Shyrma cudnn (#192)

* - implementation of cudnn batchnorm_bp op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing bugs in batchnorm_bp based on cudnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - move pooling mkl code and delete some unnecessary files

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation and testing cudnn pooling2d ops (avg/max, ff/bp)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation and testing cudnn pooling 3d (ff/bp) ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide ff step in case of cudnn maxpool3d_bp op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove half type from set of supported types in mkl depthwise conv op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - bring back cudaStreamSynchronize in batchnorm and pooling cudnn ops

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
Yurii Shyrma 2020-01-28 17:23:07 +02:00 committed by raver119
parent 2f08af3166
commit 7a7ee4b021
31 changed files with 3521 additions and 2242 deletions

View File

@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
// ***** calculations ***** //
// notations:
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
// g = dLdO
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO
// stdInv = 1 / (v + eps)^0.5
// N - batch size (product of spatial dimensions)
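// for reference, a sketch of the standard gradients implied by this notation (assuming
// reduction over the N elements of each normalized group; xHat = (x - m) * stdInv):
// dLdG = sum( g * xHat )
// dLdB = sum( g )
// dLdI = gamma * stdInv / N * ( N * g - dLdB - xHat * dLdG )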

View File

@ -31,31 +31,28 @@ namespace ops {
CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
auto input = INPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf());
auto output = OUTPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
auto pH = INT_ARG(4);
auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int oH = 0;
int oW = 0;
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
}
return Status::OK();
}
DECLARE_SHAPE_FN(avgpool2d_bp) {

View File

@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]

View File

@ -32,6 +32,7 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
// maxpool2d corresponds to poolingMode=0
CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
auto input = INPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());

View File

@ -0,0 +1,138 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
auto pH = INT_ARG(4);
auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto paddingMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int oH = 0;
int oW = 0;
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
if (paddingMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode);
return Status::OK();
}
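// note on extraParam0 (a sketch, values assumed for illustration): for a window overlapping the
// padded border, CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING divides by the number of valid
// input elements only, while CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING divides by the full
// window size. E.g. with kH = kW = 2 and only three valid values {1, 2, 3} under the window:
// exclude padding: (1 + 2 + 3) / 3 = 2.0
// include padding: (1 + 2 + 3) / 4 = 1.5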
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && input->dataType() == output->dataType();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
const auto kH = INT_ARG(0); // filter(kernel) height
const auto kW = INT_ARG(1); // filter(kernel) width
const auto sH = INT_ARG(2); // strides height
const auto sW = INT_ARG(3); // strides width
auto pH = INT_ARG(4); // paddings height
auto pW = INT_ARG(5); // paddings width
const auto dH = INT_ARG(6); // dilations height
const auto dW = INT_ARG(7); // dilations width
const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
const auto extraParam0 = INT_ARG(9);
const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(paddingMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode);
return Status::OK();
}
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && (input->dataType() == gradO->dataType())
&& (input->dataType() == gradI->dataType())
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
}
}
}
}

View File

@ -0,0 +1,144 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int extraParam0 = INT_ARG(13);
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
if(paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && input->dataType() == output->dataType();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const int kD = INT_ARG(0); // filter(kernel) depth
const int kH = INT_ARG(1); // filter(kernel) height
const int kW = INT_ARG(2); // filter(kernel) width
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
const int extraParam0 = INT_ARG(13); // defines which divisor to use while averaging
const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode);
return Status::OK();
}
PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && (input->dataType() == gradO->dataType())
&& (input->dataType() == gradI->dataType())
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
}
}
}
}

View File

@ -97,9 +97,6 @@ static void batchnormCUDNN(const LaunchContext* context,
err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data());
if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err);
if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetConvolutionNdDescriptor failed", err);
// provide scaling parameters
const float alpha32(1), beta32(0);
const double alpha64(1), beta64(0);
@ -114,20 +111,127 @@ static void batchnormCUDNN(const LaunchContext* context,
x, input->getSpecialBuffer(),
z, output->getSpecialBuffer(),
params,
gamma ? gamma->getSpecialBuffer(): nullptr,
beta ? beta->getSpecialBuffer() : nullptr,
gamma->getSpecialBuffer(), beta->getSpecialBuffer(),
mean->getSpecialBuffer(), variance->getSpecialBuffer(), epsilon);
if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err);
// cudaErr = cudaStreamSynchronize(*context->getCudaStream());
// if (cudaErr != 0)
// throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr);
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
if (cudaErr != 0)
throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr);
NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta});
}
//////////////////////////////////////////////////////////////////////////
static void batchnormBpCUDNN(const LaunchContext* context,
const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO,
NDArray* gradI, NDArray* gradG, NDArray* gradB,
const double epsilon, const bool isSpatialMode) {
// input, gradO, gradI -> 4D:nchw, 5D:ncdhw
// mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode
// -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode
const cudnnDataType_t dataType = cudnnDataType(input->dataType());
const int xRank = input->rankOf();
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err);
const std::vector<int> xShape = input->getShapeAsVectorInt(); // input and output have same shapes
std::vector<int> paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes
if(isSpatialMode) { // 1xCx1x1
const int iC = mean->lengthOf();
const int stride0 = mean->strideAt(0);
paramsShape = xRank == 4 ? std::vector<int>({1, iC, 1, 1}) : std::vector<int>({1, iC, 1, 1, 1});
paramsStrides = xRank == 4 ? std::vector<int>({iC*stride0, stride0, 1, 1}) : std::vector<int>({iC*stride0, stride0, 1, 1, 1});
}
else {
paramsShape = mean->getShapeAsVectorInt();
paramsStrides = xRank == 4 ? std::vector<int>({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector<int>({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)});
}
std::vector<int> xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)};
std::vector<int> dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)};
std::vector<int> dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)};
if(xRank > 4) { // 5D
xStrides.push_back((int)input->strideAt(4));
dxStrides.push_back((int)gradI->strideAt(4));
dzStrides.push_back((int)gradO->strideAt(4));
}
cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW;
// input descriptor
cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data());
else
err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data());
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);
// gradO descriptor
cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data());
else
err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data());
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);
// gradI descriptor
cudnnTensorDescriptor_t dx;
cudnnCreateTensorDescriptor(&dx);
if(input->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data());
else
err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data());
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err);
// mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them
cudnnTensorDescriptor_t params;
cudnnCreateTensorDescriptor(&params);
if(mean->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data());
else
err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data());
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err);
// provide scaling parameters
const float alpha32(1), beta32(0);
double alpha64(1), beta64(0);
const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
const void* ptrBeta = input->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO});
// calculations
// TODO: we can use cache here
err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION,
ptrAlpha, ptrBeta, ptrAlpha, ptrBeta,
x, input->getSpecialBuffer(),
dz, gradO->getSpecialBuffer(),
dx, gradI->getSpecialBuffer(),
params,
gamma->getSpecialBuffer(), gradG->getSpecialBuffer(), gradB->getSpecialBuffer(),
epsilon,
nullptr/*mean->getSpecialBuffer()*/, nullptr/*variance->getSpecialBuffer()*/);
if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err);
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
if (cudaErr != 0)
throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO});
}
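// example of the parameter descriptor built above (assumed values): in spatial mode with iC = 3
// and a contiguous mean array (stride0 = 1), paramsShape = {1, 3, 1, 1} and
// paramsStrides = {3, 1, 1, 1}; for a 5D input a trailing 1 is appended to both, matching
// cudnn's 1xCx1x1(x1) layout for mean/variance/gamma/gradG/gradB.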
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
@ -189,11 +293,21 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1);
if(needPermut) { // if NHWC
std::vector<int> perm = {0, 3, 1, 2}; // NHWC -> NCHW
std::vector<int> perm = inRank == 4 ? std::vector<int>({0, 3, 1, 2}) : std::vector<int>({0, 4, 1, 2, 3}); // NHWC -> NCHW
input = new NDArray(input->permute(perm));
output = new NDArray(output->permute(perm));
}
// cudnn requires gamma and beta to be non-nullptr
if(!applyScale) {
gamma = new NDArray(mean);
*gamma = 1;
}
if(!applyOffset) {
beta = new NDArray(mean);
*beta = 0;
}
// calculations
batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1);
@ -202,6 +316,12 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) {
delete output;
}
if(!applyScale)
delete gamma;
if(!applyOffset)
delete beta;
return Status::OK();
}
@ -220,9 +340,6 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) {
const int numOfIntArgs = block.getIArguments()->size();
const int xRank = input->rankOf();
// disable cudnn batchnorm so far
return false;
// *********************************** //
if(xRank != 4 && xRank != 5)
return false;
@ -269,6 +386,182 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) {
return true;
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) {
NDArray* input = INPUT_VARIABLE(0);
NDArray* mean = INPUT_VARIABLE(1);
NDArray* variance = INPUT_VARIABLE(2);
NDArray* gamma = nullptr;
NDArray* beta = nullptr;
NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon
NDArray* gradI = OUTPUT_VARIABLE(0);
NDArray* gradM = OUTPUT_VARIABLE(1);
NDArray* gradV = OUTPUT_VARIABLE(2);
NDArray* gradG = nullptr;
NDArray* gradB = nullptr;
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
const float epsilon = T_ARG(0);
if(applyScale) {
gamma = INPUT_VARIABLE(3);
gradG = OUTPUT_VARIABLE(3);
}
if(applyOffset) {
beta = INPUT_VARIABLE(3 + (int)applyScale);
gradB = OUTPUT_VARIABLE(3 + (int)applyScale);
}
const int numOfIntArgs = block.getIArguments()->size();
const int inRank = input->rankOf();
// get axes args to normalize input array over
std::vector<int> axes;
if(numOfIntArgs > 2)
for(int i = 2; i < numOfIntArgs; ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
const int numOfAxes = axes.size();
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
std::vector<Nd4jLong> expShape;
if(numOfAxes == 1)
expShape.push_back(input->sizeAt(axes[0]));
else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
expShape = std::vector<Nd4jLong>(inRank, 1);
for(uint i = 0; i < numOfAxes; ++i)
expShape[axes[i]] = input->sizeAt(axes[i]);
}
REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
if(gamma)
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
if(beta)
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
// types of all input arrays should be the same (except gradO)
for(int i = 1; i < block.width() - 2; ++i)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
// cudnn supports NCHW format only
const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1);
if(needPermut) { // if NHWC
std::vector<int> perm = inRank == 4 ? std::vector<int>({0, 3, 1, 2}) : std::vector<int>({0, 4, 1, 2, 3}); // NHWC -> NCHW
input = new NDArray(input->permute(perm));
gradO = new NDArray(gradO->permute(perm));
gradI = new NDArray(gradI->permute(perm));
}
// cudnn requires gamma, gradG, gradB to be non-nullptr
if(!applyScale) {
gamma = new NDArray(mean);
gradG = new NDArray(mean);
*gamma = 1;
}
if(!applyOffset)
gradB = new NDArray(mean);
// calculations
batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1);
*gradM = 0; // put zeros so far
*gradV = 0; // put zeros so far
if(needPermut) {
delete input;
delete gradO;
delete gradI;
}
if(!applyScale) {
delete gamma;
delete gradG;
}
if(!applyOffset)
delete gradB;
return Status::OK();
}
PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) {
NDArray* input = INPUT_VARIABLE(0);
NDArray* mean = INPUT_VARIABLE(1);
NDArray* variance = INPUT_VARIABLE(2);
NDArray* gamma = nullptr;
NDArray* beta = nullptr;
NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon
NDArray* gradI = OUTPUT_VARIABLE(0);
NDArray* gradM = OUTPUT_VARIABLE(1);
NDArray* gradV = OUTPUT_VARIABLE(2);
NDArray* gradG = nullptr;
NDArray* gradB = nullptr;
const int numOfIntArgs = block.getIArguments()->size();
const int xRank = input->rankOf();
// *********************************** //
if(xRank != 4 && xRank != 5)
return false;
// *********************************** //
const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF;
if(badType)
return false;
// *********************************** //
// get axes args to normalize input array over
std::vector<int> axes;
if(numOfIntArgs > 2)
for(int i = 2; i < numOfIntArgs; ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(xRank-1); // default dimension to reduce along is last dimension
if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4)
return false;
// *********************************** //
bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo());
if(gamma)
allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo());
if(gradG)
allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradG->getShapeInfo());
if(gradB)
allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradB->getShapeInfo());
if(!allParamsHaveSameShapeAndStrides)
return false;
// *********************************** //
bool isFormatGood = false;
if(axes.size() == 1)
isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C]
else {
auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4]
inputShapeModif[0] = 1;
isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4]
}
if(!isFormatGood)
return false;
return true;
}
}
}

View File

@ -0,0 +1,412 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
const int iH, const int iW,
const int oH, const int oW,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW) {
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
const bool isPHasymm = pH != (pHsum - pH);
const bool isPWasymm = pW != (pWsum - pW);
if(!isPHasymm && !isPWasymm)
return;
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
const int iHposition = isNCHW ? 2 : 1;
if(isPHasymm)
newShape[iHposition] += 1;
if(isPWasymm)
newShape[iHposition + 1] += 1;
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
if(isNCHW)
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input);
else
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input);
input = newInput;
if(gradI != nullptr)
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
}
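// worked example (assumed numbers): iH = 5, kH = 2, sH = 2, dH = 1 with SAME padding give
// oH = 3 and pHsum = (3 - 1) * 2 + ((2 - 1) * 1 + 1) - 5 = 1; with pH = 0 the bottom padding
// would be pHsum - pH = 1 != pH, i.e. asymmetric. Since the cudnn descriptors only accept a
// single (symmetric) padding value per dimension, the helper grows the input by one row instead.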
//////////////////////////////////////////////////////////////////////////
void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
const int iD, const int iH, const int iW,
const int oD, const int oH, const int oW,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW) {
const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD);
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
const bool isPDasymm = pD != (pDsum - pD);
const bool isPHasymm = pH != (pHsum - pH);
const bool isPWasymm = pW != (pWsum - pW);
if(!isPDasymm && !isPHasymm && !isPWasymm)
return;
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
const int iDposition = isNCDHW ? 2 : 1;
if(isPDasymm)
newShape[iDposition] += 1;
if(isPHasymm)
newShape[iDposition + 1] += 1;
if(isPWasymm)
newShape[iDposition + 2] += 1;
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
if(isNCDHW)
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input);
else
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input);
input = newInput;
if(gradI != nullptr)
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
}
//////////////////////////////////////////////////////////////////////////
void pooling2dCUDNN(const LaunchContext* context,
const NDArray* input, NDArray* output,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW, const cudnnPoolingMode_t mode) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err);
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
// input descriptor
cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1)
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err);
// output descriptor
cudnnTensorDescriptor_t z;
cudnnCreateTensorDescriptor(&z);
if(output->ews() == 1)
err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW);
else
err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1));
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err);
// description of pooling
cudnnPoolingDescriptor_t pooling;
cudnnCreatePoolingDescriptor(&pooling);
err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err);
// provide scaling parameters
const float alpha32(1), beta32(0);
const double alpha64(1), beta64(0);
const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
NDArray::prepareSpecialUse({output}, {input});
// run calculation
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err);
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
if (cudaErr != 0)
throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr);
NDArray::registerSpecialUse({output}, {input});
}
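// note (a sketch, assuming fully contiguous NCHW storage): when ews() == 1 the descriptor can be
// built from the shape alone, because the implied strides
// strideN = iC*iH*iW, strideC = iH*iW, strideH = iW, strideW = 1
// are exactly what cudnnSetTensor4dDescriptor derives from CUDNN_TENSOR_NCHW; the explicit
// cudnnSetTensor4dDescriptorEx path is only needed for non-contiguous views.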
//////////////////////////////////////////////////////////////////////////
void pooling2dBpCUDNN(const LaunchContext* context,
const NDArray* input, const NDArray* gradO,
NDArray* gradI,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW, const cudnnPoolingMode_t mode) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err);
cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
// input and gradI descriptor
cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1)
err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW);
else
err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1));
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err);
// gradO descriptor
cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1)
err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW);
else
err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1));
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err);
// description of pooling
cudnnPoolingDescriptor_t pooling;
cudnnCreatePoolingDescriptor(&pooling);
err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW);
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err);
// provide scaling parameters
const float alpha32(1), beta32(0);
const double alpha64(1), beta64(0);
const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
NDArray::prepareSpecialUse({gradI}, {input, gradO});
// run calculation for gradI
err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err);
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
if (cudaErr != 0)
throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
NDArray::registerSpecialUse({gradI}, {input, gradO});
}
//////////////////////////////////////////////////////////////////////////
void pooling3dCUDNN(const LaunchContext* context,
const NDArray* input, NDArray* output,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW, const cudnnPoolingMode_t mode) {
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);
const int numDims = 5;
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
const int pSizes[] = {pD, pH, pW};
const int sSizes[] = {sD, sH, sW};
const int kSizes[] = {kD, kH, kW};
const int xShape[] = {bS, iC, iD, iH, iW};
const int zShape[] = {bS, oC, oD, oH, oW};
const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)};
cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
// input descriptor
cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
else
err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err);
// output descriptor
cudnnTensorDescriptor_t z;
cudnnCreateTensorDescriptor(&z);
if(output->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape);
else
err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides);
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err);
// description of pooling
cudnnPoolingDescriptor_t pooling;
cudnnCreatePoolingDescriptor(&pooling);
err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err);
// provide scaling parameters
const float alpha32(1), beta32(0);
const double alpha64(1), beta64(0);
const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
NDArray::prepareSpecialUse({output}, {input});
// run calculation
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err);
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
if (cudaErr != 0)
throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr);
NDArray::registerSpecialUse({output}, {input});
}
//////////////////////////////////////////////////////////////////////////
void pooling3dBpCUDNN(const LaunchContext* context,
const NDArray* input, const NDArray* gradO,
NDArray* gradI,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW, const cudnnPoolingMode_t mode) {
auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err);
const int numDims = 5;
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
const int pSizes[] = {pD, pH, pW};
const int sSizes[] = {sD, sH, sW};
const int kSizes[] = {kD, kH, kW};
const int xShape[] = {bS, iC, iD, iH, iW};
const int dzShape[] = {bS, oC, oD, oH, oW};
const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)};
const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)};
cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
// input and gradI descriptor
cudnnTensorDescriptor_t x;
cudnnCreateTensorDescriptor(&x);
if(input->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape);
else
err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides);
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err);
// gradO descriptor
cudnnTensorDescriptor_t dz;
cudnnCreateTensorDescriptor(&dz);
if(gradO->ews() == 1)
err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape);
else
err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides);
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err);
// description of pooling
cudnnPoolingDescriptor_t pooling;
cudnnCreatePoolingDescriptor(&pooling);
err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes);
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err);
// provide scaling parameters
const float alpha32(1), beta32(0);
const double alpha64(1), beta64(0);
const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&alpha32) : reinterpret_cast<const void*>(&alpha64);
const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast<const void*>(&beta32) : reinterpret_cast<const void*>(&beta64);
// cudnn max-pooling backward api requires the ff output as one of its input arguments
if(mode == CUDNN_POOLING_MAX) {
NDArray temp(gradO);
NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp});
// run ff calculation
err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, dz, temp.specialBuffer());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnPoolingForward failed", err);
// run bp calculation for gradI
err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnPoolingBackward failed", err);
NDArray::registerSpecialUse({gradI}, {input, gradO, &temp});
}
else {
NDArray::prepareSpecialUse({gradI}, {input, gradO});
// run bp calculation for gradI
err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer());
if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnPoolingBackward failed", err);
NDArray::registerSpecialUse({gradI}, {input, gradO});
}
auto cudaErr = cudaStreamSynchronize(*context->getCudaStream());
if (cudaErr != 0)
throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr);
}
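// why the extra ff pass above is needed (a sketch): for CUDNN_POOLING_MAX,
// cudnnPoolingBackward routes each output gradient back to the input position that produced
// the maximum, so it must be given the forward output y. E.g. for one window holding {3, 7}
// and gradO = 1, the whole gradient goes to the element 7: gradI = {0, 1}. Average pooling
// does not inspect y, which is presumably why the average branch passes gradO in its place.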
}
}
}

View File

@ -30,8 +30,8 @@
#include <cudnn.h>
namespace nd4j {
namespace ops {
namespace nd4j {
namespace ops {
namespace platforms {
DECLARE_PLATFORM(conv2d, ENGINE_CUDA);
@ -46,6 +46,18 @@ namespace platforms {
DECLARE_PLATFORM(batchnorm, ENGINE_CUDA);
DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA);
DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA);
DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA);
DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA);
DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA);
DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA);
//////////////////////////////////////////////////////////////////////////
FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) {
switch (dataType) {
@ -65,91 +77,62 @@ FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) {
}
//////////////////////////////////////////////////////////////////////////
FORCEINLINE void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
const int iH, const int iW,
const int oH, const int oW,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW) {
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
const bool isPHasymm = pH != (pHsum - pH);
const bool isPWasymm = pW != (pWsum - pW);
if(!isPHasymm && !isPWasymm)
return;
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
const int iHposition = isNCHW ? 2 : 1;
if(isPHasymm)
newShape[iHposition] += 1;
if(isPWasymm)
newShape[iHposition + 1] += 1;
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
if(isNCHW)
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input);
else
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input);
input = newInput;
if(gradI != nullptr)
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
}
void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
const int iH, const int iW,
const int oH, const int oW,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW);
//////////////////////////////////////////////////////////////////////////
FORCEINLINE void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
const int iD, const int iH, const int iW,
const int oD, const int oH, const int oW,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW) {
void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI,
const int iD, const int iH, const int iW,
const int oD, const int oH, const int oW,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW);
const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD);
const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH);
const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW);
//////////////////////////////////////////////////////////////////////////
void pooling2dCUDNN(const LaunchContext* context,
const NDArray* input, NDArray* output,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW, const cudnnPoolingMode_t mode);
const bool isPDasymm = pD != (pDsum - pD);
const bool isPHasymm = pH != (pHsum - pH);
const bool isPWasymm = pW != (pWsum - pW);
//////////////////////////////////////////////////////////////////////////
void pooling2dBpCUDNN(const LaunchContext* context,
const NDArray* input, const NDArray* gradO,
NDArray* gradI,
const int kH, const int kW,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW,
const bool isNCHW, const cudnnPoolingMode_t mode);
if(!isPDasymm && !isPHasymm && !isPWasymm)
return;
//////////////////////////////////////////////////////////////////////////
void pooling3dCUDNN(const LaunchContext* context,
const NDArray* input, NDArray* output,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW, const cudnnPoolingMode_t mode);
std::vector<Nd4jLong> newShape = input->getShapeAsVector();
const int iDposition = isNCDHW ? 2 : 1;
if(isPDasymm)
newShape[iDposition] += 1;
if(isPHasymm)
newShape[iDposition + 1] += 1;
if(isPWasymm)
newShape[iDposition + 2] += 1;
NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext());
if(isNCDHW)
(*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input);
else
(*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input);
input = newInput;
if(gradI != nullptr)
gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext());
}
//////////////////////////////////////////////////////////////////////////
void pooling3dBpCUDNN(const LaunchContext* context,
const NDArray* input, const NDArray* gradO,
NDArray* gradI,
const int kD, const int kH, const int kW,
const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW,
const int dD, const int dH, const int dW,
const bool isNCDHW, const cudnnPoolingMode_t mode);
}
}
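The checkConv*CUDNNPadAsymmetric helpers declared above exist because cuDNN accepts a single padding value per spatial dimension and applies it to both sides; when SAME padding requires an odd total pad, the two sides differ and the helpers instead enlarge the input (and gradI) by one row/column so that symmetric padding stays correct. A standalone numeric sketch of that test (the geometry is an assumed example; the ceil/floor conventions follow the ConvolutionUtils helpers used by the ops):

#include <cstdio>

int main() {
    const int iH = 5, kH = 2, sH = 2, dH = 1;                        // assumed example geometry

    const int oH    = (iH + sH - 1) / sH;                            // SAME output size: ceil(5/2) = 3
    const int pHsum = (oH - 1) * sH + ((kH - 1) * dH + 1) - iH;      // total padding needed = 1
    const int pH    = pHsum / 2;                                     // one-sided padding = 0

    const bool asymmetric = pH != (pHsum - pH);                      // 0 != 1 -> input gets an extra row
    std::printf("oH=%d totalPad=%d pad=%d otherSide=%d asymmetric=%d\n",
                oH, pHsum, pH, pHsum - pH, (int)asymmetric);
    return 0;
}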

View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingMode;
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
auto pH = INT_ARG(4);
auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto paddingMode = static_cast<bool>(INT_ARG(8));
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int oH = 0;
int oW = 0;
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);
if (paddingMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && input->dataType() == output->dataType();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
const auto kH = INT_ARG(0); // filter(kernel) height
const auto kW = INT_ARG(1); // filter(kernel) width
const auto sH = INT_ARG(2); // strides height
const auto sW = INT_ARG(3); // strides width
auto pH = INT_ARG(4); // paddings height
auto pW = INT_ARG(5); // paddings width
const auto dH = INT_ARG(6); // dilations height
const auto dW = INT_ARG(7); // dilations width
const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(paddingMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX);
return Status::OK();
}
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && (input->dataType() == gradO->dataType())
&& (input->dataType() == gradI->dataType())
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
}
}
}
}
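On CUDA builds with cuDNN enabled, the PLATFORM_IMPL/PLATFORM_CHECK pairs above intercept the generic maxpool2d and maxpool2d_bp ops. A hypothetical test-style invocation is sketched below; the execute overload and factory helpers mirror the existing test suites and are assumed rather than taken from this diff. The IArgs layout is {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, extraParam0, dataFormat}.

#include <ops/declarable/CustomOperations.h>

void maxpool2dSketch() {
    auto input = nd4j::NDArrayFactory::create<float>('c', {2, 3, 5, 5});   // NCHW input
    input.linspace(1.f);

    nd4j::ops::maxpool2d op;
    // 2x2 kernel, stride 2, no padding, dilation 1, VALID, extraParam0 = 0, NCHW
    auto results = op.execute({&input}, {}, {2,2, 2,2, 0,0, 1,1, 0, 0, 0});
    auto output  = results->at(0);                                         // expected shape [2, 3, 2, 2]
    output->printShapeInfo("maxpool2d output shape");

    delete results;
}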

View File

@ -0,0 +1,140 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include "cudnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13);
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
if(paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && input->dataType() == output->dataType();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const int kD = INT_ARG(0); // filter(kernel) depth
const int kH = INT_ARG(1); // filter(kernel) height
const int kW = INT_ARG(2); // filter(kernel) width
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging
const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX);
return Status::OK();
}
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32;
return goodType && (input->dataType() == gradO->dataType())
&& (input->dataType() == gradI->dataType())
&& shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo());
}
}
}
}
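The 3-D variants take the same inputs but a longer IArgs list: {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, extraParam0, dataFormat}. A hypothetical call mirroring the 2-D sketch above, under the same assumptions about the test-style execute helper:

#include <ops/declarable/CustomOperations.h>

void maxpool3dSketch() {
    auto input = nd4j::NDArrayFactory::create<float>('c', {2, 3, 4, 6, 6});   // NCDHW input

    nd4j::ops::maxpool3dnew op;
    // 2x2x2 kernel, stride 2, no padding, dilation 1, VALID, extraParam0 = 0, NCDHW
    auto results = op.execute({&input}, {}, {2,2,2, 2,2,2, 0,0,0, 1,1,1, 0, 0, 0});
    results->at(0)->printShapeInfo("maxpool3dnew output shape");              // expected [2, 3, 2, 3, 3]

    delete results;
}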

View File

@ -30,111 +30,231 @@
using namespace dnnl;
using namespace samediff;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
namespace nd4j {
namespace ops {
namespace platforms {
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
input->rankOf());
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
input->rankOf());
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
dH, dW);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
const auto extraParam0 = INT_ARG(9);
int oH = 0;
int oW = 0;
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
dH, dW);
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
int oH = 0;
int oW = 0;
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
if (!isNCHW) {
input = new NDArray(
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const int bS = input->sizeAt(0);
const int iC = input->sizeAt(1);
const int oC = output->sizeAt(1);
auto poolingMode = PoolingType::AVG_POOL;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
//streams[0].submitAndWait();
if (!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
if (!isNCHW) {
input = new NDArray(
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const int bS = input->sizeAt(0);
const int iC = input->sizeAt(1);
const int oC = output->sizeAt(1);
auto poolingMode = PoolingType::AVG_POOL;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
//streams[0].submitAndWait();
if (!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
int isNCHW =
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCHW) {
input = new NDArray(input->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
auto poolingMode = PoolingType::AVG_POOL;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
dnnl::stream stream(engine);
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}
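Both oneDNN pooling kernels above rely on the same layout trick: when the graph supplies NHWC data, permute({0, 3, 1, 2}) yields a strided view that presents the tensor as NCHW without copying, and the wrapper NDArray objects created around those views are deleted once the primitive has run. A minimal illustration with assumed names:

auto nhwc  = nd4j::NDArrayFactory::create<float>('c', {2, 5, 5, 3});   // [bS, iH, iW, iC]
auto* nchw = new nd4j::NDArray(nhwc.permute({0, 3, 1, 2}));            // [bS, iC, iH, iW] view over the same buffer
// ... hand *nchw to the oneDNN primitive in place of the NHWC tensor ...
delete nchw;                                                           // frees only the view; nhwc still owns the data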

View File

@ -1,149 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace dnnl;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
int isNCHW =
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCHW) {
input = new NDArray(input->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
auto poolingMode = PoolingType::AVG_POOL;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides,
pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
dnnl::stream stream(engine);
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}

View File

@ -34,24 +34,23 @@ namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
const int isNCHW) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
dnnl_memory_desc_t empty;
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty);
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty);
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output,
@ -61,13 +60,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
@ -112,6 +110,135 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
stream.wait();
}
//////////////////////////////////////////////////////////////////////
static void conv2dBpMKLDNN(nd4j::graph::Context &block,
const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
NDArray *gradI, NDArray *gradW, NDArray *gradB,
const int kH, const int kW, const int sH,const int sW, int pH, int pW, const int dH, const int dW,
const int paddingMode, const int isNCHW) {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
dnnl_memory_desc_t empty;
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine()));
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{DNNL_ARG_SRC, convW_src_memory},
{DNNL_ARG_DIFF_DST, convW_dst_memory},
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
}
else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{DNNL_ARG_SRC, convW_src_memory},
{DNNL_ARG_DIFF_DST, convW_dst_memory},
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{DNNL_ARG_DIFF_DST, convI_dst_memory},
{DNNL_ARG_WEIGHTS, convI_weights_memory},
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
}
stream.wait();
}
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
@ -132,7 +259,7 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
return Status::OK();
}
@ -152,6 +279,7 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) {
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@ -172,158 +300,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
gradO->rankOf());
REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
dnnl_memory_desc_t empty;
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = dnnl::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{DNNL_ARG_SRC, convW_src_memory},
{DNNL_ARG_DIFF_DST, convW_dst_memory},
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{DNNL_ARG_SRC, convW_src_memory},
{DNNL_ARG_DIFF_DST, convW_dst_memory},
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc =
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
conv_weights_md, conv_dst_md, conv_strides, conv_dilation,
conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{DNNL_ARG_DIFF_DST, convI_dst_memory},
{DNNL_ARG_WEIGHTS, convI_weights_memory},
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
};
conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
return Status::OK();
}
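The backward helper above also shows the oneDNN convention that backward primitive descriptors are built against the forward primitive descriptor, which serves as a hint so the library chooses compatible memory formats for the data and weight gradients. Stripped of the memory setup, the pattern is roughly (illustrative names, descriptor construction elided):

auto fwd_pd  = convolution_forward::primitive_desc(fwd_desc, engine);
auto bwdW_pd = convolution_backward_weights::primitive_desc(bwdW_desc, engine, fwd_pd);   // gradW / gradB
auto bwdD_pd = convolution_backward_data::primitive_desc(bwdD_desc, engine, fwd_pd);      // gradI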

View File

@ -34,62 +34,23 @@ namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW =
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
static void conv3dMKLDNN(nd4j::graph::Context &block,
const NDArray *input, const NDArray *weights, const NDArray *bias,
NDArray *output,
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW,
const int paddingMode, const int isNCDHW) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
dnnl_memory_desc_t empty;
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty);
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty);
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty);
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty);
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode,
isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
nullptr, bias, output,
@ -98,151 +59,73 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
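        // note: algorithm::convolution_auto lets oneDNN choose a suitable convolution implementation
        // (direct, Winograd, ...) for the given shapes rather than hard-coding one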
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->getBuffer());
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
{DNNL_ARG_WEIGHTS, conv_weights_memory},
{DNNL_ARG_BIAS, conv_bias_memory},
{DNNL_ARG_DST, conv_dst_memory}});
} else {
}
else {
convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory},
{DNNL_ARG_WEIGHTS, conv_weights_memory},
{DNNL_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc())
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
return Status::OK();
}
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
gradO->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNDHWC =
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
static void conv3dBpMKLDNN(nd4j::graph::Context &block,
const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
NDArray *gradI, NDArray *gradW, NDArray *gradB,
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW,
const int paddingMode, const int isNCDHW) {
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
dW, iD, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
dnnl_memory_desc_t empty;
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNDHWC,
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode,
isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
gradW, gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
@ -250,43 +133,30 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = dnnl::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()));
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
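        // note: the forward primitive descriptor above is not executed here; it only serves as the hint
        // required by the backward-weights and backward-data primitive descriptors built below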
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
                                               : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc);
auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
@ -297,65 +167,53 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{DNNL_ARG_SRC, convW_src_memory},
{DNNL_ARG_DIFF_DST, convW_dst_memory},
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory},
{DNNL_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
}
else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{DNNL_ARG_SRC, convW_src_memory},
{DNNL_ARG_DIFF_DST, convW_dst_memory},
{DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc())
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory);
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
conv_diff_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc);
auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = dnnl::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = dnnl::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc())
convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
@ -363,30 +221,128 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
{DNNL_ARG_WEIGHTS, convI_weights_memory},
{DNNL_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc())
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
}
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
    int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if (paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
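        // note (assumption): for SAME mode calcPadding3D presumably recomputes the leading pads roughly as
        // pD = ((oD - 1)*sD + ((kD - 1)*dD + 1) - iD) / 2 (likewise pH, pW), i.e. the usual convention that
        // keeps oD = ceil(iD / sD); the exact handling of odd remainders is up to the helper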
conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
return Status::OK();
}
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
if(paddingMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
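    // note: the "true" output sizes are recomputed from the hyper-parameters here only so that the shape of
    // the supplied gradO (next epsilon) can be validated against them below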
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);
return Status::OK();
}
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&


@ -177,7 +177,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
}
//////////////////////////////////////////////////////////////////////////
static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode) {
@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
}
deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
delete weights;
delete gradW;


@ -421,7 +421,7 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
return block.isUseMKLDNN() && mC == 1 &&
(
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) ||
(xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
);
}


@ -29,117 +29,258 @@
using namespace dnnl;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
namespace nd4j {
namespace ops {
namespace platforms {
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
input->rankOf());
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
input->rankOf());
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
auto argI = *(block.getIArguments());
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
dH, dW);
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto isSameMode = static_cast<bool>(INT_ARG(8));
int oH = 0;
int oW = 0;
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
dH, dW);
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
int oH = 0;
int oW = 0;
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
if (!isNCHW) {
input = new NDArray(
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
const int bS = input->sizeAt(0);
const int iC = input->sizeAt(1);
const int oC = output->sizeAt(1);
auto poolingMode = PoolingType::MAX_POOL;
int extraParam0 = 1;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
if (!isNCHW) {
input = new NDArray(
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
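            // note (assumption): calcOutSizePool2D presumably follows the usual convention
            // oH = isSameMode ? ceil(iH / sH) : (iH + 2*pH - ((kH - 1)*dH + 1)) / sH + 1 (integer division),
            // with calcPadding2D then back-solving pH/pW so that SAME reproduces that size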
const int bS = input->sizeAt(0);
const int iC = input->sizeAt(1);
const int oC = output->sizeAt(1);
auto poolingMode = PoolingType::MAX_POOL;
int extraParam0 = 1;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
dnnl::stream stream(engine);
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
int isNCHW =
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCHW) {
input = new NDArray(input->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
auto poolingMode = PoolingType::MAX_POOL;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
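            // note: the forward pooling pass is executed here presumably because max-pooling backward in oneDNN
            // needs the workspace (the arg-max indices) produced by a training-mode forward pass in order to
            // route the gradients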
// probably wrong, fix that
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}


@ -1,174 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace dnnl;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto gradO = INPUT_VARIABLE(
1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int extraParam0 = INT_ARG(9);
int isNCHW =
block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0,
"AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(
ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCHW) {
input = new NDArray(input->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
auto poolingMode = PoolingType::MAX_POOL;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
true,
bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
// probably wrong, fix that
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}


@ -28,124 +28,273 @@
using namespace dnnl;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
namespace nd4j {
namespace ops {
namespace platforms {
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
REQUIRE_TRUE(input->rankOf() == 5, 0,
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
REQUIRE_TRUE(input->rankOf() == 5, 0,
"MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
if (!isNCDHW) {
input = new NDArray(
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = new NDArray(
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
"MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
dW);
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
if (!isNCDHW) {
input = new NDArray(
input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = new NDArray(
output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
dW);
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
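            // note: extraParam0 is only meaningful for avg (padding-divisor mode) and pnorm pooling (see the
            // "unnecessary for max case" remark above); it is presumably set here just to satisfy the shared
            // getMKLDNNMemoryDescPool3d signature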
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
algorithm,
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
&user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
pool_dst_md, pool_strides, pool_kernel, pool_padding,
pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = user_dst_memory;
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
}
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory}});
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
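// Illustrative sketch only (not part of the op above): a minimal, self-contained oneDNN max-pooling
// forward call showing the same desc -> primitive_desc -> execute pattern, without the NDArray and
// reorder scaffolding. It assumes the dnnl headers this file already pulls in (oneDNN 1.x API);
// the function name, shapes and buffers below are made up for the sketch.
static void maxpool3d_forward_sketch() {
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    dnnl::stream strm(eng);

    const dnnl::memory::dims srcDims = {1, 2, 4, 4, 4};   // bS, iC, iD, iH, iW
    const dnnl::memory::dims dstDims = {1, 2, 2, 2, 2};   // bS, oC, oD, oH, oW
    const dnnl::memory::dims kernel  = {2, 2, 2};
    const dnnl::memory::dims strides = {2, 2, 2};
    const dnnl::memory::dims padL    = {0, 0, 0};
    const dnnl::memory::dims padR    = {0, 0, 0};

    auto srcMd = dnnl::memory::desc(srcDims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ncdhw);
    auto dstMd = dnnl::memory::desc(dstDims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ncdhw);

    float srcBuf[1 * 2 * 4 * 4 * 4] = {};                 // plain buffers stand in for NDArray storage
    float dstBuf[1 * 2 * 2 * 2 * 2] = {};
    auto srcMem = dnnl::memory(srcMd, eng, srcBuf);
    auto dstMem = dnnl::memory(dstMd, eng, dstBuf);

    // operation descriptor -> primitive descriptor -> primitive, then execute and wait on the stream
    auto poolDesc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::pooling_max,
                                                srcMd, dstMd, strides, kernel, padL, padR);
    auto poolPrimDesc = dnnl::pooling_forward::primitive_desc(poolDesc, eng);
    dnnl::pooling_forward(poolPrimDesc).execute(strm, {{DNNL_ARG_SRC, srcMem}, {DNNL_ARG_DST, dstMem}});
    strm.wait();
}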
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const int kD = INT_ARG(0); // filter(kernel) depth
const int kH = INT_ARG(1); // filter(kernel) height
const int kW = INT_ARG(2); // filter(kernel) width
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // INT_ARG(14): 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0,
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCDHW) {
input = new NDArray(input->permute(
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
dW);
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input's buffer is sometimes null, in which case the helper above left pool_src_md/user_src_md unset
if (input->buffer() == nullptr) {
pool_src_md = pool_diff_src_md;
user_src_md = user_diff_src_md;
}
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
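// max-pooling backward in oneDNN consumes a workspace (the positions of the max elements) produced
// by a forward run, so the forward primitive is executed here first to fill that workspace before
// pooling_backward routes gradO through those positions into gradI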
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}

View File

@ -1,181 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace dnnl;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(
1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
const int kD = INT_ARG(0); // filter(kernel) depth
const int kH = INT_ARG(1); // filter(kernel) height
const int kW = INT_ARG(2); // filter(kernel) width
const int sD = INT_ARG(3); // strides depth
const int sH = INT_ARG(4); // strides height
const int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
const int dD = INT_ARG(9); // dilations depth
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0,
"MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
"MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
"MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if (!isNCDHW) {
input = new NDArray(input->permute(
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute(
{0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = new NDArray(gradO->permute(
{0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
dW);
auto poolingMode = PoolingType::MAX_POOL;
auto extraParam0 = 1;
dnnl_memory_desc_t empty;
dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
dnnl::algorithm algorithm;
mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
extraParam0, true,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
algorithm,
&pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
&user_diff_src_md, &user_dst_md,
pool_strides, pool_kernel, pool_padding, pool_padding_r);
// input is sometimes null, so we can't rely on pool_src_md being valid
if (input->buffer() == nullptr) {
pool_src_md = pool_diff_src_md;
user_src_md = user_diff_src_md;
}
auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
auto poolB_src_memory = userB_src_memory;
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
}
auto poolB_dst_memory = userB_dst_memory;
if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
}
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
auto pool_src_memory = user_src_memory;
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
}
auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
{DNNL_ARG_DST, pool_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory}});
pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
{DNNL_ARG_WORKSPACE, pool_workspace_memory},
{DNNL_ARG_DIFF_SRC, poolB_src_memory}});
if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
}
stream.wait();
if (!isNCDHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
}
}
}

View File

@ -23,383 +23,388 @@
using namespace dnnl;
namespace nd4j {
namespace mkldnnUtils {
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
namespace nd4j {
namespace mkldnnUtils {
pool_strides = { sH, sW };
pool_kernel = { kH, kW };
pool_padding = { pH, pW };
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW };
dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW };
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
pool_strides = { sH, sW };
pool_kernel = { kH, kW };
pool_padding = { pH, pW };
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
}
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
};
void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
pool_strides = { sD, sH, sW };
pool_kernel = { kD, kH, kW };
pool_padding = { pD, pH, pW };
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
};
void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
conv_dilation = { dH-1, dW-1};
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto formatw = dnnl::memory::format_tag::hwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
}
void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
conv_strides = { sD, sH, sW };
conv_padding = { pD, pH, pW };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
conv_dilation = { dD-1, dH-1, dW-1};
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto formatw = dnnl::memory::format_tag::dhwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
};
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
// const Nd4jLong* shape = src->getShapeInfo();
// Nd4jLong rank = shape[0];
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
// auto type = dnnl::memory::data_type::f32;
// auto format = dnnl::memory::format_tag::nchw;
// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
// }
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
// }
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
// }
// };
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
long rank = shape[0];
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
long dim2 = axis >= 2 ? 1 : 2;
long dim3 = axis >= 3 ? 2 : 3;
dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = dnnl::memory::data_type::f32;
auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = format; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
*lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked;
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
*lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked;
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
*lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked;
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
}
dnnl::engine& getEngine(void *ptr) {
auto eng = reinterpret_cast<dnnl::engine*>(ptr);
return *eng;
}
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
};
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm,
dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) {
dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
pool_strides = { sD, sH, sW };
pool_kernel = { kD, kH, kW };
pool_padding = { pD, pH, pW };
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
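// pool_padding_r is the right/bottom/back padding implied by the output size: the library expects
// oX = (iX + pX + pX_r - kX) / sX + 1 per spatial dimension, hence pX_r = (oX - 1) * sX - iX + kX - pX
// (which may exceed pX when SAME mode was requested)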
algorithm = poolingMode == 0 ? algorithm::pooling_max
: extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding
: algorithm::pooling_avg_include_padding;
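// poolingMode: 0 -> max pooling; otherwise average pooling, where extraParam0 selects whether
// padded cells are counted in the divisor (0 -> exclude padding, else include it)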
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
*pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
*pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
*pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
};
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescConv2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW };
dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW };
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
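// with a dilated kernel the effective width extent is (kW - 1) * dW + 1, so the total SAME-style
// width padding is (oW - 1) * sW + (kW - 1) * dW + 1 - iW; pWSame takes its left half and replaces
// pW in conv_padding_r below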
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
conv_dilation = { dH-1, dW-1};
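// oneDNN counts dilation from 0 (0 means a dense kernel), hence the dH-1/dW-1 conversion from the
// framework convention where dilation 1 means no dilation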
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto formatw = dnnl::memory::format_tag::hwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2];
}
}
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md,
dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md,
dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) {
dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
conv_strides = { sD, sH, sW };
conv_padding = { pD, pH, pW };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
conv_dilation = { dD-1, dH-1, dW-1};
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
auto formatw = dnnl::memory::format_tag::dhwio;
if (src != nullptr && conv_src_md != nullptr) {
*conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4];
user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1];
user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2];
user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3];
}
if (diff_src != nullptr && conv_diff_src_md != nullptr) {
*conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any);
*user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
}
if (weights != nullptr && conv_weights_md != nullptr) {
*conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4];
user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3];
user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0];
user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1];
user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2];
}
if (diff_weights != nullptr && conv_diff_weights_md != nullptr) {
*conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any);
*user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw);
user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio"
user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4];
user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3];
user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0];
user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1];
user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2];
}
if (bias != nullptr && conv_bias_md != nullptr) {
*conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any);
*user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x);
}
if (dst != nullptr && conv_dst_md != nullptr) {
*conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any);
*user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4];
user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1];
user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2];
user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3];
}
};
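As a concrete (hypothetical) illustration of the stride remapping above: an NDHWC input of shape {2, 8, 8, 8, 16} has c-order strides {8192, 1024, 128, 16, 1} for its n, d, h, w, c dimensions, so the NDHWC branch would record blocked strides {8192 (N), 1 (C), 1024 (D), 128 (H), 16 (W)}; that is, the [0, 4, 1, 2, 3] index remapping hands the dimensions to oneDNN in logical NCDHW order while still describing the user's NDHWC layout.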
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
// const Nd4jLong* shape = src->getShapeInfo();
// Nd4jLong rank = shape[0];
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
// auto type = dnnl::memory::data_type::f32;
// auto format = dnnl::memory::format_tag::nchw;
// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
// }
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
// }
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
// user_dst_md->data.format_kind = dnnl_blocked; // overrides format
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
// }
// };
//////////////////////////////////////////////////////////////////////////
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
long rank = shape[0];
long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
long dim2 = axis >= 2 ? 1 : 2;
long dim3 = axis >= 3 ? 2 : 3;
dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = dnnl::memory::data_type::f32;
auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
auto supposed_to_be_any_format = format; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
*lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_src_md->data.format_kind = dnnl_blocked;
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
*lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_diff_src_md->data.format_kind = dnnl_blocked;
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
*lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
user_dst_md->data.format_kind = dnnl_blocked;
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
}
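A worked example of how this collapses arbitrary ranks to 4-D (illustrative shapes only): for a rank-4 NHWC array of shape {2, 5, 5, 3} with axis = 3, the code picks dim1 = 3, dim2 = 1, dim3 = 2, so lrn_src_tz becomes {2, 3, 5, 5}; the dims are always passed to oneDNN in logical NCHW order, while the nhwc format tag plus the manual strides record the real memory layout, and ranks below 4 are padded with trailing 1s.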
//////////////////////////////////////////////////////////////////////////
dnnl::engine& getEngine(void *ptr) {
auto eng = reinterpret_cast<dnnl::engine*>(ptr);
return *eng;
}
}
}
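A rough usage sketch for the descriptors built above (not part of this commit; the exampleReorder helper, its parameters and shapes are illustrative only): the user descriptor wraps the NDArray buffer in its actual layout, and a reorder is only needed when the primitive's preferred descriptor differs from it.
#include <dnnl.hpp>
// hypothetical consumer of a user_src_md / prim_src_md pair produced by the helpers above
static dnnl::memory exampleReorder(const dnnl::engine& engine, const dnnl::memory::desc& user_src_md,
                                   const dnnl::memory::desc& prim_src_md, void* userBuffer) {
    dnnl::stream stream(engine);
    dnnl::memory userMem(user_src_md, engine, userBuffer);    // view of the NDArray buffer as-is
    if (prim_src_md == user_src_md)
        return userMem;                                        // layouts match, no copy needed
    dnnl::memory primMem(prim_src_md, engine);                 // buffer in the library-chosen layout
    dnnl::reorder(userMem, primMem).execute(stream, userMem, primMem);
    stream.wait();
    return primMem;                                            // ready to feed the conv/pool primitive
}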

File diff suppressed because one or more lines are too long


@ -970,7 +970,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
x.linspace(1);
nd4j::ops::maxpool2d op;
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
@ -991,7 +990,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
x.linspace(1);
nd4j::ops::maxpool2d op;
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
@ -1012,7 +1010,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
x.linspace(1);
nd4j::ops::maxpool2d op;
auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});
@ -1467,11 +1464,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f});
input.linspace(1.);
gradO.linspace(0.1, 0.1);


@ -57,6 +57,17 @@ TEST_F(CuDnnTests, helpers_includer) {
nd4j::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d;
nd4j::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp;
nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm;
nd4j::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp;
nd4j::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d;
nd4j::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp;
nd4j::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d;
nd4j::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp;
nd4j::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew;
nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp;
nd4j::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew;
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp;
printer({&conv2d});
printer({&conv2d_bp});
@ -65,6 +76,15 @@ TEST_F(CuDnnTests, helpers_includer) {
printer({&depthwise_conv2d});
printer({&depthwise_conv2d_bp});
printer({&batchnorm});
printer({&batchnorm_bp});
printer({&avgpool2d});
printer({&avgpool2d_bp});
printer({&maxpool2d});
printer({&maxpool2d_bp});
printer({&avgpool3dnew});
printer({&avgpool3dnew_bp});
printer({&maxpool3dnew});
printer({&maxpool3dnew_bp});
#endif
}


@ -25,6 +25,7 @@
#include <ops/ops.h>
#include <GradCheck.h>
#include <memory>
#include <PointersManager.h>
using namespace nd4j;
@ -2247,3 +2248,525 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
variance.assign(0.46666667);
gamma.assign(1.2);
beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
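A reading aid for the execute() calls in these tests (assumed from the op's usual conventions rather than stated in this commit): the single floating-point argument is epsilon, the first two integer arguments toggle applyScale/applyOffset, and any further integers name the normalization axes, the last axis being the default. The beta gradient is then simply dLdO summed over the non-axis positions, which is easy to check against expdLdB above, e.g. for channel 0: (-0.9) + (-0.3) + 0.3 + 0.9 + 1.5 + 2.1 = 3.6.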
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {3}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
-0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
-0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) {
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) {
#if defined(HAVE_CUDNN)
return;
#endif
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
-1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
-0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) {
#if defined(HAVE_CUDNN)
return;
#endif
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
-0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) {
#if defined(HAVE_CUDNN)
return;
#endif
NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
-43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
-15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
-27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
// dLdI->printBuffer();
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) {
#if defined(HAVE_CUDNN)
return;
#endif
NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
-27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
// dLdI->printBuffer();
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) {
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818,
-0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063,
-0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
input.linspace(1,0.01);
gradO.linspace(-0.9, 0.15);
// calculate mean and variance of input
PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9");
std::vector<int> dimensions = {0,2,3};
int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
NDArray::prepareSpecialUse({&variance}, {&input});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
manager.synchronize();
NDArray::registerSpecialUse({&variance}, {&input});
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
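For reference, the variance filled in by the test above is presumably the biased (population) estimate that batchnorm expects: variance[c] = (1/N) * sum over n,h,w of (input[n,c,h,w] - mean[c])^2, with N = 2*2*2 = 8 for the {2,4,2,2} input; the trailing false passed to execSummaryStats is assumed to select exactly this non-bias-corrected form.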
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) {
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343,
-0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171,
-0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);
input.linspace(1,0.01);
gradO.linspace(-0.9, 0.15);
// calculate mean and variance of input
PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test10");
std::vector<int> dimensions = {0,1,2};
int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions);
NDArray::prepareSpecialUse({&variance}, {&input});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
manager.synchronize();
NDArray::registerSpecialUse({&variance}, {&input});
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) {
NDArray input ('c', {2,3,4,5}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
NDArray variance('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {1,3,4,5}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4,5}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837,
0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498,
0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670,
-0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821,
-0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985,
-0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835,
-0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334,
0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669,
0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498,
8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503,
8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501,
8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0,
15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5,
22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, nd4j::DataType::FLOAT32);
input.linspace(1,0.01);
gradO.linspace(-0.9, 0.15);
gamma.linspace(-3, 0.1);
// calculate mean and variance of input
PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test11");
std::vector<int> dimensions = {0};
int* dims = reinterpret_cast<int*>(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)));
input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions, true);
NDArray::prepareSpecialUse({&variance}, {&input});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false);
manager.synchronize();
NDArray::registerSpecialUse({&variance}, {&input});
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}


@ -72,71 +72,6 @@ TEST_F(DeclarableOpsTests15, Test_Half_assign_1) {
ASSERT_EQ(10, x.sumNumber().e<int>(0));
}
TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) {
int inOutH = 5;// 35;
int inOutW = 5;// 35;
int inOutC = 10;// 192;
auto x = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
x.linspace(1.0);
nd4j::ops::avgpool2d op;
auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
int padTop = totalPadHeight / 2;
int padBottom = totalPadHeight - totalPadHeight / 2;
int k = 3;
auto m = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
auto c = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
for (int h = 0; h < inOutH; h++) {
for (int w = 0; w < inOutW; w++) {
int hFrom = h - padTop;
int wFrom = w - padBottom;
int hTo = hFrom + k;
int wTo = wFrom + k;
hFrom = nd4j::math::nd4j_max<int>(0, hFrom);
wFrom = nd4j::math::nd4j_max<int>(0, wFrom);
hTo = nd4j::math::nd4j_min<int>(inOutH, hTo);
wTo = nd4j::math::nd4j_min<int>(inOutW, wTo);
int idxOut[4];
int idxIn[4];
for (int ch = 0; ch < inOutC; ch++) {
idxOut[1] = h;
idxOut[2] = w;
idxOut[3] = ch;
idxIn[3] = ch;
for (int kh = hFrom; kh < hTo; kh++) {
for (int kw = wFrom; kw < wTo; kw++) {
idxIn[1] = kh;
idxIn[2] = kw;
auto inVal = x.e<double>(0, kh, kw, ch);
m.p(0, h, w, ch, inVal + m.e<double>(0, h, w, ch));
c.p(0, h, w, ch, 1 + c.e<int>(0, h, w, ch));
}
}
}
}
}
m /= c;
ASSERT_EQ(m, *z);
delete result;
}
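The padding arithmetic in the removed test follows the usual SAME-mode rule, totalPad = (out - 1) * stride + kernel - in; with in = out = 5, stride = 1 and kernel = 3 this gives (5 - 1) * 1 + 3 - 5 = 2, split into padTop = 1 and padBottom = 1 (the same value is then reused for the width offset, which only works here because the split is symmetric).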
TEST_F(DeclarableOpsTests15, Test_standarize_1) {
auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
auto e = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});
@ -1097,7 +1032,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete result;
}
@ -1106,7 +1041,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {
// rank 2
NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, { 0 });
auto output = result->at(0);
@ -1170,7 +1105,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {
// rank 3
NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32);
NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {});
auto output = result->at(0);
@ -1210,7 +1145,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete result;
}
@ -1484,7 +1419,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {
auto Y = NDArrayFactory::create<float>(2.f);
NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
dLdzC.linspace(0.1, 0.1);
x = 4.f;


@ -883,22 +883,6 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
delete result;
}
TEST_F(DeclarableOpsTests3, Test_AvgPool_1) {
auto x= NDArrayFactory::create<float>('c', {2, 10, 10, 3});
x.linspace(1);
nd4j::ops::avgpool2d op;
// kY kX sY sX pY pX dY dX M P
auto result = op.execute({&x}, {}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1});
// 0 1 2 3 4 5 6 7 8 9 10
auto z = result->at(0);
// z->printShapeInfo("z shape");
// z->printIndexedBuffer("z buffr");
delete result;
}
TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
auto x= NDArrayFactory::create<double>('c', {1, 3}, {2, 2, 2});
auto y= NDArrayFactory::create<double>('c', {1, 3}, {4, 6, 8});

File diff suppressed because one or more lines are too long


@ -2459,42 +2459,6 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) {
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests8, avgpool2d_test13) {
int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1;
int oH=4, oW=4;
int paddingMode = 1; // 1-SAME, 0-VALID
int dataFormat = 1; // 1-NHWC, 0-NDHW
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
auto expected = NDArrayFactory::create<double>('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5,
182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5,
317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5,
482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5,
617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5,
782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5,
917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5});
input.linspace(1.);
input.syncToDevice();
nd4j::ops::avgpool2d op;
auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
//output->printIndexedBuffer("output");
//expected.printIndexedBuffer("expected");
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) {


@ -2894,344 +2894,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
variance.assign(0.46666667);
gamma.assign(1.2);
beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {3}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
-0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
-0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
-1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
-0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
-0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
-43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
-15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
-27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
// dLdI->printBuffer();
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
-27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
// dLdI->printBuffer();
ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
/*
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {