Shyrma cudnn (#192)
* implementation of cudnn batchnorm_bp op
* testing and fixing bugs in batchnorm_bp based on cudnn api
* move pooling mkl code and delete some unnecessary files
* implementation and testing of cudnn pooling2d ops (avg/max, ff/bp)
* implementation and testing of cudnn pooling3d ops (ff/bp)
* provide ff step in case of cudnn maxpool3d_bp op
* remove half type from set of supported types in mkl depthwise conv op
* bring back cudaStreamSynchronize in batchnorm and pooling cudnn ops

Signed-off-by: Yurii <iuriish@yahoo.com>
Co-authored-by: raver119 <raver119@gmail.com>
parent 2f08af3166
commit 7a7ee4b021
@@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     // ***** calculations ***** //

     // notations:
-    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
-    // g = dLdO
+    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO
     // stdInv = 1 / (v + eps)^0.5
     // N - batch size (product of spatial dimensions)
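The comment block above only names the symbols; the gradient expressions live further down in the op. For orientation, the textbook batch-norm backward pass that this notation describes is, with x-hat the normalized input (a sketch of the standard algebra, not copied from the kernel):

```latex
\hat{x}_i = (x_i - m)\,\mathrm{stdInv}, \qquad \mathrm{stdInv} = (v + \epsilon)^{-1/2}
\]
\[
\frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} g_i, \qquad
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} g_i\,\hat{x}_i, \qquad
\frac{\partial L}{\partial x_i} = \frac{\gamma\,\mathrm{stdInv}}{N}\Big(N g_i - \sum_{j=1}^{N} g_j - \hat{x}_i \sum_{j=1}^{N} g_j\,\hat{x}_j\Big)
```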
@@ -31,31 +31,28 @@ namespace ops {
 CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {

     auto input = INPUT_VARIABLE(0);
-    REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf());
+    auto output = OUTPUT_VARIABLE(0);

     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-    auto argI = *(block.getIArguments());
-    auto output = OUTPUT_VARIABLE(0);
-
     const auto kH = INT_ARG(0);
     const auto kW = INT_ARG(1);
     const auto sH = INT_ARG(2);
     const auto sW = INT_ARG(3);
-    int pH = INT_ARG(4);
-    int pW = INT_ARG(5);
+    auto pH = INT_ARG(4);
+    auto pW = INT_ARG(5);
     const auto dH = INT_ARG(6);
     const auto dW = INT_ARG(7);
     const auto isSameMode = static_cast<bool>(INT_ARG(8));
     const auto extraParam0 = INT_ARG(9);
+    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC

+    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
     REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

     int oH = 0;
     int oW = 0;

-    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;             // INT_ARG(10): 0-NCHW, 1-NHWC
-
     const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
     const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
@@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
     }

     return Status::OK();
-
 }

 DECLARE_SHAPE_FN(avgpool2d_bp) {
@@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
     int isNCDHW  = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;       // 0-NCDHW, 1-NDHWC

     REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
-    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
     ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
-    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
+    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());

     if(!isNCDHW) {
         input = new NDArray(input->permute({0, 4, 1, 2, 3}));                   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {

     std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
     std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

     if(!isNCDHW) {
         input = new NDArray(input->permute({0, 4, 1, 2, 3}));                   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@@ -32,6 +32,7 @@ namespace ops {
 //////////////////////////////////////////////////////////////////////////
 // maxpool2d corresponds to poolingMode=0
 CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {

     auto input = INPUT_VARIABLE(0);
+
     REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());
@@ -0,0 +1,138 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include "cudnnUtils.h"
+#include <ops/declarable/helpers/convolutions.h>
+
+namespace nd4j      {
+namespace ops       {
+namespace platforms {
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+
+    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+    const auto kH = INT_ARG(0);
+    const auto kW = INT_ARG(1);
+    const auto sH = INT_ARG(2);
+    const auto sW = INT_ARG(3);
+          auto pH = INT_ARG(4);
+          auto pW = INT_ARG(5);
+    const auto dH = INT_ARG(6);
+    const auto dW = INT_ARG(7);
+    const auto paddingMode = static_cast<bool>(INT_ARG(8));
+    const auto extraParam0 = INT_ARG(9);
+    const int  isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;      // INT_ARG(10): 0-NCHW, 1-NHWC
+
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+
+    int oH = 0;
+    int oW = 0;
+
+    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
+    const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
+
+    ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
+
+    if (paddingMode)
+        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+
+    const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+
+    pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode);
+
+    return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+
+    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
+
+    return goodType && input->dataType() == output->dataType();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto gradO = INPUT_VARIABLE(1);                          // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+
+    const auto kH = INT_ARG(0);          // filter(kernel) height
+    const auto kW = INT_ARG(1);          // filter(kernel) width
+    const auto sH = INT_ARG(2);          // strides height
+    const auto sW = INT_ARG(3);          // strides width
+          auto pH = INT_ARG(4);          // paddings height
+          auto pW = INT_ARG(5);          // paddings width
+    const auto dH = INT_ARG(6);          // dilations height
+    const auto dW = INT_ARG(7);          // dilations width
+    const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
+    const auto extraParam0 = INT_ARG(9);
+    const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;      // INT_ARG(10): 0-NCHW, 1-NHWC
+
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+
+    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
+    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1});
+    std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1});
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+    REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+
+    if(paddingMode)                       // SAME
+        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+
+    const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+
+    pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode);
+
+    return Status::OK();
+}
+
+PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto gradO = INPUT_VARIABLE(1);                          // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+
+    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
+
+    return goodType && (input->dataType() == gradO->dataType())
+                    && (input->dataType() == gradI->dataType())
+                    && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
+}
+
+}
+}
+}
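A note on the extraParam0 switch used in both avgpool2d implementations above: it selects which divisor cuDNN applies in windows that overlap the padded border. A minimal standalone C++ sketch of the difference between the two averaging modes (toy values of my own choosing; independent of libnd4j and cuDNN):

```cpp
#include <cstdio>

// Average a kH x kW window anchored at (r, c) over an iH x iW input,
// once per cuDNN averaging mode:
//  - EXCLUDE_PADDING: divide by the number of in-bounds elements
//  - INCLUDE_PADDING: divide by the full window size kH*kW
int main() {
    const int iH = 2, iW = 2, kH = 2, kW = 2;
    const float in[iH][iW] = {{1, 2}, {3, 4}};
    const int r = 0, c = 1;          // window covers column 1 plus one padded column

    float sum = 0; int valid = 0;
    for (int i = r; i < r + kH; ++i)
        for (int j = c; j < c + kW; ++j)
            if (i < iH && j < iW) { sum += in[i][j]; ++valid; }  // padded cells contribute 0

    printf("exclude padding: %.3f\n", sum / valid);              // (2+4)/2 = 3.0
    printf("include padding: %.3f\n", sum / (kH * kW));          // (2+4)/4 = 1.5
}
```

Interior windows divide by kH*kW either way; the two modes only disagree on border windows, which is why the choice is exposed as an op argument rather than fixed.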
@@ -0,0 +1,144 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include "cudnnUtils.h"
+#include <ops/declarable/helpers/convolutions.h>
+
+namespace nd4j      {
+namespace ops       {
+namespace platforms {
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);                                   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+    auto output = OUTPUT_VARIABLE(0);                                 // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+
+    int kD = INT_ARG(0);                                              // filter(kernel) depth
+    int kH = INT_ARG(1);                                              // filter(kernel) height
+    int kW = INT_ARG(2);                                              // filter(kernel) width
+    int sD = INT_ARG(3);                                              // strides depth
+    int sH = INT_ARG(4);                                              // strides height
+    int sW = INT_ARG(5);                                              // strides width
+    int pD = INT_ARG(6);                                              // paddings depth
+    int pH = INT_ARG(7);                                              // paddings height
+    int pW = INT_ARG(8);                                              // paddings width
+    int dD = INT_ARG(9);                                              // dilations depth
+    int dH = INT_ARG(10);                                             // dilations height
+    int dW = INT_ARG(11);                                             // dilations width
+    int paddingMode = INT_ARG(12);                                    // 1-SAME, 0-VALID
+    int extraParam0 = INT_ARG(13);
+    int isNCDHW  = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;       // 0-NCDHW, 1-NDHWC
+
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+    int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+    std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+    REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
+
+    if(paddingMode)                       // SAME
+        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+
+    const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+
+    pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode);
+
+    return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+
+    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
+
+    return goodType && input->dataType() == output->dataType();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);                          // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+    auto gradO = INPUT_VARIABLE(1);                          // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
+    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+
+    const int kD = INT_ARG(0);                               // filter(kernel) depth
+    const int kH = INT_ARG(1);                               // filter(kernel) height
+    const int kW = INT_ARG(2);                               // filter(kernel) width
+    const int sD = INT_ARG(3);                               // strides depth
+    const int sH = INT_ARG(4);                               // strides height
+    const int sW = INT_ARG(5);                               // strides width
+          int pD = INT_ARG(6);                               // paddings depth
+          int pH = INT_ARG(7);                               // paddings height
+          int pW = INT_ARG(8);                               // paddings width
+    const int dD = INT_ARG(9);                               // dilations depth
+    const int dH = INT_ARG(10);                              // dilations height
+    const int dW = INT_ARG(11);                              // dilations width
+    const int isSameMode = INT_ARG(12);                      // 1-SAME, 0-VALID
+    const int extraParam0 = INT_ARG(13);                     // define what divisor to use while averaging
+    const int isNCDHW  = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;      // 0-NCDHW, 1-NDHWC
+
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW;                     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
+    int indIOioC, indIOioD, indWoC, indWiC, indWkD;             // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
+
+    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+    std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+    REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+
+    if(isSameMode)                        // SAME
+        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+
+    const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+
+    pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode);
+
+    return Status::OK();
+}
+
+PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) {
+
+    auto input = INPUT_VARIABLE(0);                          // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+    auto gradO = INPUT_VARIABLE(1);                          // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
+    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+
+    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
+
+    return goodType && (input->dataType() == gradO->dataType())
+                    && (input->dataType() == gradI->dataType())
+                    && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
+}
+
+}
+}
+}
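Both new platform files defer shape arithmetic to ConvolutionUtils::calcOutSizePool2D/calcPadding2D (and the 3D counterparts) before handing descriptors to cuDNN. The quantities being computed are the usual pooling relationships; as a sketch of that standard arithmetic (consistent with the pHsum expression that appears in cudnnUtils.cu further below, with effective kernel extent k' = (k-1)d + 1):

```latex
oH_{\text{valid}} = \left\lfloor \frac{iH + 2\,pH - k'_H}{sH} \right\rfloor + 1,
\qquad
oH_{\text{same}} = \left\lceil \frac{iH}{sH} \right\rceil
\]
\[
pH_{\text{sum}} = (oH - 1)\,sH + k'_H - iH,
\qquad
\text{padding is asymmetric iff } pH \neq pH_{\text{sum}} - pH
```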
@@ -97,9 +97,6 @@ static void batchnormCUDNN(const LaunchContext* context,
     err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data());
     if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err);
-
-    if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetConvolutionNdDescriptor failed", err);
-
     // provide scaling parameters
     const float  alpha32(1), beta32(0);
     const double alpha64(1), beta64(0);
@@ -114,20 +111,127 @@ static void batchnormCUDNN(const LaunchContext* context,
                                 x, input->getSpecialBuffer(),
                                 z, output->getSpecialBuffer(),
                                 params,
-                                gamma ? gamma->getSpecialBuffer(): nullptr,
-                                beta  ? beta->getSpecialBuffer() : nullptr,
+                                gamma->getSpecialBuffer(), beta->getSpecialBuffer(),
                                 mean->getSpecialBuffer(), variance->getSpecialBuffer(), epsilon);

     if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err);

-    // cudaErr = cudaStreamSynchronize(*context->getCudaStream());
-    // if (cudaErr != 0)
-    //     throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr);
+    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
+    if (cudaErr != 0)
+        throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr);

     NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta});
 }

+//////////////////////////////////////////////////////////////////////////
+static void batchnormBpCUDNN(const LaunchContext* context,
+                             const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO,
+                             NDArray* gradI, NDArray* gradG, NDArray* gradB,
+                             const double epsilon, const bool isSpatialMode) {
+
+    // input, gradO, gradI -> 4D:nchw, 5D:ncdhw
+    // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode
+    //                                                         -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode
+
+    const cudnnDataType_t dataType = cudnnDataType(input->dataType());
+
+    const int xRank = input->rankOf();
+
+    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
+    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
+    if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err);
+
+    const std::vector<int> xShape = input->getShapeAsVectorInt();   // input and output have same shapes
+
+    std::vector<int> paramsShape, paramsStrides;                    // mean, variance, gamma and beta have same shapes
+    if(isSpatialMode) { // 1xCx1x1
+        const int iC = mean->lengthOf();
+        const int stride0 = mean->strideAt(0);
+        paramsShape   = xRank == 4 ? std::vector<int>({1, iC, 1, 1}) : std::vector<int>({1, iC, 1, 1, 1});
+        paramsStrides = xRank == 4 ? std::vector<int>({iC*stride0, stride0, 1, 1}) : std::vector<int>({iC*stride0, stride0, 1, 1, 1});
+    }
+    else {
+        paramsShape = mean->getShapeAsVectorInt();
+        paramsStrides = xRank == 4 ? std::vector<int>({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector<int>({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)});
+    }
+
+    std::vector<int> xStrides  = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)};
+    std::vector<int> dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)};
+    std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)};
+
+    if(xRank > 4) { // 5D
+        xStrides.push_back((int)input->strideAt(4));
+        dxStrides.push_back((int)gradI->strideAt(4));
+        dzStrides.push_back((int)gradO->strideAt(4));
+    }
+
+    cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
+
+    // input descriptor
+    cudnnTensorDescriptor_t x;
+    cudnnCreateTensorDescriptor(&x);
+    if(input->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data());
+    else
+        err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data());
+    if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);
+
+    // gradO descriptor
+    cudnnTensorDescriptor_t dz;
+    cudnnCreateTensorDescriptor(&dz);
+    if(gradO->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data());
+    else
+        err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data());
+    if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);
+
+    // gradI descriptor
+    cudnnTensorDescriptor_t dx;
+    cudnnCreateTensorDescriptor(&dx);
+    if(input->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data());
+    else
+        err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data());
+    if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err);
+
+    // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them
+    cudnnTensorDescriptor_t params;
+    cudnnCreateTensorDescriptor(&params);
+    if(mean->ews() == 1)
+        err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data());
+    else
+        err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data());
+    if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err);
+
+    // provide scaling parameters
+    const float alpha32(1), beta32(0);
+    double alpha64(1), beta64(0);
+    const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
+    const void* ptrBeta  = input->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);
+
+    NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO});
+
+    // calculations
+    // TODO: we can use cache here
+    err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION,
+                                          ptrAlpha, ptrBeta, ptrAlpha, ptrBeta,
+                                          x, input->getSpecialBuffer(),
+                                          dz, gradO->getSpecialBuffer(),
+                                          dx, gradI->getSpecialBuffer(),
+                                          params,
+                                          gamma->getSpecialBuffer(), gradG->getSpecialBuffer(), gradB->getSpecialBuffer(),
+                                          epsilon,
+                                          nullptr/*mean->getSpecialBuffer()*/, nullptr/*variance->getSpecialBuffer()*/);
+
+    if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err);
+
+    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
+    if (cudaErr != 0)
+        throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
+
+    NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO});
+}

 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
@@ -189,11 +293,21 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
     const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1);

     if(needPermut) {        // if NHWC
-        std::vector<int> perm = {0, 3, 1, 2};                                                               // NHWC -> NCHW
+        std::vector<int> perm = inRank == 4 ? std::vector<int>({0, 3, 1, 2}) : std::vector<int>({0, 4, 1, 2, 3});  // NHWC -> NCHW
         input  = new NDArray(input->permute(perm));
         output = new NDArray(output->permute(perm));
     }

+    // cudnn requires gamma and beta to be non-nullptr
+    if(!applyScale) {
+        gamma = new NDArray(mean);
+        *gamma = 1;
+    }
+    if(!applyOffset) {
+        beta = new NDArray(mean);
+        *beta = 0;
+    }
+
     // calculations
     batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1);
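The fix above makes the layout permutation rank-aware: 4-D NHWC input still uses {0,3,1,2}, while 5-D NDHWC input now gets {0,4,1,2,3}. As a quick reference for what perm means here — assuming permute follows the usual transpose semantics that the file's own comments imply, i.e. output axis i takes input axis perm[i] — a small standalone sketch (permuteShape is a hypothetical helper, not part of the library):

```cpp
#include <cstdio>
#include <vector>

// Apply a permutation in the transpose sense: out dim i = in dim perm[i].
std::vector<int> permuteShape(const std::vector<int>& shape, const std::vector<int>& perm) {
    std::vector<int> out(perm.size());
    for (size_t i = 0; i < perm.size(); ++i)
        out[i] = shape[perm[i]];
    return out;
}

int main() {
    // NHWC [bS, iH, iW, iC] with perm {0,3,1,2} -> NCHW [bS, iC, iH, iW]
    auto nchw  = permuteShape({32, 28, 28, 3}, {0, 3, 1, 2});
    // NDHWC [bS, iD, iH, iW, iC] with perm {0,4,1,2,3} -> NCDHW [bS, iC, iD, iH, iW]
    auto ncdhw = permuteShape({32, 16, 28, 28, 3}, {0, 4, 1, 2, 3});

    printf("%d %d %d %d\n", nchw[0], nchw[1], nchw[2], nchw[3]);                       // 32 3 28 28
    printf("%d %d %d %d %d\n", ncdhw[0], ncdhw[1], ncdhw[2], ncdhw[3], ncdhw[4]);      // 32 3 16 28 28
}
```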
@@ -202,6 +316,12 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
         delete output;
     }

+    if(!applyScale)
+        delete gamma;
+
+    if(!applyOffset)
+        delete beta;
+
     return Status::OK();
 }
@@ -220,9 +340,6 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) {
     const int numOfIntArgs = block.getIArguments()->size();
     const int xRank = input->rankOf();

-    // disable cudnn batchnorm so far
-    return false;
-
     // *********************************** //
     if(xRank != 4 && xRank != 5)
         return false;
@@ -269,6 +386,182 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) {
     return true;
 }

+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) {
+
+    NDArray* input    = INPUT_VARIABLE(0);
+    NDArray* mean     = INPUT_VARIABLE(1);
+    NDArray* variance = INPUT_VARIABLE(2);
+    NDArray* gamma    = nullptr;
+    NDArray* beta     = nullptr;
+    NDArray* gradO    = INPUT_VARIABLE(block.width() - 1);    // next epsilon
+
+    NDArray* gradI = OUTPUT_VARIABLE(0);
+    NDArray* gradM = OUTPUT_VARIABLE(1);
+    NDArray* gradV = OUTPUT_VARIABLE(2);
+    NDArray* gradG = nullptr;
+    NDArray* gradB = nullptr;
+
+    const bool  applyScale  = (bool)INT_ARG(0);
+    const bool  applyOffset = (bool)INT_ARG(1);
+    const float epsilon     = T_ARG(0);
+
+    if(applyScale) {
+        gamma = INPUT_VARIABLE(3);
+        gradG = OUTPUT_VARIABLE(3);
+    }
+    if(applyOffset) {
+        beta  = INPUT_VARIABLE(3 + (int)applyScale);
+        gradB = OUTPUT_VARIABLE(3 + (int)applyScale);
+    }
+
+    const int numOfIntArgs = block.getIArguments()->size();
+    const int inRank = input->rankOf();
+
+    // get axes args to normalize input array over
+    std::vector<int> axes;
+    if(numOfIntArgs > 2)
+        for(int i = 2; i < numOfIntArgs; ++i)
+            axes.push_back(INT_ARG(i));
+    else
+        axes.push_back(inRank-1);               // default dimension to reduce along is last dimension
+
+    const int numOfAxes = axes.size();
+    REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
+
+    // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
+    // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
+    std::vector<Nd4jLong> expShape;
+    if(numOfAxes == 1)
+        expShape.push_back(input->sizeAt(axes[0]));
+    else {      // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
+        expShape = std::vector<Nd4jLong>(inRank, 1);
+        for(uint i = 0; i < numOfAxes; ++i)
+            expShape[axes[i]] = input->sizeAt(axes[i]);
+    }
+
+    REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
+    REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
+    if(gamma)
+        REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
+    if(beta)
+        REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
+
+    REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+
+    // types of all input arrays should be the same (except gradO)
+    for(int i = 1; i < block.width() - 2; ++i)
+        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
+
+    // cudnn supports NCHW format only
+    const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1);
+
+    if(needPermut) {    // if NHWC
+        std::vector<int> perm = inRank == 4 ? std::vector<int>({0, 3, 1, 2}) : std::vector<int>({0, 4, 1, 2, 3});  // NHWC -> NCHW
+        input = new NDArray(input->permute(perm));
+        gradO = new NDArray(gradO->permute(perm));
+        gradI = new NDArray(gradI->permute(perm));
+    }
+
+    // cudnn requires gamma, gradG, gradB to be non-nullptr
+    if(!applyScale) {
+        gamma = new NDArray(mean);
+        gradG = new NDArray(mean);
+        *gamma = 1;
+    }
+    if(!applyOffset)
+        gradB = new NDArray(mean);
+
+    // calculations
+    batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1);
+
+    *gradM = 0;      // put zeros so far
+    *gradV = 0;      // put zeros so far
+
+    if(needPermut) {
+        delete input;
+        delete gradO;
+        delete gradI;
+    }
+
+    if(!applyScale) {
+        delete gamma;
+        delete gradG;
+    }
+
+    if(!applyOffset)
+        delete gradB;
+
+    return Status::OK();
+}
+
+PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) {
+
+    NDArray* input    = INPUT_VARIABLE(0);
+    NDArray* mean     = INPUT_VARIABLE(1);
+    NDArray* variance = INPUT_VARIABLE(2);
+    NDArray* gamma    = nullptr;
+    NDArray* beta     = nullptr;
+    NDArray* gradO    = INPUT_VARIABLE(block.width() - 1);   // next epsilon
+
+    NDArray* gradI = OUTPUT_VARIABLE(0);
+    NDArray* gradM = OUTPUT_VARIABLE(1);
+    NDArray* gradV = OUTPUT_VARIABLE(2);
+    NDArray* gradG = nullptr;
+    NDArray* gradB = nullptr;
+
+    const int numOfIntArgs = block.getIArguments()->size();
+    const int xRank = input->rankOf();
+
+    // *********************************** //
+    if(xRank != 4 && xRank != 5)
+        return false;
+
+    // *********************************** //
+    const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
+    if(badType)
+        return false;
+
+    // *********************************** //
+    // get axes args to normalize input array over
+    std::vector<int> axes;
+    if(numOfIntArgs > 2)
+        for(int i = 2; i < numOfIntArgs; ++i)
+            axes.push_back(INT_ARG(i));
+    else
+        axes.push_back(xRank-1);               // default dimension to reduce along is last dimension
+
+    if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4)
+        return false;
+
+    // *********************************** //
+    bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo());
+    if(gamma)
+        allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo());
+    if(gradG)
+        allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradG->getShapeInfo());
+    if(gradB)
+        allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradB->getShapeInfo());
+
+    if(!allParamsHaveSameShapeAndStrides)
+        return false;
+
+    // *********************************** //
+    bool isFormatGood = false;
+    if(axes.size() == 1)
+        isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1);   // mean [C]
+    else {
+        auto inputShapeModif = input->getShapeAsVector();    // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4]
+        inputShapeModif[0] = 1;
+        isFormatGood = mean->isSameShape(inputShapeModif);   // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4]
+    }
+    if(!isFormatGood)
+        return false;
+
+    return true;
+}
+
 }
 }
 }
@ -0,0 +1,412 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#include "cudnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||||
|
const int iH, const int iW,
|
||||||
|
const int oH, const int oW,
|
||||||
|
const int kH, const int kW,
|
||||||
|
const int sH, const int sW,
|
||||||
|
const int pH, const int pW,
|
||||||
|
const int dH, const int dW,
|
||||||
|
const bool isNCHW) {
|
||||||
|
|
||||||
|
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
|
||||||
|
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
|
||||||
|
|
||||||
|
const bool isPHasymm = pH != (pHsum - pH);
|
||||||
|
const bool isPWasymm = pW != (pWsum - pW);
|
||||||
|
|
||||||
|
if(!isPHasymm && !isPWasymm)
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
|
||||||
|
|
||||||
|
const int iHposition = isNCHW ? 2 : 1;
|
||||||
|
|
||||||
|
if(isPHasymm)
|
||||||
|
newShape[iHposition] += 1;
|
||||||
|
if(isPWasymm)
|
||||||
|
newShape[iHposition + 1] += 1;
|
||||||
|
|
||||||
|
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
|
||||||
|
|
||||||
|
if(isNCHW)
|
||||||
|
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input);
|
||||||
|
else
|
||||||
|
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input);
|
||||||
|
|
||||||
|
input = newInput;
|
||||||
|
|
||||||
|
if(gradI != nullptr)
|
||||||
|
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
|
||||||
|
const int iD, const int iH, const int iW,
|
||||||
|
const int oD, const int oH, const int oW,
|
||||||
|
const int kD, const int kH, const int kW,
|
||||||
|
const int sD, const int sH, const int sW,
|
||||||
|
const int pD, const int pH, const int pW,
|
||||||
|
const int dD, const int dH, const int dW,
|
||||||
|
const bool isNCDHW) {
|
||||||
|
|
||||||
|
const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD);
|
||||||
|
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
|
||||||
|
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
|
||||||
|
|
||||||
|
const bool isPDasymm = pD != (pDsum - pD);
|
||||||
|
const bool isPHasymm = pH != (pHsum - pH);
|
||||||
|
const bool isPWasymm = pW != (pWsum - pW);
|
||||||
|
|
||||||
|
if(!isPDasymm && !isPHasymm && !isPWasymm)
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
|
||||||
|
|
||||||
|
const int iDposition = isNCDHW ? 2 : 1;
|
||||||
|
|
||||||
|
if(isPDasymm)
|
||||||
|
newShape[iDposition] += 1;
|
||||||
|
if(isPHasymm)
|
||||||
|
newShape[iDposition + 1] += 1;
|
||||||
|
if(isPWasymm)
|
||||||
|
newShape[iDposition + 2] += 1;
|
||||||
|
|
||||||
|
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
|
||||||
|
|
||||||
|
if(isNCDHW)
|
||||||
|
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input);
|
||||||
|
else
|
||||||
|
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input);
|
||||||
|
|
||||||
|
input = newInput;
|
||||||
|
|
||||||
|
if(gradI != nullptr)
|
||||||
|
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void pooling2dCUDNN(const LaunchContext* context,
|
||||||
|
const NDArray* input, NDArray* output,
|
||||||
|
const int kH, const int kW,
|
||||||
|
const int sH, const int sW,
|
||||||
|
const int pH, const int pW,
|
||||||
|
const int dH, const int dW,
|
||||||
|
const bool isNCHW, const cudnnPoolingMode_t mode) {
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
|
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
|
||||||
|
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
|
||||||
|
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err);
|
||||||
|
|
||||||
|
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
|
||||||
|
|
||||||
|
// input descriptor
|
||||||
|
cudnnTensorDescriptor_t x;
|
||||||
|
cudnnCreateTensorDescriptor(&x);
|
||||||
|
if(input->ews() == 1)
|
||||||
|
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
|
||||||
|
else
|
||||||
|
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
|
||||||
|
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err);
|
||||||
|
|
||||||
|
// output descriptor
|
||||||
|
cudnnTensorDescriptor_t z;
|
||||||
|
cudnnCreateTensorDescriptor(&z);
|
||||||
|
if(output->ews() == 1)
|
||||||
|
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
|
||||||
|
else
|
||||||
|
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
|
||||||
|
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err);
|
||||||
|
|
||||||
|
// description of pooling
|
||||||
|
cudnnPoolingDescriptor_t pooling;
|
||||||
|
cudnnCreatePoolingDescriptor(&pooling);
|
||||||
|
err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
|
||||||
|
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err);
|
||||||
|
|
||||||
|
// provide scaling parameters
|
||||||
|
const float alpha32(1), beta32(0);
|
||||||
|
const double alpha64(1), beta64(0);
|
||||||
|
const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
|
||||||
|
const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
|
||||||
|
// run calculation
|
||||||
|
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
|
||||||
|
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err);
|
||||||
|
|
||||||
|
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
|
||||||
|
if (cudaErr != 0)
|
||||||
|
throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
}

//////////////////////////////////////////////////////////////////////////
void pooling2dBpCUDNN(const LaunchContext* context,
                      const NDArray* input, const NDArray* gradO,
                      NDArray* gradI,
                      const int kH, const int kW,
                      const int sH, const int sW,
                      const int pH, const int pW,
                      const int dH, const int dW,
                      const bool isNCHW, const cudnnPoolingMode_t mode) {

    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err);

    cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;

    // input and gradI descriptor
    cudnnTensorDescriptor_t x;
    cudnnCreateTensorDescriptor(&x);
    if(input->ews() == 1)
        err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
    else
        err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err);

    // gradO descriptor
    cudnnTensorDescriptor_t dz;
    cudnnCreateTensorDescriptor(&dz);
    if(gradO->ews() == 1)
        err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
    else
        err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err);

    // description of pooling
    cudnnPoolingDescriptor_t pooling;
    cudnnCreatePoolingDescriptor(&pooling);
    err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err);

    // provide scaling parameters
    const float  alpha32(1), beta32(0);
    const double alpha64(1), beta64(0);
    const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
    const void* beta  = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);

    NDArray::prepareSpecialUse({gradI}, {input, gradO});
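    // note: cudnnPoolingBackward also takes the forward output y; below gradO stands in for both
    // y and dy, which is harmless for average pooling (y is ignored there) but for CUDNN_POOLING_MAX
    // the real ff output is needed to locate the argmax positions - see the extra forward pass that
    // pooling3dBpCUDNN performs for exactly this reason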

    // run calculation for gradI
    err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
    if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err);

    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
    if (cudaErr != 0)
        throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);

    NDArray::registerSpecialUse({gradI}, {input, gradO});
}

//////////////////////////////////////////////////////////////////////////
void pooling3dCUDNN(const LaunchContext* context,
                    const NDArray* input, NDArray* output,
                    const int kD, const int kH, const int kW,
                    const int sD, const int sH, const int sW,
                    const int pD, const int pH, const int pW,
                    const int dD, const int dH, const int dW,
                    const bool isNCDHW, const cudnnPoolingMode_t mode) {

    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);

    const int numDims = 5;

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;             // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;     // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    const int pSizes[] = {pD, pH, pW};
    const int sSizes[] = {sD, sH, sW};
    const int kSizes[] = {kD, kH, kW};

    const int xShape[] = {bS, iC, iD, iH, iW};
    const int zShape[] = {bS, oC, oD, oH, oW};

    const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
    const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)};

    cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
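    // cuDNN has no separate 5d format enum: for the Nd descriptor calls CUDNN_TENSOR_NCHW/NHWC
    // are read as channels-first/channels-last for however many dims are given, i.e. NCDHW/NDHWC here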

    // input descriptor
    cudnnTensorDescriptor_t x;
    cudnnCreateTensorDescriptor(&x);
    if(input->ews() == 1)
        err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
    else
        err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);

    // output descriptor
    cudnnTensorDescriptor_t z;
    cudnnCreateTensorDescriptor(&z);
    if(output->ews() == 1)
        err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape);
    else
        err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides);
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err);

    // description of pooling
    cudnnPoolingDescriptor_t pooling;
    cudnnCreatePoolingDescriptor(&pooling);
    err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err);

    // provide scaling parameters
    const float  alpha32(1), beta32(0);
    const double alpha64(1), beta64(0);
    const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
    const void* beta  = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);

    NDArray::prepareSpecialUse({output}, {input});

    // run calculation
    err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err);

    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
    if (cudaErr != 0)
        throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr);

    NDArray::registerSpecialUse({output}, {input});
}

//////////////////////////////////////////////////////////////////////////
void pooling3dBpCUDNN(const LaunchContext* context,
                      const NDArray* input, const NDArray* gradO,
                      NDArray* gradI,
                      const int kD, const int kH, const int kW,
                      const int sD, const int sH, const int sW,
                      const int pD, const int pH, const int pW,
                      const int dD, const int dH, const int dW,
                      const bool isNCDHW, const cudnnPoolingMode_t mode) {

    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err);

    const int numDims = 5;

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;             // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;     // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    const int pSizes[] = {pD, pH, pW};
    const int sSizes[] = {sD, sH, sW};
    const int kSizes[] = {kD, kH, kW};

    const int xShape[]  = {bS, iC, iD, iH, iW};
    const int dzShape[] = {bS, oC, oD, oH, oW};

    const int xStrides[]  = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
    const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)};

    cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;

    // input and gradI descriptor
    cudnnTensorDescriptor_t x;
    cudnnCreateTensorDescriptor(&x);
    if(input->ews() == 1)
        err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
    else
        err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err);

    // gradO descriptor
    cudnnTensorDescriptor_t dz;
    cudnnCreateTensorDescriptor(&dz);
    if(gradO->ews() == 1)
        err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape);
    else
        err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides);
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);

    // description of pooling
    cudnnPoolingDescriptor_t pooling;
    cudnnCreatePoolingDescriptor(&pooling);
    err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err);

    // provide scaling parameters
    const float  alpha32(1), beta32(0);
    const double alpha64(1), beta64(0);
    const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
    const void* beta  = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32)  : reinterpret_cast<const void*>(&beta64);

    // cuDNN's pooling backward api requires the ff output as one of its input arguments
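    // (max-pooling backward routes each output gradient back to the input element that won the max,
    // and cuDNN rediscovers those positions by comparing x against y, hence the extra forward pass below)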
    if(mode == CUDNN_POOLING_MAX) {

        NDArray temp(gradO);

        NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp});

        // run ff calculation
        err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, dz, temp.specialBuffer());
        if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnPoolingForward failed", err);

        // run bp calculation for gradI
        err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
        if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnPoolingBackward failed", err);

        NDArray::registerSpecialUse({gradI}, {input, gradO, &temp});
    }
    else {

        NDArray::prepareSpecialUse({gradI}, {input, gradO});

        // run bp calculation for gradI
        err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
        if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnPoolingBackward failed", err);

        NDArray::registerSpecialUse({gradI}, {input, gradO});
    }

    auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
    if (cudaErr != 0)
        throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
}

}
}
}
@ -30,8 +30,8 @@

#include <cudnn.h>

namespace nd4j {
namespace ops {
namespace platforms {

DECLARE_PLATFORM(conv2d, ENGINE_CUDA);

@ -46,6 +46,18 @@ namespace platforms {

DECLARE_PLATFORM(batchnorm, ENGINE_CUDA);
DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA);

DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA);
DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA);

DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA);

DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA);
DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA);

DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA);

//////////////////////////////////////////////////////////////////////////
FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) {
    switch (dataType) {
@ -65,91 +77,62 @@ FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) {
}

//////////////////////////////////////////////////////////////////////////
void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
                                   const int iH, const int iW,
                                   const int oH, const int oW,
                                   const int kH, const int kW,
                                   const int sH, const int sW,
                                   const int pH, const int pW,
                                   const int dH, const int dW,
                                   const bool isNCHW);

//////////////////////////////////////////////////////////////////////////
void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
                                   const int iD, const int iH, const int iW,
                                   const int oD, const int oH, const int oW,
                                   const int kD, const int kH, const int kW,
                                   const int sD, const int sH, const int sW,
                                   const int pD, const int pH, const int pW,
                                   const int dD, const int dH, const int dW,
                                   const bool isNCDHW);

//////////////////////////////////////////////////////////////////////////
void pooling2dCUDNN(const LaunchContext* context,
                    const NDArray* input, NDArray* output,
                    const int kH, const int kW,
                    const int sH, const int sW,
                    const int pH, const int pW,
                    const int dH, const int dW,
                    const bool isNCHW, const cudnnPoolingMode_t mode);

//////////////////////////////////////////////////////////////////////////
void pooling2dBpCUDNN(const LaunchContext* context,
                      const NDArray* input, const NDArray* gradO,
                      NDArray* gradI,
                      const int kH, const int kW,
                      const int sH, const int sW,
                      const int pH, const int pW,
                      const int dH, const int dW,
                      const bool isNCHW, const cudnnPoolingMode_t mode);

//////////////////////////////////////////////////////////////////////////
void pooling3dCUDNN(const LaunchContext* context,
                    const NDArray* input, NDArray* output,
                    const int kD, const int kH, const int kW,
                    const int sD, const int sH, const int sW,
                    const int pD, const int pH, const int pW,
                    const int dD, const int dH, const int dW,
                    const bool isNCDHW, const cudnnPoolingMode_t mode);

//////////////////////////////////////////////////////////////////////////
void pooling3dBpCUDNN(const LaunchContext* context,
                      const NDArray* input, const NDArray* gradO,
                      NDArray* gradI,
                      const int kD, const int kH, const int kW,
                      const int sD, const int sH, const int sW,
                      const int pD, const int pH, const int pW,
                      const int dD, const int dH, const int dW,
                      const bool isNCDHW, const cudnnPoolingMode_t mode);

}
}

@ -0,0 +1,132 @@

/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingMode;
    const auto kH = INT_ARG(0);
    const auto kW = INT_ARG(1);
    const auto sH = INT_ARG(2);
    const auto sW = INT_ARG(3);
    auto pH = INT_ARG(4);
    auto pW = INT_ARG(5);
    const auto dH = INT_ARG(6);
    const auto dW = INT_ARG(7);
    const auto paddingMode = static_cast<bool>(INT_ARG(8));
    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int oH = 0;
    int oW = 0;

    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
    const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));

    ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);

    if (paddingMode)
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX);

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;

    return goodType && input->dataType() == output->dataType();
}
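
// a note on the mechanics (an assumption about the platform-helper framework, not part of this diff):
// PLATFORM_CHECK acts as a gate - when it returns false, e.g. for a dtype cuDNN can't handle,
// the executor is expected to fall back to the generic CUDA implementation of the same op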

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(1);                          // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    const auto kH = INT_ARG(0);                                                 // filter(kernel) height
    const auto kW = INT_ARG(1);                                                 // filter(kernel) width
    const auto sH = INT_ARG(2);                                                 // strides height
    const auto sW = INT_ARG(3);                                                 // strides width
    auto pH = INT_ARG(4);                                                       // paddings height
    auto pW = INT_ARG(5);                                                       // paddings width
    const auto dH = INT_ARG(6);                                                 // dilations height
    const auto dW = INT_ARG(7);                                                 // dilations width
    const auto paddingMode = INT_ARG(8);                                        // 0-VALID, 1-SAME
    const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;  // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW,  0,indIOioC,indIiH,indIiH+1});
    std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW,  0,indIOioC,indIiH,indIiH+1});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(paddingMode)                       // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX);

    return Status::OK();
}

PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(1);                          // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;

    return goodType && (input->dataType() == gradO->dataType())
                    && (input->dataType() == gradI->dataType())
                    && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
}

}
}
}
@ -0,0 +1,140 @@

/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto output = OUTPUT_VARIABLE(0);                                  // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

    int kD = INT_ARG(0);                                                        // filter(kernel) depth
    int kH = INT_ARG(1);                                                        // filter(kernel) height
    int kW = INT_ARG(2);                                                        // filter(kernel) width
    int sD = INT_ARG(3);                                                        // strides depth
    int sH = INT_ARG(4);                                                        // strides height
    int sW = INT_ARG(5);                                                        // strides width
    int pD = INT_ARG(6);                                                        // paddings depth
    int pH = INT_ARG(7);                                                        // paddings height
    int pW = INT_ARG(8);                                                        // paddings width
    int dD = INT_ARG(9);                                                        // dilations depth
    int dH = INT_ARG(10);                                                       // dilations height
    int dW = INT_ARG(11);                                                       // dilations width
    int paddingMode = INT_ARG(12);                                              // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13);
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;        // 0-NCDHW, 1-NDHWC

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;             // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;     // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());

    if(paddingMode)                       // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX);

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;

    return goodType && input->dataType() == output->dataType();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);                          // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto gradO = INPUT_VARIABLE(1);                          // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

    const int kD = INT_ARG(0);                                                  // filter(kernel) depth
    const int kH = INT_ARG(1);                                                  // filter(kernel) height
    const int kW = INT_ARG(2);                                                  // filter(kernel) width
    const int sD = INT_ARG(3);                                                  // strides depth
    const int sH = INT_ARG(4);                                                  // strides height
    const int sW = INT_ARG(5);                                                  // strides width
    int pD = INT_ARG(6);                                                        // paddings depth
    int pH = INT_ARG(7);                                                        // paddings height
    int pW = INT_ARG(8);                                                        // paddings width
    const int dD = INT_ARG(9);                                                  // dilations depth
    const int dH = INT_ARG(10);                                                 // dilations height
    const int dW = INT_ARG(11);                                                 // dilations width
    const int isSameMode = INT_ARG(12);                                         // 1-SAME, 0-VALID
    // const int extraParam0 = INT_ARG(13);                                     // define what divisor to use while averaging
    const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;  // 0-NCDHW, 1-NDHWC

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;             // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;     // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX);

    return Status::OK();
}

PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) {

    auto input = INPUT_VARIABLE(0);                          // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto gradO = INPUT_VARIABLE(1);                          // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

    const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;

    return goodType && (input->dataType() == gradO->dataType())
                    && (input->dataType() == gradI->dataType())
                    && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
}

}
}
}
@ -30,111 +30,231 @@

using namespace dnnl;
using namespace samediff;

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);

    REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf());

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
    auto argI = *(block.getIArguments());
    auto output = OUTPUT_VARIABLE(0);

    const auto kH = INT_ARG(0);
    const auto kW = INT_ARG(1);
    const auto sH = INT_ARG(2);
    const auto sW = INT_ARG(3);
    int pH = INT_ARG(4);
    int pW = INT_ARG(5);
    const auto dH = INT_ARG(6);
    const auto dW = INT_ARG(7);
    const auto isSameMode = static_cast<bool>(INT_ARG(8));
    const auto extraParam0 = INT_ARG(9);

    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int oH = 0;
    int oW = 0;

    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC

    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
    const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));

    if (!isNCHW) {
        input = new NDArray(input->permute({0, 3, 1, 2}));       // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        output = new NDArray(output->permute({0, 3, 1, 2}));     // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

    if (isSameMode)
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    const int bS = input->sizeAt(0);
    const int iC = input->sizeAt(1);
    const int oC = output->sizeAt(1);

    auto poolingMode = PoolingType::AVG_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;
    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, algorithm,
                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
    auto pool_src_memory = user_src_memory;
    dnnl::stream stream(engine);
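    // the primitive may prefer a blocked/opaque memory layout that differs from the user's plain
    // NCHW buffers; when the descriptors disagree, oneDNN reorder primitives copy the data into
    // (and later out of) the layout the pooling primitive was compiled for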
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }
    auto pool_dst_memory = user_dst_memory;
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    }
    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory}});
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
    }
    stream.wait();

    //streams[0].submitAndWait();

    if (!isNCHW) {
        delete input;
        delete output;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(1);                          // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0);                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    int kH = INT_ARG(0);          // filter(kernel) height
    int kW = INT_ARG(1);          // filter(kernel) width
    int sH = INT_ARG(2);          // strides height
    int sW = INT_ARG(3);          // strides width
    int pH = INT_ARG(4);          // paddings height
    int pW = INT_ARG(5);          // paddings width
    int dH = INT_ARG(6);          // dilations height
    int dW = INT_ARG(7);          // dilations width
    int isSameMode = INT_ARG(8);  // 0-VALID, 1-SAME
    int extraParam0 = INT_ARG(9);
    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int bS, iC, iH, iW, oC, oH, oW;                         // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if (!isNCHW) {
        input = new NDArray(input->permute({0, 3, 1, 2}));   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradI = new NDArray(gradI->permute({0, 3, 1, 2}));   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradO = new NDArray(gradO->permute({0, 3, 1, 2}));   // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    if (isSameMode)                       // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    auto poolingMode = PoolingType::AVG_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;
    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
                                           pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
                                             pool_kernel, pool_padding, pool_padding_r);
    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
    auto poolB_src_memory = userB_src_memory;
    dnnl::stream stream(engine);
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }
    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }
    stream.wait();

    if (!isNCHW) {
        delete input;
        delete gradI;
        delete gradO;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

}
}
}
|
@@ -1,149 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// @author saudet
-// @author raver119@gmail.com
-//
-
-#include <ops/declarable/PlatformHelper.h>
-#include <ops/declarable/OpRegistrator.h>
-#include <platform_boilerplate.h>
-
-#include <helpers/MKLDNNStream.h>
-#include "mkldnnUtils.h"
-#include <ops/declarable/helpers/convolutions.h>
-
-using namespace dnnl;
-
-namespace nd4j {
-namespace ops {
-namespace platforms {
-PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
-    auto input = INPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto gradO = INPUT_VARIABLE(1);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-
-    int kH = INT_ARG(0); // filter(kernel) height
-    int kW = INT_ARG(1); // filter(kernel) width
-    int sH = INT_ARG(2); // strides height
-    int sW = INT_ARG(3); // strides width
-    int pH = INT_ARG(4); // paddings height
-    int pW = INT_ARG(5); // paddings width
-    int dH = INT_ARG(6); // dilations height
-    int dW = INT_ARG(7); // dilations width
-    int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
-    int extraParam0 = INT_ARG(9);
-    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
-
-    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
-    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
-
-    int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
-    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
-
-    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
-    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
-
-    if (!isNCHW) {
-        input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
-    }
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
-
-    auto poolingMode = PoolingType::AVG_POOL;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
-                                           true,
-                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
-                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
-                                           &user_diff_src_md, &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
-                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
-                                           pool_dst_md, pool_strides, pool_kernel, pool_padding,
-                                           pool_padding_r);
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
-                                             pool_kernel, pool_padding, pool_padding_r);
-    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
-    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-    auto poolB_src_memory = userB_src_memory;
-    dnnl::stream stream(engine);
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
-    }
-    auto poolB_dst_memory = userB_dst_memory;
-    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
-        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
-    }
-    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
-                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
-    }
-    stream.wait();
-
-    if (!isNCHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
-
-    return Status::OK();
-}
-
-PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) {
-    auto input = INPUT_VARIABLE(0);
-    auto output = OUTPUT_VARIABLE(0);
-
-    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
-}
-}
-}
-}
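Note on the deleted pooling backward code (re-added under the moved pooling files): oneDNN constructs backward primitive descriptors from a forward primitive descriptor that acts as a "hint", which is why the file built a pooling_forward::desc it never executed. A sketch of just that dependency, assuming oneDNN 1.x; every size below is an illustrative placeholder:

    #include <dnnl.hpp>

    dnnl::pooling_backward::primitive_desc makeAvgPoolBpPd(const dnnl::engine &eng) {
        using namespace dnnl;
        memory::dims in = {1, 3, 8, 8}, out = {1, 3, 4, 4};
        memory::dims kernel = {2, 2}, strides = {2, 2}, pad_l = {0, 0}, pad_r = {0, 0};

        memory::desc src_md(in,  memory::data_type::f32, memory::format_tag::nchw);
        memory::desc dst_md(out, memory::data_type::f32, memory::format_tag::nchw);

        // Forward descriptor built only to serve as the backward pd's hint.
        pooling_forward::desc fwd_d(prop_kind::forward, algorithm::pooling_avg_exclude_padding,
                                    src_md, dst_md, strides, kernel, pad_l, pad_r);
        pooling_forward::primitive_desc fwd_pd(fwd_d, eng);

        pooling_backward::desc bwd_d(algorithm::pooling_avg_exclude_padding,
                                     src_md, dst_md, strides, kernel, pad_l, pad_r);
        return pooling_backward::primitive_desc(bwd_d, eng, fwd_pd); // hint goes last
    }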
@@ -34,24 +34,23 @@ namespace ops {
 namespace platforms {

 //////////////////////////////////////////////////////////////////////
-static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
+static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
                           const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
                           const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
                           const int isNCHW) {

     int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
-                                               indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

     dnnl_memory_desc_t empty;
-    dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
-            empty);
-    dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
-            empty);
+    dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
+    dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;

     mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                            bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
                                            bias, output,
@@ -61,13 +60,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
                                            &user_bias_md, &user_dst_md,
                                            conv_strides, conv_padding, conv_padding_r, conv_dilation);

-    auto conv_desc = bias != nullptr
-                     ? convolution_forward::desc(prop_kind::forward,
+    auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward,
                                                  algorithm::convolution_auto, conv_src_md,
                                                  conv_weights_md, conv_bias_md,
                                                  conv_dst_md, conv_strides, conv_dilation, conv_padding,
                                                  conv_padding_r)
                                      : convolution_forward::desc(prop_kind::forward,
                                                  algorithm::convolution_auto, conv_src_md,
                                                  conv_weights_md,
                                                  conv_dst_md, conv_strides, conv_dilation, conv_padding,
@@ -112,6 +110,135 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     stream.wait();
 }

+//////////////////////////////////////////////////////////////////////
+static void conv2dBpMKLDNN(nd4j::graph::Context &block,
+                           const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
+                           NDArray *gradI, NDArray *gradW, NDArray *gradB,
+                           const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW,
+                           const int paddingMode, const int isNCHW) {
+
+    int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
+    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
+
+    dnnl_memory_desc_t empty;
+    dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
+    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
+
+    dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
+
+    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
+                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
+                                           gradB, gradO,
+                                           &conv_src_md, &conv_diff_src_md, &conv_weights_md,
+                                           &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
+                                           &user_src_md, &user_diff_src_md, &user_weights_md,
+                                           &user_diff_weights_md, &user_bias_md, &user_dst_md,
+                                           conv_strides, conv_padding, conv_padding_r, conv_dilation);
+    auto conv_desc = gradB != nullptr
+                     ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
+                     : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
+
+    auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()));
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+    dnnl::stream stream(engine);
+
+    if (gradW != nullptr) {
+        auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
+                                           : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
+
+        auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
+
+        auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
+        auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
+        auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
+
+        auto convW_src_memory = userW_src_memory;
+        if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
+            convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
+            reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory);
+        }
+
+        auto convW_weights_memory = userW_weights_memory;
+        if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
+            convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
+        }
+
+        auto convW_dst_memory = userW_dst_memory;
+        if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
+            convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
+            reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
+        }
+
+        if (gradB != nullptr) {
+            auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
+
+            convolution_backward_weights(convW_prim_desc).execute(stream,
+                                                                  {{DNNL_ARG_SRC, convW_src_memory},
+                                                                   {DNNL_ARG_DIFF_DST, convW_dst_memory},
+                                                                   {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
+                                                                   {DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
+        }
+        else {
+            convolution_backward_weights(convW_prim_desc).execute(stream,
+                                                                  {{DNNL_ARG_SRC, convW_src_memory},
+                                                                   {DNNL_ARG_DIFF_DST, convW_dst_memory},
+                                                                   {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
+        }
+
+        if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
+            reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory);
+        }
+
+        stream.wait();
+    }
+
+    if (gradI != nullptr) {
+
+        auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
+
+        auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
+        auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
+        auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
+        auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
+
+        auto convI_src_memory = userI_src_memory;
+        if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
+            convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
+        }
+
+        auto convI_weights_memory = userI_weights_memory;
+        if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
+            convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
+            reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
+        }
+
+        auto convI_dst_memory = userI_dst_memory;
+        if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
+            convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
+            reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
+        }
+
+        convolution_backward_data(convI_prim_desc).execute(stream,
+                                                           {{DNNL_ARG_DIFF_DST, convI_dst_memory},
+                                                            {DNNL_ARG_WEIGHTS, convI_weights_memory},
+                                                            {DNNL_ARG_DIFF_SRC, convI_src_memory}});
+
+        if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
+            reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
+        }
+
+        stream.wait();
+    }
+}
+
 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d, ENGINE_CPU) {
     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
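Note on conv2dBpMKLDNN above: both backward primitives (weights and data) need a forward primitive descriptor as a hint, and the gradB != nullptr ternaries exist because oneDNN has separate with-bias and without-bias overloads of the descriptors. A condensed sketch of the backward-weights half, assuming oneDNN 1.x; all memory descriptors below are illustrative placeholders supplied by the caller:

    #include <dnnl.hpp>

    using namespace dnnl;

    convolution_backward_weights::primitive_desc
    makeConvBpWeightsPd(const engine &eng,
                        const memory::desc &src_md, const memory::desc &w_md,
                        const memory::desc &b_md, const memory::desc &dst_md,
                        const memory::dims &strides, const memory::dims &dilation,
                        const memory::dims &pad_l, const memory::dims &pad_r,
                        bool withBias) {
        // Forward descriptor, created only to act as the backward pd's hint.
        auto fwd_d = withBias
            ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto,
                                        src_md, w_md, b_md, dst_md, strides, dilation, pad_l, pad_r)
            : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto,
                                        src_md, w_md, dst_md, strides, dilation, pad_l, pad_r);
        convolution_forward::primitive_desc fwd_pd(fwd_d, eng);

        auto bwd_d = withBias
            ? convolution_backward_weights::desc(algorithm::convolution_auto,
                                                 src_md, w_md, b_md, dst_md, strides, dilation, pad_l, pad_r)
            : convolution_backward_weights::desc(algorithm::convolution_auto,
                                                 src_md, w_md, dst_md, strides, dilation, pad_l, pad_r);
        return convolution_backward_weights::primitive_desc(bwd_d, eng, fwd_pd);
    }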
@@ -132,7 +259,7 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width

-    conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
+    conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

     return Status::OK();
 }
@@ -152,6 +279,7 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {

 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
+
     auto input = INPUT_VARIABLE(0);                              // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto weights = INPUT_VARIABLE(1);                            // [kH, kW, iC, oC] always
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@@ -172,158 +300,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
     int paddingMode = INT_ARG(8);                                     // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC

-    REQUIRE_TRUE(input->rankOf() == 4, 0,
-                 "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
-                 input->rankOf());
-    REQUIRE_TRUE(weights->rankOf() == 4, 0,
-                 "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
-                 weights->rankOf());
-    REQUIRE_TRUE(gradO->rankOf() == 4, 0,
-                 "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
-                 gradO->rankOf());
-
-    int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
-    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
-
-    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
-    dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
-                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO,
-                                           &conv_src_md, &conv_diff_src_md, &conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
-                                           &user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md,
-                                           conv_strides, conv_padding, conv_padding_r, conv_dilation);
-    auto conv_desc = gradB != nullptr
-                     ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
-                     : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
-    auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()));
-    if (gradW != nullptr) {
-        auto convW_desc = gradB != nullptr
-                          ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
-                          : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
-
-        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-        dnnl::stream stream(engine);
-        auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
-        auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
-        auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
-        auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
-
-        auto convW_src_memory = userW_src_memory;
-        if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
-            convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
-            reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory);
-        }
-
-        auto convW_weights_memory = userW_weights_memory;
-        if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
-            convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
-        }
-
-        auto convW_dst_memory = userW_dst_memory;
-        if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
-            convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
-            reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
-        }
-
-        if (gradB != nullptr) {
-            auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
-            convolution_backward_weights(convW_prim_desc).execute(stream,
-                                                                  {{DNNL_ARG_SRC, convW_src_memory},
-                                                                   {DNNL_ARG_DIFF_DST, convW_dst_memory},
-                                                                   {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
-                                                                   {DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
-        } else {
-            convolution_backward_weights(convW_prim_desc).execute(stream,
-                                                                  {{DNNL_ARG_SRC, convW_src_memory},
-                                                                   {DNNL_ARG_DIFF_DST, convW_dst_memory},
-                                                                   {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
-        }
-
-        if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
-            reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory);
-        }
-
-        stream.wait();
-    }
-
-    if (gradI != nullptr) {
-        auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
-
-        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-        dnnl::stream stream(engine);
-        auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
-        auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
-        auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
-        auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
-
-        auto convI_src_memory = userI_src_memory;
-        if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
-            convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
-        }
-
-        auto convI_weights_memory = userI_weights_memory;
-        if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
-            convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
-            reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
-        }
-
-        auto convI_dst_memory = userI_dst_memory;
-        if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
-            convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
-            reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
-        }
-
-        convolution_backward_data(convI_prim_desc).execute(stream,
-                                                           {{DNNL_ARG_DIFF_DST, convI_dst_memory},
-                                                            {DNNL_ARG_WEIGHTS, convI_weights_memory},
-                                                            {DNNL_ARG_DIFF_SRC, convI_src_memory}});
-
-        if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
-            reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
-        }
-
-        stream.wait();
-    };
+    REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf());
+    REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf());
+    REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf());
+
+    conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

     return Status::OK();
 }
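Note: conv2d_bp now just validates ranks and delegates to conv2dBpMKLDNN. One detail worth calling out in the helper is the recurring const_cast<NDArray *>(x)->buffer(): dnnl::memory wraps a raw void* handle and has no const overload, so read-only inputs must be cast even though the primitive only reads them. A tiny sketch with a hypothetical stand-in type (FakeNDArray is not an nd4j class):

    #include <dnnl.hpp>

    struct FakeNDArray {                // illustrative stand-in for nd4j::NDArray
        float data[16] = {};
        void *buffer() { return data; } // non-const accessor, like NDArray::buffer()
    };

    dnnl::memory wrapReadOnlyInput(const FakeNDArray *input, const dnnl::memory::desc &md,
                                   const dnnl::engine &eng) {
        // DNNL_ARG_SRC is only read by the primitive, but the memory ctor wants void*.
        return dnnl::memory(md, eng, const_cast<FakeNDArray *>(input)->buffer());
    }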
@@ -34,62 +34,23 @@ namespace ops {
 namespace platforms {

 //////////////////////////////////////////////////////////////////////
-PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
-    auto input = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1);                             // [kD, kH, kW, iC, oC] always
-    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
-    auto output = OUTPUT_VARIABLE(0);                             // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
-
-    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
-    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
-
-    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
-    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
-    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
-    int sD = INT_ARG(3);  // strides depth
-    int sH = INT_ARG(4);  // strides height
-    int sW = INT_ARG(5);  // strides width
-    int pD = INT_ARG(6);  // paddings depth
-    int pH = INT_ARG(7);  // paddings height
-    int pW = INT_ARG(8);  // paddings width
-    int dD = INT_ARG(9);  // dilations depth
-    int dH = INT_ARG(10); // dilations height
-    int dW = INT_ARG(11); // dilations width
-    int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
-    int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+static void conv3dMKLDNN(nd4j::graph::Context &block,
+                         const NDArray *input, const NDArray *weights, const NDArray *bias,
+                         NDArray *output,
+                         const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW,
+                         const int paddingMode, const int isNCDHW) {

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
-                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
-
-    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
-    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
-    if (bias)
-        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     dnnl_memory_desc_t empty;
-    dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
-            empty);
-    dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
-            empty);
+    dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
+    dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
+
+    mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode,
                                            isNCDHW,
                                            bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
                                            nullptr, bias, output,
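Note: the isNCDHW flag threaded through conv3dMKLDNN ultimately decides which plain format tag describes the user's tensors; the blocked, primitive-preferred formats are chosen by oneDNN itself. A sketch of that mapping, assuming oneDNN 1.x; the helper name is illustrative and is not getMKLDNNMemoryDescConv3d itself:

    #include <dnnl.hpp>

    dnnl::memory::desc makeUserConv3dMd(const dnnl::memory::dims &shape, // {bS, iC, iD, iH, iW}
                                        bool isNCDHW) {
        auto tag = isNCDHW ? dnnl::memory::format_tag::ncdhw
                           : dnnl::memory::format_tag::ndhwc;
        return dnnl::memory::desc(shape, dnnl::memory::data_type::f32, tag);
    }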
@@ -98,151 +59,73 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
                                            &user_src_md, nullptr, &user_weights_md, nullptr,
                                            &user_bias_md, &user_dst_md,
                                            conv_strides, conv_padding, conv_padding_r, conv_dilation);
-    auto conv_desc = bias != nullptr
-                     ? convolution_forward::desc(prop_kind::forward,
-                                                 algorithm::convolution_auto, conv_src_md,
-                                                 conv_weights_md, conv_bias_md,
-                                                 conv_dst_md, conv_strides, conv_dilation, conv_padding,
-                                                 conv_padding_r)
-                     : convolution_forward::desc(prop_kind::forward,
-                                                 algorithm::convolution_auto, conv_src_md,
-                                                 conv_weights_md,
-                                                 conv_dst_md, conv_strides, conv_dilation, conv_padding,
-                                                 conv_padding_r);
+    auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
+                                     : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
     dnnl::stream stream(engine);

     auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
     auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
-    auto user_weights_memory = dnnl::memory(user_weights_md, engine,
-                                            const_cast<NDArray *>(weights)->buffer());
+    auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
     auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());

     auto conv_src_memory = user_src_memory;
     if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
         conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
         reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
     }

     auto conv_weights_memory = user_weights_memory;
     if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
         conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
-        reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
-                                                                  conv_weights_memory);
+        reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory);
     }

     auto conv_dst_memory = user_dst_memory;
     if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
         conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
     }

     if (bias != nullptr) {
-        auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
+        auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->getBuffer());
         convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
                                                              {DNNL_ARG_WEIGHTS, conv_weights_memory},
                                                              {DNNL_ARG_BIAS, conv_bias_memory},
                                                              {DNNL_ARG_DST, conv_dst_memory}});
-    } else {
+    }
+    else {
         convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
                                                              {DNNL_ARG_WEIGHTS, conv_weights_memory},
                                                              {DNNL_ARG_DST, conv_dst_memory}});
     }
-    if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
+
+    if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc())
         reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
-    }
     stream.wait();
-
-    return Status::OK();
-}
-
-PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    if (::optimalLevel() < 2)
-        return false;
-
-    auto input = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1);                             // [kD, kH, kW, iC, oC] always
-    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
-    auto output = OUTPUT_VARIABLE(0);                             // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
-
-    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
 }

 //////////////////////////////////////////////////////////////////////
-PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
-    auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
-    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
-    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
-
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
-    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
-
-    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
-    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
-    REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());
-
-    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
-    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
-    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
-    int sD = INT_ARG(3);  // strides depth
-    int sH = INT_ARG(4);  // strides height
-    int sW = INT_ARG(5);  // strides width
-    int pD = INT_ARG(6);  // paddings depth
-    int pH = INT_ARG(7);  // paddings height
-    int pW = INT_ARG(8);  // paddings width
-    int dD = INT_ARG(9);  // dilations depth
-    int dH = INT_ARG(10); // dilations height
-    int dW = INT_ARG(11); // dilations width
-    int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
-    int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
+static void conv3dBpMKLDNN(nd4j::graph::Context &block,
+                           const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
+                           NDArray *gradI, NDArray *gradW, NDArray *gradB,
+                           const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW,
+                           const int paddingMode, const int isNCDHW) {

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
-                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
-
-    int trueoD, trueoH, trueoW; // true output depth/height/width
-    ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
-
-    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
-    if (bias)
-        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

     dnnl_memory_desc_t empty;
-    dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
-            conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
-            user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
+    dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
+    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
-                                           isNDHWC,
+
+    mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode,
+                                           isNCDHW,
                                            bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
                                            gradW, gradB, gradO,
                                            &conv_src_md, &conv_diff_src_md, &conv_weights_md,
@@ -250,43 +133,30 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
                                            &user_src_md, &user_diff_src_md, &user_weights_md,
                                            &user_diff_weights_md, &user_bias_md, &user_dst_md,
                                            conv_strides, conv_padding, conv_padding_r, conv_dilation);
-    auto conv_desc = gradB != nullptr
-                     ? convolution_forward::desc(prop_kind::forward,
-                                                 algorithm::convolution_auto, conv_src_md,
-                                                 conv_weights_md, conv_bias_md,
-                                                 conv_dst_md, conv_strides, conv_dilation, conv_padding,
-                                                 conv_padding_r)
-                     : convolution_forward::desc(prop_kind::forward,
-                                                 algorithm::convolution_auto, conv_src_md,
-                                                 conv_weights_md,
-                                                 conv_dst_md, conv_strides, conv_dilation, conv_padding,
-                                                 conv_padding_r);
-    auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
-            LaunchContext::defaultContext()->engine()));
-    if (gradW != nullptr) {
-        auto convW_desc = gradB != nullptr
-                          ? convolution_backward_weights::desc(
-                                  algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
-                                  conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
-                          : convolution_backward_weights::desc(
-                                  algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
-                                  conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
-
-        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-        dnnl::stream stream(engine);
-        auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
-                                                                            conv_prim_desc);
-        auto userW_src_memory = dnnl::memory(user_src_md, engine,
-                                             const_cast<NDArray *>(input)->buffer());
+    auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
+                                      : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
+
+    auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()));
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+    dnnl::stream stream(engine);
+
+    if (gradW != nullptr) {
+
+        auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
+                                           : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
+
+        auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
+
+        auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
         auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
-        auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
-                                             const_cast<NDArray *>(gradO)->buffer());
+        auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());

         auto convW_src_memory = userW_src_memory;
         if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
             convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
-            reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
-                                                                convW_src_memory);
+            reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory);
         }

         auto convW_weights_memory = userW_weights_memory;
|
||||||
auto convW_dst_memory = userW_dst_memory;
|
auto convW_dst_memory = userW_dst_memory;
|
||||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||||
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
|
||||||
convW_dst_memory);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gradB != nullptr) {
|
if (gradB != nullptr) {
|
||||||
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
|
|
||||||
gradB->buffer());
|
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
|
||||||
|
|
||||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
{{DNNL_ARG_SRC, convW_src_memory},
|
{{DNNL_ARG_SRC, convW_src_memory},
|
||||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||||
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
{{DNNL_ARG_SRC, convW_src_memory},
|
{{DNNL_ARG_SRC, convW_src_memory},
|
||||||
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
{DNNL_ARG_DIFF_DST, convW_dst_memory},
|
||||||
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc())
|
||||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory);
|
||||||
userW_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.wait();
|
stream.wait();
|
||||||
}
|
}
|
||||||
if (gradI != nullptr) {
|
if (gradI != nullptr) {
|
||||||
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
|
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||||
conv_diff_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
|
||||||
conv_padding_r);
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
|
||||||
dnnl::stream stream(engine);
|
|
||||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
|
||||||
conv_prim_desc);
|
|
||||||
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
|
auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
|
||||||
const_cast<NDArray *>(weights)->buffer());
|
auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
|
||||||
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
|
|
||||||
const_cast<NDArray *>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convI_src_memory = userI_src_memory;
|
auto convI_src_memory = userI_src_memory;
|
||||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc())
|
||||||
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_weights_memory = userI_weights_memory;
|
auto convI_weights_memory = userI_weights_memory;
|
||||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||||
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
|
||||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
|
||||||
convI_weights_memory);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto convI_dst_memory = userI_dst_memory;
|
auto convI_dst_memory = userI_dst_memory;
|
||||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||||
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
|
||||||
convI_dst_memory);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||||
|
@ -363,30 +221,128 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
||||||
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
{DNNL_ARG_WEIGHTS, convI_weights_memory},
|
||||||
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
|
||||||
|
|
||||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc())
|
||||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
|
||||||
userI_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.wait();
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
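// Note on the recurring DNNL staging idiom used above (a sketch with generic
// names, shown for orientation only; it is not part of this commit):
//
//     auto prim_mem = user_mem;                                   // start with the user's layout
//     if (prim_desc.src_desc() != user_mem.get_desc()) {          // primitive prefers another layout
//         prim_mem = dnnl::memory(prim_desc.src_desc(), engine);  // allocate a staging buffer
//         reorder(user_mem, prim_mem).execute(stream, user_mem, prim_mem);
//     }
//
// Results travel the opposite way: the primitive writes into the staging buffer,
// which is reordered back into user memory when the descriptors differ, and
// stream.wait() must complete before the NDArray buffers are read, since
// execution is asynchronous with respect to the stream.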
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                             // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
    auto output = OUTPUT_VARIABLE(0);                             // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());

    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
    int sD = INT_ARG(3);           // strides depth
    int sH = INT_ARG(4);           // strides height
    int sW = INT_ARG(5);           // strides width
    int pD = INT_ARG(6);           // paddings depth
    int pH = INT_ARG(7);           // paddings height
    int pW = INT_ARG(8);           // paddings width
    int dD = INT_ARG(9);           // dilations depth
    int dH = INT_ARG(10);          // dilations height
    int dW = INT_ARG(11);          // dilations width
    int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
    int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    if (paddingMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);

    return Status::OK();
}
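// For reference, calcPadding3D/calcOutSizePool3D implement the usual convolution
// arithmetic; with VALID padding each spatial extent follows, e.g. for depth,
//
//     oD = (iD + 2*pD - ((kD - 1)*dD + 1)) / sD + 1;
//
// while SAME mode fixes oD = ceil(iD / sD) and back-solves the padding. The
// helpers remain the source of truth; the formula is shown only for orientation.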
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
    // we don't want to use mkldnn if cpu doesn't support avx/avx2
    if (::optimalLevel() < 2)
        return false;

    auto input = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                             // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
    auto output = OUTPUT_VARIABLE(0);                             // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
}
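// As the comment inside the check notes, ::optimalLevel() serves as a coarse
// SIMD-capability probe (levels below 2 mean no AVX/AVX2 here); without vector
// ISA support the MKL-DNN path is unlikely to outperform the generic
// implementation, so the platform helper opts out before any descriptor work.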
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);                                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                       // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;            // [oC]
    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                        // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1);                                        // [kD, kH, kW, iC, oC] always
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;          // [oC]

    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
    REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());

    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
    int sD = INT_ARG(3);           // strides depth
    int sH = INT_ARG(4);           // strides height
    int sW = INT_ARG(5);           // strides width
    int pD = INT_ARG(6);           // paddings depth
    int pH = INT_ARG(7);           // paddings height
    int pW = INT_ARG(8);           // paddings width
    int dD = INT_ARG(9);           // dilations depth
    int dH = INT_ARG(10);          // dilations height
    int dW = INT_ARG(11);          // dilations width
    int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
    int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    if (paddingMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    int trueoD, trueoH, trueoW; // true output depth/height/width
    ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);

    return Status::OK();
}
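// Argument-layout convention for the backward op: when the forward pass received
// a bias, the inputs arrive as (input, weights, bias, gradO) and a gradB output is
// produced; without a bias, gradO sits at input index 2 and gradB is absent. This
// is why the INPUT_VARIABLE/OUTPUT_VARIABLE indices above and below are guarded
// by block.width() > 3.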
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);                                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                       // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;            // [oC]
    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                        // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1);                                        // [kD, kH, kW, iC, oC] always
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;          // [oC]

    return block.isUseMKLDNN() &&
@@ -177,7 +177,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
}

//////////////////////////////////////////////////////////////////////////
-static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
+static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
                              const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                              const int paddingMode) {
@@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
        gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
    }

-   deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
+   deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);

    delete weights;
    delete gradW;
@@ -421,7 +421,7 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
    return block.isUseMKLDNN() && mC == 1 &&
        (
            (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
-           (xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) ||
+           (xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) ||
            ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
        );
}
@@ -29,117 +29,258 @@

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);

    REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf());

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
    auto argI = *(block.getIArguments());
    auto output = OUTPUT_VARIABLE(0);

    const auto kH = INT_ARG(0);
    const auto kW = INT_ARG(1);
    const auto sH = INT_ARG(2);
    const auto sW = INT_ARG(3);
    int pH = INT_ARG(4);
    int pW = INT_ARG(5);
    const auto dH = INT_ARG(6);
    const auto dW = INT_ARG(7);
    const auto isSameMode = static_cast<bool>(INT_ARG(8));

    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int oH = 0;
    int oW = 0;

    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
    const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));

    if (!isNCHW) {
        input = new NDArray(input->permute({0, 3, 1, 2}));   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

    if (isSameMode)
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    const int bS = input->sizeAt(0);
    const int iC = input->sizeAt(1);
    const int oC = output->sizeAt(1);

    auto poolingMode = PoolingType::MAX_POOL;
    int extraParam0 = 1;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;

    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, algorithm,
                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, pool_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());

    auto pool_src_memory = user_src_memory;
    dnnl::stream stream(engine);
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }

    auto pool_dst_memory = user_dst_memory;
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    }

    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory}});

    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
    }

    stream.wait();

    if (!isNCHW) {
        delete input;
        delete output;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
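// The forward op builds its descriptor with prop_kind::forward_inference, so DNNL
// requests no workspace here; max pooling only needs the argmax workspace when a
// backward pass will follow (see maxpool2d_bp below).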
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(1);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    int kH = INT_ARG(0);         // filter(kernel) height
    int kW = INT_ARG(1);         // filter(kernel) width
    int sH = INT_ARG(2);         // strides height
    int sW = INT_ARG(3);         // strides width
    int pH = INT_ARG(4);         // paddings height
    int pW = INT_ARG(5);         // paddings width
    int dH = INT_ARG(6);         // dilations height
    int dW = INT_ARG(7);         // dilations width
    int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
    int extraParam0 = INT_ARG(9);
    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if (!isNCHW) {
        input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    if (isSameMode) // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    auto poolingMode = PoolingType::MAX_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;

    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    // input is sometimes null, so we can't rely on pool_src_md being valid
    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
                                           pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);

    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
                                             pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());

    auto poolB_src_memory = userB_src_memory;
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }

    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }

    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto pool_src_memory = user_src_memory;
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }

    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);

    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory},
                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
    // probably wrong, fix that
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});

    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }

    stream.wait();

    if (!isNCHW) {
        delete input;
        delete gradI;
        delete gradO;
    }

    return Status::OK();
}
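// Why the backward op re-runs the forward primitive: DNNL's max-pooling backward
// consumes a workspace (the per-window argmax positions) that only a
// training-propagation forward pass produces. Hence pool_desc above is created
// with prop_kind::forward rather than forward_inference, and pool_workspace_memory
// is threaded through both execute() calls.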
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

}
}
}
@@ -1,174 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author saudet
// @author raver119@gmail.com
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>

#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {

PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradO = INPUT_VARIABLE(1);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    int kH = INT_ARG(0);         // filter(kernel) height
    int kW = INT_ARG(1);         // filter(kernel) width
    int sH = INT_ARG(2);         // strides height
    int sW = INT_ARG(3);         // strides width
    int pH = INT_ARG(4);         // paddings height
    int pW = INT_ARG(5);         // paddings width
    int dH = INT_ARG(6);         // dilations height
    int dW = INT_ARG(7);         // dilations width
    int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
    int extraParam0 = INT_ARG(9);
    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    int bS, iC, iH, iW, oC, oH, oW;                       // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if (!isNCHW) {
        input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
        gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    if (isSameMode) // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    auto poolingMode = PoolingType::MAX_POOL;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;

    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    // input is sometimes null, so we can't rely on pool_src_md being valid
    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
                                           pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);

    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
                                             pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());

    auto poolB_src_memory = userB_src_memory;
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }

    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }

    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto pool_src_memory = user_src_memory;
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }

    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);

    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory},
                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
    // probably wrong, fix that
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});

    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }

    stream.wait();

    if (!isNCHW) {
        delete input;
        delete gradI;
        delete gradO;
    }

    return Status::OK();
}

PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

}
}
}
@@ -28,124 +28,273 @@

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

    int kD = INT_ARG(0);          // filter(kernel) depth
    int kH = INT_ARG(1);          // filter(kernel) height
    int kW = INT_ARG(2);          // filter(kernel) width
    int sD = INT_ARG(3);          // strides depth
    int sH = INT_ARG(4);          // strides height
    int sW = INT_ARG(5);          // strides width
    int pD = INT_ARG(6);          // paddings depth
    int pH = INT_ARG(7);          // paddings height
    int pW = INT_ARG(8);          // paddings width
    int dD = INT_ARG(9);          // dilations depth
    int dH = INT_ARG(10);         // dilations height
    int dW = INT_ARG(11);         // dilations width
    int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
    // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
    // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);

    if (!isNCDHW) {
        input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
    }

    if (isSameMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    auto poolingMode = PoolingType::MAX_POOL;
    auto extraParam0 = 1;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;

    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, algorithm,
                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
                                           pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());

    auto pool_src_memory = user_src_memory;
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }

    auto pool_dst_memory = user_dst_memory;
    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    }

    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory}});

    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
    }

    stream.wait();

    if (!isNCDHW) {
        delete input;
        delete output;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
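// As in the 2D ops, NDHWC tensors are adapted by wrapping permuted views
// ({0, 4, 1, 2, 3} moves channels next to batch) rather than copying data, and
// the temporary NDArray wrappers are deleted once the primitive has run; the
// same pattern repeats in the backward op below.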
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto gradO = INPUT_VARIABLE(1);  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

    const int kD = INT_ARG(0);          // filter(kernel) depth
    const int kH = INT_ARG(1);          // filter(kernel) height
    const int kW = INT_ARG(2);          // filter(kernel) width
    const int sD = INT_ARG(3);          // strides depth
    const int sH = INT_ARG(4);          // strides height
    const int sW = INT_ARG(5);          // strides width
    int pD = INT_ARG(6);                // paddings depth
    int pH = INT_ARG(7);                // paddings height
    int pW = INT_ARG(8);                // paddings width
    const int dD = INT_ARG(9);          // dilations depth
    const int dH = INT_ARG(10);         // dilations height
    const int dW = INT_ARG(11);         // dilations width
    const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3D_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;         // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if (!isNCDHW) {
        input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
    }

    if (isSameMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    auto poolingMode = PoolingType::MAX_POOL;
    auto extraParam0 = 1;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;

    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    // input is sometimes null, so we can't rely on pool_src_md being valid
    if (input->buffer() == nullptr) {
        pool_src_md = pool_diff_src_md;
        user_src_md = user_diff_src_md;
    }
    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);

    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
    auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());

    auto poolB_src_memory = userB_src_memory;
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }

    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }

    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());

    auto pool_src_memory = user_src_memory;
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }

    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);

    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory},
                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});

    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }

    stream.wait();

    if (!isNCDHW) {
        delete input;
        delete gradI;
        delete gradO;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
|
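
Aside: the forward pooling call above is not redundant. dnnl's max-pooling backward consumes a "workspace" holding the argmax positions, and only a forward run produces it, so the bp op must replay the ff step first. A minimal standalone sketch of that dependency (plain C++, hypothetical 1-D sizes, not part of this commit):

    #include <cstdio>
    #include <vector>

    // The gradient of max pooling flows only to the element that won each
    // window, so the argmax indices (dnnl's "workspace") must be recomputed
    // by a forward pass before the backward pass can route gradO into gradI.
    int main() {
        const std::vector<float> in    = {1.f, 3.f, 2.f, 5.f};  // input, kernel 2, stride 2
        const std::vector<float> gradO = {0.5f, -1.f};          // one incoming gradient per window
        std::vector<float> gradI(in.size(), 0.f);

        for (size_t w = 0; w < gradO.size(); ++w) {
            const size_t a = 2 * w, b = 2 * w + 1;
            const size_t argmax = in[a] >= in[b] ? a : b;  // ff step: find the winner
            gradI[argmax] += gradO[w];                     // bp step: route the gradient to it
        }

        for (float g : gradI)
            printf("%g ", g);  // prints: 0 0.5 0 -1
        printf("\n");
        return 0;
    }
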
@ -1,181 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>

#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {

PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto gradO = INPUT_VARIABLE(1);    // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI = OUTPUT_VARIABLE(0);   // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

    const int kD = INT_ARG(0);            // filter(kernel) depth
    const int kH = INT_ARG(1);            // filter(kernel) height
    const int kW = INT_ARG(2);            // filter(kernel) width
    const int sD = INT_ARG(3);            // strides depth
    const int sH = INT_ARG(4);            // strides height
    const int sW = INT_ARG(5);            // strides width
    int pD = INT_ARG(6);                  // paddings depth
    int pH = INT_ARG(7);                  // paddings height
    int pW = INT_ARG(8);                  // paddings width
    const int dD = INT_ARG(9);            // dilations depth
    const int dH = INT_ARG(10);           // dilations height
    const int dW = INT_ARG(11);           // dilations width
    const int isSameMode = INT_ARG(12);   // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13);     // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;   // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;           // batch size, input channels, input depth/height/width, output channels, output depth/height/width
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if (!isNCDHW) {
        input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3}));   // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
    }

    if (isSameMode)  // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    auto poolingMode = PoolingType::MAX_POOL;
    auto extraParam0 = 1;

    dnnl_memory_desc_t empty;
    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
    dnnl::algorithm algorithm;

    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
                                           extraParam0, true,
                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
                                           algorithm,
                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
                                           &user_diff_src_md, &user_dst_md,
                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);

    // input is sometimes null, so we can't rely on pool_src_md being valid
    if (input->buffer() == nullptr) {
        pool_src_md = pool_diff_src_md;
        user_src_md = user_diff_src_md;
    }

    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);

    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);

    auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());

    auto poolB_src_memory = userB_src_memory;
    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
    }

    auto poolB_dst_memory = userB_dst_memory;
    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
    }

    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());

    auto pool_src_memory = user_src_memory;
    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
    }

    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);

    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
                                                     {DNNL_ARG_DST, pool_dst_memory},
                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});

    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
    }

    stream.wait();

    if (!isNCDHW) {
        delete input;
        delete gradI;
        delete gradO;
    }

    return Status::OK();
}

PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

}
}
}
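
Aside: in SAME mode the op above calls ConvolutionUtils::calcPadding3D before building the dnnl descriptors, and the helpers below then derive the right padding as (o - 1) * s - i + k - p, i.e. whatever is still missing for the last window to end exactly at the padded edge. A standalone arithmetic check (plain C++, hypothetical sizes, not part of this commit):

    #include <cstdio>

    int main() {
        const int i = 5, k = 2, s = 2;             // input extent, kernel, stride (dilation 1)
        const int p = 0;                           // left padding; SAME puts the excess on the right
        const int o = (i + s - 1) / s;             // SAME output extent: ceil(i / s) = 3
        const int padR = (o - 1) * s - i + k - p;  // right padding = 1
        printf("o = %d, pad_r = %d\n", o, padR);
        return 0;
    }
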
@ -23,383 +23,388 @@
using namespace dnnl;

namespace nd4j {
namespace mkldnnUtils {

//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescPool2d(
        int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
        int bS, int iC, int iH, int iW, int oC, int oH, int oW,
        const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
        dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
        dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
        dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {

    dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
    dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };

    pool_strides = { sH, sW };
    pool_kernel = { kH, kW };
    pool_padding = { pH, pW };
    pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
                       (oW - 1) * sW - iW + kW - pW };

    algorithm = poolingMode == 0 ? algorithm::pooling_max
                                 : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
                                                    : algorithm::pooling_avg_include_padding;
    auto type = dnnl::memory::data_type::f32;
    auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"

    if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
        *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
        user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
        user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
        *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
        user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
        user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
        *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
        *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
        user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
        user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
    }
};
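
Aside: the nested ternary that picks the dnnl algorithm is the same in the 2d and 3d helpers; spelled out as a plain function it reads (standalone sketch, names are illustrative, not part of this commit):

    #include <cassert>

    enum class PoolAlg { Max, AvgExcludePadding, AvgIncludePadding };

    // Same mapping as above: poolingMode 0 selects max pooling; otherwise
    // extraParam0 decides whether padded cells count toward the average's divisor.
    PoolAlg pickPoolAlg(int poolingMode, int extraParam0) {
        if (poolingMode == 0) return PoolAlg::Max;
        return extraParam0 == 0 ? PoolAlg::AvgExcludePadding : PoolAlg::AvgIncludePadding;
    }

    int main() {
        assert(pickPoolAlg(0, 1) == PoolAlg::Max);
        assert(pickPoolAlg(1, 0) == PoolAlg::AvgExcludePadding);
        assert(pickPoolAlg(1, 1) == PoolAlg::AvgIncludePadding);
        return 0;
    }
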
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescPool3d(
        int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
        int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
        const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
        dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
        dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
        dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {

    dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
    dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };

    pool_strides = { sD, sH, sW };
    pool_kernel = { kD, kH, kW };
    pool_padding = { pD, pH, pW };
    pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
                       (oH - 1) * sH - iH + kH - pH,
                       (oW - 1) * sW - iW + kW - pW };

    algorithm = poolingMode == 0 ? algorithm::pooling_max
                                 : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
                                                    : algorithm::pooling_avg_include_padding;
    auto type = dnnl::memory::data_type::f32;
    auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"

    if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
        *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
        user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
        user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
        user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
        *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
        *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
        user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
        user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
        user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
        *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
        *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
        user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
        user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
        user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
    }
};
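
Aside: the stridesOf() index maps in these helpers ({0, 4, 1, 2, 3} for NDHWC, {0, 3, 1, 2} for NHWC) exist because the user_*_md descriptors must list strides in logical NC[D]HW order even when the buffer is laid out channels-last. A standalone sketch with a dense NDHWC buffer (plain C++, hypothetical shape, not part of this commit):

    #include <cstdio>

    int main() {
        const long D = 3, H = 4, W = 5, C = 6;
        // dense NDHWC strides, innermost (C) last
        const long ndhwc[5] = { D * H * W * C, H * W * C, W * C, C, 1 };
        // logical NCDHW order picks NDHWC positions 0, 4, 1, 2, 3
        const long ncdhw[5] = { ndhwc[0], ndhwc[4], ndhwc[1], ndhwc[2], ndhwc[3] };
        printf("N=%ld C=%ld D=%ld H=%ld W=%ld\n", ncdhw[0], ncdhw[1], ncdhw[2], ncdhw[3], ncdhw[4]);
        return 0;
    }
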
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescConv2d(
        int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
        int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
        const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
        dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
        dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
        dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
        dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
        dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {

    dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
    dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
    dnnl::memory::dims conv_bias_tz = { oC };
    dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;   // dH == 1 for causal mode in conv1d

    conv_strides = { sH, sW };
    conv_padding = { pH, pW };
    conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    conv_dilation = { dH-1, dW-1};

    auto type = dnnl::memory::data_type::f32;
    auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    auto formatw = dnnl::memory::format_tag::hwio;

    if (src != nullptr && conv_src_md != nullptr) {
        *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
        *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
        user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
        user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (diff_src != nullptr && conv_diff_src_md != nullptr) {
        *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
        *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
        user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
        user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
    }

    if (weights != nullptr && conv_weights_md != nullptr) {
        *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
        *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
        user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
        user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
        user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
        user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
        user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
    }

    if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
        *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
        *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
        user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
        user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
        user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
        user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
        user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
    }

    if (bias != nullptr && conv_bias_md != nullptr) {
        *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
        *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
    }

    if (dst != nullptr && conv_dst_md != nullptr) {
        *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
        *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
        user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
        user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
    }
}
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescConv3d(
        int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
        int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
        const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
        dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
        dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
        dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
        dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
        dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {

    dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
    dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
    dnnl::memory::dims conv_bias_tz = { oC };
    dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };

    conv_strides = { sD, sH, sW };
    conv_padding = { pD, pH, pW };
    conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
    conv_dilation = { dD-1, dH-1, dW-1};

    auto type = dnnl::memory::data_type::f32;
    auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    auto formatw = dnnl::memory::format_tag::dhwio;

    if (src != nullptr && conv_src_md != nullptr) {
        *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
        *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
        user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
        user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
        user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (diff_src != nullptr && conv_diff_src_md != nullptr) {
        *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
        *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
        user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
        user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
        user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (weights != nullptr && conv_weights_md != nullptr) {
        *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
        *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
        user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
        user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
        user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
        user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
        user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
        user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
    }

    if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
        *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
        *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
        user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
        user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
        user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
        user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
        user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
        user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
    }

    if (bias != nullptr && conv_bias_md != nullptr) {
        *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
        *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
    }

    if (dst != nullptr && conv_dst_md != nullptr) {
        *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
        *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
        user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
        user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
        user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
    }
};
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
//                                   dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
//                                   dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
//     const Nd4jLong* shape = src->getShapeInfo();
//     Nd4jLong rank = shape[0];
//     Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
//     Nd4jLong dim2 = axis >= 2 ? 1 : 2;
//     Nd4jLong dim3 = axis >= 3 ? 2 : 3;
//     dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
//
//     auto type = dnnl::memory::data_type::f32;
//     auto format = dnnl::memory::format_tag::nchw;
//     auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
//
//     if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
//         *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
//         *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
//         user_src_md->data.format_kind = dnnl_blocked; // overrides format
//         user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
//         user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
//         user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
//         user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
//     }
//
//     if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
//         *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
//         *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
//         user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
//         user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
//         user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
//         user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
//         user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
//     }
//
//     if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
//         *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
//         *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
//         user_dst_md->data.format_kind = dnnl_blocked; // overrides format
//         user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
//         user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
//         user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
//         user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
//     }
// };
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
                            dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
                            dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {

    const Nd4jLong* shape = src->getShapeInfo();
    long rank = shape[0];
    long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
    long dim2 = axis >= 2 ? 1 : 2;
    long dim3 = axis >= 3 ? 2 : 3;
    dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};

    auto type = dnnl::memory::data_type::f32;
    auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    auto supposed_to_be_any_format = format; // doesn't work with "any"

    if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
        *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
        *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked;
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
        user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
        user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
    }

    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
        *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
        *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked;
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
        user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
        user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
    }

    if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
        *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
        *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked;
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
        user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
        user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
    }
}
dnnl::engine& getEngine(void *ptr) {
    auto eng = reinterpret_cast<dnnl::engine*>(ptr);
    return *eng;
}

}
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||||
|
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||||
|
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void getMKLDNNMemoryDescPool3d(
|
||||||
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||||
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||||
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
|
||||||
|
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
|
||||||
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
|
||||||
|
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
|
||||||
|
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||||
|
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||||
|
|
||||||
|
pool_strides = { sD, sH, sW };
|
||||||
|
pool_kernel = { kD, kH, kW };
|
||||||
|
pool_padding = { pD, pH, pW };
|
||||||
|
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||||
|
(oH - 1) * sH - iH + kH - pH,
|
||||||
|
(oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
algorithm = poolingMode == 0 ? algorithm::pooling_max
|
||||||
|
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
|
||||||
|
: algorithm::pooling_avg_include_padding;
|
||||||
|
auto type = dnnl::memory::data_type::f32;
|
||||||
|
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||||
|
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
|
||||||
|
|
||||||
|
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||||
|
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||||
|
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||||
|
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||||
|
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void getMKLDNNMemoryDescConv2d(
|
||||||
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
||||||
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
|
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||||
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||||
|
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||||
|
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||||
|
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||||
|
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||||
|
dnnl::memory::dims conv_bias_tz = { oC };
|
||||||
|
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||||
|
|
||||||
|
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
|
||||||
|
|
||||||
|
conv_strides = { sH, sW };
|
||||||
|
conv_padding = { pH, pW };
|
||||||
|
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
|
||||||
|
conv_dilation = { dH-1, dW-1};
|
||||||
|
|
||||||
|
auto type = dnnl::memory::data_type::f32;
|
||||||
|
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||||
|
auto formatw = dnnl::memory::format_tag::hwio;
|
||||||
|
|
||||||
|
if (src != nullptr && conv_src_md != nullptr) {
|
||||||
|
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||||
|
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||||
|
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
|
||||||
|
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
|
||||||
|
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
|
||||||
|
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weights != nullptr && conv_weights_md != nullptr) {
|
||||||
|
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||||
|
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||||
|
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
|
||||||
|
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
|
||||||
|
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
|
||||||
|
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
|
||||||
|
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
|
||||||
|
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bias != nullptr && conv_bias_md != nullptr) {
|
||||||
|
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
|
||||||
|
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst != nullptr && conv_dst_md != nullptr) {
|
||||||
|
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
|
||||||
|
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
|
||||||
|
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||||
|
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void getMKLDNNMemoryDescConv3d(
|
||||||
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
|
||||||
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
|
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
|
||||||
|
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
|
||||||
|
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
|
||||||
|
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
|
||||||
|
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||||
|
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||||
|
dnnl::memory::dims conv_bias_tz = { oC };
|
||||||
|
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||||
|
|
||||||
|
conv_strides = { sD, sH, sW };
|
||||||
|
conv_padding = { pD, pH, pW };
|
||||||
|
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||||
|
conv_dilation = { dD-1, dH-1, dW-1};
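
    // A sketch of the conventions assumed above (not spelled out in this file): oneDNN
    // encodes a dilation factor k as k-1, so a non-dilated convolution passes zeros,
    // which is why dD-1/dH-1/dW-1 are stored. conv_padding_r is the right-side padding
    // obtained by solving the (dilation-free) output-size relation
    //     oX = (iX + pX + padding_r - kX) / sX + 1
    // for padding_r. For example, iH = 5, kH = 3, sH = 1, pH = 1, oH = 5 gives
    // padding_r = (5 - 1)*1 - 5 + 3 - 1 = 1, i.e. symmetric SAME padding.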

    auto type = dnnl::memory::data_type::f32;
    auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    auto formatw = dnnl::memory::format_tag::dhwio;

    if (src != nullptr && conv_src_md != nullptr) {
        *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
        *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; // batch stride is the same in both layouts
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
        user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
        user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
        user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (diff_src != nullptr && conv_diff_src_md != nullptr) {
        *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
        *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
        user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
        user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
        user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
    }

    if (weights != nullptr && conv_weights_md != nullptr) {
        *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
        *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
        user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
        user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
        user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
        user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
        user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
        user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
    }

    if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
        *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
        *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
        user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
        user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
        user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
        user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
        user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
        user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
    }

    if (bias != nullptr && conv_bias_md != nullptr) {
        *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
        *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
    }

    if (dst != nullptr && conv_dst_md != nullptr) {
        *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
        *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
        user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
        user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
        user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
    }
}

// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
//                                   dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
//                                   dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
//     const Nd4jLong* shape = src->getShapeInfo();
//     Nd4jLong rank = shape[0];
//     Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
//     Nd4jLong dim2 = axis >= 2 ? 1 : 2;
//     Nd4jLong dim3 = axis >= 3 ? 2 : 3;
//     dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1 };
//
//     auto type = dnnl::memory::data_type::f32;
//     auto format = dnnl::memory::format_tag::nchw;
//     auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
//
//     if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
//         *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
//         *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
//         user_src_md->data.format_kind = dnnl_blocked; // overrides format
//         user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
//         user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
//         user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
//         user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
//     }
//
//     if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
//         *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
//         *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
//         user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
//         user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
//         user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
//         user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
//         user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
//     }
//
//     if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
//         *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
//         *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
//         user_dst_md->data.format_kind = dnnl_blocked; // overrides format
//         user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
//         user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
//         user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
//         user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
//     }
// }

//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
                            dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
                            dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {

    const Nd4jLong* shape = src->getShapeInfo();
    long rank = shape[0];
    long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
    long dim2 = axis >= 2 ? 1 : 2;
    long dim3 = axis >= 3 ? 2 : 3;
    dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1 };

    auto type = dnnl::memory::data_type::f32;
    auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    auto supposed_to_be_any_format = format; // doesn't work with "any"

    if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
        *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
        *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
        user_src_md->data.format_kind = dnnl_blocked;
        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
        user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
        user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
    }

    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
        *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
        *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
        user_diff_src_md->data.format_kind = dnnl_blocked;
        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
        user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
        user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
    }

    if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
        *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
        *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
        user_dst_md->data.format_kind = dnnl_blocked;
        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
        user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
        user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
    }
}

//////////////////////////////////////////////////////////////////////////
dnnl::engine& getEngine(void *ptr) {
    auto eng = reinterpret_cast<dnnl::engine*>(ptr);
    return *eng;
}

}
}
}

File diff suppressed because one or more lines are too long

@ -970,7 +970,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {

    x.linspace(1);

    nd4j::ops::maxpool2d op;
    auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});

@ -991,7 +990,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {

    x.linspace(1);

    nd4j::ops::maxpool2d op;
    auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});

@ -1012,7 +1010,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {

    x.linspace(1);

    nd4j::ops::maxpool2d op;
    auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});

@ -1467,11 +1464,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) {

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
                  0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
                  0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f,
                  0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f,
                  0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f,
                  0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

@ -57,6 +57,17 @@ TEST_F(CuDnnTests, helpers_includer) {

    nd4j::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d;
    nd4j::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp;
    nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm;
    nd4j::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp;
    nd4j::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d;
    nd4j::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp;
    nd4j::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d;
    nd4j::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp;
    nd4j::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew;
    nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp;
    nd4j::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew;
    nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp;

    printer({&conv2d});
    printer({&conv2d_bp});

@ -65,6 +76,15 @@ TEST_F(CuDnnTests, helpers_includer) {

    printer({&depthwise_conv2d});
    printer({&depthwise_conv2d_bp});
    printer({&batchnorm});
    printer({&batchnorm_bp});
    printer({&avgpool2d});
    printer({&avgpool2d_bp});
    printer({&maxpool2d});
    printer({&maxpool2d_bp});
    printer({&avgpool3dnew});
    printer({&avgpool3dnew_bp});
    printer({&maxpool3dnew});
    printer({&maxpool3dnew_bp});
#endif
}

@ -25,6 +25,7 @@

#include <ops/ops.h>
#include <GradCheck.h>
#include <memory>
#include <PointersManager.h>

using namespace nd4j;

@ -2247,3 +2248,525 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {

    delete results;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {

    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    variance.assign(0.46666667);
    gamma.assign(1.2);
    beta.assign(1.); // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}
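
////////////////////////////////////////////////////////////////////////////////
// A brief sketch of the batchnorm_bp call convention these tests rely on, inferred
// from the calls themselves rather than stated anywhere in this diff:
//   inputs : {input, mean, variance, gamma, beta, dLdO}
//   t-args : {epsilon}
//   i-args : {applyScale, applyOffset, axes...} - {1,1} defaults to the last axis,
//            {1,1,1} normalizes along axis 1 (the NCHW channel), and {1,1,0,2}
//            normalizes along axes 0 and 2 (see batchnorm_bp_test3 below).
//   outputs: {dLdI, dLdM, dLdV, dLdG, dLdB}, hence results->at(0), at(3), at(4).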

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) {

    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {3}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
                                    0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
                                   -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    // beta.assign(1.); // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) {

    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {2,1,4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
                                    0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
                                   -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {2,1,4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    // beta.assign(1.); // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) {

    NDArray input   ('c', {2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) {

#if defined(HAVE_CUDNN)
    return;
#endif
    NDArray input   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
                                     -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
                                     -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) {

#if defined(HAVE_CUDNN)
    return;
#endif

    NDArray input   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
                                      0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
                                     -0.339330, 3.563660, -1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) {

#if defined(HAVE_CUDNN)
    return;
#endif

    NDArray input   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
                                       -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
                                        15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
                                       -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
                                       -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {282.38734, 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {57.6, 60., 62.4, 64.8}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    // dLdI->printBuffer();

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) {

#if defined(HAVE_CUDNN)
    return;
#endif

    NDArray input   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
                                        32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
                                       -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
                                        30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
                                        31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    // dLdI->printBuffer();

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) {

    NDArray input   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818,
                                     -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063,
                                     -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, nd4j::DataType::FLOAT32);

    input.linspace(1, 0.01);
    gradO.linspace(-0.9, 0.15);

    // calculate mean and variance of input
    PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9");
    std::vector<int> dimensions = {0,2,3};
    int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
    input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
    NDArray::prepareSpecialUse({&variance}, {&input});
    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
    NativeOpExecutioner::execSummaryStats(input.getContext(), 0, input.getBuffer(), input.getShapeInfo(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, variance.getBuffer(), variance.getShapeInfo(), variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), false);
    manager.synchronize();
    NDArray::registerSpecialUse({&variance}, {&input});

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}
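
////////////////////////////////////////////////////////////////////
// Note on the setup of tests 9-11: the variance fed to batchnorm_bp is produced
// by execSummaryStats with its trailing biasCorrected flag set to false, i.e. the
// population variance sum((x - m)^2) / N that batch normalization itself uses,
// not the sample variance with an N-1 divisor - presumably why the executioner
// is invoked directly instead of going through a variance op.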

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) {

    NDArray input   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343,
                                     -0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171,
                                     -0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);

    input.linspace(1, 0.01);
    gradO.linspace(-0.9, 0.15);

    // calculate mean and variance of input
    PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test10");
    std::vector<int> dimensions = {0,1,2};
    int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
    input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
    NDArray::prepareSpecialUse({&variance}, {&input});
    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
    NativeOpExecutioner::execSummaryStats(input.getContext(), 0, input.getBuffer(), input.getShapeInfo(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, variance.getBuffer(), variance.getShapeInfo(), variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), false);
    manager.synchronize();
    NDArray::registerSpecialUse({&variance}, {&input});

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) {

    NDArray input   ('c', {2,3,4,5}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4,5}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837,
                                     0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498,
                                     0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670,
                                     -0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821,
                                     -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985,
                                     -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835,
                                     -0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334,
                                     0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669,
                                     0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498,
                                     8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503,
                                     8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501,
                                     8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0,
                                     15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5,
                                     22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, nd4j::DataType::FLOAT32);

    input.linspace(1, 0.01);
    gradO.linspace(-0.9, 0.15);
    gamma.linspace(-3, 0.1);

    // calculate mean and variance of input
    PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test11");
    std::vector<int> dimensions = {0};
    int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
    input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions, true);
    NDArray::prepareSpecialUse({&variance}, {&input});
    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
    NativeOpExecutioner::execSummaryStats(input.getContext(), 0, input.getBuffer(), input.getShapeInfo(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, variance.getBuffer(), variance.getShapeInfo(), variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), false);
    manager.synchronize();
    NDArray::registerSpecialUse({&variance}, {&input});

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

@ -72,71 +72,6 @@ TEST_F(DeclarableOpsTests15, Test_Half_assign_1) {

    ASSERT_EQ(10, x.sumNumber().e<int>(0));
}

TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) {
    int inOutH = 5;  // 35;
    int inOutW = 5;  // 35;
    int inOutC = 10; // 192;

    auto x = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
    x.linspace(1.0);

    nd4j::ops::avgpool2d op;
    auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
    int padTop = totalPadHeight / 2;
    int padBottom = totalPadHeight - totalPadHeight / 2;

    int k = 3;

    auto m = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
    auto c = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});

    for (int h = 0; h < inOutH; h++) {
        for (int w = 0; w < inOutW; w++) {
            int hFrom = h - padTop;
            int wFrom = w - padBottom;

            int hTo = hFrom + k;
            int wTo = wFrom + k;

            hFrom = nd4j::math::nd4j_max<int>(0, hFrom);
            wFrom = nd4j::math::nd4j_max<int>(0, wFrom);

            hTo = nd4j::math::nd4j_min<int>(inOutH, hTo);
            wTo = nd4j::math::nd4j_min<int>(inOutW, wTo);

            int idxOut[4];
            int idxIn[4];
            for (int ch = 0; ch < inOutC; ch++) {
                idxOut[1] = h;
                idxOut[2] = w;
                idxOut[3] = ch;
                idxIn[3] = ch;

                for (int kh = hFrom; kh < hTo; kh++) {
                    for (int kw = wFrom; kw < wTo; kw++) {
                        idxIn[1] = kh;
                        idxIn[2] = kw;

                        auto inVal = x.e<double>(0, kh, kw, ch);
                        m.p(0, h, w, ch, inVal + m.e<double>(0, h, w, ch));
                        c.p(0, h, w, ch, 1 + c.e<int>(0, h, w, ch));
                    }
                }
            }
        }
    }
    m /= c;

    ASSERT_EQ(m, *z);

    delete result;
}

TEST_F(DeclarableOpsTests15, Test_standarize_1) {
    auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
    auto e = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});

@ -1097,7 +1032,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {

    ASSERT_EQ(Status::OK(), result->status());
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete result;
}

@ -1106,7 +1041,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {

    // rank 2
    NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
    NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);

    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, { 0 });
    auto output = result->at(0);

@ -1170,7 +1105,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {

    // rank 3
    NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, -5.8844e+01f, 2.2750e+01f }, nd4j::DataType::FLOAT32);
    NDArray expected('f', { 2, 2, 3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);

    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, {});
    auto output = result->at(0);

@ -1210,7 +1145,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {

    ASSERT_EQ(Status::OK(), result->status());
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete result;
}

@ -1484,7 +1419,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {

    auto Y = NDArrayFactory::create<float>(2.f);
    NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
    NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);

    dLdzC.linspace(0.1, 0.1);
    x = 4.f;

@ -883,22 +883,6 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {

    delete result;
}

TEST_F(DeclarableOpsTests3, Test_AvgPool_1) {
    auto x = NDArrayFactory::create<float>('c', {2, 10, 10, 3});
    x.linspace(1);

    nd4j::ops::avgpool2d op;
    //                                  kY kX sY sX pY pX dY dX M  P
    auto result = op.execute({&x}, {}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1});
    //                                  0  1  2  3  4  5  6  7  8  9  10
    auto z = result->at(0);

    // z->printShapeInfo("z shape");
    // z->printIndexedBuffer("z buffr");

    delete result;
}

TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
    auto x = NDArrayFactory::create<double>('c', {1, 3}, {2, 2, 2});
    auto y = NDArrayFactory::create<double>('c', {1, 3}, {4, 6, 8});

File diff suppressed because one or more lines are too long
|
@ -2459,42 +2459,6 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests8, avgpool2d_test13) {

    int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1;
    int oH=4, oW=4;
    int paddingMode = 1;             // 1-SAME, 0-VALID
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto expected = NDArrayFactory::create<double>('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5,
                                                                           182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5,
                                                                           317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5,
                                                                           482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5,
                                                                           617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5,
                                                                           782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5,
                                                                           917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
                                                                           1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5});
    input.linspace(1.);
    input.syncToDevice();

    nd4j::ops::avgpool2d op;
    auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());

    //output->printIndexedBuffer("output");
    //expected.printIndexedBuffer("expected");

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

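// The first expected entry above can be reproduced by hand: with SAME padding a
// 10x10 input gets one padded row/column on each border, and with extraParam0 = 0
// padded cells are excluded from the divisor, so output(0,0,0,0) averages the four
// valid input cells of the corner window. A standalone spot check:
#include <cstdio>

int main() {
    const int iW = 10, iC = 3;
    auto val = [&](int h, int w, int c) {       // input.linspace(1.) on {4,10,10,3}
        return double(h * iW * iC + w * iC + c + 1);
    };
    double sum = val(0,0,0) + val(0,1,0) + val(1,0,0) + val(1,1,0);   // 1+4+31+34
    printf("%g\n", sum / 4);                    // 17.5, the first expected value
    return 0;
}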
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) {
@@ -2894,344 +2894,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {

    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    variance.assign(0.46666667);
    gamma.assign(1.2);
    beta.assign(1.);    // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

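// The expdLdB values above follow from the beta gradient being the plain sum of
// dLdO over all non-channel positions. A standalone spot check (plain C++, no
// nd4j dependency):
#include <cstdio>

int main() {
    const int rows = 2 * 3, C = 4;              // {2,3,4} flattened to 6 rows of 4
    double dLdB[C] = {0, 0, 0, 0};
    double v = -0.9;                            // gradO.linspace(-0.9, 0.15)
    for (int i = 0; i < rows; ++i)
        for (int c = 0; c < C; ++c, v += 0.15)
            dLdB[c] += v;                       // accumulate dLdO per channel
    printf("%g %g %g %g\n", dLdB[0], dLdB[1], dLdB[2], dLdB[3]);  // 3.6 4.5 5.4 6.3
    return 0;
}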
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {

    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {3}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
                                    0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
                                   -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    // beta.assign(1.);    // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

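// Here the trailing integer argument 1 names the normalized axis (mean has shape
// {3}, matching axis 1 of the {2,3,4} input), and the gamma gradient is the
// stdInv-scaled correlation of dLdO with the centered input. A standalone spot
// check of expdLdG[0]:
#include <cmath>
#include <cstdio>

int main() {
    const double m = 1.05, v = 0.5, eps = 1e-5; // channel 0 of mean/variance
    double sum = 0.0;
    for (int b = 0; b < 2; ++b)                 // input/gradO shape {2,3,4}
        for (int k = 0; k < 4; ++k) {
            int idx = b * 12 + 0 * 4 + k;       // 'c'-order flat index, axis-1 coord 0
            double x = 0.1 * (idx + 1);         // input.linspace(0.1, 0.1)
            double g = -0.9 + 0.15 * idx;       // gradO.linspace(-0.9, 0.15)
            sum += g * (x - m);
        }
    printf("%f\n", sum / std::sqrt(v + eps));   // ~5.812360, matching expdLdG[0]
    return 0;
}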
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {

    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {2,1,4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
                                    0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
                                   -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    // beta.assign(1.);    // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

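// In this variant the trailing integer arguments {0,2} name two parameter axes at
// once, which is why mean/variance/gamma/beta carry shape {2,1,4}; the reductions
// then run over axis 1 alone. A standalone spot check of expdLdB's first entry:
#include <cstdio>

int main() {
    double sum = 0.0;
    for (int m = 0; m < 3; ++m) {               // axis 1 runs over m
        int idx = 0 * 12 + m * 4 + 0;           // b = 0, k = 0 fixed, 'c' ordering
        sum += -0.9 + 0.15 * idx;               // gradO values -0.9, -0.3, 0.3
    }
    printf("%g\n", sum);                        // -0.9, matching expdLdB(0,0,0)
    return 0;
}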
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {

    NDArray input   ('c', {2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

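// The small {2,4} case makes the input gradient easy to verify by hand. One
// derivation that reproduces the expectation, under the assumption that the op
// chains the mean/variance gradients back through the input (a reconstruction
// for illustration, not the op's literal code):
//   dLdI = gamma * ( stdInv*(g - gSum/N) + (2/N)*dfdv*(dvdm/2 + (x - m)) )
// with dfdv = -0.5*sum(g*(x-m))*stdInv^3 and dvdm = -2*sum(x-m)/N.
// Channel 0 spot check (m=1.05, v=0.5, gamma=-1.2, x={0.1,0.5}, g={-0.9,-0.3}):
#include <cmath>
#include <cstdio>

int main() {
    const double m = 1.05, v = 0.5, eps = 1e-5, gamma = -1.2;
    const double x[2] = {0.1, 0.5}, g[2] = {-0.9, -0.3};
    const int N = 2;
    const double stdInv = 1.0 / std::sqrt(v + eps);

    double gSum = 0, xmuSum = 0, gxmuSum = 0;
    for (int i = 0; i < N; ++i) {
        gSum    += g[i];
        xmuSum  += x[i] - m;
        gxmuSum += g[i] * (x[i] - m);
    }
    const double dfdv = -0.5 * gxmuSum * stdInv * stdInv * stdInv;
    const double dvdm = -2.0 * xmuSum / N;

    for (int i = 0; i < N; ++i) {
        double dLdI = gamma * (stdInv * (g[i] - gSum / N)
                               + (2.0 / N) * dfdv * (dvdm / 2 + (x[i] - m)));
        printf("dLdI[%d] = %f\n", i, dLdI);     // ~0.162923 and ~-0.162923
    }
    return 0;
}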
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {

    NDArray input   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
                                     -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
                                     -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {

    NDArray input   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
                                      0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
                                     -0.339330, 3.563660, -1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

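// Tests 5 and 6 exercise the same gradient on NCHW and NHWC layouts; the trailing
// integer argument (1 there, 3 here) names the channel axis, matching the {4}
// parameter shapes. A standalone spot check of expdLdB for this channels-last case:
#include <cstdio>

int main() {
    const int total = 2 * 2 * 2 * 4, C = 4;
    double dLdB[C] = {0, 0, 0, 0};
    for (int i = 0; i < total; ++i) {
        int c = i % C;                          // channel is the fastest axis in 'c' order
        dLdB[c] += -0.9 + 0.15 * i;             // gradO.linspace(-0.9, 0.15)
    }
    printf("%g %g %g %g\n", dLdB[0], dLdB[1], dLdB[2], dLdB[3]);  // 9.6 10.8 12 13.2
    return 0;
}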
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {

    NDArray input   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
                                       -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
                                        15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
                                       -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
                                       -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    // dLdI->printBuffer();

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {

    NDArray input   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);

    NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
                                        32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
                                       -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
                                        30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
                                        31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
    NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
    NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm_bp op;

    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto dLdI = results->at(0);
    auto dLdG = results->at(3);
    auto dLdB = results->at(4);

    // dLdI->printBuffer();

    ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
    ASSERT_TRUE(expdLdI.equalsTo(dLdI));

    ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
    ASSERT_TRUE(expdLdG.equalsTo(dLdG));

    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));

    delete results;
}

/*
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {