Shyrma casual conv1d (#90)
* - add causal mode of padding to convolutions Signed-off-by: Yurii <iuriish@yahoo.com> * - add additional tests for causal conv1d Signed-off-by: Yurii <iuriish@yahoo.com> * - add causal mode for cuda conv kernels Signed-off-by: Yurii <iuriish@yahoo.com> * Java side of Conv1D changes Signed-off-by: raver119 <raver119@gmail.com> * Add Conv1DDerivative op Signed-off-by: Alex Black <blacka101@gmail.com> * Causal Conv1D gradient checks Signed-off-by: Alex Black <blacka101@gmail.com> * Tweaks Signed-off-by: Alex Black <blacka101@gmail.com> * - add causal padding mode to conv2d_bp Signed-off-by: Yurii <iuriish@yahoo.com> * More thorough causal conv1d tests Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
5e07998e59
commit
d19eeaec52
|
@ -31,7 +31,7 @@ namespace ops {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
|
CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
|
auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
|
||||||
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
|
||||||
|
@ -42,8 +42,9 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
|
||||||
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
||||||
int sW = INT_ARG(1); // strides width
|
int sW = INT_ARG(1); // strides width
|
||||||
int pW = INT_ARG(2); // paddings width
|
int pW = INT_ARG(2); // paddings width
|
||||||
int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME
|
int dW = INT_ARG(3); // dilations width
|
||||||
int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 0-NCW, 1-NWC
|
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL
|
||||||
|
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC
|
||||||
|
|
||||||
const int rank = 3;
|
const int rank = 3;
|
||||||
REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
|
||||||
|
@ -81,7 +82,12 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
|
||||||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
nd4j::ops::conv2d conv2d;
|
||||||
|
const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
||||||
|
if (status != ND4J_STATUS_OK)
|
||||||
|
return status;
|
||||||
|
|
||||||
|
// ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -96,8 +102,9 @@ DECLARE_SHAPE_FN(conv1d) {
|
||||||
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) width
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) width
|
||||||
int sW = INT_ARG(1); // strides width
|
int sW = INT_ARG(1); // strides width
|
||||||
int pW = INT_ARG(2); // paddings width
|
int pW = INT_ARG(2); // paddings width
|
||||||
int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME
|
int dW = INT_ARG(3); // dilations width
|
||||||
int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
|
||||||
|
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
||||||
|
|
||||||
int indIOioC, indIiW, indWoC(2);
|
int indIOioC, indIiW, indWoC(2);
|
||||||
if(!isNCW) {
|
if(!isNCW) {
|
||||||
|
@ -122,7 +129,7 @@ DECLARE_SHAPE_FN(conv1d) {
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||||
|
|
||||||
int oH, oW; // output height, width
|
int oH, oW; // output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,1, 1,iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
|
||||||
|
|
||||||
Nd4jLong* outputShapeInfo = nullptr;
|
Nd4jLong* outputShapeInfo = nullptr;
|
||||||
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
|
@ -153,7 +160,7 @@ DECLARE_TYPES(conv1d) {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
|
auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW)
|
||||||
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
|
||||||
|
@ -167,8 +174,9 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
||||||
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
||||||
int sW = INT_ARG(1); // strides width
|
int sW = INT_ARG(1); // strides width
|
||||||
int pW = INT_ARG(2); // paddings width
|
int pW = INT_ARG(2); // paddings width
|
||||||
int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME
|
int dW = INT_ARG(3); // dilations width
|
||||||
int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL
|
||||||
|
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
||||||
|
|
||||||
const int rank = 3;
|
const int rank = 3;
|
||||||
REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf());
|
||||||
|
@ -188,7 +196,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
||||||
const int oC = weights->sizeAt(indWoC); // output channels
|
const int oC = weights->sizeAt(indWoC); // output channels
|
||||||
|
|
||||||
int trueoH, trueoW; // true output height, width
|
int trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,1, 1,iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
|
||||||
|
@ -213,7 +221,12 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
nd4j::ops::conv2d_bp conv2dBP;
|
||||||
|
const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
||||||
|
if (status != ND4J_STATUS_OK)
|
||||||
|
return status;
|
||||||
|
|
||||||
|
// ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -234,8 +247,9 @@ DECLARE_SHAPE_FN(conv1d_bp) {
|
||||||
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) width
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) width
|
||||||
int sW = INT_ARG(1); // strides width
|
int sW = INT_ARG(1); // strides width
|
||||||
int pW = INT_ARG(2); // paddings width
|
int pW = INT_ARG(2); // paddings width
|
||||||
int isSameMode = INT_ARG(3); // 0-VALID, 1-SAME
|
int dW = INT_ARG(3); // dilations width
|
||||||
int isNCW = block.getIArguments()->size() > 4 ? !INT_ARG(4) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
|
||||||
|
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
||||||
|
|
||||||
int indIOioC, indIiW, indWoC(2);
|
int indIOioC, indIiW, indWoC(2);
|
||||||
if(!isNCW) {
|
if(!isNCW) {
|
||||||
|
@ -251,7 +265,7 @@ DECLARE_SHAPE_FN(conv1d_bp) {
|
||||||
const int oC = weightsShapeInfo[indWoC+1]; // output channels
|
const int oC = weightsShapeInfo[indWoC+1]; // output channels
|
||||||
|
|
||||||
int trueoH, trueoW; // true output height, width
|
int trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,1, 1,iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
|
||||||
|
|
|
@ -51,20 +51,20 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
||||||
int dD = INT_ARG(9); // dilations depth
|
int dD = INT_ARG(9); // dilations depth
|
||||||
int dH = INT_ARG(10); // dilations height
|
int dH = INT_ARG(10); // dilations height
|
||||||
int dW = INT_ARG(11); // dilations width
|
int dW = INT_ARG(11); // dilations width
|
||||||
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
|
||||||
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0);
|
||||||
|
|
||||||
|
@ -116,10 +116,11 @@ DECLARE_SHAPE_FN(conv3dnew) {
|
||||||
int dD = INT_ARG(9); // dilations depth
|
int dD = INT_ARG(9); // dilations depth
|
||||||
int dH = INT_ARG(10); // dilations height
|
int dH = INT_ARG(10); // dilations height
|
||||||
int dW = INT_ARG(11); // dilations width
|
int dW = INT_ARG(11); // dilations width
|
||||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID;
|
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID;
|
||||||
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
const int rank = 5;
|
const int rank = 5;
|
||||||
|
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
|
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
|
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
|
||||||
|
|
||||||
|
@ -144,7 +145,7 @@ DECLARE_SHAPE_FN(conv3dnew) {
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||||
|
|
||||||
int oD, oH, oW; // output depth, height, width
|
int oD, oH, oW; // output depth, height, width
|
||||||
ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
|
||||||
|
|
||||||
Nd4jLong* outputShapeInfo = nullptr;
|
Nd4jLong* outputShapeInfo = nullptr;
|
||||||
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong);
|
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong);
|
||||||
|
@ -197,7 +198,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
int dD = INT_ARG(9); // dilations depth
|
int dD = INT_ARG(9); // dilations depth
|
||||||
int dH = INT_ARG(10); // dilations height
|
int dH = INT_ARG(10); // dilations height
|
||||||
int dW = INT_ARG(11); // dilations width
|
int dW = INT_ARG(11); // dilations width
|
||||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
@ -205,8 +206,9 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
int trueoD, trueoH, trueoW; // true output depth/height/width
|
int trueoD, trueoH, trueoW; // true output depth/height/width
|
||||||
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
@ -214,8 +216,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
|
||||||
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0);
|
||||||
|
|
||||||
|
@ -285,10 +286,11 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
|
||||||
int dD = INT_ARG(9); // dilations depth
|
int dD = INT_ARG(9); // dilations depth
|
||||||
int dH = INT_ARG(10); // dilations height
|
int dH = INT_ARG(10); // dilations height
|
||||||
int dW = INT_ARG(11); // dilations width
|
int dW = INT_ARG(11); // dilations width
|
||||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
int isNDHWC = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
const int rank = 5;
|
const int rank = 5;
|
||||||
|
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
|
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
|
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
|
||||||
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo);
|
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo);
|
||||||
|
@ -309,7 +311,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
|
||||||
int oC = weightsShapeInfo[indWoC+1]; // output channels
|
int oC = weightsShapeInfo[indWoC+1]; // output channels
|
||||||
|
|
||||||
int trueoD, trueoH, trueoW; // true output depth/height/width
|
int trueoD, trueoH, trueoW; // true output depth/height/width
|
||||||
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}));
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}));
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||||
|
|
|
@ -39,8 +39,8 @@ namespace nd4j {
|
||||||
* 2: padding
|
* 2: padding
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_conv1d)
|
#if NOT_EXCLUDED(OP_conv1d)
|
||||||
DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 4);
|
DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 5);
|
||||||
DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 4);
|
DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 5);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -37,79 +37,93 @@ namespace nd4j {
|
||||||
|
|
||||||
class ConvolutionUtils {
|
class ConvolutionUtils {
|
||||||
public:
|
public:
|
||||||
static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int isSameMode) {
|
static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) {
|
||||||
if(isSameMode > 0) {
|
|
||||||
oH = (int) math::nd4j_ceil<double, double>(iH * 1. / sH);
|
if(paddingMode == 0) { // valid
|
||||||
oW = (int) math::nd4j_ceil<double, double>(iW * 1. / sW);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
|
oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
|
||||||
oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1;
|
oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1;
|
||||||
}
|
}
|
||||||
|
else if (paddingMode == 1) { // same
|
||||||
|
oH = (int) math::nd4j_ceil<double, double>(iH * 1. / sH);
|
||||||
|
oW = (int) math::nd4j_ceil<double, double>(iW * 1. / sW);
|
||||||
|
}
|
||||||
|
else { // causal
|
||||||
|
oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH
|
||||||
|
oW = (iW - 1) / sW + 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int isSameMode) {
|
static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) {
|
||||||
if(!isSameMode) { // valid
|
|
||||||
|
|
||||||
|
if(paddingMode == 0) { // valid
|
||||||
oD = (iD - (kD + (kD - 1) * (dD - 1)) + 2 * pD) / sD + 1;
|
oD = (iD - (kD + (kD - 1) * (dD - 1)) + 2 * pD) / sD + 1;
|
||||||
oH = (iH - (kH + (kH - 1) * (dH - 1)) + 2 * pH) / sH + 1;
|
oH = (iH - (kH + (kH - 1) * (dH - 1)) + 2 * pH) / sH + 1;
|
||||||
oW = (iW - (kW + (kW - 1) * (dW - 1)) + 2 * pW) / sW + 1;
|
oW = (iW - (kW + (kW - 1) * (dW - 1)) + 2 * pW) / sW + 1;
|
||||||
}
|
}
|
||||||
else { // same
|
else if(paddingMode == 1) { // same
|
||||||
|
|
||||||
oD = (int) nd4j::math::nd4j_ceil<double, double>(iD * 1. / sD);
|
oD = (int) nd4j::math::nd4j_ceil<double, double>(iD * 1. / sD);
|
||||||
oH = (int) nd4j::math::nd4j_ceil<double, double>(iH * 1. / sH);
|
oH = (int) nd4j::math::nd4j_ceil<double, double>(iH * 1. / sH);
|
||||||
oW = (int) nd4j::math::nd4j_ceil<double, double>(iW * 1. / sW);
|
oW = (int) nd4j::math::nd4j_ceil<double, double>(iW * 1. / sW);
|
||||||
|
|
||||||
|
}
|
||||||
|
else { // causal
|
||||||
|
oD = (iD - 1) / sD + 1;
|
||||||
|
oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH
|
||||||
|
oW = (iW - 1) / sW + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW) {
|
static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW, const int paddingMode = 1 /* default is same mode*/) {
|
||||||
int eKH, eKW;
|
|
||||||
if (dH == 1 && dW == 1) {
|
if(paddingMode == 0) // valid
|
||||||
eKH = kH;
|
return;
|
||||||
eKW = kW;
|
|
||||||
} else {
|
if(paddingMode == 1) { // same
|
||||||
eKH = (kH - 1) * dH + 1;
|
|
||||||
eKW = (kW - 1) * dW + 1;
|
const int eKH = (kH - 1) * dH + 1;
|
||||||
}
|
const int eKW = (kW - 1) * dW + 1;
|
||||||
|
|
||||||
pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
||||||
pW = ((oW - 1) * sW + eKW - iW) / 2;
|
pW = ((oW - 1) * sW + eKW - iW) / 2;
|
||||||
}
|
}
|
||||||
|
else { // causal
|
||||||
static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW) {
|
pH = (kH - 1) * dH;
|
||||||
int eKD, eKH, eKW;
|
pW = (kW - 1) * dW;
|
||||||
if (dD == 1 && dH == 1 && dW == 1) {
|
}
|
||||||
eKD = kD;
|
|
||||||
eKH = kH;
|
|
||||||
eKW = kW;
|
|
||||||
} else {
|
|
||||||
eKD = (kD - 1) * dD + 1;
|
|
||||||
eKH = (kH - 1) * dH + 1;
|
|
||||||
eKW = (kW - 1) * dW + 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pD = ((oD - 1) * sD + eKD - iD) / 2; // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW, const int paddingMode = 1 /* default is same mode*/) {
|
||||||
pH = ((oH - 1) * sH + eKH - iH) / 2;
|
|
||||||
pW = ((oW - 1) * sW + eKW - iW) / 2;
|
|
||||||
|
|
||||||
|
if(paddingMode == 0) // valid
|
||||||
|
return;
|
||||||
|
|
||||||
|
if(paddingMode == 1) { // same
|
||||||
|
|
||||||
|
const int eKD = (kD - 1) * dD + 1;
|
||||||
|
const int eKH = (kH - 1) * dH + 1;
|
||||||
|
const int eKW = (kW - 1) * dW + 1;
|
||||||
|
|
||||||
|
pD = ((oD - 1) * sD + eKD - iD) / 2;
|
||||||
|
pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
||||||
|
pW = ((oW - 1) * sW + eKW - iW) / 2;
|
||||||
|
}
|
||||||
|
else { // causal
|
||||||
|
pD = (kD - 1) * dD;
|
||||||
|
pH = (kH - 1) * dH;
|
||||||
|
pW = (kW - 1) * dW;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculation of output height and width in 2D deconvolution procedure
|
// calculation of output height and width in 2D deconvolution procedure
|
||||||
static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int isSameMode) {
|
static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) {
|
||||||
if (isSameMode) {
|
|
||||||
|
if (paddingMode) {
|
||||||
oH = sH * iH;
|
oH = sH * iH;
|
||||||
oW = sW * iW;
|
oW = sW * iW;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
int ekH, ekW;
|
const int ekH = (kH - 1) * dH + 1;
|
||||||
if (dH == 1 && dW == 1) {
|
const int ekW = (kW - 1) * dW + 1;
|
||||||
ekH = kH;
|
|
||||||
ekW = kW;
|
|
||||||
} else {
|
|
||||||
ekH = (kH - 1) * dH + 1;
|
|
||||||
ekW = (kW - 1) * dW + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
oH = sH * (iH - 1) + ekH - 2 * pH;
|
oH = sH * (iH - 1) + ekH - 2 * pH;
|
||||||
oW = sW * (iW - 1) + ekW - 2 * pW;
|
oW = sW * (iW - 1) + ekW - 2 * pW;
|
||||||
|
@ -117,24 +131,19 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculation of output height and width in 3D deconvolution procedure
|
// calculation of output height and width in 3D deconvolution procedure
|
||||||
static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int isSameMode) {
|
static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) {
|
||||||
if (isSameMode) {
|
|
||||||
|
if (paddingMode) {
|
||||||
oD = sD * iD;
|
oD = sD * iD;
|
||||||
oH = sH * iH;
|
oH = sH * iH;
|
||||||
oW = sW * iW;
|
oW = sW * iW;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
int ekD, ekH, ekW;
|
|
||||||
if (dD == 1 && dH == 1 && dW == 1) {
|
const int ekD = (kD - 1) * dD + 1;
|
||||||
ekD = kD;
|
const int ekH = (kH - 1) * dH + 1;
|
||||||
ekH = kH;
|
const int ekW = (kW - 1) * dW + 1;
|
||||||
ekW = kW;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
ekD = (kD - 1) * dD + 1;
|
|
||||||
ekH = (kH - 1) * dH + 1;
|
|
||||||
ekW = (kW - 1) * dW + 1;
|
|
||||||
}
|
|
||||||
oD = sD * (iD - 1) + ekD - 2 * pD;
|
oD = sD * (iD - 1) + ekD - 2 * pD;
|
||||||
oH = sH * (iH - 1) + ekH - 2 * pH;
|
oH = sH * (iH - 1) + ekH - 2 * pH;
|
||||||
oW = sW * (iW - 1) + ekW - 2 * pW;
|
oW = sW * (iW - 1) + ekW - 2 * pW;
|
||||||
|
@ -194,10 +203,10 @@ namespace nd4j {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
|
// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) {
|
||||||
|
|
||||||
// if(kH != 1) {
|
// if(kH != 1) {
|
||||||
// if(isSameMode) {
|
// if(paddingMode) {
|
||||||
// pH = (oH - 1) * sH - iH + kH - pH;
|
// pH = (oH - 1) * sH - iH + kH - pH;
|
||||||
// dH = dH - 1;
|
// dH = dH - 1;
|
||||||
// }
|
// }
|
||||||
|
@ -205,7 +214,7 @@ namespace nd4j {
|
||||||
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
||||||
// }
|
// }
|
||||||
// if(kW != 1) {
|
// if(kW != 1) {
|
||||||
// if(isSameMode) {
|
// if(paddingMode) {
|
||||||
// pW = (oW - 1) * sW - iW + kW - pW;
|
// pW = (oW - 1) * sW - iW + kW - pW;
|
||||||
// dW = dW - 1;
|
// dW = dW - 1;
|
||||||
// }
|
// }
|
||||||
|
@ -214,10 +223,10 @@ namespace nd4j {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
|
// static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
|
||||||
|
|
||||||
// if(kD != 1) {
|
// if(kD != 1) {
|
||||||
// if(isSameMode) {
|
// if(paddingMode) {
|
||||||
// pD = (oD - 1) * sD - iD + kD - pD;
|
// pD = (oD - 1) * sD - iD + kD - pD;
|
||||||
// dD = dD - 1;
|
// dD = dD - 1;
|
||||||
// }
|
// }
|
||||||
|
@ -225,7 +234,7 @@ namespace nd4j {
|
||||||
// dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
|
// dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
|
||||||
// }
|
// }
|
||||||
// if(kH != 1) {
|
// if(kH != 1) {
|
||||||
// if(isSameMode) {
|
// if(paddingMode) {
|
||||||
// pH = (oH - 1) * sH - iH + kH - pH;
|
// pH = (oH - 1) * sH - iH + kH - pH;
|
||||||
// dH = dH - 1;
|
// dH = dH - 1;
|
||||||
// }
|
// }
|
||||||
|
@ -233,7 +242,7 @@ namespace nd4j {
|
||||||
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
||||||
// }
|
// }
|
||||||
// if(kW != 1) {
|
// if(kW != 1) {
|
||||||
// if(isSameMode) {
|
// if(paddingMode) {
|
||||||
// pW = (oW - 1) * sW - iW + kW - pW;
|
// pW = (oW - 1) * sW - iW + kW - pW;
|
||||||
// dW = dW - 1;
|
// dW = dW - 1;
|
||||||
// }
|
// }
|
||||||
|
@ -242,19 +251,19 @@ namespace nd4j {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
|
||||||
|
|
||||||
// static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
// static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
||||||
|
|
||||||
// static void conv2dBP(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
|
// static void conv2dBP(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
|
||||||
|
|
||||||
static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
|
||||||
|
|
||||||
static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
|
||||||
|
|
||||||
static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
|
||||||
|
|
||||||
static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW);
|
||||||
|
|
||||||
static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||||
|
|
||||||
|
|
|
@ -258,7 +258,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weights [kH, kW, iC, oC] always
|
// weights [kH, kW, iC, oC] always
|
||||||
|
@ -273,15 +273,14 @@ namespace nd4j {
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 1-NCHW, 0-NHWC
|
// isNCHW 1-NCHW, 0-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
|
||||||
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv2d!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv2d!\n", 0);
|
||||||
|
|
||||||
|
@ -320,7 +319,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weights [kH, kW, iC, oC] always
|
// weights [kH, kW, iC, oC] always
|
||||||
|
@ -339,15 +338,14 @@ namespace nd4j {
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 0-NHWC, 1-NCHW
|
// isNCHW 0-NHWC, 1-NCHW
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
|
||||||
|
|
||||||
nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0);
|
nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0);
|
||||||
|
|
||||||
|
@ -393,7 +391,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weights [kH, kW, iC, mC] always
|
// weights [kH, kW, iC, mC] always
|
||||||
|
@ -408,7 +406,7 @@ namespace nd4j {
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 0-NCHW, 1-NHWC
|
// isNCHW 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
||||||
|
@ -430,7 +428,7 @@ namespace nd4j {
|
||||||
modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
|
modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(paddingMode == 1) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||||
|
@ -449,7 +447,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||||
// weights [kH, kW, iC, mC] always
|
// weights [kH, kW, iC, mC] always
|
||||||
|
@ -467,7 +465,7 @@ namespace nd4j {
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 0-NHWC, 1-NCHW
|
// isNCHW 0-NHWC, 1-NCHW
|
||||||
|
|
||||||
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
||||||
|
@ -492,7 +490,7 @@ namespace nd4j {
|
||||||
modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
|
modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(paddingMode == 1) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||||
|
@ -526,7 +524,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weightsDepth [kH, kW, iC, mC] always
|
// weightsDepth [kH, kW, iC, mC] always
|
||||||
|
@ -542,7 +540,7 @@ namespace nd4j {
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 1-NCHW, 0-NHWC
|
// isNCHW 1-NCHW, 0-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
|
||||||
|
@ -555,11 +553,11 @@ namespace nd4j {
|
||||||
outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());
|
outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());
|
||||||
|
|
||||||
// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
|
// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
|
||||||
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW);
|
||||||
|
|
||||||
// ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
|
// ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
|
||||||
if (weightsPoint) {
|
if (weightsPoint) {
|
||||||
ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH, oW=iW
|
ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW
|
||||||
delete outputDepth;
|
delete outputDepth;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1774,20 +1772,20 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||||
BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
|
||||||
|
|
|
@ -217,7 +217,7 @@ void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col,
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weights [kH, kW, iC, oC] always
|
// weights [kH, kW, iC, oC] always
|
||||||
|
@ -232,15 +232,14 @@ static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDA
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 1-NCHW, 0-NHWC
|
// isNCHW 1-NCHW, 0-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
|
||||||
|
|
||||||
std::vector<int> permutForOutput;
|
std::vector<int> permutForOutput;
|
||||||
|
|
||||||
|
@ -276,13 +275,13 @@ static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDA
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weights [kH, kW, iC, mC] always
|
// weights [kH, kW, iC, mC] always
|
||||||
|
@ -297,7 +296,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 0-NCHW, 1-NHWC
|
// isNCHW 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
||||||
|
@ -319,7 +318,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
|
||||||
modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
|
modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(paddingMode == 1) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||||
|
@ -337,13 +336,13 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weightsDepth [kH, kW, iC, mC] always
|
// weightsDepth [kH, kW, iC, mC] always
|
||||||
|
@ -359,7 +358,7 @@ static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const ND
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 1-NCHW, 0-NHWC
|
// isNCHW 1-NCHW, 0-NHWC
|
||||||
|
|
||||||
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width
|
||||||
|
@ -372,18 +371,18 @@ static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const ND
|
||||||
outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());
|
outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector<Nd4jLong>({bS, oH, oW, iC*mC}) : std::vector<Nd4jLong>({bS, iC*mC, oH, oW}), input->dataType(), input->getContext());
|
||||||
|
|
||||||
// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
|
// ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- //
|
||||||
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW);
|
||||||
|
|
||||||
// ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
|
// ----- perform pointwise convolution (oH = iH, oW = iW) ----- //
|
||||||
if (weightsPoint) {
|
if (weightsPoint) {
|
||||||
ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH, oW=iW
|
ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW); // in this case oH=iH, oW=iW
|
||||||
delete outputDepth;
|
delete outputDepth;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1177,7 +1176,7 @@ void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& i
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
// weights [kH, kW, iC, oC] always
|
// weights [kH, kW, iC, oC] always
|
||||||
|
@ -1196,15 +1195,14 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 0-NHWC, 1-NCHW
|
// isNCHW 0-NHWC, 1-NCHW
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
|
||||||
|
|
||||||
std::vector<int> gradOaxesForDot;
|
std::vector<int> gradOaxesForDot;
|
||||||
|
|
||||||
|
@ -1247,13 +1245,13 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
|
|
||||||
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||||
// weights [kH, kW, iC, mC] always
|
// weights [kH, kW, iC, mC] always
|
||||||
|
@ -1271,7 +1269,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
|
||||||
// pW paddings width
|
// pW paddings width
|
||||||
// dH dilations height
|
// dH dilations height
|
||||||
// dW dilations width
|
// dW dilations width
|
||||||
// isSameMode 0-VALID, 1-SAME
|
// paddingMode 0-VALID, 1-SAME
|
||||||
// isNCHW 0-NHWC, 1-NCHW
|
// isNCHW 0-NHWC, 1-NCHW
|
||||||
|
|
||||||
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width
|
||||||
|
@ -1296,7 +1294,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
|
||||||
modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
|
modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(paddingMode == 1) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||||
|
@ -1328,8 +1326,8 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) {
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include <ops/declarable/helpers/convolutions.h>
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
#include <ops/declarable/helpers/col2im.h>
|
#include <ops/declarable/helpers/col2im.h>
|
||||||
#include <PointersManager.h>
|
#include <PointersManager.h>
|
||||||
|
#include <GradCheck.h>
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
#ifdef HAVE_MKLDNN
|
||||||
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
|
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
|
||||||
|
@ -771,7 +772,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
|
||||||
bias.linspace(1);
|
bias.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::conv1d op;
|
nd4j::ops::conv1d op;
|
||||||
auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 0});
|
auto result_FF = op.execute({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result_FF->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result_FF->status());
|
||||||
|
|
||||||
|
@ -785,7 +786,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
|
||||||
auto epsilonNxt = z->dup();
|
auto epsilonNxt = z->dup();
|
||||||
epsilonNxt->linspace(1);
|
epsilonNxt->linspace(1);
|
||||||
|
|
||||||
auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 0});
|
auto result_BP = op_bp.execute({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result_BP->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result_BP->status());
|
||||||
|
|
||||||
auto eps = result_BP->at(0);
|
auto eps = result_BP->at(0);
|
||||||
|
@ -813,7 +814,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::conv1d op;
|
nd4j::ops::conv1d op;
|
||||||
auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1});
|
auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1, 1,0});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -822,6 +823,219 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_1) {
|
||||||
|
|
||||||
|
int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iW, iC});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3});
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oW, oC}, {18. , 18. , 18. , 53. , 55.6, 58.2, 89.8, 95.6, 101.4, 102. , 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_2) {
|
||||||
|
|
||||||
|
int bS=2, iW=16, iC=3,oC=4, kW=2, sW=2, pW=0, dW=1;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iW, iC});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3,-4});
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oW, oC}, { 10. , 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95. , 101.5, 108. , 128.1, 138.2, 148.3, 158.4,
|
||||||
|
167.7, 181.4, 195.1, 208.8, 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311. , 335.5, 360. ,
|
||||||
|
254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6,
|
||||||
|
484.5, 527. , 569.5, 612. , 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, 603.3, 656.6, 709.9, 763.2});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_3) {
|
||||||
|
|
||||||
|
int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iW, iC});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3,-4});
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16.,145.4, 151.6, 157.8, 164.,283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488.,
|
||||||
|
558.5, 589., 619.5, 650.,696.2001, 734.8, 773.4, 812., 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028.,
|
||||||
|
1017.5, 1075., 1132.5, 1190.,1155.2001, 1220.8, 1286.4, 1352.,1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_4) {
|
||||||
|
|
||||||
|
int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iW, iC});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3,-4});
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16. ,43.3, 43.8, 44.3, 44.8,69.4, 70.8, 72.2, 73.6,106.5, 109.4, 112.3, 115.2,147.9, 152.6, 157.3, 162. ,189.3, 195.8, 202.3,
|
||||||
|
208.8,234.5, 243.4, 252.3, 261.2,280.4, 292. , 303.6, 315.2, 226. , 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2,278.2, 286.8, 295.4, 304. ,437.7,
|
||||||
|
455. , 472.3, 489.6,479.1, 498.2, 517.3, 536.4,520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, 647.6, 680.8, 714. , 747.2});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_5) {
|
||||||
|
|
||||||
|
int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iC, iW});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3,-4});
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oC, oW}, { 83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7,85.4, 94.4, 103.4, 167.4, 181.8, 196.2, 233.2, 249.4,87.1, 96.4, 105.7, 172.7, 187.7, 202.7, 243. , 260.1,
|
||||||
|
88.8, 98.4, 108. , 178. , 193.6, 209.2, 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, 301.4, 310.4, 319.4, 513. , 527.4, 541.8, 622. , 638.2,
|
||||||
|
310.3, 319.6, 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, 568. , 583.6, 684.8, 702.8});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_6) {
|
||||||
|
|
||||||
|
int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iC, iW});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3,-4});
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oC, oW}, {159.7,335.3,381.2,427.1,473. ,518.9,163.8,351.4,400. ,448.6,497.2,545.8,167.9,367.5,418.8,470.1,521.4,572.7,172. ,383.6,437.6,491.6,545.6,599.6,
|
||||||
|
577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3,595.8, 1129. , 1177.6, 1226.2, 1274.8, 1323.4,614.3, 1188.3, 1239.6, 1290.9, 1342.2, 1393.5,
|
||||||
|
632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
|
||||||
|
|
||||||
|
int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iW, iC});
|
||||||
|
NDArray weights('c', {kW, iC, oC});
|
||||||
|
NDArray bias('c', {oC}, {-1,-2,-3});
|
||||||
|
NDArray gradO('c', {bS, oW, oC});
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
gradO.linspace(-1.5, 0.1);
|
||||||
|
|
||||||
|
const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
|
||||||
|
nd4j::ops::conv1d opFF;
|
||||||
|
nd4j::ops::conv1d_bp opBP;
|
||||||
|
|
||||||
|
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||||
|
|
||||||
|
ASSERT_TRUE(isGradCorrect);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ConvolutionTests1, Test_Dilation2D_1) {
|
TEST_F(ConvolutionTests1, Test_Dilation2D_1) {
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 6, 6, 3});
|
auto input = NDArrayFactory::create<double>('c', {2, 6, 6, 3});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {3, 2, 3});
|
auto weights = NDArrayFactory::create<double>('c', {3, 2, 3});
|
||||||
|
|
|
@ -908,7 +908,7 @@ TEST_F(ParityOpsTests, scatterMax_test4) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
|
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
|
||||||
|
|
||||||
nd4j::ops::scatter_max op;
|
nd4j::ops::scatter_max op;
|
||||||
auto result = op.execute({&matrix, &idc, &updates}, {}, {});
|
auto result = op.execute({&matrix, &idc, &updates}, {}, {true});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
|
|
@ -29,12 +29,11 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.util.Collections;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -79,7 +78,8 @@ public class Conv1D extends DynamicCustomOp {
|
||||||
addIArgument(config.getK(),
|
addIArgument(config.getK(),
|
||||||
config.getS(),
|
config.getS(),
|
||||||
config.getP(),
|
config.getP(),
|
||||||
ArrayUtil.fromBoolean(config.isSameMode()),
|
config.getD(),
|
||||||
|
config.getPaddingMode().ordinal(),
|
||||||
ArrayUtil.fromBoolean(config.isNWC()));
|
ArrayUtil.fromBoolean(config.isNWC()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,10 +95,12 @@ public class Conv1D extends DynamicCustomOp {
|
||||||
public Object getValue(Field property) {
|
public Object getValue(Field property) {
|
||||||
if (config == null && !iArguments.isEmpty()) {
|
if (config == null && !iArguments.isEmpty()) {
|
||||||
config = Conv1DConfig.builder()
|
config = Conv1DConfig.builder()
|
||||||
.s(iArguments.get(0))
|
.k(iArguments.get(0))
|
||||||
.p(iArguments.get(1))
|
.s(iArguments.get(1))
|
||||||
.isSameMode(iArguments.get(2) == 1)
|
.p(iArguments.get(2))
|
||||||
.dataFormat(iArguments.get(3) == 1 ? Conv1DConfig.NCW : Conv1DConfig.NWC)
|
.d(iArguments.get(3))
|
||||||
|
.paddingMode(PaddingMode.values()[iArguments.get(4).intValue()])
|
||||||
|
.dataFormat(iArguments.get(5) == 1 ? Conv1DConfig.NCW : Conv1DConfig.NWC)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,16 +127,20 @@ public class Conv1D extends DynamicCustomOp {
|
||||||
return "conv1d";
|
return "conv1d";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
int n = args().length;
|
int n = args().length;
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
return Collections.singletonList(inputDataTypes.get(0));
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SDVariable> doDiff(List<SDVariable> grads){
|
||||||
|
List<SDVariable> args = new ArrayList<>();
|
||||||
|
Collections.addAll(args, args());
|
||||||
|
args.add(grads.get(0));
|
||||||
|
|
||||||
|
Conv1DDerivative gradFn = new Conv1DDerivative(sameDiff, args.toArray(new SDVariable[0]), config);
|
||||||
|
return Arrays.asList(gradFn.outputVariables());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Conv1D Backprop operation
|
||||||
|
*
|
||||||
|
* @author Alex Black
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Getter
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class Conv1DDerivative extends DynamicCustomOp {
|
||||||
|
|
||||||
|
protected Conv1DConfig config;
|
||||||
|
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
||||||
|
|
||||||
|
public Conv1DDerivative(@NonNull SameDiff sameDiff,
|
||||||
|
@NonNull SDVariable[] inputs,
|
||||||
|
@NonNull Conv1DConfig config) {
|
||||||
|
super(sameDiff, inputs);
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1DDerivative(@NonNull SameDiff sd, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, SDVariable gradOut, @NonNull Conv1DConfig config){
|
||||||
|
this(sd, wrapFilterNull(input, weights, bias, gradOut), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1DDerivative(INDArray[] inputs, INDArray[] outputs, Conv1DConfig config){
|
||||||
|
super(inputs, outputs);
|
||||||
|
|
||||||
|
initConfig(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Conv1DDerivative(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull INDArray gradOut, INDArray output, @NonNull Conv1DConfig config){
|
||||||
|
this(wrapFilterNull(input, weights, bias, gradOut), wrapOrNull(output), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initConfig(Conv1DConfig config){
|
||||||
|
this.config = config;
|
||||||
|
Preconditions.checkState(config.getS() >= 1 && config.getP() >= 0, INVALID_CONFIGURATION, config.getS(), config.getP());
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void addArgs() {
|
||||||
|
if (config == null)
|
||||||
|
config = Conv1DConfig.builder().build();
|
||||||
|
|
||||||
|
addIArgument(config.getK(),
|
||||||
|
config.getS(),
|
||||||
|
config.getP(),
|
||||||
|
config.getD(),
|
||||||
|
config.getPaddingMode().ordinal(),
|
||||||
|
ArrayUtil.fromBoolean(config.isNWC()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long[] iArgs() {
|
||||||
|
if (iArguments.size() == 0)
|
||||||
|
addArgs();
|
||||||
|
|
||||||
|
return super.iArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object getValue(Field property) {
|
||||||
|
if (config == null && !iArguments.isEmpty()) {
|
||||||
|
config = Conv1DConfig.builder()
|
||||||
|
.k(iArguments.get(0))
|
||||||
|
.s(iArguments.get(1))
|
||||||
|
.p(iArguments.get(2))
|
||||||
|
.d(iArguments.get(3))
|
||||||
|
.paddingMode(PaddingMode.values()[iArguments.get(4).intValue()])
|
||||||
|
.dataFormat(iArguments.get(5) == 1 ? Conv1DConfig.NCW : Conv1DConfig.NWC)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
return config.getValue(property);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Object> propertiesForFunction() {
|
||||||
|
return config.toProperties();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isConfigProperties() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String configFieldName() {
|
||||||
|
return "config";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "conv1d_bp";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumOutputs(){
|
||||||
|
if(args().length == 4){
|
||||||
|
return 3; //Includes bias
|
||||||
|
} else {
|
||||||
|
return 2; //No bias - only input + weight grads
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return new ArrayList<>(inputDataTypes.subList(0, inputDataTypes.size()-1)); //All except gradient input variable
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,6 +21,7 @@ import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
|
@ -38,15 +39,28 @@ public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private long p = 0; // padding
|
private long p = 0; // padding
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
|
private long d = 1; // dilation
|
||||||
|
@Builder.Default
|
||||||
private String dataFormat = NCW;
|
private String dataFormat = NCW;
|
||||||
private boolean isSameMode;
|
private PaddingMode paddingMode;
|
||||||
|
|
||||||
|
public Conv1DConfig(long k, long s, long p, long d, String dataFormat, @NonNull PaddingMode paddingMode) {
|
||||||
|
this.k = k;
|
||||||
|
this.s = s;
|
||||||
|
this.p = p;
|
||||||
|
this.d = d;
|
||||||
|
this.dataFormat = dataFormat;
|
||||||
|
this.paddingMode = paddingMode;
|
||||||
|
|
||||||
|
validate();
|
||||||
|
}
|
||||||
|
|
||||||
public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) {
|
public Conv1DConfig(long k, long s, long p, String dataFormat, boolean isSameMode) {
|
||||||
this.k = k;
|
this.k = k;
|
||||||
this.s = s;
|
this.s = s;
|
||||||
this.p = p;
|
this.p = p;
|
||||||
this.dataFormat = dataFormat;
|
this.dataFormat = dataFormat;
|
||||||
this.isSameMode = isSameMode;
|
this.paddingMode = isSameMode ? PaddingMode.SAME : PaddingMode.VALID;
|
||||||
|
|
||||||
validate();
|
validate();
|
||||||
}
|
}
|
||||||
|
@ -71,14 +85,15 @@ public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
ret.put("k", k);
|
ret.put("k", k);
|
||||||
ret.put("s", s);
|
ret.put("s", s);
|
||||||
ret.put("p", p);
|
ret.put("p", p);
|
||||||
ret.put("isSameMode", isSameMode);
|
ret.put("d", d);
|
||||||
|
ret.put("isSameMode", paddingMode);
|
||||||
ret.put("dataFormat", dataFormat);
|
ret.put("dataFormat", dataFormat);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void validate() {
|
protected void validate() {
|
||||||
ConvConfigUtil.validate1D(k, s, p);
|
ConvConfigUtil.validate1D(k, s, p, d);
|
||||||
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
Preconditions.checkArgument(dataFormat != null, "Data format can't be null");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
|
|
||||||
|
public enum PaddingMode {
|
||||||
|
VALID,
|
||||||
|
SAME,
|
||||||
|
CAUSAL
|
||||||
|
}
|
|
@ -76,11 +76,13 @@ public class ConvConfigUtil {
|
||||||
/**
|
/**
|
||||||
* Validate a 1D convolution's Kernel, Stride, and Padding
|
* Validate a 1D convolution's Kernel, Stride, and Padding
|
||||||
*/
|
*/
|
||||||
public static void validate1D(long k, long s, long p){
|
public static void validate1D(long k, long s, long p, long d){
|
||||||
Preconditions.checkArgument(k != 0, "Kernel can not be 0");
|
Preconditions.checkArgument(k != 0, "Kernel can not be 0");
|
||||||
|
|
||||||
Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s);
|
Preconditions.checkArgument(s > 0, "Stride can not be negative or 0, got: %s", s);
|
||||||
|
|
||||||
|
Preconditions.checkArgument(d > 0, "Dilation can not be negative or 0, got: %s", s);
|
||||||
|
|
||||||
Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p);
|
Preconditions.checkArgument(p >= 0, "Padding can not be negative, got: %s", p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,14 +39,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -944,7 +937,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
||||||
.k(k).p(0).s(1)
|
.k(k).p(0).s(1)
|
||||||
.isSameMode(false)
|
.paddingMode(PaddingMode.VALID)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
|
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
|
||||||
|
@ -960,6 +953,55 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConv1dCausal() {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
int nIn = 3;
|
||||||
|
int nOut = 4;
|
||||||
|
int mb = 2;
|
||||||
|
|
||||||
|
for( int k : new int[]{2, 3}) {
|
||||||
|
for (int sz : new int[]{3, 4, 5}) {
|
||||||
|
for (int s : new int[]{1, 2}) {
|
||||||
|
for (int d : new int[]{1, 2}) {
|
||||||
|
for (boolean ncw : new boolean[]{true, false}) {
|
||||||
|
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
INDArray wArr = Nd4j.rand(DataType.DOUBLE, k, nIn, nOut);
|
||||||
|
INDArray inArr = Nd4j.rand(DataType.DOUBLE, (ncw ? new long[]{mb, nIn, sz} : new long[]{mb, sz, nIn}));
|
||||||
|
INDArray bArr = Nd4j.rand(DataType.DOUBLE, nOut);
|
||||||
|
|
||||||
|
SDVariable in = sd.var("in", inArr);
|
||||||
|
SDVariable w = sd.var("W", wArr);
|
||||||
|
SDVariable b = sd.var("b", bArr);
|
||||||
|
|
||||||
|
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
||||||
|
.dataFormat(ncw ? Conv1DConfig.NCW : Conv1DConfig.NWC)
|
||||||
|
.k(k).p(0).s(s).d(d)
|
||||||
|
.paddingMode(PaddingMode.CAUSAL)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig);
|
||||||
|
SDVariable loss = sd.nn().tanh(out).std(true).rename("loss");
|
||||||
|
|
||||||
|
sd.setLossVariables("loss");
|
||||||
|
|
||||||
|
String name = "k=" + k + ", sz=" + sz + ", ncw=" + ncw;
|
||||||
|
|
||||||
|
System.out.println(name);
|
||||||
|
|
||||||
|
TestCase tc = new TestCase(sd).testName(name).gradientCheck(true);
|
||||||
|
String err = OpValidation
|
||||||
|
.validate(tc);
|
||||||
|
assertNull(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConv1dForward(){
|
public void testConv1dForward(){
|
||||||
int nIn = 2;
|
int nIn = 2;
|
||||||
|
@ -1254,7 +1296,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
Conv1DConfig conv1DConfig = Conv1DConfig.builder()
|
||||||
.k(k).p(-1).s(0)
|
.k(k).p(-1).s(0)
|
||||||
.isSameMode(false)
|
.paddingMode(PaddingMode.VALID)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
|
SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig);
|
||||||
|
|
Loading…
Reference in New Issue