Shyrma deconv3 (#69)

* - profiling cuda kernels for vol2col and im2col

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct addBias helper

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct mkl dilation formula and switch off mkl api for dilated deconvolutions

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2019-11-21 21:17:30 +02:00 committed by GitHub
parent ff73e6da3f
commit 7a90a31cfb
20 changed files with 654 additions and 386 deletions


@@ -186,9 +186,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
    if(B->rankOf() != 2)
        throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !");

    const auto M = A->sizeAt(0);
    const auto K = A->sizeAt(1);
    const auto N = B->sizeAt(1);

    if(C != nullptr && C->rankOf() != 2)
        throw std::runtime_error("MmulHelper::mmulMxM: rank of C array is not equal 2 !");
@@ -235,7 +235,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
        aMcont = true;
    }
    if(!bKcont && !bNcont) {
        pB = B->dup('f');
        toDelete.push_back(pB);
        bKcont = true;
    }
@@ -258,10 +258,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
    const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1);

    if(typeFloat) {
-       BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, reinterpret_cast<float *>(pA->getBuffer()), lda, reinterpret_cast<float *>(pB->getBuffer()), ldb, (float) beta, reinterpret_cast<float *>(pC->getBuffer()), ldc);
+       BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT<float>(), lda, pB->bufferAsT<float>(), ldb, (float) beta, pC->bufferAsT<float>(), ldc);
    }
    else if(typeDouble) {
-       BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast<double *>(pA->getBuffer()), lda, reinterpret_cast<double *>(pB->getBuffer()), ldb, (double) beta, reinterpret_cast<double *>(pC->getBuffer()), ldc);
+       BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT<double>(), lda, pB->bufferAsT<double>(), ldb, (double) beta, pC->bufferAsT<double>(), ldc);
    }

    if(pC != C) {


@@ -263,9 +263,13 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
    const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 6 + 128;    // 6 = aRank + bRank + cRank

    NDArray::prepareSpecialUse({C}, {A, B});
-   // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+   // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-   BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
+   BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
    NDArray::registerSpecialUse({C}, {A, B});

+   auto cudaResult = cudaStreamSynchronize(*stream);
+   if (cudaResult != 0)
+       throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
    }
    else {
@@ -334,6 +338,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
    NDArray::registerSpecialUse({pC}, {pA, pB});

+   auto cudaResult = cudaStreamSynchronize(*stream);
+   if (cudaResult != 0)
+       throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);

    if(C != pC)
        C->assign(pC);
@@ -341,10 +349,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
            delete toDelete[i];
    }

-   auto cudaResult = cudaStreamSynchronize(*stream);
-   if (cudaResult != 0)
-       throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);

    return C;
}
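Note: the hunks above move cudaStreamSynchronize from the tail of mmulMxM to immediately after each launch path, so the stream is drained (and launch errors surface) before the temporaries in toDelete are freed and pC is copied back into C. A minimal self-contained sketch of that launch/sync/check ordering, in plain CUDA with made-up names (scale is not a libnd4j kernel):

    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void scale(float* x, int n, float a) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] *= a;
    }

    int main() {
        const int n = 1024;
        float* d = nullptr;
        cudaMalloc(&d, n * sizeof(float));

        cudaStream_t stream;
        cudaStreamCreate(&stream);

        scale<<<(n + 255) / 256, 256, 0, stream>>>(d, n, 2.f);

        // Synchronize (and check) before freeing buffers or returning them to
        // the caller -- the same ordering the hunks above enforce.
        cudaError_t err = cudaStreamSynchronize(stream);
        if (err != cudaSuccess)
            printf("cuda failed: %s\n", cudaGetErrorString(err));

        cudaFree(d);
        cudaStreamDestroy(stream);
        return 0;
    }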
@@ -397,10 +401,14 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
    const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;

    NDArray::prepareSpecialUse({Y}, {A, X});
-   // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+   // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-   BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
+   BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
    NDArray::registerSpecialUse({Y}, {A, X});

+   auto cudaResult = cudaStreamSynchronize(*stream);
+   if (cudaResult != 0)
+       throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
    }
    else {
@@ -434,16 +442,16 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
        if (status != CUBLAS_STATUS_SUCCESS)
            throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);

+       auto cudaResult = cudaStreamSynchronize(*stream);
+       if (cudaResult != 0)
+           throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);

        NDArray::registerSpecialUse({Y}, {pA, X});

        if(pA != A)
            delete pA;
    }

-   auto cudaResult = cudaStreamSynchronize(*stream);
-   if (cudaResult != 0)
-       throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);

    return Y;
}
@@ -624,7 +632,7 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
            throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !");
    }
    else
-       C = new NDArray(outOrder, cExpectedShape, B->dataType());
+       C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());

    const int cRank = C->rankOf();
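Note: the old code gave an auto-created C the data type of B alone, which is wrong whenever A is the wider operand; the fix asks DataTypeUtils::pickPairwiseResultType for the promoted type and pins C to A's context. A simplified stand-in for the promotion rule, covering only the float/double pair (assumed behavior for illustration, not the real implementation):

    enum class DType { FLOAT32, DOUBLE };

    // Simplified stand-in for DataTypeUtils::pickPairwiseResultType:
    // promote to the wider of the two floating-point inputs.
    DType pickPairwiseResultType(DType a, DType b) {
        return (a == DType::DOUBLE || b == DType::DOUBLE) ? DType::DOUBLE
                                                          : DType::FLOAT32;
    }

    int main() {
        // double x float -> double: the case the old code (B->dataType())
        // got wrong whenever B happened to be the narrower operand.
        DType r = pickPairwiseResultType(DType::DOUBLE, DType::FLOAT32);
        return r == DType::DOUBLE ? 0 : 1;
    }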


@@ -889,7 +889,8 @@ namespace shape {
    */
    ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
-   ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices);
+   ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *indices, Nd4jLong baseOffset = 0);
+   ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *indices, Nd4jLong baseOffset = 0);

    ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
@@ -900,6 +901,8 @@ namespace shape {
    * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
    */
    ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
+   ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
+   ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
    ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);

    /**
    * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
@@ -913,6 +916,8 @@ namespace shape {
    * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
    */
    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords);
+   ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords);
+   ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords);
    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords);

    /**
    * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
@@ -1756,6 +1761,34 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLo
    return index;
}

+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords) {
+
+    Nd4jLong index, shift = 1;
+
+    index = coords[shapeInfo[0] - 1];
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        shift *= shapeInfo[i];
+        index += shift * coords[i - 2];
+    }
+
+    return index;
+}
+
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords) {
+
+    Nd4jLong index, shift = 1;
+
+    index = coords[shapeInfo[0] - 1];
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        shift *= shapeInfo[i];
+        index += shift * coords[i - 2];
+    }
+
+    return index;
+}

//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices) {
@@ -3223,18 +3256,28 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong
}

//////////////////////////////////////////////////////////////////////////
-INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {

-   Nd4jLong offset = 0;
+   Nd4jLong offset = baseOffset;

    for(uint i = 1; i <= shapeInfo[0]; ++i)
        if(shapeInfo[i] != 1)
-           offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i];
+           offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];

    return offset;
}

+//////////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
+
+   Nd4jLong offset = baseOffset;
+
+   for(uint i = 1; i <= shapeInfo[0]; ++i)
+       if(shapeInfo[i] != 1)
+           offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
+
+   return offset;
+}
/**
@@ -4720,6 +4763,26 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo,
    coords[0] = index;      // last iteration
}

+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords) {
+
+   for(uint i = shapeInfo[0]; i > 1; --i) {
+       coords[i - 1] = static_cast<int>(index) % static_cast<int>(shapeInfo[i]);
+       index /= static_cast<int>(shapeInfo[i]);
+   }
+   coords[0] = static_cast<int>(index);      // last iteration
+}
+
+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords) {
+
+   for(uint i = shapeInfo[0]; i > 1; --i) {
+       coords[i - 1] = static_cast<uint>(index) % static_cast<uint>(shapeInfo[i]);
+       index /= static_cast<uint>(shapeInfo[i]);
+   }
+   coords[0] = static_cast<uint>(index);      // last iteration
+}

//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) {
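Note: the new index2coords overloads are the inverse mapping, peeling coordinates off with mod/div from the last (fastest-varying) dimension, per the {2, 4} example in the header comments: index 5 <-> [1, 1]. A quick round-trip check under the same convention:

    #include <cassert>

    typedef long long Nd4jLong;

    // Mirrors the index2coords overloads above: last dimension varies fastest.
    void indexToCoords(Nd4jLong index, const Nd4jLong* shapeInfo, int* coords) {
        for (int i = shapeInfo[0]; i > 1; --i) {
            coords[i - 1] = static_cast<int>(index % shapeInfo[i]);
            index /= shapeInfo[i];
        }
        coords[0] = static_cast<int>(index);
    }

    int main() {
        const Nd4jLong shapeInfo[] = {2, 2, 4, 4, 1};   // rank 2, shape {2,4}, strides {4,1}
        int coords[2];
        indexToCoords(5, shapeInfo, coords);
        assert(coords[0] == 1 && coords[1] == 1);        // 5 -> [1, 1], as documented
        return 0;
    }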


@@ -35,13 +35,13 @@ namespace ops {
CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {

    auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC] always
    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]

    auto output  = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

    int sH = INT_ARG(2);                                                 // strides height
    int sW = INT_ARG(3);                                                 // strides width
    int pH = INT_ARG(4);                                                 // paddings height

@@ -60,8 +60,8 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);

@@ -99,21 +99,21 @@ DECLARE_SHAPE_FN(conv2d) {
    if(!isNCHW) {
        indIOioC = 3; indIiH = 1;
    }
    else {
        indIOioC = 1; indIiH = 2;
    }

    const int bS = inputShapeInfo[1];                   // batch size
    const int iH = inputShapeInfo[indIiH+1];            // input height
    const int iW = inputShapeInfo[indIiH+2];            // input width
    const int iC = inputShapeInfo[indIOioC+1];          // input channels
    const int oC = weightsShapeInfo[indWoC+1];          // output channels

    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
    if (biasShapeInfo)
        REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

    Nd4jLong* outputShapeInfo = nullptr;
    ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);

@@ -132,9 +132,9 @@ DECLARE_SHAPE_FN(conv2d) {
        outputShapeInfo[3] = oW;
        outputShapeInfo[4] = oC;
    }

    ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo));

    return SHAPELIST(CONSTANT(outputShapeInfo));
}
@@ -153,18 +153,18 @@ DECLARE_SHAPE_FN(conv2d) {
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {

    auto input   = INPUT_VARIABLE(0);                                            // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1);                                            // [kH, kW, iC, oC] always
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;              // [oC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);    // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                             // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1);                                             // [kH, kW, iC, oC] always
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;               // [oC]

    int kH = INT_ARG(0);                                                         // filter(kernel) height
    int kW = INT_ARG(1);                                                         // filter(kernel) width
    int sH = INT_ARG(2);                                                         // strides height

@@ -195,7 +195,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);

    return Status::OK();
}

@@ -229,14 +229,14 @@ DECLARE_SHAPE_FN(conv2d_bp) {
    if(!isNCHW) {
        indIOioC = 3; indIiH = 1; indOoH = 1;
    }
    else {
        indIOioC = 1; indIiH = 2; indOoH = 2;
    }

    const int bS = inputShapeInfo[1];                   // batch size
    const int iH = inputShapeInfo[indIiH+1];            // input height
    const int iW = inputShapeInfo[indIiH+2];            // input width
    const int iC = inputShapeInfo[indIOioC+1];          // input channels
    const int oC = weightsShapeInfo[indWoC+1];          // output channels

    int trueoH, trueoW;          // true output height, width

@@ -248,27 +248,27 @@ DECLARE_SHAPE_FN(conv2d_bp) {
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
    if(biasShapeInfo)
        REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

    auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo,   gradOShapeInfo, false, block.getWorkspace());
    auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace());

    if(biasShapeInfo) {
        auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace());
        return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo));
    }

    return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo));
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {

    auto gradIShape = INPUT_VARIABLE(0);          // [4]
    auto weights    = INPUT_VARIABLE(1);          // [kH, kW, iC, oC] always
    auto gradO      = INPUT_VARIABLE(2);          // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);              // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

    int kH = INT_ARG(0);                          // filter(kernel) height
    int kW = INT_ARG(1);                          // filter(kernel) width
    int sH = INT_ARG(2);                          // strides height

@@ -284,9 +284,9 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
    REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
    REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
    REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());

    // create empty conv2d input array
    std::vector<Nd4jLong> gradIShapeAsVector(rank);
    for(int i = 0; i < rank; ++i)
        gradIShapeAsVector[i] = gradIShape->e<Nd4jLong>(i);

@@ -306,7 +306,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());

    ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);

    return Status::OK();
}
@@ -325,8 +325,8 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
    auto gradOShapeInfo = inputShape->at(2);      // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    const int rank = 4;

    REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]);
    REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
    REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);

@@ -345,18 +345,18 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
    if(!isNCHW) {
        indIOioC = 3; indIiH = 1; indOoH = 1;
    }
    else {
        indIOioC = 1; indIiH = 2; indOoH = 2;
    }

    std::vector<Nd4jLong> gradIShape = INPUT_VARIABLE(0)->template asVectorT<Nd4jLong>();

    const int bS = gradIShape[0];                 // batch size
    const int iH = gradIShape[indIiH];            // input height
    const int iW = gradIShape[indIiH+1];          // input width
    const int iC = gradIShape[indIOioC];          // input channels
    const int oC = weightsShapeInfo[indWoC+1];    // output channels

    int trueoH, trueoW;          // true output height, width
    ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

@@ -364,8 +364,8 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());

    Nd4jLong* gradIshapeInfo(nullptr);
    ALLOCATE(gradIshapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);

    gradIshapeInfo[0] = rank;

@@ -380,7 +380,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
        gradIshapeInfo[3] = iW;
        gradIshapeInfo[4] = iC;
    }

    ShapeUtils::updateStridesAndType(gradIshapeInfo, gradOShapeInfo, shape::order(gradOShapeInfo));

    return SHAPELIST(CONSTANT(gradIshapeInfo));


@@ -66,10 +66,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
    if(!isNCHW)
        output = new NDArray(output->permute({0, 3, 1, 2}));        // [bS, oH, oW, oC] -> [bS, oC, oH, oW]

-   if(isSameMode){     // SAME
-       //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
+   if(isSameMode)      // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
        ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
-   }

    NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());
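Note: as the comment says, SAME padding for a deconv is computed by running the forward-conv padding formula with input and output heights swapped, since a deconv's output plays the role of a conv's input. A numeric check using the calcPadding2D formula from the convolutions.h hunks below (values are made up):

    #include <cassert>

    // SAME padding for a forward conv, as in ConvolutionUtils::calcPadding2D:
    // pH = ((oH - 1) * sH + eKH - iH) / 2, with eKH = (kH - 1) * dH + 1.
    int samePadding(int iH, int oH, int kH, int sH, int dH) {
        const int eKH = (kH - 1) * dH + 1;
        return ((oH - 1) * sH + eKH - iH) / 2;
    }

    int main() {
        // A deconv2d with iH = 4, sH = 2, kH = 3, dH = 1: the op calls the
        // padding helper as if it were a forward conv from 8 down to 4,
        // i.e. with iH and oH swapped.
        const int pH = samePadding(/*swapped iH*/ 8, /*swapped oH*/ 4, 3, 2, 1);
        assert(pH == 0);    // ((4-1)*2 + 3 - 8) / 2 = 0; padBottom is 1 bigger
                            // here since the bracketed term (1) is odd
        return 0;
    }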


@@ -67,16 +67,16 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
    if(!isNCDHW)
        output = new NDArray(output->permute({0, 4, 1, 2, 3}));     // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]

-   if(isSameMode)      //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
+   if(isSameMode)      // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
        ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-   auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
+   NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());

    //----- calculation of output -----//
    // NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
    // NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW]
    nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7});   // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
    ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW);                  // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]

    //----- add biases if required -----//
    if(bias)


@@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
    NDArray *input        = INPUT_VARIABLE(0);        // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    NDArray *weightsDepth = INPUT_VARIABLE(1);        // [kH, kW, iC, mC] always
    NDArray *weightsPoint = nullptr;                  // [1, 1, iC*mC, oC] always
-   NDArray *bias         = nullptr;                  // [oC], oC = iC*mC if weightsPoint=nullptr
+   NDArray *bias         = nullptr;                  // [oC], if weightsPoint=nullptr then oC = iC*mC

    NDArray *output = OUTPUT_VARIABLE(0);             // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

@@ -300,7 +300,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
DECLARE_SHAPE_FN(sconv2d_bp) {

    auto inputShapeInfo    = inputShape->at(0);        // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto gradOShapeInfo    = inputShape->at(1);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
    auto weightsDShapeInfo = inputShape->at(2);        // [kH, kW, iC, mC] always
    Nd4jLong* weightsPShapeInfo = nullptr;             // [1, 1, iC*mC, oC] always
    Nd4jLong* biasShapeInfo     = nullptr;             // [oC], oC = iC*mC if weightsPoint=nullptr


@@ -69,8 +69,8 @@ namespace nd4j {
        eKH = kH;
        eKW = kW;
    } else {
-       eKH = kH + (kH - 1) * (dH - 1);
-       eKW = kW + (kW - 1) * (dW - 1);
+       eKH = (kH - 1) * dH + 1;
+       eKW = (kW - 1) * dW + 1;
    }

    pH = ((oH - 1) * sH + eKH - iH) / 2;    // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2

@@ -84,9 +84,9 @@ namespace nd4j {
        eKH = kH;
        eKW = kW;
    } else {
-       eKD = kD + (kD - 1) * (dD - 1);
-       eKH = kH + (kH - 1) * (dH - 1);
-       eKW = kW + (kW - 1) * (dW - 1);
+       eKD = (kD - 1) * dD + 1;
+       eKH = (kH - 1) * dH + 1;
+       eKW = (kW - 1) * dW + 1;
    }

    pD = ((oD - 1) * sD + eKD - iD) / 2;    // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2

@@ -107,8 +107,8 @@ namespace nd4j {
        ekH = kH;
        ekW = kW;
    } else {
-       ekH = kH + (kH - 1) * (dH - 1);
-       ekW = kW + (kW - 1) * (dW - 1);
+       ekH = (kH - 1) * dH + 1;
+       ekW = (kW - 1) * dW + 1;
    }

    oH = sH * (iH - 1) + ekH - 2 * pH;

@@ -131,9 +131,9 @@ namespace nd4j {
        ekW = kW;
    }
    else {
-       ekD = kD + (kD - 1) * (dD - 1);
-       ekH = kH + (kH - 1) * (dH - 1);
-       ekW = kW + (kW - 1) * (dW - 1);
+       ekD = (kD - 1) * dD + 1;
+       ekH = (kH - 1) * dH + 1;
+       ekW = (kW - 1) * dW + 1;
    }

    oD = sD * (iD - 1) + ekD - 2 * pD;
    oH = sH * (iH - 1) + ekH - 2 * pH;
@@ -194,53 +194,53 @@ namespace nd4j {
    }

-   static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
-       if(kH != 1) {
-           if(isSameMode) {
-               pH = (oH - 1) * sH - iH + kH - pH;
-               dH = dH - 1;
-           }
-           else
-               dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
-       }
-       if(kW != 1) {
-           if(isSameMode) {
-               pW = (oW - 1) * sW - iW + kW - pW;
-               dW = dW - 1;
-           }
-           else
-               dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
-       }
-   }
-
-   static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
-       if(kD != 1) {
-           if(isSameMode) {
-               pD = (oD - 1) * sD - iD + kD - pD;
-               dD = dD - 1;
-           }
-           else
-               dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
-       }
-       if(kH != 1) {
-           if(isSameMode) {
-               pH = (oH - 1) * sH - iH + kH - pH;
-               dH = dH - 1;
-           }
-           else
-               dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
-       }
-       if(kW != 1) {
-           if(isSameMode) {
-               pW = (oW - 1) * sW - iW + kW - pW;
-               dW = dW - 1;
-           }
-           else
-               dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
-       }
-   }
+   // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
+   //     if(kH != 1) {
+   //         if(isSameMode) {
+   //             pH = (oH - 1) * sH - iH + kH - pH;
+   //             dH = dH - 1;
+   //         }
+   //         else
+   //             dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
+   //     }
+   //     if(kW != 1) {
+   //         if(isSameMode) {
+   //             pW = (oW - 1) * sW - iW + kW - pW;
+   //             dW = dW - 1;
+   //         }
+   //         else
+   //             dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
+   //     }
+   // }
+
+   // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
+   //     if(kD != 1) {
+   //         if(isSameMode) {
+   //             pD = (oD - 1) * sD - iD + kD - pD;
+   //             dD = dD - 1;
+   //         }
+   //         else
+   //             dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
+   //     }
+   //     if(kH != 1) {
+   //         if(isSameMode) {
+   //             pH = (oH - 1) * sH - iH + kH - pH;
+   //             dH = dH - 1;
+   //         }
+   //         else
+   //             dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
+   //     }
+   //     if(kW != 1) {
+   //         if(isSameMode) {
+   //             pW = (oW - 1) * sW - iW + kW - pW;
+   //             dW = dW - 1;
+   //         }
+   //         else
+   //             dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
+   //     }
+   // }

    static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);


@@ -47,9 +47,12 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output,
    const bool inOutAreSame = x == z;

+   int posOfNonUnityDim;
+   bias.isCommonVector(posOfNonUnityDim);

    const uint bS           = output.sizeAt(0);              // batch size
-   const Nd4jLong yStrideC = bias.stridesOf()[0];
-   const Nd4jLong zStrideB = output.stridesOf()[0];
+   const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim);
+   const Nd4jLong zStrideB = output.strideAt(0);

    if(output.rankOf() == 4) {
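Note: this is the "- correct addBias helper" fix. A bias may be rank 1 ([oC]) or rank 2 ([1, oC]); for the latter, stridesOf()[0] is the stride of the unit dimension, not the step between bias elements. Finding the single non-unity dimension first yields the right channel stride in both layouts. A standalone sketch of that test (plain C++, not the NDArray API):

    #include <cassert>
    #include <vector>

    // Reports the position of the only non-unity dimension when 'shape' is
    // vector-like (all other dims are 1), mirroring what
    // NDArray::isCommonVector(posOfNonUnityDim) provides.
    bool isCommonVector(const std::vector<long long>& shape, int& pos) {
        int nonUnit = 0;
        for (size_t i = 0; i < shape.size(); ++i)
            if (shape[i] != 1) { pos = static_cast<int>(i); ++nonUnit; }
        return nonUnit <= 1;
    }

    int main() {
        // bias of shape [1, oC] with c-order strides {oC, 1}
        std::vector<long long> shape = {1, 64}, strides = {64, 1};
        int pos = 0;
        assert(isCommonVector(shape, pos) && pos == 1);
        assert(strides[pos] == 1);     // correct step between bias elements
        assert(strides[0]  == 64);     // the stride the old code would have used
        return 0;
    }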


@@ -54,6 +54,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
    const uint oW = output->sizeAt(2);

    auto func = PRAGMA_THREADS_FOR_2D {
        for (uint b = start_x; b < stop_x; b += inc_x) {
            for (uint oh = start_y; oh < stop_y; oh += inc_y) {
                for (uint ow = 0; ow < oW; ++ow) {

@@ -69,13 +70,17 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
                            const int iw = ow * sW - pW + kw * dW;
                            if (iw < 0 || iw >= iW) continue;

-                           const X val = x[shape::getOffset(xShapeInfo, {b, (uint) ih, (uint) iw, c})] + y[shape::getOffset(yShapeInfo, {kh, kw, c})];
+                           uint xCoords[4] = {b, (uint)ih, (uint)iw, c};
+                           uint yCoords[3] = {kh, kw, c};
+
+                           const X val = x[shape::getOffset(xShapeInfo, xCoords)] + y[shape::getOffset(yShapeInfo, yCoords)];
                            if (val > max)
                                max = val;
                        }
                    }

-                   z[shape::getOffset(zShapeInfo, {b, oh, ow, c})] = static_cast<Z>(max);
+                   uint zCoords[4] = {b, oh, ow, c};
+                   z[shape::getOffset(zShapeInfo, zCoords)] = static_cast<Z>(max);
                }
            }
        }
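Note: the explicit coordinate arrays exist because the std::vector<uint> getOffset overload was removed (see the shape.h hunks above); a braced list used to materialize a heap-allocated vector on every inner-loop iteration, while a fixed-size stack array costs nothing. A minimal illustration with a stand-in getOffset (same strided dot product, not the library function):

    #include <cstdint>

    typedef long long Nd4jLong;

    static Nd4jLong getOffsetOf(const Nd4jLong* shapeInfo, const uint32_t* coords) {
        Nd4jLong offset = 0;
        for (int i = 1; i <= shapeInfo[0]; ++i)
            offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
        return offset;
    }

    Nd4jLong hotLoopBody(const Nd4jLong* shapeInfo, uint32_t b, uint32_t ih, uint32_t iw, uint32_t c) {
        // Stack array: no allocation, unlike passing
        // std::vector<uint32_t>{b, ih, iw, c} on every iteration.
        uint32_t coords[4] = {b, ih, iw, c};
        return getOffsetOf(shapeInfo, coords);
    }

    int main() {
        const Nd4jLong shapeInfo[] = {4, 2, 8, 8, 3, 192, 24, 3, 1};   // rank, dims, c-order strides
        return hotLoopBody(shapeInfo, 1, 2, 3, 1) == 250 ? 0 : 1;      // 192 + 48 + 9 + 1
    }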


@@ -44,7 +44,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
    const Y* y = reinterpret_cast<const Y*>(vy);
          X* z = reinterpret_cast<X*>(vz);

-   __shared__ int rank, channelPosition;
+   __shared__ int rank, channelPosition, posOfNonUnityDim;
    __shared__ Nd4jLong *sharedMem, len;
    __shared__ bool xzSameOffsets, xzAreSame;

@@ -58,6 +58,8 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
        len = shape::length(xShapeInfo);
        channelPosition = isNCHW ? 1 : rank - 1;      // second or last
        xzAreSame = x == z;
+
+       shape::isCommonVector(yShapeInfo, posOfNonUnityDim);
    }
    __syncthreads();

@@ -69,7 +71,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
        const auto xOffsets = shape::getOffset(xShapeInfo, coords);
        const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords);
-       const auto yOffsets = shape::getOffset(yShapeInfo, coords + channelPosition);
+       const auto yOffsets = coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim];

        if(xzAreSame)
            z[zOffsets] += static_cast<X>(y[yOffsets]);

@@ -94,7 +96,7 @@ void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& b
    PointersManager manager(block.launchContext(), "addBias");

-   const int threadsPerBlock = MAX_NUM_THREADS;
+   const int threadsPerBlock = MAX_NUM_THREADS/2;
    const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
    const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
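Note: same bias-stride fix as the CPU helper above: rather than reinterpreting a tail slice of the input coordinates as bias coordinates, the kernel multiplies the channel coordinate by the stride of the bias's single non-unity dimension, which is correct for both [oC] and [1, oC] layouts. The indexing identity with assumed sizes:

    #include <cassert>

    typedef long long Nd4jLong;

    int main() {
        // bias shaped [1, 64], c-order strides {64, 1}; non-unity dim is 1.
        const Nd4jLong yStrides[] = {64, 1};
        const int posOfNonUnityDim = 1;

        const int channelCoord = 37;    // coords[channelPosition] for some element
        const Nd4jLong yOffset = channelCoord * yStrides[posOfNonUnityDim];
        assert(yOffset == 37);          // the old indexing would effectively step by 64
        return 0;
    }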


@@ -34,64 +34,59 @@ static __global__ void col2imCuda(const void* columns, const Nd4jLong* colShapeI
    const T* col = reinterpret_cast<const T*>(columns);
          T* im = reinterpret_cast<T*>(image);

-   __shared__ int colRank, imRank, kHeff, kWeff, oH, oW;
-   __shared__ Nd4jLong *sharedMem, imLen;
+   __shared__ uint kH, kW, oH, oW, *sharedMem;
+   __shared__ Nd4jLong imLen;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
-       sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+       sharedMem = reinterpret_cast<uint*>(shmem);
+
+       kH = dH * (colShapeInfo[3] - 1) + 1;
+       kW = dW * (colShapeInfo[4] - 1) + 1;

        oH = colShapeInfo[5];
        oW = colShapeInfo[6];

-       kHeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dH - 1);
-       kWeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dW - 1);
-
-       imRank  = 4;
-       colRank = 6;
-
        imLen = shape::length(imShapeInfo);
    }
    __syncthreads();

-   const auto imInd = threadIdx.x + blockIdx.x * blockDim.x;
-
-   if(imInd >= imLen)
-       return;
-
-   auto coords = sharedMem + threadIdx.x * colRank;
-
-   shape::index2coords(imInd, imShapeInfo, coords);
-
-   const auto imOffset = shape::getOffset(imShapeInfo, coords);
-
-   const int imH = coords[2] + pH;
-   const int imW = coords[3] + pW;
-
-   const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1;
-   const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1;
-
-   const int colHend = nd4j::math::nd4j_min<int>(imH / sH + 1, oH);
-   const int colWend = nd4j::math::nd4j_min<int>(imW / sW + 1, oW);
-
-   T val = 0;
-
-   for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) {
-       coords[2] = imH - coords[4] * sH;
-
-       for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) {
-           coords[3] = imW - coords[5] * sW;
-
-           if(coords[2] % dH == 0 && coords[3] % dW == 0) {
-               coords[2] /= dH;
-               coords[3] /= dW;
-               val += col[shape::getOffset(colShapeInfo, coords)];
-           }
-       }
-   }
-   im[imOffset] = val;
+   auto coords = sharedMem + threadIdx.x * 6;
+
+   const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+   for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) {
+
+       shape::index2coords(i, imShapeInfo, coords);
+
+       const auto imOffset = shape::getOffset(imShapeInfo, coords);
+
+       const auto bSiCoffset = coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8];
+
+       const uint imH = coords[2] + pH;
+       const uint imW = coords[3] + pW;
+
+       const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1;
+       const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1;
+
+       const uint colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
+       const uint colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
+
+       T val = 0;
+
+       for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) {
+           coords[2] = imH - coords[4] * sH;
+           if(coords[2] % dH != 0) continue;
+
+           for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) {
+               coords[3] = imW - coords[5] * sW;
+               if(coords[3] % dW != 0) continue;
+
+               val += col[bSiCoffset + (coords[2]/dH)*colShapeInfo[9] + (coords[3]/dW)*colShapeInfo[10] + coords[4]*colShapeInfo[11] + coords[5]*colShapeInfo[12]];
+           }
+       }
+       im[imOffset] = val;
+   }
}

////////////////////////////////////////////////////////////////////////
@@ -184,8 +179,8 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc
                        void* image, const Nd4jLong* imShapeInfo,
                        const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {

-   col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
+   // col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
-   //col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
+   col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
}

//////////////////////////////////////////////////////////////////////////
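Note: the col2imCuda rewrite above (and col2volCuda below) switches from one-thread-per-element with an early return to a grid-stride loop, the pattern behind the "- profiling cuda kernels" bullet: any grid size now covers any problem size. A minimal self-contained grid-stride kernel for comparison (illustrative CUDA, not libnd4j code):

    #include <cuda_runtime.h>

    // Grid-stride loop: each thread handles elements tid, tid + gridSize, ...
    // so correctness no longer depends on blocksPerGrid covering the array.
    __global__ void fill(float* out, long long n, float v) {
        const long long step = (long long)gridDim.x * blockDim.x;
        for (long long i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += step)
            out[i] = v;
    }

    int main() {
        const long long n = 1 << 20;
        float* d = nullptr;
        cudaMalloc(&d, n * sizeof(float));
        fill<<<256, 256>>>(d, n, 1.f);    // deliberately fewer blocks than n / 256
        cudaDeviceSynchronize();
        cudaFree(d);
        return 0;
    }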
@@ -195,7 +190,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const

    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-   const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+   const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256;

    NDArray::prepareSpecialUse({&im}, {&col});
    BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES);


@@ -122,74 +122,71 @@ static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShape

     const T* col = reinterpret_cast<const T*>(columns);
           T* vol = reinterpret_cast<T*>(volume);

-    __shared__ int colRank, volRank, kDeff, kHeff, kWeff, oD, oH, oW;
-    __shared__ Nd4jLong *sharedMem, volLen;
+    __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem;
+    __shared__ Nd4jLong volLen;

     if (threadIdx.x == 0) {
         extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+        sharedMem = reinterpret_cast<uint*>(shmem);

         oD = colShapeInfo[6];
         oH = colShapeInfo[7];
         oW = colShapeInfo[8];

-        kDeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dD - 1);
-        kHeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dH - 1);
-        kWeff = colShapeInfo[5] + (colShapeInfo[5] - 1) * (dW - 1);
-
-        volRank = 5;
-        colRank = 8;
+        kD = dD * (colShapeInfo[3] - 1) + 1;
+        kH = dH * (colShapeInfo[4] - 1) + 1;
+        kW = dW * (colShapeInfo[5] - 1) + 1;

         volLen = shape::length(volShapeInfo);
     }
     __syncthreads();

-    const auto volInd = threadIdx.x + blockIdx.x * blockDim.x;
-
-    if(volInd >= volLen)
-        return;
-
-    auto coords = sharedMem + threadIdx.x * colRank;
-
-    shape::index2coords(volInd, volShapeInfo, coords);
-
-    const auto volOffset = shape::getOffset(volShapeInfo, coords);
-
-    const int imD = coords[2] + pD;
-    const int imH = coords[3] + pH;
-    const int imW = coords[4] + pW;
-
-    const int colDstart = (imD < kDeff) ? 0 : (imD - kDeff) / sD + 1;
-    const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1;
-    const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1;
-
-    const int colDend = nd4j::math::nd4j_min<uint>(imD / sD + 1, oD);
-    const int colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
-    const int colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
-
-    T val = 0;
-
-    for(coords[5] = colDstart; coords[5] < colDend; ++coords[5]) {
-        coords[2] = imD - coords[5] * sD;
-        for(coords[6] = colHstart; coords[6] < colHend; ++coords[6]) {
-            coords[3] = imH - coords[6] * sH;
-            for(coords[7] = colWstart; coords[7] < colWend; ++coords[7]) {
-                coords[4] = imW - coords[7] * sW;
-                if(coords[2] % dD == 0 && coords[3] % dH == 0 && coords[4] % dW == 0) {
-                    coords[2] /= dD;
-                    coords[3] /= dH;
-                    coords[4] /= dW;
-                    val += col[shape::getOffset(colShapeInfo, coords)];
-                }
-            }
-        }
-    }
-    vol[volOffset] = val;
+    auto coords = sharedMem + threadIdx.x * 8;
+
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) {
+
+        shape::index2coords(i, volShapeInfo, coords);
+
+        const auto volOffset = shape::getOffset(volShapeInfo, coords);
+
+        const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10];
+
+        const uint imD = coords[2] + pD;
+        const uint imH = coords[3] + pH;
+        const uint imW = coords[4] + pW;
+
+        const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1;
+        const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1;
+        const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1;
+
+        const uint colDend = nd4j::math::nd4j_min<uint>(imD / sD + 1, oD);
+        const uint colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
+        const uint colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
+
+        T val = 0;
+
+        for(uint colD = colDstart; colD < colDend; ++colD) {
+            coords[2] = imD - colD * sD;
+            if(coords[2] % dD != 0) continue;
+            for(uint colH = colHstart; colH < colHend; ++colH) {
+                coords[3] = imH - colH * sH;
+                if(coords[3] % dH != 0) continue;
+                for(uint colW = colWstart; colW < colWend; ++colW) {
+                    coords[4] = imW - colW * sW;
+                    if(coords[4] % dW != 0) continue;
+                    val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]];
+                }
+            }
+        }
+        vol[volOffset] = val;
+    }
 }
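The rewritten kernel swaps the one-thread-per-output mapping for a grid-stride loop and computes the dilated kernel extent with a shorter expression. The new form is algebraically identical to the old kDeff one; a standalone check (illustration only, not part of the kernel):

    #include <cstdio>

    int main() {
        for (int k = 1; k <= 5; ++k)
            for (int d = 1; d <= 4; ++d) {
                const int oldForm = k + (k - 1) * (d - 1);   // old kDeff/kHeff/kWeff expression
                const int newForm = d * (k - 1) + 1;         // new kD/kH/kW expression
                if (oldForm != newForm)
                    printf("mismatch at k=%d d=%d\n", k, d);
            }
        printf("done: the two forms agree for all tested k, d\n");
        return 0;
    }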
//////////////////////////////////////////////////////////////////////////
@@ -209,7 +206,7 @@ void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col,

     const int threadsPerBlock = MAX_NUM_THREADS / 4;
     const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-    const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+    const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256;

     NDArray::prepareSpecialUse({&vol}, {&col});
     BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
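The launcher provisions dynamic shared memory for one 8-coordinate scratch per thread; the per-coordinate type shrank from Nd4jLong to uint while the fixed headroom grew from 128 to 256 bytes. A standalone sketch of the arithmetic, assuming MAX_NUM_THREADS = 512 (an assumed value for illustration, not taken from this diff):

    #include <cstdio>

    int main() {
        const int MAX_NUM_THREADS = 512;                   // assumption for illustration
        const int threadsPerBlock = MAX_NUM_THREADS / 4;   // 128, as in the launcher
        const int colRank         = 8;                     // {bS, iC, kD, kH, kW, oD, oH, oW}
        // one uint coordinate per dimension per thread, plus fixed headroom
        const int sharedMem = colRank * (int)sizeof(unsigned int) * threadsPerBlock + 256;
        printf("sharedMem = %d bytes per block\n", sharedMem);   // 8 * 4 * 128 + 256 = 4352
        return 0;
    }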

View File

@@ -46,13 +46,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
     ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

-    int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
-    ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
-
     dnnl::memory::dims strides   = { sH, sW };
     dnnl::memory::dims padding   = { pH, pW };
-    dnnl::memory::dims padding_r = { pHmkl, pWmkl };
-    dnnl::memory::dims dilation  = { dHmkl, dWmkl };
+    dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
+    dnnl::memory::dims dilation  = { dH-1, dW-1 };

     // input type
     dnnl::memory::data_type xType;
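The right padding handed to oneDNN now comes straight from the deconvolution shape relation instead of the removed calcPaddingAndDilationForConv2DMKL helper. A minimal standalone check of the expression (mine, not library code), under VALID mode and unit dilation, which is all the MKL-DNN path accepts after this change:

    #include <cstdio>

    int main() {
        const int iH = 4, kH = 5, sH = 2, pH = 1;
        const int oH = sH * (iH - 1) + kH - 2 * pH;    // VALID-mode deconv output height
        const int pR = (iH - 1) * sH - oH + kH - pH;   // the padding_r expression above
        printf("oH=%d, padding_r=%d (equals pH=%d here)\n", oH, pR, pH);
        return 0;
    }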
@@ -193,13 +190,10 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
     ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

-    int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
-    ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
-
     dnnl::memory::dims strides   = { sH, sW };
     dnnl::memory::dims padding   = { pH, pW };
-    dnnl::memory::dims padding_r = { pHmkl, pWmkl };
-    dnnl::memory::dims dilation  = { dHmkl, dWmkl };
+    dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
+    dnnl::memory::dims dilation  = { dH-1, dW-1 };

     // input type
     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;

     // weights type
@@ -423,12 +417,17 @@ PLATFORM_CHECK(deconv2d) {
     auto output = INPUT_VARIABLE(0);

+    int dH = INT_ARG(6);            // dilations height
+    int dW = INT_ARG(7);            // dilations width
+    int isSameMode = INT_ARG(8);    // 0-VALID, 1-SAME
+
     const DataType xType = input->dataType();
     const DataType wType = weights->dataType();
     const DataType zType = output->dataType();
     const DataType bType = bias != nullptr ? bias->dataType() : zType;

-    return block.isUseMKLDNN() && (
+    return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) &&
+           (
             (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
             ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
            );
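A condensed restatement of the new gate (my sketch, not the library code), assuming the same INT_ARG layout as deconv2d above: the MKL-DNN path is taken only for non-dilated, VALID-mode deconvolutions, and everything else falls back to the generic helper.

    #include <cstdio>

    // dH, dW, isSameMode mirror INT_ARG(6..8) from the platform check.
    static bool useMkldnnDeconv2d(bool mkldnnEnabled, int dH, int dW, int isSameMode) {
        return mkldnnEnabled && dH <= 1 && dW <= 1 && !isSameMode;
    }

    int main() {
        printf("%d\n", useMkldnnDeconv2d(true, 1, 1, 0));   // 1: eligible for MKL-DNN
        printf("%d\n", useMkldnnDeconv2d(true, 2, 2, 0));   // 0: dilated -> generic helper
        printf("%d\n", useMkldnnDeconv2d(true, 1, 1, 1));   // 0: SAME mode -> generic helper
        return 0;
    }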
@@ -521,6 +520,9 @@ PLATFORM_CHECK(deconv2d_bp) {
     auto gradW = OUTPUT_VARIABLE(1);                                 // [kH, kW, oC, iC] always
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;   // [oC]

+    int dH = INT_ARG(6);            // dilations height
+    int dW = INT_ARG(7);            // dilations width
+    int isSameMode = INT_ARG(8);    // 0-VALID, 1-SAME
+
     const DataType xType = input->dataType();
     const DataType wType = weights->dataType();
@@ -530,7 +532,7 @@ PLATFORM_CHECK(deconv2d_bp) {
     const DataType gradWType = gradW->dataType();
     const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;

-    return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
+    return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
 }

View File

@@ -47,13 +47,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
     ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

-    int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
-    ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
-
     dnnl::memory::dims strides   = { sD, sH, sW };
     dnnl::memory::dims padding   = { pD, pH, pW };
-    dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
-    dnnl::memory::dims dilation  = { dDmkl, dHmkl, dWmkl };
+    dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
+    dnnl::memory::dims dilation  = { dD-1, dH-1, dW-1 };

     // input type
     dnnl::memory::data_type xType;
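The dD-1/dH-1/dW-1 values reflect oneDNN's convention of storing "extra" dilation, where 0 means a dense kernel, while the framework counts dilation starting at 1. A small illustration of the mapping (mine, for this write-up only):

    #include <cstdio>

    int main() {
        const int k = 3;                                           // taps per axis
        for (int d = 1; d <= 3; ++d) {
            const int dnnlDilation = d - 1;                        // value handed to dnnl
            const int extent = (k - 1) * (dnnlDilation + 1) + 1;   // cells spanned by the taps
            printf("framework d=%d -> dnnl dilation=%d, extent=%d\n", d, dnnlDilation, extent);
        }
        return 0;
    }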
@@ -197,13 +194,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
     ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

-    int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
-    ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
-
     dnnl::memory::dims strides   = { sD, sH, sW };
     dnnl::memory::dims padding   = { pD, pH, pW };
-    dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
-    dnnl::memory::dims dilation  = { dDmkl, dHmkl, dWmkl };
+    dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
+    dnnl::memory::dims dilation  = { dD-1, dH-1, dW-1 };

     // input type
     dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
@@ -437,12 +431,18 @@ PLATFORM_CHECK(deconv3d) {
     auto output = INPUT_VARIABLE(0);

+    int dD = INT_ARG(9);            // dilations depth
+    int dH = INT_ARG(10);           // dilations height
+    int dW = INT_ARG(11);           // dilations width
+    int isSameMode = INT_ARG(12);   // 0-VALID, 1-SAME
+
     const DataType xType = input->dataType();
     const DataType wType = weights->dataType();
     const DataType zType = output->dataType();
     const DataType bType = bias != nullptr ? bias->dataType() : zType;

-    return block.isUseMKLDNN() && (
+    return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) &&
+           (
             (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
             ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
            );
@@ -538,6 +538,11 @@ PLATFORM_CHECK(deconv3d_bp) {
     auto gradW = OUTPUT_VARIABLE(1);                                 // [kD, kH, kW, oC, iC] always
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;   // [oC]

+    int dD = INT_ARG(9);            // dilations depth
+    int dH = INT_ARG(10);           // dilations height
+    int dW = INT_ARG(11);           // dilations width
+    int isSameMode = INT_ARG(12);   // 0-VALID, 1-SAME
+
     const DataType xType = input->dataType();
     const DataType wType = weights->dataType();
     const DataType gradOType = gradO->dataType();
@@ -546,7 +551,7 @@ PLATFORM_CHECK(deconv3d_bp) {
     const DataType gradWType = gradW->dataType();
     const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;

-    return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
+    return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
 }
 }

View File

@@ -20,7 +20,6 @@
 #include <dnnl_types.h>
 #include "mkldnnUtils.h"
-#include <ops/declarable/helpers/convolutions.h>

 using namespace dnnl;
@@ -155,19 +154,10 @@ namespace nd4j {
         dnnl::memory::dims conv_bias_tz = { oC };
         dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };

-        int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
-        nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
-
         conv_strides   = { sH, sW };
         conv_padding   = { pH, pW };
-        conv_padding_r = { pHmkl, pWmkl };
-        conv_dilation  = { dHmkl, dWmkl };
+        conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
+                           (oW - 1) * sW - iW + kW - pW };
+        conv_dilation  = { dH-1, dW-1};

         auto type = dnnl::memory::data_type::f32;
         auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
@@ -243,13 +233,10 @@ namespace nd4j {
         dnnl::memory::dims conv_bias_tz = { oC };
         dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };

-        int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
-        nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
-
         conv_strides   = { sD, sH, sW };
         conv_padding   = { pD, pH, pW };
-        conv_padding_r = { pDmkl, pHmkl, pWmkl };
-        conv_dilation  = { dDmkl, dHmkl, dWmkl };
+        conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        conv_dilation  = { dD-1, dH-1, dW-1};

         auto type = dnnl::memory::data_type::f32;
         auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
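For the forward-convolution descriptors the roles of input and output sizes are swapped relative to the deconvolution case, so the right padding uses (oH-1)*sH - iH + kH - pH. A minimal standalone check of that rule against the usual conv shape formula (my sketch, not library code):

    #include <cstdio>

    int main() {
        const int iH = 7, kH = 3, sH = 2, pL = 1;
        const int oH = (iH + 2 * pL - kH) / sH + 1;     // conv output with symmetric padding
        const int pR = (oH - 1) * sH - iH + kH - pL;    // the conv_padding_r expression above
        printf("oH=%d, padding_r=%d\n", oH, pR);        // pR == pL when the division is exact
        return 0;
    }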

File diff suppressed because one or more lines are too long

View File

@@ -605,7 +605,6 @@ TEST_F(ConvolutionTests2, deconv3d_test5) {
     ASSERT_EQ(Status::OK(), results->status());

     auto output = results->at(0);
-    // output->printBuffer();

     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));

View File

@@ -2285,66 +2285,6 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
     ASSERT_EQ(e, z);
 }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {
NDArray input('c', {2, 3, 4, 4}, nd4j::DataType::FLOAT32);
NDArray weights('c', {3, 3, 5, 5}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,
100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,
84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,
54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,
90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,
8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,
144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,
118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,
115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,
268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,
52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,
78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,
89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, nd4j::DataType::FLOAT32);
input.linspace(1);
weights.linspace(1);
weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
auto z = result->at(0);
// z->printShapeInfo();
// z->printBuffer();
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, CompactLaunchTests2) {
Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99};
double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,};
NDArray exp(_expB, _expS);
auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 4});
auto weights = NDArrayFactory::create<double>('c', {3, 3, 5, 5});
auto z = NDArrayFactory::create<double>('c', {2, 3, 8, 8});
input.linspace(1);
weights.linspace(1);
weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{});
ASSERT_EQ(ND4J_STATUS_OK, result);
ASSERT_TRUE(exp.isSameShape(&z));
ASSERT_TRUE(exp.equalsTo(&z));
}
////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, sru_old_test1) {

View File

@@ -34,6 +34,7 @@
 #include <ops/declarable/helpers/im2col.h>
 #include <Loops.h>
 #include <RandomLauncher.h>
+#include <ops/declarable/helpers/convolutions.h>
 #include <helpers/BenchmarkHelper.h>
 #include <ops/declarable/helpers/scatter.h>
@@ -279,24 +280,65 @@ TEST_F(PlaygroundTests, test_relubp_1) {

 //////////////////////////////////////////////////////////////////////
 TEST_F(PlaygroundTests, my) {

-    int bS=1, iH=56,iW=56, iC=144,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
-    int oC=iC*mC;
-    int oH=56,oW=56;
-
-    int paddingMode = 1;             // 1-SAME, 0-VALID;
-    int dataFormat  = 1;             // 1-NHWC, 0-NCHW
-
-    auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
-    auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
-
-    input = 2.;
-    weights.linspace(0.1, 0.1);
-
-    nd4j::ops::depthwise_conv2d op;
-    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-
-    delete results;
+    int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
+    int oD,oH,oW;
+
+    nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
+
+    printf("!!%i, %i, %i\n", oD,oH,oW);
+
+    NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
+    NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
+
+    col = 3.77;
+    vol = -10.33;
+
+    auto variableSpace = new VariableSpace();
+    auto block = new Context(1, variableSpace, false);  // not-in-place
+
+    auto timeStart = std::chrono::system_clock::now();
+
+    nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
+
+    auto timeEnd = std::chrono::system_clock::now();
+    auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
+
+    printf("time: %i \n", time);
+
+    delete block;
+    delete variableSpace;
+}
+
+TEST_F(PlaygroundTests, my) {
+
+    int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
+    int oD,oH,oW;
+
+    // nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
+    nd4j::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, 0);
+
+    printf("!!%i, %i, %i\n", oD,oH,oW);
+
+    // NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
+    // NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
+    NDArray col('c', {bS, iC, kH, kW, iH, iW}, nd4j::DataType::DOUBLE);
+    NDArray im('c', {bS, iC, oH, oW}, nd4j::DataType::DOUBLE);
+
+    col = 3.77;
+    // vol = -10.33;
+    im = -10.33;
+
+    auto variableSpace = new VariableSpace();
+    auto block = new Context(1, variableSpace, false);  // not-in-place
+
+    auto timeStart = std::chrono::system_clock::now();
+
+    // nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
+    nd4j::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW);
+
+    auto timeEnd = std::chrono::system_clock::now();
+    auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
+
+    printf("time: %i \n", time);
+
+    delete block;
+    delete variableSpace;
 }
 */
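These commented-out playground tests follow the profiling pattern from the commit message (timing col2vol and col2im). A standalone sketch of that timing pattern, with a stand-in workload instead of the helper call; note the duration count is a 64-bit value, so it is cast before printing, whereas the tests pass it to %i, which is not portable:

    #include <chrono>
    #include <cstdio>

    int main() {
        auto timeStart = std::chrono::system_clock::now();

        volatile double sink = 0;                      // stand-in for the profiled helper call
        for (int i = 0; i < 1000000; ++i)
            sink += i * 1e-9;

        auto timeEnd = std::chrono::system_clock::now();
        auto time = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();

        printf("time: %lld us\n", static_cast<long long>(time));
        return 0;
    }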