commit
7583ccfa15
|
@ -17,20 +17,20 @@ endif()
|
||||||
# -fsanitize=address
|
# -fsanitize=address
|
||||||
# -fsanitize=leak
|
# -fsanitize=leak
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
|
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
|
||||||
elseif(WIN32)
|
elseif(WIN32)
|
||||||
set(X86_BUILD true)
|
set(X86_BUILD true)
|
||||||
if (NOT CUDA_BLAS)
|
if (NOT CUDA_BLAS)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fmax-errors=2")
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
|
set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
|
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -std=c++11 -fmax-errors=2")
|
||||||
|
|
||||||
if (CPU_BLAS)
|
if (CPU_BLAS)
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
|
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
|
||||||
|
|
|
@ -97,6 +97,8 @@ namespace nd4j {
|
||||||
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
|
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
|
||||||
static std::string strideAsString(const NDArray* array);
|
static std::string strideAsString(const NDArray* array);
|
||||||
|
|
||||||
|
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
|
||||||
|
|
||||||
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
|
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
|
||||||
static Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, nd4j::memory::Workspace* workspace);
|
static Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, nd4j::memory::Workspace* workspace);
|
||||||
|
|
||||||
|
|
|
@ -469,7 +469,7 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
|
||||||
useBisection = true;
|
useBisection = true;
|
||||||
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
||||||
useBisection = true;
|
useBisection = true;
|
||||||
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev))
|
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
|
||||||
useBisection = true;
|
useBisection = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -903,12 +903,8 @@ void SVD<T>::evalData(const NDArray& matrix) {
|
||||||
scale = 1.;
|
scale = 1.;
|
||||||
|
|
||||||
NDArray copy;
|
NDArray copy;
|
||||||
if(_transp) {
|
if(_transp)
|
||||||
copy = NDArrayFactory::create<T>(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(0)}, matrix.getContext());
|
copy = matrix.transpose();
|
||||||
for(int i = 0; i < copy.sizeAt(0); ++i)
|
|
||||||
for(int j = 0; j < copy.sizeAt(1); ++j)
|
|
||||||
copy.p<T>(i, j, matrix.e<T>(j,i) / scale);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
copy = matrix / scale;
|
copy = matrix / scale;
|
||||||
|
|
||||||
|
|
|
@ -671,6 +671,20 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
std::vector<Nd4jLong> ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) {
|
||||||
|
|
||||||
|
if(!shapeInfo)
|
||||||
|
throw std::runtime_error("ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr !");
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> vector(shapeInfo[0]);
|
||||||
|
|
||||||
|
for (uint e = 0; e < shapeInfo[0]; e++)
|
||||||
|
vector[e] = shapeInfo[e + 1];
|
||||||
|
|
||||||
|
return vector;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
|
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
|
||||||
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, nd4j::memory::Workspace* workspace){
|
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, nd4j::memory::Workspace* workspace){
|
||||||
|
|
|
@ -323,7 +323,9 @@
|
||||||
(11, TruncatedNormalDistribution) ,\
|
(11, TruncatedNormalDistribution) ,\
|
||||||
(12, AlphaDropOut),\
|
(12, AlphaDropOut),\
|
||||||
(13, ExponentialDistribution),\
|
(13, ExponentialDistribution),\
|
||||||
(14, ExponentialDistributionInv)
|
(14, ExponentialDistributionInv), \
|
||||||
|
(15, PoissonDistribution), \
|
||||||
|
(16, GammaDistribution)
|
||||||
|
|
||||||
#define PAIRWISE_INT_OPS \
|
#define PAIRWISE_INT_OPS \
|
||||||
(0, ShiftLeft), \
|
(0, ShiftLeft), \
|
||||||
|
|
|
@ -58,8 +58,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
|
||||||
|
|
||||||
//----- calculation of output -----//
|
//----- calculation of output -----//
|
||||||
// NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
|
// NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
|
||||||
// NCHW: [iC, oC, kH, kW] x [bS, iC, iH, iW] = [oC, kH, kW, bS, iH, iW]
|
// NCHW: [kH, kW, oC, iC] x [bS, iC, iH, iW] = [kH, kW, oC, bS, iH, iW]
|
||||||
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5});
|
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5});
|
||||||
LaunchContext* ctx = block.launchContext();
|
LaunchContext* ctx = block.launchContext();
|
||||||
helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
|
helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
|
||||||
|
@ -103,8 +103,8 @@ DECLARE_SHAPE_FN(deconv2d) {
|
||||||
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
|
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
|
||||||
|
|
||||||
const int rank = 4;
|
const int rank = 4;
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
|
||||||
|
|
||||||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
||||||
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
||||||
|
@ -131,10 +131,10 @@ DECLARE_SHAPE_FN(deconv2d) {
|
||||||
const int iC = inputShapeInfo[indIOioC+1]; // input channels
|
const int iC = inputShapeInfo[indIOioC+1]; // input channels
|
||||||
const int oC = weightsShapeInfo[indWoC+1]; // output channels
|
const int oC = weightsShapeInfo[indWoC+1]; // output channels
|
||||||
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
if (biasShapeInfo)
|
if (biasShapeInfo)
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||||
|
|
||||||
int oH, oW; // output height, width
|
int oH, oW; // output height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
@ -196,15 +196,18 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
int trueoH, trueoW; // true output height, width
|
int trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode){ // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
|
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
|
||||||
nd4j::ops::conv2d conv2d;
|
nd4j::ops::conv2d conv2d;
|
||||||
|
@ -252,9 +255,9 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
|
||||||
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
const int rank = 4;
|
const int rank = 4;
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo));
|
||||||
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
|
||||||
|
|
||||||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
||||||
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
||||||
|
@ -284,10 +287,10 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
|
||||||
int trueoH, trueoW; // true output height, width
|
int trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
if(biasShapeInfo)
|
if(biasShapeInfo)
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||||
|
|
||||||
|
|
|
@ -65,10 +65,10 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
|
||||||
int trueoH, trueoW; // true output height, width
|
int trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
|
@ -89,10 +89,9 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
|
||||||
|
|
||||||
const int rank = 4;
|
const int rank = 4;
|
||||||
|
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
|
||||||
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_TF OP: rank of input array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
|
||||||
REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(gradIShapeShapeInfo) == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, shape::rank(gradIShapeShapeInfo));
|
||||||
|
|
||||||
|
|
||||||
const int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
const int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
|
||||||
const int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
const int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
|
||||||
|
@ -126,10 +125,10 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
|
||||||
int trueiH, trueiW; // output height, width
|
int trueiH, trueiW; // output height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1}));
|
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
|
||||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradIShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradIShape).c_str());
|
REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
|
|
||||||
Nd4jLong shape[4];
|
Nd4jLong shape[4];
|
||||||
shape[0] = bS;
|
shape[0] = bS;
|
||||||
|
|
|
@ -59,22 +59,22 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
|
||||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(!isNCDHW)
|
if(!isNCDHW)
|
||||||
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
||||||
|
|
||||||
//----- calculation of output -----//
|
//----- calculation of output -----//
|
||||||
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||||
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
|
// NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||||
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
|
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||||
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
|
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
|
||||||
|
|
||||||
|
@ -105,8 +105,8 @@ DECLARE_SHAPE_FN(deconv3d) {
|
||||||
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
|
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
|
||||||
|
|
||||||
const int rank = 5;
|
const int rank = 5;
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
|
||||||
|
|
||||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
|
||||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height
|
||||||
|
@ -138,10 +138,10 @@ DECLARE_SHAPE_FN(deconv3d) {
|
||||||
const int iC = inputShapeInfo[indIOioC+1]; // input channels
|
const int iC = inputShapeInfo[indIOioC+1]; // input channels
|
||||||
const int oC = weightsShapeInfo[indWoC+1]; // output channels
|
const int oC = weightsShapeInfo[indWoC+1]; // output channels
|
||||||
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
if (biasShapeInfo)
|
if (biasShapeInfo)
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
|
||||||
|
|
||||||
int oD, oH, oW; // output depth, height, width
|
int oD, oH, oW; // output depth, height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
||||||
|
@ -209,15 +209,15 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
||||||
int trueoD, trueoH, trueoW; // true output height, width
|
int trueoD, trueoH, trueoW; // true output height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
// ----- calculation of gradI -> pass it through conv3d_ff ----- //
|
// ----- calculation of gradI -> pass it through conv3d_ff ----- //
|
||||||
nd4j::ops::conv3dnew conv3d;
|
nd4j::ops::conv3dnew conv3d;
|
||||||
|
@ -252,7 +252,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
||||||
if(!isNCDHW)
|
if(!isNCDHW)
|
||||||
delete gradO;
|
delete gradO;
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(deconv3d_bp) {
|
DECLARE_TYPES(deconv3d_bp) {
|
||||||
|
@ -272,9 +272,9 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
|
||||||
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
const int rank = 5;
|
const int rank = 5;
|
||||||
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
|
||||||
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, weightsShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo));
|
||||||
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
|
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
|
||||||
|
|
||||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
|
||||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height
|
||||||
|
@ -309,10 +309,10 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
|
||||||
int trueoD, trueoH, trueoW; // true output depth, height, width
|
int trueoD, trueoH, trueoW; // true output depth, height, width
|
||||||
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}));
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
|
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
|
||||||
if(biasShapeInfo)
|
if(biasShapeInfo)
|
||||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||||
|
|
||||||
|
|
|
@ -69,36 +69,26 @@ DECLARE_TYPES(biasadd) {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
|
CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto bias = INPUT_VARIABLE(1);
|
auto bias = INPUT_VARIABLE(1);
|
||||||
auto epsilonNext = INPUT_VARIABLE(2);
|
auto gradO = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
auto epsilon = OUTPUT_VARIABLE(0);
|
auto gradI = OUTPUT_VARIABLE(0);
|
||||||
auto gradB = OUTPUT_VARIABLE(1);
|
auto gradB = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
epsilon->assign(epsilonNext);
|
const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false;
|
||||||
|
const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last
|
||||||
|
|
||||||
// cnn case
|
gradI->assign(gradO);
|
||||||
if (input->rankOf() == 4) {
|
|
||||||
auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
|
|
||||||
epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});
|
|
||||||
|
|
||||||
auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
|
gradO->reduceAlongDimension(nd4j::reduce::Sum, gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));
|
||||||
gradB->assign(sum);
|
|
||||||
|
|
||||||
delete sum;
|
|
||||||
} else if (input->rankOf() == 2) {
|
|
||||||
// regular fully-connected case
|
|
||||||
auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});
|
|
||||||
gradB->assign(sum);
|
|
||||||
|
|
||||||
delete sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
DECLARE_SYN(BiasAddGrad, biasadd_bp);
|
DECLARE_SYN(BiasAddGrad, biasadd_bp);
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
DECLARE_SHAPE_FN(biasadd_bp) {
|
DECLARE_SHAPE_FN(biasadd_bp) {
|
||||||
auto input = inputShape->at(0);
|
auto input = inputShape->at(0);
|
||||||
auto bias = inputShape->at(1);
|
auto bias = inputShape->at(1);
|
||||||
|
|
|
@ -623,7 +623,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//Zero output array, so unused elements have 0 gradient
|
//Zero output array, so unused elements have 0 gradient
|
||||||
output->nullify();
|
output->nullify();
|
||||||
|
std::sort(indices.begin(), indices.end());
|
||||||
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
|
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
|
||||||
output->p(indices[0], *epsNext);
|
output->p(indices[0], *epsNext);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_random_gamma)
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/random.h>
|
||||||
|
#include <ops/declarable/helpers/random.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) {
|
||||||
|
// gamma distribution
|
||||||
|
auto rng = block.randomGenerator();
|
||||||
|
auto shape = INPUT_VARIABLE(0);
|
||||||
|
auto alpha = INPUT_VARIABLE(1);
|
||||||
|
NDArray* beta = nullptr;
|
||||||
|
|
||||||
|
if (block.width() > 2) {
|
||||||
|
beta = INPUT_VARIABLE(2);
|
||||||
|
REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, "random_gamma: alpha and beta shapes should be broadcastable.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
auto seed = 0;
|
||||||
|
|
||||||
|
if (block.getIArguments()->size()) {
|
||||||
|
seed = INT_ARG(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
rng.setSeed(seed);
|
||||||
|
|
||||||
|
helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(random_gamma) {
|
||||||
|
auto in = INPUT_VARIABLE(0);
|
||||||
|
auto shape = in->template asVectorT<Nd4jLong>();
|
||||||
|
auto alphaShape = inputShape->at(1);
|
||||||
|
auto additionalShape = alphaShape;
|
||||||
|
if (inputShape->size() > 2) {
|
||||||
|
auto rest = inputShape->at(2); additionalShape = nullptr;
|
||||||
|
REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, "random_gamma: alpha and beta shapes should be broadcastable.");
|
||||||
|
ShapeUtils::evalBroadcastShapeInfo(alphaShape, rest, true, additionalShape, block.workspace());
|
||||||
|
}
|
||||||
|
auto lastDim = shape::sizeAt(alphaShape, 0);
|
||||||
|
auto dtype = ArrayOptions::dataType(alphaShape);
|
||||||
|
for (auto i = 0; i < shape::rank(additionalShape); i++)
|
||||||
|
shape.push_back(shape::sizeAt(additionalShape, i));
|
||||||
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
|
||||||
|
return SHAPELIST(newShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(random_gamma) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_INTS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(2, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,67 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_random_poisson)
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/random.h>
|
||||||
|
#include <ops/declarable/helpers/random.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) {
|
||||||
|
// gamma distribution
|
||||||
|
auto rng = block.randomGenerator();
|
||||||
|
auto shape = INPUT_VARIABLE(0);
|
||||||
|
auto lambda = INPUT_VARIABLE(1);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
auto seed = 0;
|
||||||
|
if (block.getIArguments()->size()) {
|
||||||
|
seed = INT_ARG(0);
|
||||||
|
}
|
||||||
|
rng.setSeed(seed);
|
||||||
|
helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(random_poisson) {
|
||||||
|
auto in = INPUT_VARIABLE(0);
|
||||||
|
auto shape = in->template asVectorT<Nd4jLong>();
|
||||||
|
auto lambdaShape = inputShape->at(1);
|
||||||
|
auto dtype = ArrayOptions::dataType(lambdaShape);
|
||||||
|
for (auto d = 0; d < shape::rank(lambdaShape); ++d ) {
|
||||||
|
shape.emplace_back(shape::sizeAt(lambdaShape, d));
|
||||||
|
}
|
||||||
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
|
||||||
|
return SHAPELIST(newShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(random_poisson) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_INTS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -185,42 +185,42 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
|
||||||
|
|
||||||
// Wx validation
|
// Wx validation
|
||||||
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
|
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||||
// Wr validation
|
// Wr validation
|
||||||
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
|
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
|
||||||
// biases validation
|
// biases validation
|
||||||
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
|
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||||
// initial output validation
|
// initial output validation
|
||||||
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
|
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
|
||||||
// initial cell validation
|
// initial cell validation
|
||||||
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
|
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
|
||||||
// peephole weights validation
|
// peephole weights validation
|
||||||
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
|
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
|
||||||
}
|
}
|
||||||
else { // bidirectional
|
else { // bidirectional
|
||||||
// Wx validation
|
// Wx validation
|
||||||
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
|
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||||
// Wr validation
|
// Wr validation
|
||||||
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
|
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
|
||||||
// biases validation
|
// biases validation
|
||||||
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
|
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||||
// initial output validation
|
// initial output validation
|
||||||
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
|
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
|
||||||
// initial cell validation
|
// initial cell validation
|
||||||
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
|
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
|
||||||
// peephole weights validation
|
// peephole weights validation
|
||||||
if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
|
if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
|
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
|
||||||
|
|
|
@ -28,10 +28,14 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
|
||||||
|
|
||||||
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
|
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
|
||||||
|
|
||||||
|
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
|
||||||
|
|
||||||
|
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
|
||||||
|
|
||||||
// first of all take into account possible presence of empty arrays
|
// first of all take into account possible presence of empty arrays
|
||||||
// also if scalar is present -> copy its value to vector with length=1
|
// also if scalar is present -> copy its value to vector with length=1
|
||||||
std::vector<NDArray*> nonEmptyArrs;
|
std::vector<NDArray*> nonEmptyArrs;
|
||||||
|
@ -40,7 +44,8 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
||||||
bool allOfSameType = true;
|
bool allOfSameType = true;
|
||||||
auto theFirstRank = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
|
auto theFirstRank = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
|
||||||
auto theFirstDatatype = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();
|
auto theFirstDatatype = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();
|
||||||
for(int i = 0; i < block.width(); ++i) {
|
|
||||||
|
for(int i = 0; i < numOfInArrs; ++i) {
|
||||||
auto input = INPUT_VARIABLE(i);
|
auto input = INPUT_VARIABLE(i);
|
||||||
auto currentRank = input->rankOf();
|
auto currentRank = input->rankOf();
|
||||||
|
|
||||||
|
@ -50,6 +55,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
||||||
if(!input->isEmpty()) {
|
if(!input->isEmpty()) {
|
||||||
|
|
||||||
allOfSameType &= (theFirstDatatype == input->dataType());
|
allOfSameType &= (theFirstDatatype == input->dataType());
|
||||||
|
|
||||||
if(input->rankOf() == 0) {
|
if(input->rankOf() == 0) {
|
||||||
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
|
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
|
||||||
vec->assign(input);
|
vec->assign(input);
|
||||||
|
@ -63,25 +69,28 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const int numOfArrs = nonEmptyArrs.size();
|
const int numOfNonEmptyArrs = nonEmptyArrs.size();
|
||||||
|
|
||||||
if(numOfArrs == 0){
|
if(numOfNonEmptyArrs == 0){
|
||||||
//All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
|
//All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
|
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array
|
const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array
|
||||||
int axis = INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank;
|
int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
|
||||||
|
if(axis < 0){
|
||||||
|
axis += rank;
|
||||||
|
}
|
||||||
|
|
||||||
// ******** input validation ******** //
|
// ******** input validation ******** //
|
||||||
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
|
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
|
||||||
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfNonEmptyArrs; ++i)
|
||||||
REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i) {
|
for(int i = 1; i < numOfNonEmptyArrs; ++i) {
|
||||||
for(int dim = 0; dim < rank; ++dim)
|
for(int dim = 0; dim < rank; ++dim)
|
||||||
if(dim != axis)
|
if(dim != axis)
|
||||||
REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
||||||
|
@ -90,7 +99,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
if(numOfArrs == 1)
|
if(numOfNonEmptyArrs == 1)
|
||||||
output->assign(nonEmptyArrs[0]);
|
output->assign(nonEmptyArrs[0]);
|
||||||
else
|
else
|
||||||
helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);
|
helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);
|
||||||
|
@ -108,20 +117,25 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
||||||
|
|
||||||
DECLARE_TYPES(concat) {
|
DECLARE_TYPES(concat) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY);
|
||||||
->setSameMode(true);
|
// ->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
DECLARE_SHAPE_FN(concat) {
|
DECLARE_SHAPE_FN(concat) {
|
||||||
|
|
||||||
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
|
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
|
||||||
|
|
||||||
|
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
|
||||||
|
|
||||||
|
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
|
||||||
|
|
||||||
// first of all take into account possible presence of empty arrays
|
// first of all take into account possible presence of empty arrays
|
||||||
// also if scalar is present -> use the shape of vector with length=1 instead
|
// also if scalar is present -> use the shape of vector with length=1 instead
|
||||||
std::vector<Nd4jLong*> arrShapes;
|
std::vector<Nd4jLong*> arrShapes;
|
||||||
std::vector<int> shapesToDelete;
|
std::vector<int> shapesToDelete;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for(int i = 0; i < block.width(); ++i) {
|
for(int i = 0; i < numOfInArrs; ++i) {
|
||||||
|
|
||||||
if(inputShape->at(i)[0] == 0) {
|
if(inputShape->at(i)[0] == 0) {
|
||||||
if (shape::isEmpty(inputShape->at(i)))
|
if (shape::isEmpty(inputShape->at(i)))
|
||||||
|
@ -135,21 +149,22 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
++index;
|
++index;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int numOfArrs = arrShapes.size();
|
const int numOfNonEmptyArrs = arrShapes.size();
|
||||||
|
|
||||||
const int rank = arrShapes[0][0];
|
const int rank = arrShapes[0][0];
|
||||||
|
|
||||||
int axis = INT_ARG(0);
|
int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
|
||||||
if(axis < 0)
|
if(axis < 0){
|
||||||
axis += rank;
|
axis += rank;
|
||||||
|
}
|
||||||
|
|
||||||
// ******** input validation ******** //
|
// ******** input validation ******** //
|
||||||
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfNonEmptyArrs; ++i)
|
||||||
REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i) {
|
for(int i = 1; i < numOfNonEmptyArrs; ++i) {
|
||||||
for(int dim = 0; dim < rank; ++dim)
|
for(int dim = 0; dim < rank; ++dim)
|
||||||
if(dim != axis)
|
if(dim != axis)
|
||||||
REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
||||||
|
@ -161,12 +176,12 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
COPY_SHAPE(arrShapes[0], outShapeInfo);
|
COPY_SHAPE(arrShapes[0], outShapeInfo);
|
||||||
|
|
||||||
// case when we have only one input array
|
// case when we have only one input array
|
||||||
if(numOfArrs == 1) {
|
if(numOfNonEmptyArrs == 1) {
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||||
return SHAPELIST(CONSTANT(outShapeInfo));
|
return SHAPELIST(CONSTANT(outShapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfNonEmptyArrs; ++i)
|
||||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||||
|
@ -358,24 +373,22 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
// return SHAPELIST(newShape);
|
// return SHAPELIST(newShape);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
DECLARE_TYPES(concat_bp) {
|
//////////////////////////////////////////////////////////////////////////
|
||||||
getOpDescriptor()
|
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) {
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
|
||||||
}
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 1) {
|
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
|
||||||
auto epsilonNext = INPUT_VARIABLE(block.width() - 1);
|
|
||||||
|
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
|
||||||
|
|
||||||
|
auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1);
|
||||||
|
|
||||||
auto first = INPUT_VARIABLE(0);
|
auto first = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
int axis = INT_ARG(0);
|
const int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf());
|
||||||
|
|
||||||
if (axis < 0)
|
|
||||||
axis += first->rankOf();
|
|
||||||
|
|
||||||
int startPos = 0;
|
int startPos = 0;
|
||||||
for (int e = 0; e < block.width() - 1; e++) {
|
|
||||||
|
for (int e = 0; e < numOfInArrs - 1; e++) {
|
||||||
auto originalChunk = INPUT_VARIABLE(e);
|
auto originalChunk = INPUT_VARIABLE(e);
|
||||||
auto epsilonChunk = OUTPUT_VARIABLE(e);
|
auto epsilonChunk = OUTPUT_VARIABLE(e);
|
||||||
std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());
|
std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());
|
||||||
|
@ -398,15 +411,28 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(concat_bp) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(concat_bp) {
|
DECLARE_SHAPE_FN(concat_bp) {
|
||||||
|
|
||||||
|
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
|
||||||
|
|
||||||
|
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
|
||||||
|
|
||||||
auto shapeList = SHAPELIST();
|
auto shapeList = SHAPELIST();
|
||||||
|
|
||||||
for (int e = 0; e < inputShape->size() - 1; e++) {
|
for (int e = 0; e < numOfInArrs - 1; e++) {
|
||||||
auto inShape = inputShape->at(e);
|
auto inShape = inputShape->at(e);
|
||||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
|
||||||
}
|
}
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,23 @@ namespace nd4j {
|
||||||
DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0);
|
DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if NOT_EXCLUDED(OP_random_crop)
|
||||||
DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* random_gamma op.
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_random_gamma)
|
||||||
|
DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* random_poisson op.
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_random_poisson)
|
||||||
|
DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,8 +59,8 @@ namespace nd4j {
|
||||||
DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0);
|
DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 1);
|
DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 0);
|
||||||
DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 1);
|
DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 0);
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_mergemax)
|
#if NOT_EXCLUDED(OP_mergemax)
|
||||||
DECLARE_OP(mergemax, -1, 1, false);
|
DECLARE_OP(mergemax, -1, 1, false);
|
||||||
|
|
|
@ -0,0 +1,132 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/random.h>
|
||||||
|
//#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
//#include <graph/Context.h>
|
||||||
|
#include <ShapeUtils.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
|
||||||
|
|
||||||
|
Nd4jLong* broadcasted = nullptr;
|
||||||
|
if (beta != nullptr)
|
||||||
|
ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace());
|
||||||
|
else
|
||||||
|
broadcasted = alpha->shapeInfo();
|
||||||
|
auto step = shape::length(broadcasted);
|
||||||
|
auto shift = output->lengthOf() / step;
|
||||||
|
|
||||||
|
auto copyAlpha = alpha;
|
||||||
|
auto copyBeta = beta;
|
||||||
|
if (beta != nullptr) {
|
||||||
|
NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context);
|
||||||
|
NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context);
|
||||||
|
|
||||||
|
copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha));
|
||||||
|
copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta));
|
||||||
|
|
||||||
|
}
|
||||||
|
// bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c';
|
||||||
|
bool directOutput = output->ews() == 1 && output->ordering() == 'c';
|
||||||
|
T* outputBuf = output->dataBuffer()->primaryAsT<T>();
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (auto k = 0; k < shift; k++) {
|
||||||
|
auto pos = k * step;
|
||||||
|
auto u = rng.relativeT<T>(k, 0., 1.);
|
||||||
|
for (auto e = 0; e < step; e++)
|
||||||
|
if (directOutput) {
|
||||||
|
outputBuf[pos + e] = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
|
||||||
|
beta != nullptr ? copyBeta->t<T>(e) * u : u);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output->t<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
|
||||||
|
beta != nullptr ? copyBeta->t<T>(e) * u : u);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (beta != nullptr) {
|
||||||
|
delete copyAlpha;
|
||||||
|
delete copyBeta;
|
||||||
|
//delete broadcasted;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context,
|
||||||
|
graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* algorithm Poisson generator based upon the inversion by sequential search:[48]:505
|
||||||
|
init:
|
||||||
|
Let x ← 0, p ← e−λ, s ← p.
|
||||||
|
Generate uniform random number u in [0,1].
|
||||||
|
while u > s do:
|
||||||
|
x ← x + 1.
|
||||||
|
p ← p * λ / x.
|
||||||
|
s ← s + p.
|
||||||
|
return x.
|
||||||
|
* */
|
||||||
|
template <typename T>
|
||||||
|
void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
|
||||||
|
auto shift = output->lengthOf() / lambda->lengthOf();
|
||||||
|
auto step = lambda->lengthOf();
|
||||||
|
T* lambdaBuf = lambda->dataBuffer()->primaryAsT<T>();
|
||||||
|
T* outputBuf = output->dataBuffer()->primaryAsT<T>();
|
||||||
|
bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c';
|
||||||
|
bool directOut = output->ews() == 1 && output->ordering() == 'c';
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (auto k = 0; k < shift; k++) {
|
||||||
|
auto pos = k * step;
|
||||||
|
auto u = rng.relativeT<T>(k, 0., 1.);
|
||||||
|
for (auto e = 0; e < step; e++) {
|
||||||
|
auto p = math::nd4j_exp<T, T>(-lambda->t<T>(e));
|
||||||
|
auto s = p;
|
||||||
|
auto x = T(0.f);
|
||||||
|
while (u > s) {
|
||||||
|
x += 1.f;
|
||||||
|
p *= directLa?lambdaBuf[e]/x:lambda->t<T>(e) / x;
|
||||||
|
s += p;
|
||||||
|
}
|
||||||
|
if (directOut)
|
||||||
|
outputBuf[pos + e] = x;
|
||||||
|
else
|
||||||
|
output->t<T>(pos + e) = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
|
||||||
|
graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -466,7 +466,7 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
|
||||||
useBisection = true;
|
useBisection = true;
|
||||||
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
|
||||||
useBisection = true;
|
useBisection = true;
|
||||||
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev))
|
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
|
||||||
useBisection = true;
|
useBisection = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -900,12 +900,8 @@ void SVD<T>::evalData(const NDArray& matrix) {
|
||||||
scale = 1.;
|
scale = 1.;
|
||||||
|
|
||||||
NDArray copy;
|
NDArray copy;
|
||||||
if(_transp) {
|
if(_transp)
|
||||||
copy = NDArrayFactory::create<T>(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(0)}, matrix.getContext());
|
copy = matrix.transpose();
|
||||||
for(int i = 0; i < copy.sizeAt(0); ++i)
|
|
||||||
for(int j = 0; j < copy.sizeAt(1); ++j)
|
|
||||||
copy.p<T>(i, j, matrix.e<T>(j,i) / scale);
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
copy = matrix / scale;
|
copy = matrix / scale;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,186 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/random.h>
|
||||||
|
//#include <NativeOps.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
#include <ShapeUtils.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* fillGammaKernel - fill up output with gamma distributed values
|
||||||
|
*
|
||||||
|
* uList - uniformly distributed values set
|
||||||
|
* uLength - length of uList
|
||||||
|
* alpha - alpha param
|
||||||
|
* beta - beta param
|
||||||
|
* output - distributed output.
|
||||||
|
* */
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, Nd4jLong* alphaShape,
|
||||||
|
T* beta, Nd4jLong* betaShape, T* output, Nd4jLong* outputShape) {
|
||||||
|
// fill up
|
||||||
|
__shared__ Nd4jLong aLength;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
aLength = shape::length(alphaShape);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
|
||||||
|
auto pos = k * aLength;
|
||||||
|
auto u = uList[k]; // this is a vector
|
||||||
|
for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) {
|
||||||
|
auto aIndex = shape::getIndexOffset(e, alphaShape);
|
||||||
|
auto bIndex = betaShape?shape::getIndexOffset(e, betaShape):-1LL;
|
||||||
|
auto betaV = T(beta != nullptr ? beta[bIndex] * u : u);
|
||||||
|
auto zIndex = shape::getIndexOffset(e + pos, outputShape);
|
||||||
|
|
||||||
|
output[zIndex] = math::nd4j_igamma<T, T, T>(alpha[aIndex], betaV);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
|
||||||
|
// To fill up output need to broadcast alpha and beta to the same shape and in
|
||||||
|
Nd4jLong* broadcasted = nullptr;
|
||||||
|
if (beta != nullptr)
|
||||||
|
ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace());
|
||||||
|
else
|
||||||
|
broadcasted = alpha->shapeInfo();
|
||||||
|
auto step = shape::length(broadcasted);
|
||||||
|
auto shift = output->lengthOf() / step;
|
||||||
|
|
||||||
|
auto copyAlpha = alpha;
|
||||||
|
auto copyBeta = beta;
|
||||||
|
if (beta != nullptr) {
|
||||||
|
NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context);
|
||||||
|
NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context);
|
||||||
|
|
||||||
|
copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha));
|
||||||
|
copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta));
|
||||||
|
copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray uniform = NDArrayFactory::create<T>('c', {shift}, context);
|
||||||
|
uniform.syncToDevice();
|
||||||
|
// fill up uniform with given length
|
||||||
|
RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
|
||||||
|
|
||||||
|
fillGammaKernel<T><<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), shift,
|
||||||
|
copyAlpha->dataBuffer()->specialAsT<T>(), copyAlpha->specialShapeInfo(),
|
||||||
|
beta?copyBeta->dataBuffer()->specialAsT<T>():(T*)nullptr,
|
||||||
|
beta?copyBeta->specialShapeInfo():(Nd4jLong*)nullptr,
|
||||||
|
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
|
||||||
|
|
||||||
|
if (beta != nullptr) {
|
||||||
|
delete copyAlpha;
|
||||||
|
delete copyBeta;
|
||||||
|
//delete broadcasted;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
|
||||||
|
if (beta)
|
||||||
|
NDArray::prepareSpecialUse({output}, {alpha, beta});
|
||||||
|
else
|
||||||
|
NDArray::prepareSpecialUse({output}, {alpha});
|
||||||
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
|
||||||
|
if (beta)
|
||||||
|
NDArray::registerSpecialUse({output}, {alpha, beta});
|
||||||
|
else
|
||||||
|
NDArray::prepareSpecialUse({output}, {alpha});
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
* algorithm Poisson generator based upon the inversion by sequential search
|
||||||
|
*
|
||||||
|
init:
|
||||||
|
Let x ← 0, p ← e−λ, s ← p.
|
||||||
|
using uniformly random sequence U (u in U) distributed at [0, 1].
|
||||||
|
while u > s do:
|
||||||
|
x ← x + 1.
|
||||||
|
p ← p * λ / x.
|
||||||
|
s ← s + p.
|
||||||
|
return x.
|
||||||
|
* */
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, Nd4jLong* lambdaShape, T* output,
|
||||||
|
Nd4jLong* outputShape) {
|
||||||
|
|
||||||
|
__shared__ Nd4jLong step;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
step = shape::length(lambdaShape);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
|
||||||
|
auto pos = k * step;
|
||||||
|
auto u = uList[k];
|
||||||
|
for (auto e = threadIdx.x; e < step; e += blockDim.x) {
|
||||||
|
auto p = math::nd4j_exp<T,T>(-lambda[e]);
|
||||||
|
auto s = p;
|
||||||
|
auto x = T(0.f);
|
||||||
|
auto lIndex = shape::getIndexOffset(e, lambdaShape);
|
||||||
|
auto zIndex = shape::getIndexOffset(e + pos, outputShape);
|
||||||
|
while (u > s) {
|
||||||
|
x += T(1.);
|
||||||
|
p *= lambda[lIndex] / x;
|
||||||
|
s += p;
|
||||||
|
}
|
||||||
|
output[zIndex] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
|
||||||
|
auto shift = output->lengthOf() / lambda->lengthOf();
|
||||||
|
NDArray uniform('c', {shift}, output->dataType());
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// fill up uniform with given length
|
||||||
|
RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
|
||||||
|
fillPoissonKernel<T><<<128, 256, 128, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), uniform.lengthOf(),
|
||||||
|
lambda->dataBuffer()->specialAsT<T>(), lambda->specialShapeInfo(),
|
||||||
|
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {lambda});
|
||||||
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
|
||||||
|
NDArray::registerSpecialUse({output}, {lambda});
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// Declaration of distribution helpers
|
||||||
|
//
|
||||||
|
#ifndef __RANDOM_HELPERS__
|
||||||
|
#define __RANDOM_HELPERS__
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <helpers/helper_random.h>
|
||||||
|
#include <graph/Context.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
|
||||||
|
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -268,8 +268,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
|
|
||||||
// dLdO
|
// dLdO
|
||||||
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
|
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
|
||||||
const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc();
|
const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
|
||||||
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem;
|
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
|
||||||
if (dLdOReorder)
|
if (dLdOReorder)
|
||||||
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
|
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
|
||||||
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
|
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
|
||||||
|
@ -284,8 +284,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
|
|
||||||
// dLdI
|
// dLdI
|
||||||
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
|
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
|
||||||
const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc();
|
const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
|
||||||
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem;
|
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
|
||||||
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
|
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
|
||||||
|
|
||||||
// gamma and beta (and their gradients) if they are present
|
// gamma and beta (and their gradients) if they are present
|
||||||
|
|
|
@ -32,6 +32,8 @@ using namespace mkldnn;
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
||||||
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
|
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
|
||||||
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
|
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
|
||||||
|
@ -50,7 +52,7 @@ namespace nd4j {
|
||||||
empty);
|
empty);
|
||||||
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||||
empty);
|
empty);
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
||||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
||||||
bias, output,
|
bias, output,
|
||||||
|
@ -58,17 +60,18 @@ namespace nd4j {
|
||||||
&conv_bias_md, &conv_dst_md,
|
&conv_bias_md, &conv_dst_md,
|
||||||
&user_src_md, nullptr, &user_weights_md, nullptr,
|
&user_src_md, nullptr, &user_weights_md, nullptr,
|
||||||
&user_bias_md, &user_dst_md,
|
&user_bias_md, &user_dst_md,
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||||
|
|
||||||
auto conv_desc = bias != nullptr
|
auto conv_desc = bias != nullptr
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
algorithm::convolution_auto, conv_src_md,
|
algorithm::convolution_auto, conv_src_md,
|
||||||
conv_weights_md, conv_bias_md,
|
conv_weights_md, conv_bias_md,
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
conv_padding_r)
|
conv_padding_r)
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
algorithm::convolution_auto, conv_src_md,
|
algorithm::convolution_auto, conv_src_md,
|
||||||
conv_weights_md,
|
conv_weights_md,
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
conv_padding_r);
|
conv_padding_r);
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
mkldnn::stream stream(engine);
|
mkldnn::stream stream(engine);
|
||||||
|
@ -110,6 +113,7 @@ namespace nd4j {
|
||||||
stream.wait();
|
stream.wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
PLATFORM_IMPL(conv2d) {
|
PLATFORM_IMPL(conv2d) {
|
||||||
auto input = INPUT_VARIABLE(
|
auto input = INPUT_VARIABLE(
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
@ -148,6 +152,217 @@ namespace nd4j {
|
||||||
return block.isUseMKLDNN() && input->dataType() == nd4j::DataType::FLOAT32 &&
|
return block.isUseMKLDNN() && input->dataType() == nd4j::DataType::FLOAT32 &&
|
||||||
weights->dataType() == nd4j::DataType::FLOAT32;
|
weights->dataType() == nd4j::DataType::FLOAT32;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(conv2d_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
||||||
|
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 4, 0,
|
||||||
|
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
|
||||||
|
weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
|
||||||
|
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
|
||||||
|
gradO->rankOf());
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
||||||
|
indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
|
if (isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||||
|
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||||
|
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
||||||
|
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
|
||||||
|
gradB, gradO,
|
||||||
|
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||||
|
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
||||||
|
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||||
|
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||||
|
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||||
|
auto conv_desc = gradB != nullptr
|
||||||
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
|
conv_padding_r)
|
||||||
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
||||||
|
LaunchContext::defaultContext()->engine()));
|
||||||
|
if (gradW != nullptr) {
|
||||||
|
auto convW_desc = gradB != nullptr
|
||||||
|
? convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||||
|
: convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
||||||
|
const_cast<NDArray *>(input)->buffer());
|
||||||
|
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||||
|
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convW_src_memory = userW_src_memory;
|
||||||
|
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||||
|
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
||||||
|
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||||
|
convW_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_weights_memory = userW_weights_memory;
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_dst_memory = userW_dst_memory;
|
||||||
|
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||||
|
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||||
|
convW_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gradB != nullptr) {
|
||||||
|
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||||
|
gradB->buffer());
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||||
|
} else {
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||||
|
userW_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gradI != nullptr) {
|
||||||
|
auto convI_desc =
|
||||||
|
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
|
||||||
|
conv_weights_md, conv_dst_md, conv_strides, conv_dilation,
|
||||||
|
conv_padding, conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
|
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||||
|
const_cast<NDArray *>(weights)->buffer());
|
||||||
|
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convI_src_memory = userI_src_memory;
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_weights_memory = userI_weights_memory;
|
||||||
|
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||||
|
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
||||||
|
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||||
|
convI_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_dst_memory = userI_dst_memory;
|
||||||
|
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||||
|
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||||
|
convI_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
||||||
|
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||||
|
userI_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
};
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(conv2d_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() &&
|
||||||
|
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,243 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author saudet
|
|
||||||
// @author raver119@gmail.com
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <ops/declarable/PlatformHelper.h>
|
|
||||||
#include <ops/declarable/OpRegistrator.h>
|
|
||||||
#include <platform_boilerplate.h>
|
|
||||||
|
|
||||||
#include <helpers/MKLDNNStream.h>
|
|
||||||
#include "mkldnnUtils.h"
|
|
||||||
#include <ops/declarable/helpers/convolutions.h>
|
|
||||||
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
namespace ops {
|
|
||||||
namespace platforms {
|
|
||||||
PLATFORM_IMPL(conv2d_bp) {
|
|
||||||
auto input = INPUT_VARIABLE(
|
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
|
||||||
auto weights = INPUT_VARIABLE(
|
|
||||||
1); // [kH, kW, iC, oC] always
|
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
|
||||||
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(
|
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
|
||||||
auto gradW = OUTPUT_VARIABLE(
|
|
||||||
1); // [kH, kW, iC, oC] always
|
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
|
||||||
int sH = INT_ARG(2); // strides height
|
|
||||||
int sW = INT_ARG(3); // strides width
|
|
||||||
int pH = INT_ARG(4); // paddings height
|
|
||||||
int pW = INT_ARG(5); // paddings width
|
|
||||||
int dH = INT_ARG(6); // dilations height
|
|
||||||
int dW = INT_ARG(7); // dilations width
|
|
||||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
|
||||||
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0,
|
|
||||||
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
|
|
||||||
input->rankOf());
|
|
||||||
REQUIRE_TRUE(weights->rankOf() == 4, 0,
|
|
||||||
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
|
|
||||||
weights->rankOf());
|
|
||||||
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
|
|
||||||
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
|
|
||||||
gradO->rankOf());
|
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
|
|
||||||
indIiH, indWiC, indWoC, indWkH, indOoH);
|
|
||||||
|
|
||||||
if (isSameMode) // SAME
|
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
|
||||||
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
|
||||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
|
||||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
|
||||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
|
||||||
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
|
|
||||||
gradB, gradO,
|
|
||||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
|
||||||
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
|
||||||
&user_src_md, &user_diff_src_md, &user_weights_md,
|
|
||||||
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
|
||||||
auto conv_desc = gradB != nullptr
|
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
|
||||||
algorithm::convolution_auto, conv_src_md,
|
|
||||||
conv_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
|
||||||
conv_padding_r)
|
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
|
||||||
algorithm::convolution_auto, conv_src_md,
|
|
||||||
conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
|
||||||
conv_padding_r);
|
|
||||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
|
||||||
LaunchContext::defaultContext()->engine()));
|
|
||||||
if (gradW != nullptr) {
|
|
||||||
auto convW_desc = gradB != nullptr
|
|
||||||
? convolution_backward_weights::desc(
|
|
||||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
|
|
||||||
: convolution_backward_weights::desc(
|
|
||||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
|
||||||
mkldnn::stream stream(engine);
|
|
||||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
|
||||||
conv_prim_desc);
|
|
||||||
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
|
||||||
const_cast<NDArray *>(input)->buffer());
|
|
||||||
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
|
||||||
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
|
||||||
const_cast<NDArray *>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convW_src_memory = userW_src_memory;
|
|
||||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
|
||||||
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
|
||||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
|
||||||
convW_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_weights_memory = userW_weights_memory;
|
|
||||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
|
||||||
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_dst_memory = userW_dst_memory;
|
|
||||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
|
||||||
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
|
||||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
|
||||||
convW_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradB != nullptr) {
|
|
||||||
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
|
||||||
gradB->buffer());
|
|
||||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
|
||||||
} else {
|
|
||||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
|
||||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
|
||||||
userW_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.wait();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradI != nullptr) {
|
|
||||||
auto convI_desc =
|
|
||||||
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
|
|
||||||
conv_weights_md, conv_dst_md, conv_strides,
|
|
||||||
conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
|
||||||
mkldnn::stream stream(engine);
|
|
||||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
|
||||||
conv_prim_desc);
|
|
||||||
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
|
||||||
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
|
||||||
const_cast<NDArray *>(weights)->buffer());
|
|
||||||
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
|
||||||
const_cast<NDArray *>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convI_src_memory = userI_src_memory;
|
|
||||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
|
||||||
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_weights_memory = userI_weights_memory;
|
|
||||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
|
||||||
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
|
||||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
|
||||||
convI_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_dst_memory = userI_dst_memory;
|
|
||||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
|
||||||
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
|
||||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
|
||||||
convI_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
|
||||||
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
|
||||||
|
|
||||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
|
||||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
|
||||||
userI_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.wait();
|
|
||||||
};
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
PLATFORM_CHECK(conv2d_bp) {
|
|
||||||
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
|
||||||
if (::optimalLevel() < 2)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(
|
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
|
||||||
auto weights = INPUT_VARIABLE(
|
|
||||||
1); // [kH, kW, iC, oC] always
|
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
|
||||||
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(
|
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
|
||||||
auto gradW = OUTPUT_VARIABLE(
|
|
||||||
1); // [kH, kW, iC, oC] always
|
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
|
|
||||||
|
|
||||||
return block.isUseMKLDNN() &&
|
|
||||||
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -32,6 +32,8 @@ using namespace mkldnn;
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
PLATFORM_IMPL(conv3dnew) {
|
PLATFORM_IMPL(conv3dnew) {
|
||||||
auto input = INPUT_VARIABLE(
|
auto input = INPUT_VARIABLE(
|
||||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
@ -86,7 +88,7 @@ namespace nd4j {
|
||||||
empty);
|
empty);
|
||||||
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||||
empty);
|
empty);
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||||
isNCDHW,
|
isNCDHW,
|
||||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
|
||||||
|
@ -95,17 +97,17 @@ namespace nd4j {
|
||||||
&conv_bias_md, &conv_dst_md,
|
&conv_bias_md, &conv_dst_md,
|
||||||
&user_src_md, nullptr, &user_weights_md, nullptr,
|
&user_src_md, nullptr, &user_weights_md, nullptr,
|
||||||
&user_bias_md, &user_dst_md,
|
&user_bias_md, &user_dst_md,
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||||
auto conv_desc = bias != nullptr
|
auto conv_desc = bias != nullptr
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
algorithm::convolution_auto, conv_src_md,
|
algorithm::convolution_auto, conv_src_md,
|
||||||
conv_weights_md, conv_bias_md,
|
conv_weights_md, conv_bias_md,
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
conv_padding_r)
|
conv_padding_r)
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
algorithm::convolution_auto, conv_src_md,
|
algorithm::convolution_auto, conv_src_md,
|
||||||
conv_weights_md,
|
conv_weights_md,
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
conv_padding_r);
|
conv_padding_r);
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
mkldnn::stream stream(engine);
|
mkldnn::stream stream(engine);
|
||||||
|
@ -162,6 +164,238 @@ namespace nd4j {
|
||||||
|
|
||||||
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
|
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(conv3dnew_bp) {
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
|
||||||
|
input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
|
||||||
|
weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
|
||||||
|
gradO->rankOf());
|
||||||
|
|
||||||
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
||||||
|
int isNDHWC =
|
||||||
|
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
||||||
|
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||||
|
|
||||||
|
int trueoD, trueoH, trueoW; // true output depth/height/width
|
||||||
|
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
|
||||||
|
dW, iD, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
||||||
|
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
||||||
|
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
||||||
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
|
||||||
|
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
|
||||||
|
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if (bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
|
||||||
|
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
|
||||||
|
oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
|
||||||
|
mkldnn_memory_desc_t empty;
|
||||||
|
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
||||||
|
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
||||||
|
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
||||||
|
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
||||||
|
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||||
|
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
||||||
|
isNDHWC,
|
||||||
|
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
|
||||||
|
gradW, gradB, gradO,
|
||||||
|
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
||||||
|
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
||||||
|
&user_src_md, &user_diff_src_md, &user_weights_md,
|
||||||
|
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
||||||
|
conv_strides, conv_padding, conv_padding_r, conv_dilation);
|
||||||
|
auto conv_desc = gradB != nullptr
|
||||||
|
? convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
|
conv_padding_r)
|
||||||
|
: convolution_forward::desc(prop_kind::forward,
|
||||||
|
algorithm::convolution_auto, conv_src_md,
|
||||||
|
conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
||||||
|
LaunchContext::defaultContext()->engine()));
|
||||||
|
if (gradW != nullptr) {
|
||||||
|
auto convW_desc = gradB != nullptr
|
||||||
|
? convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
|
||||||
|
: convolution_backward_weights::desc(
|
||||||
|
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
||||||
|
const_cast<NDArray *>(input)->buffer());
|
||||||
|
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
||||||
|
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convW_src_memory = userW_src_memory;
|
||||||
|
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
||||||
|
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
||||||
|
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
||||||
|
convW_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_weights_memory = userW_weights_memory;
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convW_dst_memory = userW_dst_memory;
|
||||||
|
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
||||||
|
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
||||||
|
convW_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gradB != nullptr) {
|
||||||
|
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
||||||
|
gradB->buffer());
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
||||||
|
} else {
|
||||||
|
convolution_backward_weights(convW_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_SRC, convW_src_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
||||||
|
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
||||||
|
userW_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
if (gradI != nullptr) {
|
||||||
|
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
|
||||||
|
conv_diff_src_md, conv_weights_md,
|
||||||
|
conv_dst_md, conv_strides, conv_dilation, conv_padding,
|
||||||
|
conv_padding_r);
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
||||||
|
conv_prim_desc);
|
||||||
|
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
||||||
|
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
||||||
|
const_cast<NDArray *>(weights)->buffer());
|
||||||
|
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
||||||
|
const_cast<NDArray *>(gradO)->buffer());
|
||||||
|
|
||||||
|
auto convI_src_memory = userI_src_memory;
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_weights_memory = userI_weights_memory;
|
||||||
|
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
||||||
|
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
||||||
|
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
||||||
|
convI_weights_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto convI_dst_memory = userI_dst_memory;
|
||||||
|
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
||||||
|
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
||||||
|
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
||||||
|
convI_dst_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
convolution_backward_data(convI_prim_desc).execute(stream,
|
||||||
|
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
||||||
|
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
||||||
|
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
||||||
|
|
||||||
|
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
||||||
|
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
||||||
|
userI_src_memory);
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(conv3dnew_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
if (::optimalLevel() < 2)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
||||||
|
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(
|
||||||
|
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
auto gradW = OUTPUT_VARIABLE(
|
||||||
|
1); // [kD, kH, kW, iC, oC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() &&
|
||||||
|
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,263 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author saudet
|
|
||||||
// @author raver119@gmail.com
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <ops/declarable/PlatformHelper.h>
|
|
||||||
#include <ops/declarable/OpRegistrator.h>
|
|
||||||
#include <platform_boilerplate.h>
|
|
||||||
|
|
||||||
#include <helpers/MKLDNNStream.h>
|
|
||||||
#include "mkldnnUtils.h"
|
|
||||||
#include <ops/declarable/helpers/convolutions.h>
|
|
||||||
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
namespace ops {
|
|
||||||
namespace platforms {
|
|
||||||
PLATFORM_IMPL(conv3dnew_bp) {
|
|
||||||
auto input = INPUT_VARIABLE(
|
|
||||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
|
||||||
auto weights = INPUT_VARIABLE(
|
|
||||||
1); // [kD, kH, kW, iC, oC] always
|
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
|
||||||
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(
|
|
||||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
|
||||||
auto gradW = OUTPUT_VARIABLE(
|
|
||||||
1); // [kD, kH, kW, iC, oC] always
|
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 5, 0,
|
|
||||||
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
|
|
||||||
input->rankOf());
|
|
||||||
REQUIRE_TRUE(weights->rankOf() == 5, 0,
|
|
||||||
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
|
|
||||||
weights->rankOf());
|
|
||||||
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
|
|
||||||
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
|
|
||||||
gradO->rankOf());
|
|
||||||
|
|
||||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
|
||||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
|
||||||
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
|
||||||
int sD = INT_ARG(3); // strides depth
|
|
||||||
int sH = INT_ARG(4); // strides height
|
|
||||||
int sW = INT_ARG(5); // strides width
|
|
||||||
int pD = INT_ARG(6); // paddings depth
|
|
||||||
int pH = INT_ARG(7); // paddings height
|
|
||||||
int pW = INT_ARG(8); // paddings width
|
|
||||||
int dD = INT_ARG(9); // dilations depth
|
|
||||||
int dH = INT_ARG(10); // dilations height
|
|
||||||
int dW = INT_ARG(11); // dilations width
|
|
||||||
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
|
|
||||||
int isNDHWC =
|
|
||||||
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
|
||||||
|
|
||||||
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
|
||||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
|
||||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
|
|
||||||
indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
|
||||||
|
|
||||||
int trueoD, trueoH, trueoW; // true output depth/height/width
|
|
||||||
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
|
|
||||||
dW, iD, iH, iW, isSameMode);
|
|
||||||
|
|
||||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
|
|
||||||
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
|
|
||||||
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
|
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
|
|
||||||
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
|
|
||||||
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
|
|
||||||
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
|
|
||||||
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
|
||||||
if (bias)
|
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
|
|
||||||
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
|
|
||||||
oC, bias->rankOf(), bias->lengthOf());
|
|
||||||
|
|
||||||
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
|
|
||||||
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
|
|
||||||
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
|
|
||||||
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
|
|
||||||
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
|
|
||||||
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
|
|
||||||
isNDHWC,
|
|
||||||
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
|
|
||||||
gradW, gradB, gradO,
|
|
||||||
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
|
|
||||||
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
|
|
||||||
&user_src_md, &user_diff_src_md, &user_weights_md,
|
|
||||||
&user_diff_weights_md, &user_bias_md, &user_dst_md,
|
|
||||||
conv_strides, conv_padding, conv_padding_r);
|
|
||||||
auto conv_desc = gradB != nullptr
|
|
||||||
? convolution_forward::desc(prop_kind::forward,
|
|
||||||
algorithm::convolution_auto, conv_src_md,
|
|
||||||
conv_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
|
||||||
conv_padding_r)
|
|
||||||
: convolution_forward::desc(prop_kind::forward,
|
|
||||||
algorithm::convolution_auto, conv_src_md,
|
|
||||||
conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
|
||||||
conv_padding_r);
|
|
||||||
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
|
|
||||||
LaunchContext::defaultContext()->engine()));
|
|
||||||
if (gradW != nullptr) {
|
|
||||||
auto convW_desc = gradB != nullptr
|
|
||||||
? convolution_backward_weights::desc(
|
|
||||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
|
|
||||||
: convolution_backward_weights::desc(
|
|
||||||
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
|
||||||
mkldnn::stream stream(engine);
|
|
||||||
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
|
|
||||||
conv_prim_desc);
|
|
||||||
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
|
|
||||||
const_cast<NDArray *>(input)->buffer());
|
|
||||||
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
|
|
||||||
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
|
|
||||||
const_cast<NDArray *>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convW_src_memory = userW_src_memory;
|
|
||||||
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
|
|
||||||
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
|
|
||||||
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
|
|
||||||
convW_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_weights_memory = userW_weights_memory;
|
|
||||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
|
||||||
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convW_dst_memory = userW_dst_memory;
|
|
||||||
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
|
|
||||||
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
|
|
||||||
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
|
|
||||||
convW_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gradB != nullptr) {
|
|
||||||
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
|
|
||||||
gradB->buffer());
|
|
||||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
|
|
||||||
} else {
|
|
||||||
convolution_backward_weights(convW_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_SRC, convW_src_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
|
|
||||||
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
|
|
||||||
userW_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.wait();
|
|
||||||
}
|
|
||||||
if (gradI != nullptr) {
|
|
||||||
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
|
|
||||||
conv_diff_src_md, conv_weights_md,
|
|
||||||
conv_dst_md, conv_strides, conv_padding,
|
|
||||||
conv_padding_r);
|
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
|
||||||
mkldnn::stream stream(engine);
|
|
||||||
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
|
|
||||||
conv_prim_desc);
|
|
||||||
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
|
|
||||||
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
|
|
||||||
const_cast<NDArray *>(weights)->buffer());
|
|
||||||
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
|
|
||||||
const_cast<NDArray *>(gradO)->buffer());
|
|
||||||
|
|
||||||
auto convI_src_memory = userI_src_memory;
|
|
||||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
|
||||||
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_weights_memory = userI_weights_memory;
|
|
||||||
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
|
|
||||||
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
|
|
||||||
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
|
|
||||||
convI_weights_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto convI_dst_memory = userI_dst_memory;
|
|
||||||
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
|
|
||||||
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
|
|
||||||
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
|
|
||||||
convI_dst_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
convolution_backward_data(convI_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
|
|
||||||
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
|
|
||||||
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
|
|
||||||
|
|
||||||
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
|
|
||||||
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
|
|
||||||
userI_src_memory);
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.wait();
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
PLATFORM_CHECK(conv3dnew_bp) {
|
|
||||||
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
|
||||||
if (::optimalLevel() < 2)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(
|
|
||||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
|
||||||
auto weights = INPUT_VARIABLE(
|
|
||||||
1); // [kD, kH, kW, iC, oC] always
|
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
|
||||||
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(
|
|
||||||
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
|
||||||
auto gradW = OUTPUT_VARIABLE(
|
|
||||||
1); // [kD, kH, kW, iC, oC] always
|
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
|
||||||
|
|
||||||
return block.isUseMKLDNN() &&
|
|
||||||
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,535 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||||
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
|
const int isSameMode) {
|
||||||
|
|
||||||
|
// input [bS, iH, iW, iC] nchw, mkl doesn't support format nhwc
|
||||||
|
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
||||||
|
// bias [oC], may be nullptr
|
||||||
|
|
||||||
|
// output [bS, oH, oW, oC] nchw, mkl doesn't support format nhwc
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
|
mkldnn::memory::dims strides = { sH, sW };
|
||||||
|
mkldnn::memory::dims dilation = { dH - 1, dW - 1};
|
||||||
|
mkldnn::memory::dims padding = { pH, pW };
|
||||||
|
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
|
|
||||||
|
// input type
|
||||||
|
mkldnn::memory::data_type xType;
|
||||||
|
if(input->dataType() == DataType::FLOAT32)
|
||||||
|
xType = mkldnn::memory::data_type::f32;
|
||||||
|
else if(input->dataType() == DataType::HALF)
|
||||||
|
xType = mkldnn::memory::data_type::f16;
|
||||||
|
else if(input->dataType() == DataType::UINT8)
|
||||||
|
xType = mkldnn::memory::data_type::u8;
|
||||||
|
else
|
||||||
|
xType = mkldnn::memory::data_type::s8;
|
||||||
|
|
||||||
|
// weights type
|
||||||
|
mkldnn::memory::data_type wType = xType;
|
||||||
|
if(xType == mkldnn::memory::data_type::u8)
|
||||||
|
wType = mkldnn::memory::data_type::s8;
|
||||||
|
|
||||||
|
// output and bias type (have the same types)
|
||||||
|
mkldnn::memory::data_type zType;
|
||||||
|
if(output->dataType() == DataType::FLOAT32)
|
||||||
|
zType = mkldnn::memory::data_type::f32;
|
||||||
|
else if(output->dataType() == DataType::HALF)
|
||||||
|
zType = mkldnn::memory::data_type::f16;
|
||||||
|
else if(output->dataType() == DataType::UINT8)
|
||||||
|
zType = mkldnn::memory::data_type::u8;
|
||||||
|
else if(output->dataType() == DataType::INT8)
|
||||||
|
zType = mkldnn::memory::data_type::s8;
|
||||||
|
else
|
||||||
|
zType = mkldnn::memory::data_type::s32;
|
||||||
|
|
||||||
|
|
||||||
|
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||||
|
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
|
||||||
|
|
||||||
|
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
|
||||||
|
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
|
||||||
|
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||||
|
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||||
|
|
||||||
|
// weights
|
||||||
|
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||||
|
|
||||||
|
// bias
|
||||||
|
mkldnn::memory::desc b_mkl_md;
|
||||||
|
if(bias != nullptr)
|
||||||
|
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
|
||||||
|
|
||||||
|
// output
|
||||||
|
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
|
||||||
|
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// operation primitive description
|
||||||
|
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
|
||||||
|
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// bias
|
||||||
|
if(bias != nullptr) {
|
||||||
|
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
|
||||||
|
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// output
|
||||||
|
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
|
||||||
|
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||||
|
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||||
|
args[MKLDNN_ARG_DST] = z_mkl_mem;
|
||||||
|
|
||||||
|
// run calculations
|
||||||
|
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (zReorder)
|
||||||
|
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||||
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||||
|
const int isSameMode) {
|
||||||
|
|
||||||
|
// input and gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
|
||||||
|
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
|
||||||
|
// gradB [oC], may be nullptr
|
||||||
|
// gradO [bS, oH, oW, oC]
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
|
mkldnn::memory::dims strides = { sH, sW };
|
||||||
|
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
|
||||||
|
mkldnn::memory::dims padding = { pH, pW };
|
||||||
|
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
|
|
||||||
|
// input type
|
||||||
|
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// weights type
|
||||||
|
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradO type
|
||||||
|
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradI type
|
||||||
|
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradW type
|
||||||
|
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradB type
|
||||||
|
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
|
||||||
|
|
||||||
|
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||||
|
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
|
||||||
|
|
||||||
|
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
|
||||||
|
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
|
||||||
|
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||||
|
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||||
|
|
||||||
|
// weights
|
||||||
|
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
|
||||||
|
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
|
||||||
|
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
|
||||||
|
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
mkldnn::memory::desc gradB_mkl_md;
|
||||||
|
if(gradB != nullptr)
|
||||||
|
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
|
||||||
|
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// forward primitive description
|
||||||
|
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// backward data primitive description
|
||||||
|
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// backward weights primitive description
|
||||||
|
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||||
|
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||||
|
if (gradOReorder)
|
||||||
|
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||||
|
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||||
|
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||||
|
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||||
|
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||||
|
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
if(gradB != nullptr) {
|
||||||
|
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||||
|
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run backward data calculations
|
||||||
|
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// run backward weights calculations
|
||||||
|
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder gradI if necessary
|
||||||
|
if (gradIReorder)
|
||||||
|
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||||
|
if (gradWReorder)
|
||||||
|
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(deconv2d) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
|
||||||
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if (bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
if(isSameMode){ // SAME
|
||||||
|
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||||
|
}
|
||||||
|
|
||||||
|
// mkl supports only [oC, iC, kH, kW] format for weights
|
||||||
|
weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||||
|
|
||||||
|
// mkl supports only NCHW
|
||||||
|
if(!isNCHW) {
|
||||||
|
input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode);
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
|
||||||
|
if(!isNCHW) {
|
||||||
|
delete input;
|
||||||
|
delete output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(deconv2d) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto weights = INPUT_VARIABLE(1);
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
|
||||||
|
|
||||||
|
auto output = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const DataType xType = input->dataType();
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType zType = output->dataType();
|
||||||
|
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && (
|
||||||
|
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
|
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
|
||||||
|
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(deconv2d_bp) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
|
||||||
|
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of weights array must be equal to 4 , but got %i instead !", weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf());
|
||||||
|
|
||||||
|
|
||||||
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
|
int trueoH, trueoW; // true output height, width
|
||||||
|
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
|
||||||
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if(bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
if(isSameMode){ // SAME
|
||||||
|
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||||
|
}
|
||||||
|
|
||||||
|
// mkl supports only [oC, iC, kH, kW] for weights
|
||||||
|
weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||||
|
gradW = new NDArray(gradW->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||||
|
|
||||||
|
// mkl supports NCHW format only
|
||||||
|
if(!isNCHW) {
|
||||||
|
input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode);
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
delete gradW;
|
||||||
|
|
||||||
|
if(!isNCHW) {
|
||||||
|
delete input;
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(deconv2d_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
|
||||||
|
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
|
||||||
|
const DataType xType = input->dataType();
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType gradOType = gradO->dataType();
|
||||||
|
|
||||||
|
const DataType gradIType = gradI->dataType();
|
||||||
|
const DataType gradWType = gradW->dataType();
|
||||||
|
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,244 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
|
||||||
|
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
|
||||||
|
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||||
|
|
||||||
|
// gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
|
||||||
|
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
|
||||||
|
// gradO [bS, oH, oW, oC]
|
||||||
|
|
||||||
|
mkldnn::memory::dims strides = { sH, sW };
|
||||||
|
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
|
||||||
|
mkldnn::memory::dims padding = { pH, pW };
|
||||||
|
mkldnn::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
// weights type
|
||||||
|
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradO type
|
||||||
|
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradI type
|
||||||
|
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
|
||||||
|
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
|
||||||
|
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
|
||||||
|
|
||||||
|
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
|
||||||
|
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
|
||||||
|
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, gradOType, mkldnn::memory::format_tag::any);
|
||||||
|
|
||||||
|
// weights
|
||||||
|
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
|
||||||
|
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
|
||||||
|
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||||
|
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// forward primitive description
|
||||||
|
mkldnn::convolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// backward data primitive description
|
||||||
|
mkldnn::convolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||||
|
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||||
|
if (gradOReorder)
|
||||||
|
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||||
|
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||||
|
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||||
|
|
||||||
|
// run backward data calculations
|
||||||
|
mkldnn::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder gradI if necessary
|
||||||
|
if (gradIReorder)
|
||||||
|
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(deconv2d_tf) {
|
||||||
|
|
||||||
|
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
|
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
||||||
|
int sH = INT_ARG(2); // strides height
|
||||||
|
int sW = INT_ARG(3); // strides width
|
||||||
|
int pH = INT_ARG(4); // paddings height
|
||||||
|
int pW = INT_ARG(5); // paddings width
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
const int rank = gradO->rankOf();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
|
||||||
|
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
|
||||||
|
|
||||||
|
int indIOioC, indIiH, indWoC(3), indOoH;
|
||||||
|
if(!isNCHW) {
|
||||||
|
indIOioC = 3; indIiH = 1; indOoH = 1;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
indIOioC = 1; indIiH = 2; indOoH = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> gradIShapeVector = gradIShape->template asVectorT<Nd4jLong>();
|
||||||
|
|
||||||
|
const int bS = gradIShapeVector[0]; // batch size
|
||||||
|
const int iH = gradIShapeVector[indIiH]; // input height
|
||||||
|
const int iW = gradIShapeVector[indIiH+1]; // input width
|
||||||
|
const int iC = gradIShapeVector[indIOioC]; // input channels
|
||||||
|
const int oC = weights->sizeAt(indWoC); // output channels
|
||||||
|
const int oH = gradO->sizeAt(indOoH); // input height
|
||||||
|
const int oW = gradO->sizeAt(indOoH); // input width
|
||||||
|
|
||||||
|
int trueoH, trueoW; // true output height, width
|
||||||
|
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
|
||||||
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
|
||||||
|
if(isSameMode) // SAME
|
||||||
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
|
// mkl supports only [oC, iC, kH, kW] for weights
|
||||||
|
weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||||
|
|
||||||
|
// mkl supports NCHW format only
|
||||||
|
if(!isNCHW) {
|
||||||
|
gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW);
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
|
||||||
|
if(!isNCHW) {
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(deconv2d_tf) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
|
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
|
||||||
|
|
||||||
|
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType gradOType = gradO->dataType();
|
||||||
|
const DataType gradIType = gradI->dataType();
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && ((wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,549 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <platform_boilerplate.h>
|
||||||
|
|
||||||
|
#include <helpers/MKLDNNStream.h>
|
||||||
|
#include "mkldnnUtils.h"
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
|
||||||
|
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
|
||||||
|
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
|
||||||
|
const int isSameMode) {
|
||||||
|
|
||||||
|
// input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
|
||||||
|
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
||||||
|
// bias [oC], may be nullptr
|
||||||
|
|
||||||
|
// output [bS, oD, oH, oW, oC] ncdhw, mkl doesn't support format ndhwc
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
|
mkldnn::memory::dims strides = { sD, sH, sW };
|
||||||
|
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1};
|
||||||
|
mkldnn::memory::dims padding = { pD, pH, pW };
|
||||||
|
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
|
|
||||||
|
// input type
|
||||||
|
mkldnn::memory::data_type xType;
|
||||||
|
if(input->dataType() == DataType::FLOAT32)
|
||||||
|
xType = mkldnn::memory::data_type::f32;
|
||||||
|
else if(input->dataType() == DataType::HALF)
|
||||||
|
xType = mkldnn::memory::data_type::f16;
|
||||||
|
else if(input->dataType() == DataType::UINT8)
|
||||||
|
xType = mkldnn::memory::data_type::u8;
|
||||||
|
else
|
||||||
|
xType = mkldnn::memory::data_type::s8;
|
||||||
|
|
||||||
|
// weights type
|
||||||
|
mkldnn::memory::data_type wType = xType;
|
||||||
|
if(xType == mkldnn::memory::data_type::u8)
|
||||||
|
wType = mkldnn::memory::data_type::s8;
|
||||||
|
|
||||||
|
// output and bias type (have the same types)
|
||||||
|
mkldnn::memory::data_type zType;
|
||||||
|
if(output->dataType() == DataType::FLOAT32)
|
||||||
|
zType = mkldnn::memory::data_type::f32;
|
||||||
|
else if(output->dataType() == DataType::HALF)
|
||||||
|
zType = mkldnn::memory::data_type::f16;
|
||||||
|
else if(output->dataType() == DataType::UINT8)
|
||||||
|
zType = mkldnn::memory::data_type::u8;
|
||||||
|
else if(output->dataType() == DataType::INT8)
|
||||||
|
zType = mkldnn::memory::data_type::s8;
|
||||||
|
else
|
||||||
|
zType = mkldnn::memory::data_type::s32;
|
||||||
|
|
||||||
|
|
||||||
|
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw;
|
||||||
|
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
|
||||||
|
|
||||||
|
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||||
|
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||||
|
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||||
|
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
|
||||||
|
|
||||||
|
// weights
|
||||||
|
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
|
||||||
|
|
||||||
|
// bias
|
||||||
|
mkldnn::memory::desc b_mkl_md;
|
||||||
|
if(bias != nullptr)
|
||||||
|
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
|
||||||
|
|
||||||
|
// output
|
||||||
|
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
|
||||||
|
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[4] = output->stridesOf()[4];
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// operation primitive description
|
||||||
|
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
|
||||||
|
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// bias
|
||||||
|
if(bias != nullptr) {
|
||||||
|
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
|
||||||
|
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// output
|
||||||
|
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
|
||||||
|
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||||
|
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||||
|
args[MKLDNN_ARG_DST] = z_mkl_mem;
|
||||||
|
|
||||||
|
// run calculations
|
||||||
|
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (zReorder)
|
||||||
|
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||||
|
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
|
||||||
|
const int isSameMode) {
|
||||||
|
|
||||||
|
// input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
|
||||||
|
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
|
||||||
|
// gradB [oC], may be nullptr
|
||||||
|
// gradO [bS, oD, oH, oW, oC]
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
|
mkldnn::memory::dims strides = { sD, sH, sW };
|
||||||
|
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1 };
|
||||||
|
mkldnn::memory::dims padding = { pD, pH, pW };
|
||||||
|
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
|
|
||||||
|
// input type
|
||||||
|
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// weights type
|
||||||
|
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradO type
|
||||||
|
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradI type
|
||||||
|
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradW type
|
||||||
|
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
|
||||||
|
// gradB type
|
||||||
|
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
|
||||||
|
|
||||||
|
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; // isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
|
||||||
|
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
|
||||||
|
|
||||||
|
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
|
||||||
|
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||||
|
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// input
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
|
||||||
|
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
|
||||||
|
|
||||||
|
// weights
|
||||||
|
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
|
||||||
|
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
|
||||||
|
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
|
||||||
|
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
|
||||||
|
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
|
||||||
|
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
|
||||||
|
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
|
||||||
|
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
|
||||||
|
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
|
||||||
|
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
|
||||||
|
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
mkldnn::memory::desc gradB_mkl_md;
|
||||||
|
if(gradB != nullptr)
|
||||||
|
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
|
||||||
|
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// forward primitive description
|
||||||
|
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// backward data primitive description
|
||||||
|
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// backward weights primitive description
|
||||||
|
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||||
|
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||||
|
if (gradOReorder)
|
||||||
|
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||||
|
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||||
|
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||||
|
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||||
|
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||||
|
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
if(gradB != nullptr) {
|
||||||
|
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||||
|
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run backward data calculations
|
||||||
|
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// run backward weights calculations
|
||||||
|
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder gradI if necessary
|
||||||
|
if (gradIReorder)
|
||||||
|
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||||
|
if (gradWReorder)
|
||||||
|
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(deconv3d) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
||||||
|
|
||||||
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
|
||||||
|
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||||
|
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if (bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
if(isSameMode){ // SAME
|
||||||
|
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
}
|
||||||
|
|
||||||
|
// mkl supports only [oC, iC, kD, kH, kW] format for weights
|
||||||
|
weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||||
|
|
||||||
|
// mkl supports only NCDHW
|
||||||
|
if(!isNCDHW) {
|
||||||
|
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode);
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
|
||||||
|
if(!isNCDHW) {
|
||||||
|
delete input;
|
||||||
|
delete output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_CHECK(deconv3d) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto weights = INPUT_VARIABLE(1);
|
||||||
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
|
||||||
|
|
||||||
|
auto output = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const DataType xType = input->dataType();
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType zType = output->dataType();
|
||||||
|
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && (
|
||||||
|
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
|
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
|
||||||
|
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(deconv3d_bp) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
|
||||||
|
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||||
|
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be equal to 5 , but got %i instead !", weights->rankOf());
|
||||||
|
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());
|
||||||
|
|
||||||
|
|
||||||
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
||||||
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
||||||
|
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
||||||
|
int sD = INT_ARG(3); // strides depth
|
||||||
|
int sH = INT_ARG(4); // strides height
|
||||||
|
int sW = INT_ARG(5); // strides width
|
||||||
|
int pD = INT_ARG(6); // paddings depth
|
||||||
|
int pH = INT_ARG(7); // paddings height
|
||||||
|
int pW = INT_ARG(8); // paddings width
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||||
|
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
|
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
|
||||||
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
|
int trueoD, trueoH, trueoW; // true output height, width
|
||||||
|
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||||
|
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
|
||||||
|
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
|
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
if(bias)
|
||||||
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
|
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
|
// mkl supports only [oC, iC, kD, kH, kW] for weights
|
||||||
|
weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||||
|
gradW = new NDArray(gradW->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||||
|
|
||||||
|
// mkl supports NCDHW format only
|
||||||
|
if(!isNCDHW) {
|
||||||
|
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
|
}
|
||||||
|
|
||||||
|
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode);
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
delete gradW;
|
||||||
|
|
||||||
|
if(!isNCDHW) {
|
||||||
|
delete input;
|
||||||
|
delete gradI;
|
||||||
|
delete gradO;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PLATFORM_CHECK(deconv3d_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW)
|
||||||
|
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||||
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
|
||||||
|
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||||
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
const DataType xType = input->dataType();
|
||||||
|
const DataType wType = weights->dataType();
|
||||||
|
const DataType gradOType = gradO->dataType();
|
||||||
|
|
||||||
|
const DataType gradIType = gradI->dataType();
|
||||||
|
const DataType gradWType = gradW->dataType();
|
||||||
|
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -416,36 +416,36 @@ PLATFORM_IMPL(lstmLayer) {
|
||||||
|
|
||||||
// Wx validation
|
// Wx validation
|
||||||
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
|
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||||
// Wr validation
|
// Wr validation
|
||||||
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
|
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
|
||||||
// biases validation
|
// biases validation
|
||||||
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
|
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||||
// initial output validation
|
// initial output validation
|
||||||
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
|
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
|
||||||
// initial cell validation
|
// initial cell validation
|
||||||
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
|
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
|
||||||
}
|
}
|
||||||
else { // bidirectional
|
else { // bidirectional
|
||||||
// Wx validation
|
// Wx validation
|
||||||
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
|
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||||
// Wr validation
|
// Wr validation
|
||||||
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
|
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
|
||||||
// biases validation
|
// biases validation
|
||||||
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
|
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||||
// initial output validation
|
// initial output validation
|
||||||
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
|
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
|
||||||
// initial cell validation
|
// initial cell validation
|
||||||
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
|
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
|
||||||
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
|
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip)};
|
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip)};
|
||||||
|
|
|
@ -148,7 +148,7 @@ namespace nd4j {
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
|
||||||
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
|
||||||
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
|
||||||
mkldnn::memory::dims conv_bias_tz = { oC };
|
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||||
|
@ -156,6 +156,7 @@ namespace nd4j {
|
||||||
|
|
||||||
conv_strides = { sH, sW };
|
conv_strides = { sH, sW };
|
||||||
conv_padding = { pH, pW };
|
conv_padding = { pH, pW };
|
||||||
|
conv_dilation = { dH-1, dW-1};
|
||||||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||||
(oW - 1) * sW - iW + kW - pW };
|
(oW - 1) * sW - iW + kW - pW };
|
||||||
|
|
||||||
|
@ -227,13 +228,14 @@ namespace nd4j {
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
|
||||||
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
|
||||||
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
|
||||||
mkldnn::memory::dims conv_bias_tz = { oC };
|
mkldnn::memory::dims conv_bias_tz = { oC };
|
||||||
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||||
|
|
||||||
conv_strides = { sD, sH, sW };
|
conv_strides = { sD, sH, sW };
|
||||||
|
conv_dilation = { dD-1, dH-1, dW-1};
|
||||||
conv_padding = { pD, pH, pW };
|
conv_padding = { pD, pH, pW };
|
||||||
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||||
(oH - 1) * sH - iH + kH - pH,
|
(oH - 1) * sH - iH + kH - pH,
|
||||||
|
|
|
@ -67,6 +67,16 @@ namespace nd4j{
|
||||||
DECLARE_PLATFORM(batchnorm_bp);
|
DECLARE_PLATFORM(batchnorm_bp);
|
||||||
|
|
||||||
DECLARE_PLATFORM(lstmLayer);
|
DECLARE_PLATFORM(lstmLayer);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(deconv2d);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(deconv2d_tf);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(deconv3d);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(deconv2d_bp);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(deconv3d_bp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +93,7 @@ namespace nd4j{
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescConv3d(
|
void getMKLDNNMemoryDescConv3d(
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
||||||
|
@ -93,7 +103,7 @@ namespace nd4j{
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
|
||||||
|
|
||||||
void getMKLDNNMemoryDescPool2d(
|
void getMKLDNNMemoryDescPool2d(
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||||
|
|
|
@ -129,6 +129,47 @@ namespace randomOps {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class PoissonDistribution {
|
||||||
|
public:
|
||||||
|
no_exec_special
|
||||||
|
no_exec_special_cuda
|
||||||
|
|
||||||
|
method_XY
|
||||||
|
|
||||||
|
random_def T op(Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
|
||||||
|
T lambda = extraParams[0];
|
||||||
|
T x = helper->relativeT(idx, -nd4j::DataTypeUtils::template max<T>() / 10 , nd4j::DataTypeUtils::template max<T>() / 10);
|
||||||
|
return x <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igammac<T,T,T>(nd4j::math::nd4j_floor<T,T>(x), lambda);
|
||||||
|
}
|
||||||
|
|
||||||
|
random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
|
||||||
|
T lambda = extraParams[0];
|
||||||
|
return valueX <= (T)0.f ? (T)0.f : (T)nd4j::math::nd4j_igammac<T,T,T>(nd4j::math::nd4j_floor<T,T>(valueX), lambda);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class GammaDistribution {
|
||||||
|
public:
|
||||||
|
no_exec_special
|
||||||
|
no_exec_special_cuda
|
||||||
|
|
||||||
|
method_XY
|
||||||
|
|
||||||
|
random_def T op(Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
|
||||||
|
T alpha = extraParams[0];
|
||||||
|
T beta = extraParams[1];
|
||||||
|
T x = helper->relativeT(idx, -nd4j::DataTypeUtils::template max<T>() / 10 , nd4j::DataTypeUtils::template max<T>() / 10);
|
||||||
|
return x <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igamma<T,T,T>(alpha, x * beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
|
||||||
|
T alpha = extraParams[0];
|
||||||
|
T beta = extraParams[1];
|
||||||
|
return valueX <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igamma<T,T,T>(alpha, beta * valueX);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic DropOut/DropConnect Op
|
* Basic DropOut/DropConnect Op
|
||||||
|
|
|
@ -894,6 +894,10 @@ namespace nd4j {
|
||||||
Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a));
|
Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a));
|
||||||
auto sum = Z(0.);
|
auto sum = Z(0.);
|
||||||
auto denom = Z(1.);
|
auto denom = Z(1.);
|
||||||
|
if (a <= X(0.000001))
|
||||||
|
//throw std::runtime_error("Cannot calculate gamma for a zero val.");
|
||||||
|
return Z(0);
|
||||||
|
|
||||||
for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
|
for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
|
||||||
denom *= (a + i);
|
denom *= (a + i);
|
||||||
sum += nd4j_pow<X, int, Z>(x, i) / denom;
|
sum += nd4j_pow<X, int, Z>(x, i) / denom;
|
||||||
|
|
|
@ -30,7 +30,7 @@ endif()
|
||||||
|
|
||||||
|
|
||||||
if (CMAKE_BUILD_TYPE STREQUAL "Release")
|
if (CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||||
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11 -fmax-errors=2")
|
||||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
||||||
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
|
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
|
||||||
else()
|
else()
|
||||||
|
@ -38,13 +38,13 @@ if (CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
|
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
|
||||||
elseif(WIN32)
|
elseif(WIN32)
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
set(CMAKE_CXX_FLAGS " -O0 -g --fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
set(CMAKE_CXX_FLAGS " -O0 -g --fPIC -std=c++11 -fmax-errors=2")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
|
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 -fmax-errors=2")
|
||||||
if (CPU_BLAS)
|
if (CPU_BLAS)
|
||||||
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -437,58 +437,38 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_3) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPED_TEST(TypedConvolutionTests1, deconv2D_FF_NoBias_1) {
|
TYPED_TEST(TypedConvolutionTests1, deconv2D_FF_NoBias_1) {
|
||||||
Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
|
||||||
TypeParam _expB[] = {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0,};
|
|
||||||
NDArray exp(_expB, _expS);
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create_<TypeParam>('c', {2, 3, 4, 4});
|
int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
||||||
auto weights = NDArrayFactory::create_<TypeParam>('c', {3, 3, 5, 5});
|
int oH=8,oW=8;
|
||||||
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
input->linspace(1);
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
|
||||||
weights->linspace(1);
|
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, oC, iC}, {1., 76., 151., 26., 101., 176., 51., 126., 201., 2., 77., 152., 27., 102., 177., 52., 127., 202., 3., 78., 153., 28., 103., 178., 53., 128., 203.,
|
||||||
weights->permutei({2,3,1,0});
|
4., 79., 154., 29., 104., 179., 54., 129., 204., 5., 80., 155., 30., 105., 180., 55., 130., 205., 6., 81., 156., 31., 106., 181., 56., 131., 206.,
|
||||||
|
7., 82., 157., 32., 107., 182., 57., 132., 207., 8., 83., 158., 33., 108., 183., 58., 133., 208., 9., 84., 159., 34., 109., 184., 59., 134., 209.,
|
||||||
|
10., 85., 160., 35., 110., 185., 60., 135., 210., 11., 86., 161., 36., 111., 186., 61., 136., 211., 12., 87., 162., 37., 112., 187., 62., 137., 212.,
|
||||||
|
13., 88., 163., 38., 113., 188., 63., 138., 213., 14., 89., 164., 39., 114., 189., 64., 139., 214., 15., 90., 165., 40., 115., 190., 65., 140., 215.,
|
||||||
|
16., 91., 166., 41., 116., 191., 66., 141., 216., 17., 92., 167., 42., 117., 192., 67., 142., 217., 18., 93., 168., 43., 118., 193., 68., 143., 218.,
|
||||||
|
19., 94., 169., 44., 119., 194., 69., 144., 219., 20., 95., 170., 45., 120., 195., 70., 145., 220., 21., 96., 171., 46., 121., 196., 71., 146., 221.,
|
||||||
|
22., 97., 172., 47., 122., 197., 72., 147., 222., 23., 98., 173., 48., 123., 198., 73., 148., 223., 24., 99., 174., 49., 124., 199., 74., 149., 224.,
|
||||||
|
25., 100., 175.,50., 125., 200.,75., 150., 225.});
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
auto exp = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}, {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0});
|
||||||
variableSpace->putVariable(-1, input);
|
|
||||||
variableSpace->putVariable(-2, weights);
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, false);
|
input.linspace(1);
|
||||||
block->fillInputs({-1, -2});
|
|
||||||
|
|
||||||
block->getIArguments()->push_back(5);
|
|
||||||
block->getIArguments()->push_back(5);
|
|
||||||
|
|
||||||
block->getIArguments()->push_back(1);
|
|
||||||
block->getIArguments()->push_back(1);
|
|
||||||
|
|
||||||
block->getIArguments()->push_back(0);
|
|
||||||
block->getIArguments()->push_back(0);
|
|
||||||
|
|
||||||
// dilation
|
|
||||||
block->getIArguments()->push_back(1);
|
|
||||||
block->getIArguments()->push_back(1);
|
|
||||||
|
|
||||||
// NOT same mode
|
|
||||||
block->getIArguments()->push_back(0);
|
|
||||||
|
|
||||||
block->getIArguments()->push_back(0);
|
|
||||||
|
|
||||||
nd4j::ops::deconv2d op;
|
nd4j::ops::deconv2d op;
|
||||||
|
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||||
|
|
||||||
Nd4jStatus status = op.execute(block);
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
auto output = results->at(0);
|
||||||
|
|
||||||
auto output = variableSpace->getVariable(1)->getNDArray();
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
|
||||||
// exp.printBuffer("Expctd buffer");
|
|
||||||
//output->printBuffer("Result buffer");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
delete variableSpace;
|
delete results;
|
||||||
delete block;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {
|
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {
|
||||||
|
@ -812,61 +792,54 @@ TEST_F(ConvolutionTests1, Test_im2col_col2im_3) {
|
||||||
|
|
||||||
TEST_F(ConvolutionTests1, TestDeconv_bp_1) {
|
TEST_F(ConvolutionTests1, TestDeconv_bp_1) {
|
||||||
|
|
||||||
|
int bS=3, iH=4,iW=4, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
||||||
|
int oH=4,oW=4;
|
||||||
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
||||||
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
double _expb[] = { 35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f, 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f, 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, 131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f, 302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f, 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, 236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, 547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f, 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f};
|
|
||||||
std::shared_ptr<DataBuffer> pBuffer1 = std::make_shared<DataBuffer>(_expb, sizeof(_expb), nd4j::DataType::DOUBLE, false);
|
|
||||||
NDArray expEpsilon(pBuffer1, 'c', {3, 3, 4, 4});
|
|
||||||
|
|
||||||
double _expwb[] = { 160008.f, 203400.f, 191112.f, 246792.f, 222216.f, 290184.f};
|
NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
|
||||||
std::shared_ptr<DataBuffer> pBuffer2 = std::make_shared<DataBuffer>(_expwb, sizeof(_expwb), nd4j::DataType::DOUBLE, false);
|
NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expGradW(pBuffer2, 'c', {3, 2, 1, 1});
|
NDArray weights('c',{kH,kW,oC,iC}, {1,3,5,2,4,6}, nd4j::DataType::FLOAT32);
|
||||||
expGradW.permutei({2,3,1,0});
|
NDArray gradO('c', {bS, oC, oH, oW},nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
double _expbb[] = {1944.f, 2712.f};
|
NDArray expGradI('c', {bS, iC, iH, iW}, {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f,
|
||||||
std::shared_ptr<DataBuffer> pBuffer3 = std::make_shared<DataBuffer>(_expbb, sizeof(_expbb), nd4j::DataType::DOUBLE, false);
|
77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f,
|
||||||
NDArray expGradB(pBuffer3, 'c', {1, 2});
|
176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f,
|
||||||
|
131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f,
|
||||||
auto input = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
|
302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f,
|
||||||
auto bias = NDArrayFactory::create<double>('c', {1, 2});
|
481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f,
|
||||||
auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
|
236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f,
|
||||||
auto epsilon = NDArrayFactory::create<double>('c', {3, 2, 4, 4});
|
547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f,
|
||||||
|
866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f}, nd4j::DataType::FLOAT32);
|
||||||
/*
|
NDArray expGradW('c', {kH, kW, oC, iC}, {160008., 191112., 222216., 203400., 246792., 290184.f}, nd4j::DataType::FLOAT32);
|
||||||
Input shape (3, 3, 4, 4)
|
NDArray expGradB('c', {oC}, {1944.f, 2712.f}, nd4j::DataType::FLOAT32);
|
||||||
Weights shape (3, 2, 1, 1)
|
|
||||||
Epsilon shape (3, 2, 4, 4)
|
|
||||||
*/
|
|
||||||
|
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
weights.linspace(1);
|
|
||||||
bias.linspace(1);
|
bias.linspace(1);
|
||||||
epsilon.linspace(1);
|
gradO.linspace(1);
|
||||||
weights.permutei({2,3,1,0});
|
|
||||||
|
|
||||||
nd4j::ops::deconv2d_bp op;
|
nd4j::ops::deconv2d_bp op;
|
||||||
|
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
||||||
|
|
||||||
auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
auto gradI = results->at(0);
|
||||||
|
auto gradW = results->at(1);
|
||||||
|
auto gradB = results->at(2);
|
||||||
|
|
||||||
auto expNext = result->at(0);
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
ASSERT_TRUE(expEpsilon.isSameShape(expNext));
|
|
||||||
ASSERT_TRUE(expEpsilon.equalsTo(expNext));
|
|
||||||
|
|
||||||
auto gradW = result->at(1);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
auto gradB = result->at(2);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expGradB.isSameShape(gradB));
|
ASSERT_TRUE(expGradB.isSameShape(gradB));
|
||||||
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
||||||
|
|
||||||
delete result;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ConvolutionTests1, TestDeconv_bp_2) {
|
TEST_F(ConvolutionTests1, TestDeconv_bp_2) {
|
||||||
/*
|
/*
|
||||||
Input shape:
|
Input shape:
|
||||||
|
@ -914,13 +887,11 @@ TEST_F(ConvolutionTests1, TestDeconv_ff_2) {
|
||||||
NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.});
|
NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.});
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
|
auto input = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
|
||||||
auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
|
auto weights = NDArrayFactory::create<double>('c',{1, 1, 2, 3}, {1,3,5,2,4,6});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {2});
|
auto bias = NDArrayFactory::create<double>('c', {2});
|
||||||
|
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
weights.linspace(1);
|
|
||||||
bias.linspace(1);
|
bias.linspace(1);
|
||||||
weights.permutei({2,3,1,0});
|
|
||||||
|
|
||||||
nd4j::ops::deconv2d op;
|
nd4j::ops::deconv2d op;
|
||||||
|
|
||||||
|
@ -2337,14 +2308,14 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ConvolutionTests1, deconv2d_test1) {
|
TEST_F(ConvolutionTests1, deconv2d_test1) {
|
||||||
|
|
||||||
int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
||||||
int oH=3,oW=3;
|
int iH=3,iW=3;
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
|
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<double>('c', {kH, kW, oC, iC});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
|
auto exp = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
|
||||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
|
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
|
||||||
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
|
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
|
||||||
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,
|
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,
|
||||||
|
|
|
@ -575,24 +575,38 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, oD, oH, oW, oC});
|
auto input = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kD, kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {iC});
|
auto bias = NDArrayFactory::create<float>('c', {iC});
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {bS, iD, iH, iW, iC});
|
auto gradO = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
|
||||||
|
|
||||||
|
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expGradB('c', {iC}, {364.5}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input = 0.5;
|
input = 0.5;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
gradO.linspace(0.5);
|
gradO.linspace(0.5);
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
nd4j::ops::deconv3d_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
|
||||||
|
|
||||||
nd4j::ops::deconv3d opFF;
|
auto gradI = results->at(0);
|
||||||
nd4j::ops::deconv3d_bp opBP;
|
auto gradW = results->at(1);
|
||||||
|
auto gradB = results->at(2);
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradB.isSameShape(gradB));
|
||||||
|
ASSERT_TRUE(expGradB.equalsTo(gradB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -603,23 +617,32 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test2) {
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, oD, oH, oW, oC});
|
auto input = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kD, kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC});
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {bS, iD, iH, iW, iC});
|
auto gradO = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
|
||||||
|
|
||||||
|
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, 5.6, 2.9, 4.3, 0.75, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input = 0.5;
|
input = 0.5;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
gradO.linspace(0.5);
|
gradO.linspace(0.5);
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
nd4j::ops::deconv3d_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
|
||||||
|
|
||||||
nd4j::ops::deconv3d opFF;
|
auto gradI = results->at(0);
|
||||||
nd4j::ops::deconv3d_bp opBP;
|
auto gradW = results->at(1);
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -630,24 +653,31 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test3) {
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, oC, oD, oH, oW});
|
auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {oC, iC, kD, kH, kW});
|
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6});
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {bS, iC, iD, iH, iW});
|
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
|
||||||
|
|
||||||
|
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input = 0.5;
|
input = 0.5;
|
||||||
weights.linspace(0.1, 0.1);
|
|
||||||
gradO.linspace(0.5);
|
gradO.linspace(0.5);
|
||||||
weights.permutei({2, 3, 4, 1, 0});
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
nd4j::ops::deconv3d_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
|
||||||
|
|
||||||
nd4j::ops::deconv3d opFF;
|
auto gradI = results->at(0);
|
||||||
nd4j::ops::deconv3d_bp opBP;
|
auto gradW = results->at(1);
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -658,24 +688,31 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test4) {
|
||||||
int paddingMode = 0; // 1-SAME, 0-VALID;
|
int paddingMode = 0; // 1-SAME, 0-VALID;
|
||||||
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {bS, oC, oD, oH, oW});
|
auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {oC, iC, kD, kH, kW});
|
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6});
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {bS, iC, iD, iH, iW});
|
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
|
||||||
|
|
||||||
|
NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input = 0.5;
|
input = 0.5;
|
||||||
weights.linspace(0.1, 0.1);
|
|
||||||
gradO.linspace(0.5);
|
gradO.linspace(0.5);
|
||||||
weights.permutei({2, 3, 4, 1, 0});
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
nd4j::ops::deconv3d_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
|
auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
|
||||||
|
|
||||||
nd4j::ops::deconv3d opFF;
|
auto gradI = results->at(0);
|
||||||
nd4j::ops::deconv3d_bp opBP;
|
auto gradW = results->at(1);
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
ASSERT_TRUE(expGradI.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(expGradI.equalsTo(gradI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expGradW.isSameShape(gradW));
|
||||||
|
ASSERT_TRUE(expGradW.equalsTo(gradW));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -37,21 +37,6 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests11, test_mixed_biasadd_1) {
|
|
||||||
if (!Environment::getInstance()->isExperimentalBuild())
|
|
||||||
return;
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
|
||||||
auto y = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f});
|
|
||||||
|
|
||||||
nd4j::ops::biasadd op;
|
|
||||||
auto status = op.execute({&x, &y}, {&z}, {}, {}, {true});
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
|
||||||
|
|
||||||
ASSERT_EQ(exp, z);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
|
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
|
||||||
auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
|
auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});
|
||||||
|
|
|
@ -243,7 +243,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_BiasAdd_NHWC_1) {
|
TEST_F(DeclarableOpsTests4, biasadd_1) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 3, 3, 2});
|
auto x = NDArrayFactory::create<double>('c', {2, 3, 3, 2});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
auto bias = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
|
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
|
||||||
|
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests4, Test_BiasAdd_NHWC_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_BiasAdd_NCHW_1) {
|
TEST_F(DeclarableOpsTests4, biasadd_2) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 3});
|
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 3});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
auto bias = NDArrayFactory::create<double>('c', {2}, {1, 2});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2});
|
auto exp = NDArrayFactory::create<double>('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2});
|
||||||
|
@ -279,6 +279,95 @@ TEST_F(DeclarableOpsTests4, Test_BiasAdd_NCHW_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests4, biasadd_3) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||||||
|
auto row = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
|
||||||
|
|
||||||
|
nd4j::ops::biasadd op;
|
||||||
|
auto result = op.execute({&x, &row}, {}, {}, {true}, false, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests4, biasadd_bp_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2,2,3}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO('c', {2,2,2,3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray bias('c', {3}, {-1., -2, -3}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expGradB('c', {3}, {9.2, 10. , 10.8}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
gradO.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::biasadd_bp op;
|
||||||
|
auto result = op.execute({&x, &bias, &gradO}, {}, {}, {false}); // NHWC
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto gradI = result->at(0);
|
||||||
|
auto gradB = result->at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(gradI->isSameShape(gradO));
|
||||||
|
ASSERT_TRUE(gradI->equalsTo(gradO));
|
||||||
|
|
||||||
|
ASSERT_TRUE(gradB->isSameShape(expGradB));
|
||||||
|
ASSERT_TRUE(gradB->equalsTo(expGradB));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests4, biasadd_bp_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,3,2,2}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO('c', {2,3,2,2}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray bias('c', {3}, {-1., -2, -3}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expGradB('c', {3}, {6.8, 10., 13.2}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
gradO.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::biasadd_bp op;
|
||||||
|
auto result = op.execute({&x, &bias, &gradO}, {}, {}, {true}); // NCHW
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto gradI = result->at(0);
|
||||||
|
auto gradB = result->at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(gradI->isSameShape(gradO));
|
||||||
|
ASSERT_TRUE(gradI->equalsTo(gradO));
|
||||||
|
|
||||||
|
ASSERT_TRUE(gradB->isSameShape(expGradB));
|
||||||
|
ASSERT_TRUE(gradB->equalsTo(expGradB));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests4, biasadd_4) {
|
||||||
|
if (!Environment::getInstance()->isExperimentalBuild())
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f});
|
||||||
|
|
||||||
|
nd4j::ops::biasadd op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {true});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, z);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Fill_1) {
|
TEST_F(DeclarableOpsTests4, Test_Fill_1) {
|
||||||
auto x = NDArrayFactory::create<int>('c', {1, 3}, {3, 2, 4});
|
auto x = NDArrayFactory::create<int>('c', {1, 3}, {3, 2, 4});
|
||||||
auto v = NDArrayFactory::create<double>(2.);
|
auto v = NDArrayFactory::create<double>(2.);
|
||||||
|
@ -639,24 +728,6 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_BiasAdd_1) {
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
|
||||||
auto row = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
|
|
||||||
|
|
||||||
nd4j::ops::biasadd op;
|
|
||||||
auto result = op.execute({&x, &row}, {}, {}, {true}, false, nd4j::DataType::DOUBLE);
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
|
||||||
|
|
||||||
auto z = result->at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_1D_1) {
|
TEST_F(DeclarableOpsTests4, Test_1D_1) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
auto x = NDArrayFactory::create<double>('c', {2, 3});
|
||||||
|
|
||||||
|
|
|
@ -241,6 +241,52 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
|
||||||
|
int zero = 0;
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {5, 4});
|
||||||
|
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
auto grad = NDArrayFactory::create<double>('c', {5,4});
|
||||||
|
|
||||||
|
matrix.linspace(1);
|
||||||
|
grad.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::strided_slice_bp op;
|
||||||
|
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printShapeInfo("Output shape");
|
||||||
|
z->printIndexedBuffer("Output");
|
||||||
|
//ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
|
||||||
|
int zero = 0;
|
||||||
|
auto matrix = NDArrayFactory::create<double>('c', {1, 2});
|
||||||
|
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
|
||||||
|
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
auto grad = NDArrayFactory::create<double>('c', {1}, {1.});
|
||||||
|
|
||||||
|
matrix.linspace(1);
|
||||||
|
//grad.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::strided_slice_bp op;
|
||||||
|
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printShapeInfo("Output shape");
|
||||||
|
z->printIndexedBuffer("Output");
|
||||||
|
//ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
|
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
|
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
|
||||||
|
|
|
@ -756,6 +756,27 @@ TEST_F(DeclarableOpsTests9, concat_test24) {
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, concat_test25) {
|
||||||
|
|
||||||
|
auto x0 = NDArrayFactory::create<double>('c', {1,4}, {1,2,3,4});
|
||||||
|
auto x1 = NDArrayFactory::create<double>('c', {1,4}, {5,6,7,8});
|
||||||
|
auto axis = NDArrayFactory::create<double>('c', {1}, {0.});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {2,4}, {1,2,3,4,5,6,7,8});
|
||||||
|
|
||||||
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
|
auto result = op.execute({&x0, &x1, &axis}, {}, {}, {true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
auto output = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
||||||
|
|
||||||
|
|
|
@ -773,6 +773,88 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_PoissonDistribution_1) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
|
||||||
|
auto la = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
|
||||||
|
|
||||||
|
la.linspace(1.0);
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::random_poisson op;
|
||||||
|
auto result = op.execute({&x, &la}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Poisson distribution");
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_GammaDistribution_1) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
|
||||||
|
auto al = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
|
||||||
|
|
||||||
|
al.linspace(1.0);
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::random_gamma op;
|
||||||
|
auto result = op.execute({&x, &al}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Gamma distribution");
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_GammaDistribution_2) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
|
||||||
|
auto al = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto be = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
|
||||||
|
|
||||||
|
al.linspace(1.0);
|
||||||
|
be.assign(1.0);
|
||||||
|
|
||||||
|
nd4j::ops::random_gamma op;
|
||||||
|
auto result = op.execute({&x, &al, &be}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Gamma distribution");
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_GammaDistribution_3) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
|
||||||
|
auto al = NDArrayFactory::create<float>('c', {3, 1});
|
||||||
|
auto be = NDArrayFactory::create<float>('c', {1, 2});
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {10, 3, 2});
|
||||||
|
|
||||||
|
al.linspace(1.0);
|
||||||
|
be.assign(2.0);
|
||||||
|
|
||||||
|
nd4j::ops::random_gamma op;
|
||||||
|
auto result = op.execute({&x, &al, &be}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Gamma distribution");
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace tests {
|
namespace tests {
|
||||||
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {
|
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {
|
||||||
|
|
|
@ -109,22 +109,22 @@ endif()
|
||||||
# -fsanitize=address
|
# -fsanitize=address
|
||||||
# -fsanitize=leak
|
# -fsanitize=leak
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -D__APPLE_OS__=true")
|
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -D__APPLE_OS__=true")
|
||||||
elseif(WIN32)
|
elseif(WIN32)
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
set(CMAKE_CXX_FLAGS " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -Wa,-mbig-obj")
|
set(CMAKE_CXX_FLAGS " -g -fPIC -std=c++11 -Wa,-mbig-obj")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
if ("${_RELEASE}" OR CMAKE_BUILD_TYPE STREQUAL "Release")
|
if ("${_RELEASE}" OR CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||||
message("Release build for tests")
|
message("Release build for tests")
|
||||||
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations")
|
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11")
|
||||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
|
||||||
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
|
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations")
|
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 ")
|
||||||
if (NOT CUDA_BLAS)
|
if (NOT CUDA_BLAS)
|
||||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
|
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -975,8 +975,8 @@ public class DifferentialFunctionFactory {
|
||||||
return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
|
return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
|
public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) {
|
||||||
return new BiasAddGrad(sameDiff(), input, bias, grad).outputVariables();
|
return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables();
|
||||||
}
|
}
|
||||||
|
|
||||||
public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {
|
public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {
|
||||||
|
|
|
@ -109,6 +109,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DDerivative.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DTF.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DTF.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DTF.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,
|
||||||
|
|
|
@ -45,12 +45,14 @@ public class BiasAdd extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[] {input, bias}, false);
|
super(null, sameDiff, new SDVariable[] {input, bias}, false);
|
||||||
bArguments.clear();
|
bArguments.clear();
|
||||||
bArguments.add(nchw);
|
bArguments.add(nchw);
|
||||||
|
this.nchw = nchw;
|
||||||
}
|
}
|
||||||
|
|
||||||
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
|
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
|
||||||
super(new INDArray[]{input, bias}, wrapOrNull(output));
|
super(new INDArray[]{input, bias}, wrapOrNull(output));
|
||||||
bArguments.clear();
|
bArguments.clear();
|
||||||
bArguments.add(nchw);
|
bArguments.add(nchw);
|
||||||
|
this.nchw = nchw;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -80,7 +82,7 @@ public class BiasAdd extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> gradient){
|
public List<SDVariable> doDiff(List<SDVariable> gradient){
|
||||||
return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0)));
|
return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -31,9 +31,12 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class BiasAddGrad extends DynamicCustomOp {
|
public class BiasAddGrad extends DynamicCustomOp {
|
||||||
|
protected boolean nchw = true;
|
||||||
|
|
||||||
public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient) {
|
public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient, boolean nchw) {
|
||||||
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
|
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
|
||||||
|
this.nchw = nchw;
|
||||||
|
addBArgument(nchw);
|
||||||
}
|
}
|
||||||
|
|
||||||
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
|
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
|
||||||
|
@ -52,8 +55,6 @@ public class BiasAddGrad extends DynamicCustomOp {
|
||||||
return "biasadd_bp";
|
return "biasadd_bp";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());
|
throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());
|
||||||
|
|
|
@ -147,69 +147,12 @@ public class DeConv3D extends DynamicCustomOp {
|
||||||
return config.getValue(property);
|
return config.getValue(property);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
|
||||||
int sD, sH, sW, dD=1, dH=1, dW=1;
|
|
||||||
|
|
||||||
val aStrides = nodeDef.getAttrOrThrow("strides");
|
|
||||||
List<Long> tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
|
|
||||||
|
|
||||||
List<Long> tfDilation = null;
|
|
||||||
if (attributesForNode.containsKey("dilations")) {
|
|
||||||
tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
|
|
||||||
}
|
|
||||||
|
|
||||||
val aPadding = nodeDef.getAttrOrDefault("padding", null);
|
|
||||||
String paddingMode = aPadding.getS().toStringUtf8();
|
|
||||||
|
|
||||||
String dataFormat = "NDHWC";
|
|
||||||
if (nodeDef.containsAttr("data_format")) {
|
|
||||||
val attr = nodeDef.getAttrOrThrow("data_format");
|
|
||||||
dataFormat = attr.getS().toStringUtf8().toLowerCase();
|
|
||||||
}
|
|
||||||
|
|
||||||
if(dataFormat.equalsIgnoreCase("NCDHW")){
|
|
||||||
sD = tfStrides.get(2).intValue();
|
|
||||||
sH = tfStrides.get(3).intValue();
|
|
||||||
sW = tfStrides.get(4).intValue();
|
|
||||||
if(tfDilation != null){
|
|
||||||
dD = tfDilation.get(2).intValue();
|
|
||||||
dH = tfDilation.get(3).intValue();
|
|
||||||
dW = tfDilation.get(4).intValue();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sD = tfStrides.get(1).intValue();
|
|
||||||
sH = tfStrides.get(2).intValue();
|
|
||||||
sW = tfStrides.get(3).intValue();
|
|
||||||
if(tfDilation != null){
|
|
||||||
dD = tfDilation.get(1).intValue();
|
|
||||||
dH = tfDilation.get(2).intValue();
|
|
||||||
dW = tfDilation.get(3).intValue();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
|
|
||||||
this.config = DeConv3DConfig.builder()
|
|
||||||
.kD(-1).kH(-1).kW(-1) //Infer from kernel
|
|
||||||
.sD(sD).sH(sW).sW(sH)
|
|
||||||
.dD(dD).dH(dH).dW(dW)
|
|
||||||
.isSameMode(isSameMode)
|
|
||||||
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
addArgs();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "deconv3d";
|
return "deconv3d";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Conv3DBackpropInputV2";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
|
|
|
@ -0,0 +1,208 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DeConv3D operation, TF-wrapper
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Getter
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class DeConv3DTF extends DynamicCustomOp {
|
||||||
|
|
||||||
|
protected DeConv3DConfig config;
|
||||||
|
|
||||||
|
public DeConv3DTF(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, @NonNull SDVariable weights, @NonNull SDVariable input, @NonNull DeConv3DConfig config) {
|
||||||
|
super(sameDiff, new SDVariable[]{shape, weights, input});
|
||||||
|
|
||||||
|
this.config = config;
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long[] iArgs() {
|
||||||
|
if (iArguments.size() == 0)
|
||||||
|
addArgs();
|
||||||
|
|
||||||
|
return super.iArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Object> propertiesForFunction() {
|
||||||
|
if(config == null && !iArguments.isEmpty()){
|
||||||
|
config = DeConv3DConfig.builder()
|
||||||
|
.kD(iArguments.get(0))
|
||||||
|
.kH(iArguments.get(1))
|
||||||
|
.kW(iArguments.get(2))
|
||||||
|
.sD(iArguments.get(3))
|
||||||
|
.sH(iArguments.get(4))
|
||||||
|
.sW(iArguments.get(5))
|
||||||
|
.pD(iArguments.get(6))
|
||||||
|
.pH(iArguments.get(7))
|
||||||
|
.pW(iArguments.get(8))
|
||||||
|
.dD(iArguments.get(9))
|
||||||
|
.dH(iArguments.get(10))
|
||||||
|
.dW(iArguments.get(11))
|
||||||
|
.isSameMode(iArguments.get(12) == 1)
|
||||||
|
.dataFormat(iArguments.get(13) == 1 ? DeConv3DConfig.NDHWC : DeConv3DConfig.NCDHW)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
return config.toProperties();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addArgs() {
|
||||||
|
addIArgument(config.getKD());
|
||||||
|
addIArgument(config.getKH());
|
||||||
|
addIArgument(config.getKW());
|
||||||
|
addIArgument(config.getSD());
|
||||||
|
addIArgument(config.getSH());
|
||||||
|
addIArgument(config.getSW());
|
||||||
|
addIArgument(config.getPD());
|
||||||
|
addIArgument(config.getPH());
|
||||||
|
addIArgument(config.getPW());
|
||||||
|
addIArgument(config.getDD());
|
||||||
|
addIArgument(config.getDH());
|
||||||
|
addIArgument(config.getDW());
|
||||||
|
addIArgument(ArrayUtil.fromBoolean(config.isSameMode()));
|
||||||
|
addIArgument(config.getDataFormat().equalsIgnoreCase(DeConv3DConfig.NCDHW) ? 0 : 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isConfigProperties() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String configFieldName() {
|
||||||
|
return "config";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object getValue(Field property) {
|
||||||
|
if (config == null) {
|
||||||
|
config = DeConv3DConfig.builder().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
return config.getValue(property);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
|
||||||
|
val aStrides = nodeDef.getAttrOrThrow("strides");
|
||||||
|
val aDilations = nodeDef.getAttrOrDefault("dilations", null);
|
||||||
|
val tfStrides = aStrides.getList().getIList();
|
||||||
|
val tfDilation = aDilations == null ? null : aDilations.getList().getIList();
|
||||||
|
int sD, sH, sW, dD, dH, dW;
|
||||||
|
|
||||||
|
val aPadding = nodeDef.getAttrOrDefault("padding", null);
|
||||||
|
String paddingMode = aPadding.getS().toStringUtf8();
|
||||||
|
|
||||||
|
String dataFormat = DeConv3DConfig.NDHWC;
|
||||||
|
if (nodeDef.containsAttr("data_format")) {
|
||||||
|
val attr = nodeDef.getAttrOrThrow("data_format");
|
||||||
|
dataFormat = attr.getS().toStringUtf8().toLowerCase();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
|
||||||
|
sD = tfStrides.get(2).intValue();
|
||||||
|
sH = tfStrides.get(3).intValue();
|
||||||
|
sW = tfStrides.get(4).intValue();
|
||||||
|
|
||||||
|
|
||||||
|
dD = tfDilation == null ? 1 : tfDilation.get(2).intValue();
|
||||||
|
dH = tfDilation == null ? 1 : tfDilation.get(3).intValue();
|
||||||
|
dW = tfDilation == null ? 1 : tfDilation.get(4).intValue();
|
||||||
|
} else {
|
||||||
|
sD = tfStrides.get(1).intValue();
|
||||||
|
sH = tfStrides.get(2).intValue();
|
||||||
|
sW = tfStrides.get(3).intValue();
|
||||||
|
|
||||||
|
dD = tfDilation == null ? 1 : tfDilation.get(1).intValue();
|
||||||
|
dH = tfDilation == null ? 1 : tfDilation.get(2).intValue();
|
||||||
|
dW = tfDilation == null ? 1 : tfDilation.get(3).intValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
|
||||||
|
DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
|
||||||
|
.kD(-1)
|
||||||
|
.kH(-1)
|
||||||
|
.kW(-1)
|
||||||
|
.sD(sD)
|
||||||
|
.sH(sW)
|
||||||
|
.sW(sH)
|
||||||
|
.dD(dD)
|
||||||
|
.dH(dH)
|
||||||
|
.dW(dW)
|
||||||
|
.isSameMode(isSameMode)
|
||||||
|
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
|
||||||
|
.build();
|
||||||
|
this.config = conv3DConfig;
|
||||||
|
|
||||||
|
addArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "deconv3d_tf";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] tensorflowNames() {
|
||||||
|
return new String[]{"Conv3DBackpropInput", "Conv3DBackpropInputV2"};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
|
throw new UnsupportedOperationException("Backprop not yet implemented for " + getClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //inShape, weights, input
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(2));
|
||||||
|
}
|
||||||
|
}
|
|
@ -39,6 +39,7 @@ import java.util.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class Concat extends DynamicCustomOp {
|
public class Concat extends DynamicCustomOp {
|
||||||
private int concatDimension = -1;
|
private int concatDimension = -1;
|
||||||
|
private boolean isDynamicAxis = false;
|
||||||
|
|
||||||
public Concat(){
|
public Concat(){
|
||||||
|
|
||||||
|
@ -83,73 +84,11 @@ public class Concat extends DynamicCustomOp {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
|
|
||||||
Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
|
|
||||||
|
|
||||||
Map<String,PropertyMapping> concatMap = new HashMap<>();
|
|
||||||
val concatDimProps = PropertyMapping.builder()
|
|
||||||
.tfInputPosition(0)
|
|
||||||
.onnxAttrName("axis")
|
|
||||||
.build();
|
|
||||||
concatMap.put("concatDimension",concatDimProps);
|
|
||||||
|
|
||||||
|
|
||||||
Map<String,PropertyMapping> concatV2Map = new HashMap<>();
|
|
||||||
val concat2DimProps = PropertyMapping.builder()
|
|
||||||
//lalst position
|
|
||||||
.tfInputPosition(-1)
|
|
||||||
.onnxAttrName("axis")
|
|
||||||
.build();
|
|
||||||
concatV2Map.put("concatDimension",concat2DimProps);
|
|
||||||
|
|
||||||
//note that onnx is already covered here
|
|
||||||
ret.put(tensorflowNames()[0],concatMap);
|
|
||||||
ret.put(tensorflowNames()[1],concatV2Map);
|
|
||||||
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
int concatDimension = -1;
|
//TF uses dynamic axis - last argument is a scalar integer array for axis
|
||||||
String input = null;
|
addBArgument(true);
|
||||||
val inputCount = nodeDef.getInputCount();
|
isDynamicAxis = true;
|
||||||
for(int i = 0; i < inputCount; i++) {
|
|
||||||
if(nodeDef.getInput(i).contains("/concat_dim")) {
|
|
||||||
input = nodeDef.getInput(i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//older versions may specify a concat_dim, usually it's the last argument
|
|
||||||
if(input == null) {
|
|
||||||
input = nodeDef.getInput(nodeDef.getInputCount() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
val variable = initWith.getVariable(input);
|
|
||||||
// concat dimension is only possible
|
|
||||||
if (variable != null) {
|
|
||||||
val arr = variable.getArr();
|
|
||||||
if (arr.length() == 1) {
|
|
||||||
concatDimension = arr.getInt(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.concatDimension = concatDimension;
|
|
||||||
addIArgument(this.concatDimension);
|
|
||||||
log.trace("Concat dimension: {}", concatDimension);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//don't pass both iArg and last axis down to libnd4j
|
|
||||||
if(inputArguments().length == nodeDef.getInputCount()) {
|
|
||||||
val inputArgs = inputArguments();
|
|
||||||
removeInputArgument(inputArgs[inputArguments().length - 1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285
|
|
||||||
sameDiff.removeArgFromOp(input,this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -159,12 +98,6 @@ public class Concat extends DynamicCustomOp {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
|
||||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String onnxName() {
|
public String onnxName() {
|
||||||
return "Concat";
|
return "Concat";
|
||||||
|
@ -175,7 +108,6 @@ public class Concat extends DynamicCustomOp {
|
||||||
return "Concat";
|
return "Concat";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[] {"Concat","ConcatV2"};
|
return new String[] {"Concat","ConcatV2"};
|
||||||
|
@ -189,18 +121,32 @@ public class Concat extends DynamicCustomOp {
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||||
SDVariable[] args = args();
|
SDVariable[] args = args();
|
||||||
SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 1);
|
SDVariable[] bpArgs;
|
||||||
|
if(isDynamicAxis){
|
||||||
|
bpArgs = Arrays.copyOf(args, args.length + 2);
|
||||||
|
bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3]; //Last input is axis -> move to end of bp args too
|
||||||
|
bpArgs[bpArgs.length - 2] = i_v.get(0);
|
||||||
|
return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
|
||||||
|
} else {
|
||||||
|
bpArgs = Arrays.copyOf(args, args.length + 1);
|
||||||
bpArgs[bpArgs.length - 1] = i_v.get(0);
|
bpArgs[bpArgs.length - 1] = i_v.get(0);
|
||||||
return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
|
return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
DataType first = dataTypes.get(0);
|
DataType first = dataTypes.get(0);
|
||||||
for( int i=1; i<dataTypes.size(); i++ ){
|
|
||||||
|
for( int i=1; i<dataTypes.size() - (isDynamicAxis ? 1 : 0); i++ ){
|
||||||
DataType dt = dataTypes.get(i);
|
DataType dt = dataTypes.get(i);
|
||||||
Preconditions.checkState(first == dt, "All inputs must have same datatype - got %s and %s for inputs 0 and %s respectively", first, dt, i);
|
Preconditions.checkState(first == dt, "All inputs must have same datatype - got %s and %s for inputs 0 and %s respectively", first, dt, i);
|
||||||
}
|
}
|
||||||
|
if(isDynamicAxis) {
|
||||||
|
Preconditions.checkState(dataTypes.get(dataTypes.size() - 1).isIntType(),
|
||||||
|
"For dynamic axis case, last datatype must be an integer type, got input types %s");
|
||||||
|
}
|
||||||
|
|
||||||
//Output type is same as input types
|
//Output type is same as input types
|
||||||
return Collections.singletonList(first);
|
return Collections.singletonList(first);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape.bp;
|
package org.nd4j.linalg.api.ops.impl.shape.bp;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -42,6 +43,7 @@ import java.util.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ConcatBp extends DynamicCustomOp {
|
public class ConcatBp extends DynamicCustomOp {
|
||||||
private int concatDimension;
|
private int concatDimension;
|
||||||
|
private boolean dynamicAxis;
|
||||||
|
|
||||||
public ConcatBp(){
|
public ConcatBp(){
|
||||||
|
|
||||||
|
@ -53,38 +55,30 @@ public class ConcatBp extends DynamicCustomOp {
|
||||||
* @param concatDimension
|
* @param concatDimension
|
||||||
* @param inputsAndGrad Original inputs, followed by output gradient
|
* @param inputsAndGrad Original inputs, followed by output gradient
|
||||||
*/
|
*/
|
||||||
public ConcatBp(SameDiff sameDiff, int concatDimension, SDVariable... inputsAndGrad){
|
public ConcatBp(@NonNull SameDiff sameDiff, int concatDimension, @NonNull SDVariable... inputsAndGrad){
|
||||||
super(null, sameDiff, inputsAndGrad);
|
super(null, sameDiff, inputsAndGrad);
|
||||||
addIArgument(concatDimension);
|
addIArgument(concatDimension);
|
||||||
this.concatDimension = concatDimension;
|
this.concatDimension = concatDimension;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param sameDiff SameDiff instance
|
||||||
|
* @param inputsGradAxis Inputs, gradient array, and axis
|
||||||
|
*/
|
||||||
|
public ConcatBp(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputsGradAxis){
|
||||||
|
super(null, sameDiff, inputsGradAxis);
|
||||||
|
Preconditions.checkState(inputsGradAxis[inputsGradAxis.length-1].dataType().isIntType(),
|
||||||
|
"When using this constructor, the last input must be an integer array (for the axis)");
|
||||||
|
addBArgument(true); //Last argument
|
||||||
|
this.dynamicAxis = true;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "concat_bp";
|
return "concat_bp";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Op.Type opType() {
|
public Op.Type opType() {
|
||||||
return Op.Type.CUSTOM;
|
return Op.Type.CUSTOM;
|
||||||
|
@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumOutputs(){
|
public int getNumOutputs(){
|
||||||
return args().length - 1;
|
return args().length - 1 - (dynamicAxis ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1358,4 +1358,35 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.build());
|
.build());
|
||||||
assertEquals(outCC, outFC); //Fails here
|
assertEquals(outCC, outFC); //Fails here
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBiasAdd_nchw_nhwc() {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
for(boolean nchw : new boolean[]{true, false}) {
|
||||||
|
log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
|
||||||
|
SameDiff sameDiff = SameDiff.create();
|
||||||
|
|
||||||
|
SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4}));
|
||||||
|
SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
|
||||||
|
|
||||||
|
SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
|
||||||
|
SDVariable loss = bAdd.std(true);
|
||||||
|
|
||||||
|
|
||||||
|
INDArray exp = in.getArr().dup();
|
||||||
|
if(nchw){
|
||||||
|
exp.addi(b.getArr().reshape(1,4,1,1));
|
||||||
|
} else {
|
||||||
|
exp.addi(b.getArr().reshape(1,1,1,4));
|
||||||
|
}
|
||||||
|
|
||||||
|
TestCase tc = new TestCase(sameDiff)
|
||||||
|
.gradientCheck(true)
|
||||||
|
.expectedOutput(bAdd.name(), exp);
|
||||||
|
|
||||||
|
String err = OpValidation.validate(tc);
|
||||||
|
assertNull(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -99,7 +99,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
||||||
"multinomial/.*",
|
"multinomial/.*",
|
||||||
|
|
||||||
//2019/11/02 AB - need deconv3d changes (for handling shape)
|
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
|
||||||
"conv3d_transpose.*"
|
"conv3d_transpose.*"
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue