Shyrma deconv3 (#69)
* - profiling cuda kernels for vol2col and im2col Signed-off-by: Yurii <iuriish@yahoo.com> * - correct addBias helper Signed-off-by: Yurii <iuriish@yahoo.com> * - correct mkl dilation formula and switch off mkl api for dilation deconvolutions Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
ff73e6da3f
commit
7a90a31cfb
|
@ -258,10 +258,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
|
||||||
const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1);
|
const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1);
|
||||||
|
|
||||||
if(typeFloat) {
|
if(typeFloat) {
|
||||||
BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, reinterpret_cast<float *>(pA->getBuffer()), lda, reinterpret_cast<float *>(pB->getBuffer()), ldb, (float) beta, reinterpret_cast<float *>(pC->getBuffer()), ldc);
|
BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT<float>(), lda, pB->bufferAsT<float>(), ldb, (float) beta, pC->bufferAsT<float>(), ldc);
|
||||||
}
|
}
|
||||||
else if(typeDouble) {
|
else if(typeDouble) {
|
||||||
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast<double *>(pA->getBuffer()), lda, reinterpret_cast<double *>(pB->getBuffer()), ldb, (double) beta, reinterpret_cast<double *>(pC->getBuffer()), ldc);
|
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT<double>(), lda, pB->bufferAsT<double>(), ldb, (double) beta, pC->bufferAsT<double>(), ldc);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(pC != C) {
|
if(pC != C) {
|
||||||
|
|
|
@ -263,9 +263,13 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 6 + 128; // 6 = aRank + bRank + cRank
|
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 6 + 128; // 6 = aRank + bRank + cRank
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({C}, {A, B});
|
NDArray::prepareSpecialUse({C}, {A, B});
|
||||||
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
|
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
|
||||||
NDArray::registerSpecialUse({C}, {A, B});
|
NDArray::registerSpecialUse({C}, {A, B});
|
||||||
|
|
||||||
|
auto cudaResult = cudaStreamSynchronize(*stream);
|
||||||
|
if (cudaResult != 0)
|
||||||
|
throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -334,6 +338,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
|
|
||||||
NDArray::registerSpecialUse({pC}, {pA, pB});
|
NDArray::registerSpecialUse({pC}, {pA, pB});
|
||||||
|
|
||||||
|
auto cudaResult = cudaStreamSynchronize(*stream);
|
||||||
|
if (cudaResult != 0)
|
||||||
|
throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
|
||||||
|
|
||||||
if(C != pC)
|
if(C != pC)
|
||||||
C->assign(pC);
|
C->assign(pC);
|
||||||
|
|
||||||
|
@ -341,10 +349,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
delete toDelete[i];
|
delete toDelete[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cudaResult = cudaStreamSynchronize(*stream);
|
|
||||||
if (cudaResult != 0)
|
|
||||||
throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
|
|
||||||
|
|
||||||
return C;
|
return C;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -397,10 +401,14 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
||||||
const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({Y}, {A, X});
|
NDArray::prepareSpecialUse({Y}, {A, X});
|
||||||
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
|
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
|
||||||
NDArray::registerSpecialUse({Y}, {A, X});
|
NDArray::registerSpecialUse({Y}, {A, X});
|
||||||
|
|
||||||
|
auto cudaResult = cudaStreamSynchronize(*stream);
|
||||||
|
if (cudaResult != 0)
|
||||||
|
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -434,16 +442,16 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
||||||
if (status != CUBLAS_STATUS_SUCCESS)
|
if (status != CUBLAS_STATUS_SUCCESS)
|
||||||
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
|
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
|
||||||
|
|
||||||
|
auto cudaResult = cudaStreamSynchronize(*stream);
|
||||||
|
if (cudaResult != 0)
|
||||||
|
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
|
||||||
|
|
||||||
NDArray::registerSpecialUse({Y}, {pA, X});
|
NDArray::registerSpecialUse({Y}, {pA, X});
|
||||||
|
|
||||||
if(pA != A)
|
if(pA != A)
|
||||||
delete pA;
|
delete pA;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cudaResult = cudaStreamSynchronize(*stream);
|
|
||||||
if (cudaResult != 0)
|
|
||||||
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
|
|
||||||
|
|
||||||
return Y;
|
return Y;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -624,7 +632,7 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
|
||||||
throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !");
|
throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !");
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
C = new NDArray(outOrder, cExpectedShape, B->dataType());
|
C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
||||||
|
|
||||||
const int cRank = C->rankOf();
|
const int cRank = C->rankOf();
|
||||||
|
|
||||||
|
|
|
@ -889,7 +889,8 @@ namespace shape {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
|
||||||
ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *indices, Nd4jLong baseOffset = 0);
|
||||||
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *indices, Nd4jLong baseOffset = 0);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
|
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
|
||||||
|
|
||||||
|
@ -900,6 +901,8 @@ namespace shape {
|
||||||
* for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
|
* for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
|
||||||
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
|
||||||
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
|
||||||
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
|
@ -913,6 +916,8 @@ namespace shape {
|
||||||
* for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
|
* for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords);
|
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords);
|
||||||
|
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords);
|
||||||
|
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords);
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords);
|
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords);
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
|
@ -1756,6 +1761,34 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLo
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords) {
|
||||||
|
|
||||||
|
Nd4jLong index, shift = 1;;
|
||||||
|
|
||||||
|
index = coords[shapeInfo[0] - 1];
|
||||||
|
for(uint i = shapeInfo[0]; i > 1; --i) {
|
||||||
|
shift *= shapeInfo[i];
|
||||||
|
index += shift * coords[i - 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords) {
|
||||||
|
|
||||||
|
Nd4jLong index, shift = 1;;
|
||||||
|
|
||||||
|
index = coords[shapeInfo[0] - 1];
|
||||||
|
for(uint i = shapeInfo[0]; i > 1; --i) {
|
||||||
|
shift *= shapeInfo[i];
|
||||||
|
index += shift * coords[i - 2];
|
||||||
|
}
|
||||||
|
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices) {
|
INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices) {
|
||||||
|
|
||||||
|
@ -3223,18 +3256,28 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {
|
||||||
|
|
||||||
Nd4jLong offset = 0;
|
Nd4jLong offset = baseOffset;
|
||||||
|
|
||||||
for(uint i = 1; i <= shapeInfo[0]; ++i)
|
for(uint i = 1; i <= shapeInfo[0]; ++i)
|
||||||
if(shapeInfo[i] != 1)
|
if(shapeInfo[i] != 1)
|
||||||
offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i];
|
offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
|
||||||
|
|
||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
|
||||||
|
|
||||||
|
Nd4jLong offset = baseOffset;
|
||||||
|
|
||||||
|
for(uint i = 1; i <= shapeInfo[0]; ++i)
|
||||||
|
if(shapeInfo[i] != 1)
|
||||||
|
offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -4720,6 +4763,26 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo,
|
||||||
coords[0] = index; // last iteration
|
coords[0] = index; // last iteration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords) {
|
||||||
|
|
||||||
|
for(uint i = shapeInfo[0]; i > 1; --i) {
|
||||||
|
coords[i - 1] = static_cast<int>(index) % static_cast<int>(shapeInfo[i]);
|
||||||
|
index /= static_cast<int>(shapeInfo[i]);
|
||||||
|
}
|
||||||
|
coords[0] = static_cast<int>(index); // last iteration
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords) {
|
||||||
|
|
||||||
|
for(uint i = shapeInfo[0]; i > 1; --i) {
|
||||||
|
coords[i - 1] = static_cast<uint>(index) % static_cast<uint>(shapeInfo[i]);
|
||||||
|
index /= static_cast<uint>(shapeInfo[i]);
|
||||||
|
}
|
||||||
|
coords[0] = static_cast<uint>(index); // last iteration
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) {
|
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) {
|
||||||
|
|
||||||
|
|
|
@ -66,10 +66,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
|
||||||
if(!isNCHW)
|
if(!isNCHW)
|
||||||
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
|
|
||||||
if(isSameMode){ // SAME
|
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
|
||||||
}
|
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());
|
NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());
|
||||||
|
|
||||||
|
|
|
@ -67,10 +67,10 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
|
||||||
if(!isNCDHW)
|
if(!isNCDHW)
|
||||||
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
|
|
||||||
if(isSameMode) //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
||||||
|
|
||||||
//----- calculation of output -----//
|
//----- calculation of output -----//
|
||||||
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||||
|
|
|
@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
|
||||||
NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
|
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
|
||||||
NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
|
NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC
|
||||||
|
|
||||||
NDArray *output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
NDArray *output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
|
|
||||||
|
|
|
@ -69,8 +69,8 @@ namespace nd4j {
|
||||||
eKH = kH;
|
eKH = kH;
|
||||||
eKW = kW;
|
eKW = kW;
|
||||||
} else {
|
} else {
|
||||||
eKH = kH + (kH - 1) * (dH - 1);
|
eKH = (kH - 1) * dH + 1;
|
||||||
eKW = kW + (kW - 1) * (dW - 1);
|
eKW = (kW - 1) * dW + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
||||||
|
@ -84,9 +84,9 @@ namespace nd4j {
|
||||||
eKH = kH;
|
eKH = kH;
|
||||||
eKW = kW;
|
eKW = kW;
|
||||||
} else {
|
} else {
|
||||||
eKD = kD + (kD - 1) * (dD - 1);
|
eKD = (kD - 1) * dD + 1;
|
||||||
eKH = kH + (kH - 1) * (dH - 1);
|
eKH = (kH - 1) * dH + 1;
|
||||||
eKW = kW + (kW - 1) * (dW - 1);
|
eKW = (kW - 1) * dW + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
pD = ((oD - 1) * sD + eKD - iD) / 2; // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
pD = ((oD - 1) * sD + eKD - iD) / 2; // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
|
||||||
|
@ -107,8 +107,8 @@ namespace nd4j {
|
||||||
ekH = kH;
|
ekH = kH;
|
||||||
ekW = kW;
|
ekW = kW;
|
||||||
} else {
|
} else {
|
||||||
ekH = kH + (kH - 1) * (dH - 1);
|
ekH = (kH - 1) * dH + 1;
|
||||||
ekW = kW + (kW - 1) * (dW - 1);
|
ekW = (kW - 1) * dW + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
oH = sH * (iH - 1) + ekH - 2 * pH;
|
oH = sH * (iH - 1) + ekH - 2 * pH;
|
||||||
|
@ -131,9 +131,9 @@ namespace nd4j {
|
||||||
ekW = kW;
|
ekW = kW;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ekD = kD + (kD - 1) * (dD - 1);
|
ekD = (kD - 1) * dD + 1;
|
||||||
ekH = kH + (kH - 1) * (dH - 1);
|
ekH = (kH - 1) * dH + 1;
|
||||||
ekW = kW + (kW - 1) * (dW - 1);
|
ekW = (kW - 1) * dW + 1;
|
||||||
}
|
}
|
||||||
oD = sD * (iD - 1) + ekD - 2 * pD;
|
oD = sD * (iD - 1) + ekD - 2 * pD;
|
||||||
oH = sH * (iH - 1) + ekH - 2 * pH;
|
oH = sH * (iH - 1) + ekH - 2 * pH;
|
||||||
|
@ -194,53 +194,53 @@ namespace nd4j {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
|
// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
|
||||||
|
|
||||||
if(kH != 1) {
|
// if(kH != 1) {
|
||||||
if(isSameMode) {
|
// if(isSameMode) {
|
||||||
pH = (oH - 1) * sH - iH + kH - pH;
|
// pH = (oH - 1) * sH - iH + kH - pH;
|
||||||
dH = dH - 1;
|
// dH = dH - 1;
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
||||||
}
|
// }
|
||||||
if(kW != 1) {
|
// if(kW != 1) {
|
||||||
if(isSameMode) {
|
// if(isSameMode) {
|
||||||
pW = (oW - 1) * sW - iW + kW - pW;
|
// pW = (oW - 1) * sW - iW + kW - pW;
|
||||||
dW = dW - 1;
|
// dW = dW - 1;
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
|
// dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
|
// static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
|
||||||
|
|
||||||
if(kD != 1) {
|
// if(kD != 1) {
|
||||||
if(isSameMode) {
|
// if(isSameMode) {
|
||||||
pD = (oD - 1) * sD - iD + kD - pD;
|
// pD = (oD - 1) * sD - iD + kD - pD;
|
||||||
dD = dD - 1;
|
// dD = dD - 1;
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
|
// dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
|
||||||
}
|
// }
|
||||||
if(kH != 1) {
|
// if(kH != 1) {
|
||||||
if(isSameMode) {
|
// if(isSameMode) {
|
||||||
pH = (oH - 1) * sH - iH + kH - pH;
|
// pH = (oH - 1) * sH - iH + kH - pH;
|
||||||
dH = dH - 1;
|
// dH = dH - 1;
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
|
||||||
}
|
// }
|
||||||
if(kW != 1) {
|
// if(kW != 1) {
|
||||||
if(isSameMode) {
|
// if(isSameMode) {
|
||||||
pW = (oW - 1) * sW - iW + kW - pW;
|
// pW = (oW - 1) * sW - iW + kW - pW;
|
||||||
dW = dW - 1;
|
// dW = dW - 1;
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
|
// dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
|
|
|
@ -47,9 +47,12 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output,
|
||||||
|
|
||||||
const bool inOutAreSame = x == z;
|
const bool inOutAreSame = x == z;
|
||||||
|
|
||||||
|
int posOfNonUnityDim;
|
||||||
|
bias.isCommonVector(posOfNonUnityDim);
|
||||||
|
|
||||||
const uint bS = output.sizeAt(0); // batch size
|
const uint bS = output.sizeAt(0); // batch size
|
||||||
const Nd4jLong yStrideC = bias.stridesOf()[0];
|
const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim);
|
||||||
const Nd4jLong zStrideB = output.stridesOf()[0];
|
const Nd4jLong zStrideB = output.strideAt(0);
|
||||||
|
|
||||||
if(output.rankOf() == 4) {
|
if(output.rankOf() == 4) {
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
|
||||||
const uint oW = output->sizeAt(2);
|
const uint oW = output->sizeAt(2);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_2D {
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
for (uint b = start_x; b < stop_x; b += inc_x) {
|
for (uint b = start_x; b < stop_x; b += inc_x) {
|
||||||
for (uint oh = start_y; oh < stop_y; oh += inc_y) {
|
for (uint oh = start_y; oh < stop_y; oh += inc_y) {
|
||||||
for (uint ow = 0; ow < oW; ++ow) {
|
for (uint ow = 0; ow < oW; ++ow) {
|
||||||
|
@ -69,13 +70,17 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
|
||||||
const int iw = ow * sW - pW + kw * dW;
|
const int iw = ow * sW - pW + kw * dW;
|
||||||
if (iw < 0 || iw >= iW) continue;
|
if (iw < 0 || iw >= iW) continue;
|
||||||
|
|
||||||
const X val = x[shape::getOffset(xShapeInfo, {b, (uint) ih, (uint) iw, c})] + y[shape::getOffset(yShapeInfo, {kh, kw, c})];
|
uint xCoords[4] = {b, (uint)ih, (uint)iw, c};
|
||||||
|
uint yCoords[3] = {kh, kw, c};
|
||||||
|
|
||||||
|
const X val = x[shape::getOffset(xShapeInfo, xCoords)] + y[shape::getOffset(yShapeInfo, yCoords)];
|
||||||
if (val > max)
|
if (val > max)
|
||||||
max = val;
|
max = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
z[shape::getOffset(zShapeInfo, {b, oh, ow, c})] = static_cast<Z>(max);
|
uint zCoords[4] = {b, oh, ow, c};
|
||||||
|
z[shape::getOffset(zShapeInfo, zCoords)] = static_cast<Z>(max);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
const Y* y = reinterpret_cast<const Y*>(vy);
|
const Y* y = reinterpret_cast<const Y*>(vy);
|
||||||
X* z = reinterpret_cast<X*>(vz);
|
X* z = reinterpret_cast<X*>(vz);
|
||||||
|
|
||||||
__shared__ int rank, channelPosition;
|
__shared__ int rank, channelPosition, posOfNonUnityDim;
|
||||||
__shared__ Nd4jLong *sharedMem, len;
|
__shared__ Nd4jLong *sharedMem, len;
|
||||||
__shared__ bool xzSameOffsets, xzAreSame;
|
__shared__ bool xzSameOffsets, xzAreSame;
|
||||||
|
|
||||||
|
@ -58,6 +58,8 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
len = shape::length(xShapeInfo);
|
len = shape::length(xShapeInfo);
|
||||||
channelPosition = isNCHW ? 1 : rank - 1; // second or last
|
channelPosition = isNCHW ? 1 : rank - 1; // second or last
|
||||||
xzAreSame = x == z;
|
xzAreSame = x == z;
|
||||||
|
|
||||||
|
shape::isCommonVector(yShapeInfo, posOfNonUnityDim);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -69,7 +71,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
|
||||||
const auto xOffsets = shape::getOffset(xShapeInfo, coords);
|
const auto xOffsets = shape::getOffset(xShapeInfo, coords);
|
||||||
const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords);
|
const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords);
|
||||||
const auto yOffsets = shape::getOffset(yShapeInfo, coords + channelPosition);
|
const auto yOffsets = coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim];
|
||||||
|
|
||||||
if(xzAreSame)
|
if(xzAreSame)
|
||||||
z[zOffsets] += static_cast<X>(y[yOffsets]);
|
z[zOffsets] += static_cast<X>(y[yOffsets]);
|
||||||
|
@ -94,7 +96,7 @@ void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& b
|
||||||
|
|
||||||
PointersManager manager(block.launchContext(), "addBias");
|
PointersManager manager(block.launchContext(), "addBias");
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS;
|
const int threadsPerBlock = MAX_NUM_THREADS/2;
|
||||||
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
|
|
@ -34,64 +34,59 @@ static __global__ void col2imCuda(const void* columns, const Nd4jLong* colShapeI
|
||||||
const T* col = reinterpret_cast<const T*>(columns);
|
const T* col = reinterpret_cast<const T*>(columns);
|
||||||
T* im = reinterpret_cast<T*>(image);
|
T* im = reinterpret_cast<T*>(image);
|
||||||
|
|
||||||
__shared__ int colRank, imRank, kHeff, kWeff, oH, oW;
|
__shared__ uint kH, kW, oH, oW, *sharedMem;
|
||||||
__shared__ Nd4jLong *sharedMem, imLen;
|
__shared__ Nd4jLong imLen;
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
extern __shared__ unsigned char shmem[];
|
extern __shared__ unsigned char shmem[];
|
||||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
sharedMem = reinterpret_cast<uint*>(shmem);
|
||||||
|
|
||||||
|
kH = dH * (colShapeInfo[3] - 1) + 1;
|
||||||
|
kW = dW * (colShapeInfo[4] - 1) + 1;
|
||||||
|
|
||||||
oH = colShapeInfo[5];
|
oH = colShapeInfo[5];
|
||||||
oW = colShapeInfo[6];
|
oW = colShapeInfo[6];
|
||||||
|
|
||||||
kHeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dH - 1);
|
|
||||||
kWeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dW - 1);
|
|
||||||
|
|
||||||
imRank = 4;
|
|
||||||
colRank = 6;
|
|
||||||
|
|
||||||
imLen = shape::length(imShapeInfo);
|
imLen = shape::length(imShapeInfo);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto imInd = threadIdx.x + blockIdx.x * blockDim.x;
|
auto coords = sharedMem + threadIdx.x * 6;
|
||||||
|
|
||||||
if(imInd >= imLen)
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
return;
|
|
||||||
|
|
||||||
auto coords = sharedMem + threadIdx.x * colRank;
|
for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
shape::index2coords(imInd, imShapeInfo, coords);
|
shape::index2coords(i, imShapeInfo, coords);
|
||||||
|
|
||||||
const auto imOffset = shape::getOffset(imShapeInfo, coords);
|
const auto imOffset = shape::getOffset(imShapeInfo, coords);
|
||||||
|
|
||||||
const int imH = coords[2] + pH;
|
const auto bSiCoffset = coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8];
|
||||||
const int imW = coords[3] + pW;
|
|
||||||
|
|
||||||
const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1;
|
const uint imH = coords[2] + pH;
|
||||||
const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1;
|
const uint imW = coords[3] + pW;
|
||||||
|
|
||||||
const int colHend = nd4j::math::nd4j_min<int>(imH / sH + 1, oH);
|
const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1;
|
||||||
const int colWend = nd4j::math::nd4j_min<int>(imW / sW + 1, oW);
|
const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1;
|
||||||
|
|
||||||
|
const uint colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
|
||||||
|
const uint colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
|
||||||
|
|
||||||
T val = 0;
|
T val = 0;
|
||||||
|
|
||||||
for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) {
|
for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) {
|
||||||
coords[2] = imH - coords[4] * sH;
|
coords[2] = imH - coords[4] * sH;
|
||||||
|
if(coords[2] % dH != 0) continue;
|
||||||
|
|
||||||
for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) {
|
for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) {
|
||||||
coords[3] = imW - coords[5] * sW;
|
coords[3] = imW - coords[5] * sW;
|
||||||
|
if(coords[3] % dW != 0) continue;
|
||||||
|
|
||||||
if(coords[2] % dH == 0 && coords[3] % dW == 0) {
|
val += col[bSiCoffset + (coords[2]/dH)*colShapeInfo[9] + (coords[3]/dW)*colShapeInfo[10] + coords[4]*colShapeInfo[11] + coords[5]*colShapeInfo[12]];
|
||||||
coords[2] /= dH;
|
|
||||||
coords[3] /= dW;
|
|
||||||
|
|
||||||
val += col[shape::getOffset(colShapeInfo, coords)];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
im[imOffset] = val;
|
im[imOffset] = val;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -184,8 +179,8 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc
|
||||||
void* image, const Nd4jLong* imShapeInfo,
|
void* image, const Nd4jLong* imShapeInfo,
|
||||||
const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||||
|
|
||||||
col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
|
// col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
|
||||||
//col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
|
col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -195,7 +190,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256;
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&im}, {&col});
|
NDArray::prepareSpecialUse({&im}, {&col});
|
||||||
BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES);
|
||||||
|
|
|
@ -122,74 +122,71 @@ static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShape
|
||||||
const T* col = reinterpret_cast<const T*>(columns);
|
const T* col = reinterpret_cast<const T*>(columns);
|
||||||
T* vol = reinterpret_cast<T*>(volume);
|
T* vol = reinterpret_cast<T*>(volume);
|
||||||
|
|
||||||
__shared__ int colRank, volRank, kDeff, kHeff, kWeff, oD, oH, oW;
|
__shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem;
|
||||||
__shared__ Nd4jLong *sharedMem, volLen;
|
__shared__ Nd4jLong volLen;
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
extern __shared__ unsigned char shmem[];
|
extern __shared__ unsigned char shmem[];
|
||||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
sharedMem = reinterpret_cast<uint*>(shmem);
|
||||||
|
|
||||||
oD = colShapeInfo[6];
|
oD = colShapeInfo[6];
|
||||||
oH = colShapeInfo[7];
|
oH = colShapeInfo[7];
|
||||||
oW = colShapeInfo[8];
|
oW = colShapeInfo[8];
|
||||||
|
|
||||||
kDeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dD - 1);
|
kD = dD * (colShapeInfo[3] - 1) + 1;
|
||||||
kHeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dH - 1);
|
kH = dH * (colShapeInfo[4] - 1) + 1;
|
||||||
kWeff = colShapeInfo[5] + (colShapeInfo[5] - 1) * (dW - 1);
|
kW = dW * (colShapeInfo[5] - 1) + 1;
|
||||||
|
|
||||||
volRank = 5;
|
|
||||||
colRank = 8;
|
|
||||||
|
|
||||||
volLen = shape::length(volShapeInfo);
|
volLen = shape::length(volShapeInfo);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto volInd = threadIdx.x + blockIdx.x * blockDim.x;
|
auto coords = sharedMem + threadIdx.x * 8;
|
||||||
|
|
||||||
if(volInd >= volLen)
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
return;
|
|
||||||
|
|
||||||
auto coords = sharedMem + threadIdx.x * colRank;
|
for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
shape::index2coords(volInd, volShapeInfo, coords);
|
shape::index2coords(i, volShapeInfo, coords);
|
||||||
|
|
||||||
const auto volOffset = shape::getOffset(volShapeInfo, coords);
|
const auto volOffset = shape::getOffset(volShapeInfo, coords);
|
||||||
|
|
||||||
const int imD = coords[2] + pD;
|
const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10];
|
||||||
const int imH = coords[3] + pH;
|
|
||||||
const int imW = coords[4] + pW;
|
|
||||||
|
|
||||||
const int colDstart = (imD < kDeff) ? 0 : (imD - kDeff) / sD + 1;
|
const uint imD = coords[2] + pD;
|
||||||
const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1;
|
const uint imH = coords[3] + pH;
|
||||||
const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1;
|
const uint imW = coords[4] + pW;
|
||||||
|
|
||||||
const int colDend = nd4j::math::nd4j_min<uint>(imD / sD + 1, oD);
|
const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1;
|
||||||
const int colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
|
const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1;
|
||||||
const int colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
|
const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1;
|
||||||
|
|
||||||
|
const uint colDend = nd4j::math::nd4j_min<uint>(imD / sD + 1, oD);
|
||||||
|
const uint colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
|
||||||
|
const uint colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
|
||||||
|
|
||||||
T val = 0;
|
T val = 0;
|
||||||
|
|
||||||
for(coords[5] = colDstart; coords[5] < colDend; ++coords[5]) {
|
for(uint colD = colDstart; colD < colDend; ++colD) {
|
||||||
coords[2] = imD - coords[5] * sD;
|
coords[2] = imD - colD * sD;
|
||||||
|
if(coords[2] % dD != 0) continue;
|
||||||
|
|
||||||
for(coords[6] = colHstart; coords[6] < colHend; ++coords[6]) {
|
for(uint colH = colHstart; colH < colHend; ++colH) {
|
||||||
coords[3] = imH - coords[6] * sH;
|
coords[3] = imH - colH * sH;
|
||||||
|
if(coords[3] % dH != 0) continue;
|
||||||
|
|
||||||
for(coords[7] = colWstart; coords[7] < colWend; ++coords[7]) {
|
for(uint colW = colWstart; colW < colWend; ++colW) {
|
||||||
coords[4] = imW - coords[7] * sW;
|
coords[4] = imW - colW * sW;
|
||||||
|
if(coords[4] % dW != 0) continue;
|
||||||
|
|
||||||
if(coords[2] % dD == 0 && coords[3] % dH == 0 && coords[4] % dW == 0) {
|
val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]];
|
||||||
coords[2] /= dD;
|
|
||||||
coords[3] /= dH;
|
|
||||||
coords[4] /= dW;
|
|
||||||
|
|
||||||
val += col[shape::getOffset(colShapeInfo, coords)];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
vol[volOffset] = val;
|
vol[volOffset] = val;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -209,7 +206,7 @@ void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col,
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256;
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&vol}, {&col});
|
NDArray::prepareSpecialUse({&vol}, {&col});
|
||||||
BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
|
||||||
|
|
|
@ -46,13 +46,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
|
||||||
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
|
||||||
|
|
||||||
dnnl::memory::dims strides = { sH, sW };
|
dnnl::memory::dims strides = { sH, sW };
|
||||||
dnnl::memory::dims padding = { pH, pW };
|
dnnl::memory::dims padding = { pH, pW };
|
||||||
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
|
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
dnnl::memory::dims dilation = { dHmkl, dWmkl };
|
dnnl::memory::dims dilation = { dH-1, dW-1 };
|
||||||
|
|
||||||
// input type
|
// input type
|
||||||
dnnl::memory::data_type xType;
|
dnnl::memory::data_type xType;
|
||||||
|
@ -193,13 +190,10 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
|
||||||
|
|
||||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
|
||||||
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
|
||||||
|
|
||||||
dnnl::memory::dims strides = { sH, sW };
|
dnnl::memory::dims strides = { sH, sW };
|
||||||
dnnl::memory::dims padding = { pH, pW };
|
dnnl::memory::dims padding = { pH, pW };
|
||||||
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
|
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
dnnl::memory::dims dilation = { dHmkl, dWmkl };
|
dnnl::memory::dims dilation = { dH-1, dW-1 };
|
||||||
// input type
|
// input type
|
||||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
// weights type
|
// weights type
|
||||||
|
@ -423,12 +417,17 @@ PLATFORM_CHECK(deconv2d) {
|
||||||
|
|
||||||
auto output = INPUT_VARIABLE(0);
|
auto output = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
|
||||||
const DataType xType = input->dataType();
|
const DataType xType = input->dataType();
|
||||||
const DataType wType = weights->dataType();
|
const DataType wType = weights->dataType();
|
||||||
const DataType zType = output->dataType();
|
const DataType zType = output->dataType();
|
||||||
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && (
|
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) &&
|
||||||
|
(
|
||||||
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||||
);
|
);
|
||||||
|
@ -521,6 +520,9 @@ PLATFORM_CHECK(deconv2d_bp) {
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
int dH = INT_ARG(6); // dilations height
|
||||||
|
int dW = INT_ARG(7); // dilations width
|
||||||
|
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
|
|
||||||
const DataType xType = input->dataType();
|
const DataType xType = input->dataType();
|
||||||
const DataType wType = weights->dataType();
|
const DataType wType = weights->dataType();
|
||||||
|
@ -530,7 +532,7 @@ PLATFORM_CHECK(deconv2d_bp) {
|
||||||
const DataType gradWType = gradW->dataType();
|
const DataType gradWType = gradW->dataType();
|
||||||
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -47,13 +47,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
||||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
|
||||||
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
|
||||||
|
|
||||||
dnnl::memory::dims strides = { sD, sH, sW };
|
dnnl::memory::dims strides = { sD, sH, sW };
|
||||||
dnnl::memory::dims padding = { pD, pH, pW };
|
dnnl::memory::dims padding = { pD, pH, pW };
|
||||||
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
|
||||||
|
|
||||||
// input type
|
// input type
|
||||||
dnnl::memory::data_type xType;
|
dnnl::memory::data_type xType;
|
||||||
|
@ -197,13 +194,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
||||||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||||
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
|
||||||
|
|
||||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
|
||||||
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
|
||||||
|
|
||||||
dnnl::memory::dims strides = { sD, sH, sW };
|
dnnl::memory::dims strides = { sD, sH, sW };
|
||||||
dnnl::memory::dims padding = { pD, pH, pW };
|
dnnl::memory::dims padding = { pD, pH, pW };
|
||||||
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
|
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||||
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
|
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
|
||||||
|
|
||||||
// input type
|
// input type
|
||||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
|
@ -437,12 +431,18 @@ PLATFORM_CHECK(deconv3d) {
|
||||||
|
|
||||||
auto output = INPUT_VARIABLE(0);
|
auto output = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||||
|
|
||||||
const DataType xType = input->dataType();
|
const DataType xType = input->dataType();
|
||||||
const DataType wType = weights->dataType();
|
const DataType wType = weights->dataType();
|
||||||
const DataType zType = output->dataType();
|
const DataType zType = output->dataType();
|
||||||
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
const DataType bType = bias != nullptr ? bias->dataType() : zType;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && (
|
return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) &&
|
||||||
|
(
|
||||||
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
|
||||||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
|
||||||
);
|
);
|
||||||
|
@ -538,6 +538,11 @@ PLATFORM_CHECK(deconv3d_bp) {
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
int dD = INT_ARG(9); // dilations depth
|
||||||
|
int dH = INT_ARG(10); // dilations height
|
||||||
|
int dW = INT_ARG(11); // dilations width
|
||||||
|
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
|
||||||
|
|
||||||
const DataType xType = input->dataType();
|
const DataType xType = input->dataType();
|
||||||
const DataType wType = weights->dataType();
|
const DataType wType = weights->dataType();
|
||||||
const DataType gradOType = gradO->dataType();
|
const DataType gradOType = gradO->dataType();
|
||||||
|
@ -546,7 +551,7 @@ PLATFORM_CHECK(deconv3d_bp) {
|
||||||
const DataType gradWType = gradW->dataType();
|
const DataType gradWType = gradW->dataType();
|
||||||
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
|
|
||||||
#include <dnnl_types.h>
|
#include <dnnl_types.h>
|
||||||
#include "mkldnnUtils.h"
|
#include "mkldnnUtils.h"
|
||||||
#include <ops/declarable/helpers/convolutions.h>
|
|
||||||
|
|
||||||
using namespace dnnl;
|
using namespace dnnl;
|
||||||
|
|
||||||
|
@ -155,19 +154,10 @@ namespace nd4j {
|
||||||
dnnl::memory::dims conv_bias_tz = { oC };
|
dnnl::memory::dims conv_bias_tz = { oC };
|
||||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||||
|
|
||||||
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
|
|
||||||
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
|
|
||||||
|
|
||||||
conv_strides = { sH, sW };
|
|
||||||
conv_padding = { pH, pW };
|
|
||||||
conv_padding_r = { pHmkl, pWmkl };
|
|
||||||
conv_dilation = { dHmkl, dWmkl };
|
|
||||||
|
|
||||||
conv_strides = { sH, sW };
|
conv_strides = { sH, sW };
|
||||||
conv_padding = { pH, pW };
|
conv_padding = { pH, pW };
|
||||||
|
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||||
conv_dilation = { dH-1, dW-1};
|
conv_dilation = { dH-1, dW-1};
|
||||||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
|
||||||
(oW - 1) * sW - iW + kW - pW };
|
|
||||||
|
|
||||||
auto type = dnnl::memory::data_type::f32;
|
auto type = dnnl::memory::data_type::f32;
|
||||||
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
|
||||||
|
@ -243,13 +233,10 @@ namespace nd4j {
|
||||||
dnnl::memory::dims conv_bias_tz = { oC };
|
dnnl::memory::dims conv_bias_tz = { oC };
|
||||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
|
||||||
|
|
||||||
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
|
|
||||||
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
|
|
||||||
|
|
||||||
conv_strides = { sD, sH, sW };
|
conv_strides = { sD, sH, sW };
|
||||||
conv_padding = { pD, pH, pW };
|
conv_padding = { pD, pH, pW };
|
||||||
conv_padding_r = { pDmkl, pHmkl, pWmkl };
|
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
||||||
conv_dilation = { dDmkl, dHmkl, dWmkl };
|
conv_dilation = { dD-1, dH-1, dW-1};
|
||||||
|
|
||||||
auto type = dnnl::memory::data_type::f32;
|
auto type = dnnl::memory::data_type::f32;
|
||||||
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -605,7 +605,6 @@ TEST_F(ConvolutionTests2, deconv3d_test5) {
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
// output->printBuffer();
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
|
@ -2285,66 +2285,6 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {
|
|
||||||
|
|
||||||
NDArray input('c', {2, 3, 4, 4}, nd4j::DataType::FLOAT32);
|
|
||||||
NDArray weights('c', {3, 3, 5, 5}, nd4j::DataType::FLOAT32);
|
|
||||||
NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,
|
|
||||||
100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,
|
|
||||||
84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,
|
|
||||||
54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,
|
|
||||||
90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,
|
|
||||||
8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,
|
|
||||||
144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,
|
|
||||||
118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,
|
|
||||||
115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,
|
|
||||||
268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,
|
|
||||||
52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,
|
|
||||||
78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,
|
|
||||||
89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, nd4j::DataType::FLOAT32);
|
|
||||||
|
|
||||||
input.linspace(1);
|
|
||||||
weights.linspace(1);
|
|
||||||
weights.permutei({2,3,1,0});
|
|
||||||
|
|
||||||
nd4j::ops::deconv2d op;
|
|
||||||
auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
|
|
||||||
|
|
||||||
auto z = result->at(0);
|
|
||||||
// z->printShapeInfo();
|
|
||||||
// z->printBuffer();
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, CompactLaunchTests2) {
|
|
||||||
Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99};
|
|
||||||
double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,};
|
|
||||||
NDArray exp(_expB, _expS);
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 4});
|
|
||||||
auto weights = NDArrayFactory::create<double>('c', {3, 3, 5, 5});
|
|
||||||
auto z = NDArrayFactory::create<double>('c', {2, 3, 8, 8});
|
|
||||||
|
|
||||||
input.linspace(1);
|
|
||||||
weights.linspace(1);
|
|
||||||
weights.permutei({2,3,1,0});
|
|
||||||
|
|
||||||
nd4j::ops::deconv2d op;
|
|
||||||
auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(&z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(&z));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
// TEST_F(DeclarableOpsTests1, sru_old_test1) {
|
// TEST_F(DeclarableOpsTests1, sru_old_test1) {
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@
|
||||||
#include <ops/declarable/helpers/im2col.h>
|
#include <ops/declarable/helpers/im2col.h>
|
||||||
#include <Loops.h>
|
#include <Loops.h>
|
||||||
#include <RandomLauncher.h>
|
#include <RandomLauncher.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
|
||||||
#include <helpers/BenchmarkHelper.h>
|
#include <helpers/BenchmarkHelper.h>
|
||||||
#include <ops/declarable/helpers/scatter.h>
|
#include <ops/declarable/helpers/scatter.h>
|
||||||
|
@ -279,24 +280,65 @@ TEST_F(PlaygroundTests, test_relubp_1) {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(PlaygroundTests, my) {
|
TEST_F(PlaygroundTests, my) {
|
||||||
|
|
||||||
int bS=1, iH=56,iW=56, iC=144,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
|
||||||
int oC=iC*mC;
|
int oD,oH,oW;
|
||||||
int oH=56,oW=56;
|
|
||||||
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
||||||
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
|
nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
|
||||||
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
|
|
||||||
|
|
||||||
input = 2.;
|
printf("!!%i, %i, %i\n", oD,oH,oW);
|
||||||
weights.linspace(0.1, 0.1);
|
|
||||||
|
|
||||||
nd4j::ops::depthwise_conv2d op;
|
NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
|
||||||
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
delete results;
|
col = 3.77;
|
||||||
|
vol = -10.33;
|
||||||
|
|
||||||
|
auto variableSpace = new VariableSpace();
|
||||||
|
auto block = new Context(1, variableSpace, false); // not-in-place
|
||||||
|
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||||
|
|
||||||
|
printf("time: %i \n", time);
|
||||||
|
|
||||||
|
delete block;
|
||||||
|
delete variableSpace;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PlaygroundTests, my) {
|
||||||
|
|
||||||
|
int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
|
||||||
|
int oD,oH,oW;
|
||||||
|
|
||||||
|
// nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
|
||||||
|
nd4j::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW,dH, dW, iH, iW, 0);
|
||||||
|
|
||||||
|
printf("!!%i, %i, %i\n", oD,oH,oW);
|
||||||
|
|
||||||
|
// NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
|
||||||
|
// NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray col('c', {bS, iC, kH, kW, iH, iW}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray im('c', {bS, iC, oH, oW}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
col = 3.77;
|
||||||
|
// vol = -10.33;
|
||||||
|
im = -10.33;
|
||||||
|
|
||||||
|
auto variableSpace = new VariableSpace();
|
||||||
|
auto block = new Context(1, variableSpace, false); // not-in-place
|
||||||
|
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
// nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||||
|
nd4j::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW);
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||||
|
|
||||||
|
printf("time: %i \n", time);
|
||||||
|
|
||||||
|
delete block;
|
||||||
|
delete variableSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue