Shyrma deconv3 (#69)

* - profiling cuda kernels for vol2col and im2col

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct addBias helper

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct mkl dilation formula and switch off mkl api for dilation deconvolutions

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2019-11-21 21:17:30 +02:00 committed by GitHub
parent ff73e6da3f
commit 7a90a31cfb
20 changed files with 654 additions and 386 deletions

View File

@ -258,10 +258,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1);
if(typeFloat) {
BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, reinterpret_cast<float *>(pA->getBuffer()), lda, reinterpret_cast<float *>(pB->getBuffer()), ldb, (float) beta, reinterpret_cast<float *>(pC->getBuffer()), ldc);
BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT<float>(), lda, pB->bufferAsT<float>(), ldb, (float) beta, pC->bufferAsT<float>(), ldc);
}
else if(typeDouble) {
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast<double *>(pA->getBuffer()), lda, reinterpret_cast<double *>(pB->getBuffer()), ldb, (double) beta, reinterpret_cast<double *>(pC->getBuffer()), ldc);
BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT<double>(), lda, pB->bufferAsT<double>(), ldb, (double) beta, pC->bufferAsT<double>(), ldc);
}
if(pC != C) {
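This hunk swaps the raw reinterpret_cast calls at the gemm call sites for the typed bufferAsT<T>() accessor. A minimal standalone sketch of that accessor pattern (illustrative only; the Buffer struct here is hypothetical, not NDArray's actual implementation):

#include <cstdio>

struct Buffer {
    void* data;
    // centralize the cast so every call site stays readable and typed
    template <typename T>
    T* bufferAsT() { return reinterpret_cast<T*>(data); }
};

int main() {
    float storage[3] = {1.f, 2.f, 3.f};
    Buffer b{storage};
    printf("%f\n", b.bufferAsT<float>()[1]); // prints 2.000000
    return 0;
}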

View File

@ -263,9 +263,13 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 6 + 128; // 6 = aRank + bRank + cRank
NDArray::prepareSpecialUse({C}, {A, B});
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
// BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({C}, {A, B});
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0)
throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
}
else {
@ -334,6 +338,10 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
NDArray::registerSpecialUse({pC}, {pA, pB});
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0)
throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
if(C != pC)
C->assign(pC);
@ -341,10 +349,6 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
delete toDelete[i];
}
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0)
throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult);
return C;
}
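These hunks move cudaStreamSynchronize from the tail of mmulMxM into each branch, so the stream is synchronized and any error surfaced before temporaries are registered or deleted. A hedged standalone sketch of that sync-and-check pattern (syncOrDie is an assumed helper name; the library throws cuda_exception::build instead of printing):

#include <cstdio>
#include <cuda_runtime.h>

static void syncOrDie(cudaStream_t stream, const char* what) {
    const cudaError_t rc = cudaStreamSynchronize(stream);
    if (rc != cudaSuccess)
        fprintf(stderr, "%s cuda failed: %s\n", what, cudaGetErrorString(rc));
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    // ... enqueue kernels on stream ...
    syncOrDie(stream, "MmulHelper::mmulMxM"); // sync before touching outputs
    cudaStreamDestroy(stream);
    return 0;
}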
@ -397,10 +401,14 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({Y}, {A, X});
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
// BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES)
NDArray::registerSpecialUse({Y}, {A, X});
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0)
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
}
else {
@ -434,16 +442,16 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
if (status != CUBLAS_STATUS_SUCCESS)
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0)
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
NDArray::registerSpecialUse({Y}, {pA, X});
if(pA != A)
delete pA;
}
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0)
throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult);
return Y;
}
@ -624,7 +632,7 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !");
}
else
C = new NDArray(outOrder, cExpectedShape, B->dataType());
C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
const int cRank = C->rankOf();
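The fix allocates C with the promoted type of both operands (and on A's context) instead of B's type alone. A standalone sketch of the promotion rule for two types (illustrative toy version; the real DataTypeUtils::pickPairwiseResultType covers the full numeric type matrix):

#include <cassert>

enum class DType { FLOAT32, DOUBLE };

// promote to the wider of the two operand types (two-type toy version)
static DType pickPairwiseResultType(DType a, DType b) {
    return (a == DType::DOUBLE || b == DType::DOUBLE) ? DType::DOUBLE : DType::FLOAT32;
}

int main() {
    assert(pickPairwiseResultType(DType::FLOAT32, DType::DOUBLE) == DType::DOUBLE);
    assert(pickPairwiseResultType(DType::FLOAT32, DType::FLOAT32) == DType::FLOAT32);
    return 0;
}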

View File

@ -889,7 +889,8 @@ namespace shape {
*/
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices);
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *indices, Nd4jLong baseOffset = 0);
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *indices, Nd4jLong baseOffset = 0);
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
@ -900,6 +901,8 @@ namespace shape {
* for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
*/
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
@ -913,6 +916,8 @@ namespace shape {
* for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
*/
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords);
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords);
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords);
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
@ -1756,6 +1761,34 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLo
return index;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords) {
Nd4jLong index, shift = 1;
index = coords[shapeInfo[0] - 1];
for(uint i = shapeInfo[0]; i > 1; --i) {
shift *= shapeInfo[i];
index += shift * coords[i - 2];
}
return index;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords) {
Nd4jLong index, shift = 1;
index = coords[shapeInfo[0] - 1];
for(uint i = shapeInfo[0]; i > 1; --i) {
shift *= shapeInfo[i];
index += shift * coords[i - 2];
}
return index;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices) {
@ -3223,18 +3256,28 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong
}
//////////////////////////////////////////////////////////////////////////
INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {
Nd4jLong offset = 0;
Nd4jLong offset = baseOffset;
for(uint i = 1; i <= shapeInfo[0]; ++i)
if(shapeInfo[i] != 1)
offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i];
offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
return offset;
}
//////////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
Nd4jLong offset = baseOffset;
for(uint i = 1; i <= shapeInfo[0]; ++i)
if(shapeInfo[i] != 1)
offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
return offset;
}
/**
@ -4720,6 +4763,26 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo,
coords[0] = index; // last iteration
}
//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords) {
for(uint i = shapeInfo[0]; i > 1; --i) {
coords[i - 1] = static_cast<int>(index) % static_cast<int>(shapeInfo[i]);
index /= static_cast<int>(shapeInfo[i]);
}
coords[0] = static_cast<int>(index); // last iteration
}
//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords) {
for(uint i = shapeInfo[0]; i > 1; --i) {
coords[i - 1] = static_cast<uint>(index) % static_cast<uint>(shapeInfo[i]);
index /= static_cast<uint>(shapeInfo[i]);
}
coords[0] = static_cast<uint>(index); // last iteration
}
//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) {
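A self-contained round trip through the two new unsigned overloads, mirroring the bodies above (assumptions: Nd4jLong is a 64-bit signed integer, and the trailing ews/order fields of a real shapeInfo are omitted since these functions never read them):

#include <cstdio>

typedef long long Nd4jLong; // assumption: 64-bit signed, as in the library

static void index2coords(Nd4jLong index, const Nd4jLong* shapeInfo, unsigned int* coords) {
    for (unsigned int i = shapeInfo[0]; i > 1; --i) {
        coords[i - 1] = (unsigned int)(index % shapeInfo[i]);
        index /= shapeInfo[i];
    }
    coords[0] = (unsigned int)index; // last iteration
}

static Nd4jLong getOffset(const Nd4jLong* shapeInfo, const unsigned int* coords) {
    Nd4jLong offset = 0;
    for (unsigned int i = 1; i <= shapeInfo[0]; ++i)
        if (shapeInfo[i] != 1)
            offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
    return offset;
}

int main() {
    // rank 3, shape {2,3,4}, c-order strides {12,4,1}
    const Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1};
    unsigned int coords[3];
    index2coords(17, shapeInfo, coords); // 17 -> [1, 1, 1]
    printf("coords [%u,%u,%u] -> offset %lld\n",
           coords[0], coords[1], coords[2], getOffset(shapeInfo, coords)); // offset 17
    return 0;
}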

View File

@ -66,10 +66,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
if(!isNCHW)
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
if(isSameMode){ // SAME
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
}
NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());

View File

@ -67,10 +67,10 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
if(!isNCDHW)
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
if(isSameMode) //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
//----- calculation of output -----//
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]

View File

@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC
NDArray *output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

View File

@ -69,8 +69,8 @@ namespace nd4j {
eKH = kH;
eKW = kW;
} else {
eKH = kH + (kH - 1) * (dH - 1);
eKW = kW + (kW - 1) * (dW - 1);
eKH = (kH - 1) * dH + 1;
eKW = (kW - 1) * dW + 1;
}
pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
@ -84,9 +84,9 @@ namespace nd4j {
eKH = kH;
eKW = kW;
} else {
eKD = kD + (kD - 1) * (dD - 1);
eKH = kH + (kH - 1) * (dH - 1);
eKW = kW + (kW - 1) * (dW - 1);
eKD = (kD - 1) * dD + 1;
eKH = (kH - 1) * dH + 1;
eKW = (kW - 1) * dW + 1;
}
pD = ((oD - 1) * sD + eKD - iD) / 2; // Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2
@ -107,8 +107,8 @@ namespace nd4j {
ekH = kH;
ekW = kW;
} else {
ekH = kH + (kH - 1) * (dH - 1);
ekW = kW + (kW - 1) * (dW - 1);
ekH = (kH - 1) * dH + 1;
ekW = (kW - 1) * dW + 1;
}
oH = sH * (iH - 1) + ekH - 2 * pH;
@ -131,9 +131,9 @@ namespace nd4j {
ekW = kW;
}
else {
ekD = kD + (kD - 1) * (dD - 1);
ekH = kH + (kH - 1) * (dH - 1);
ekW = kW + (kW - 1) * (dW - 1);
ekD = (kD - 1) * dD + 1;
ekH = (kH - 1) * dH + 1;
ekW = (kW - 1) * dW + 1;
}
oD = sD * (iD - 1) + ekD - 2 * pD;
oH = sH * (iH - 1) + ekH - 2 * pH;
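The old and new effective-kernel expressions are algebraically identical, since kH + (kH - 1)(dH - 1) = (kH - 1)dH + 1; the rewrite just settles on the simpler form. A quick standalone check, including the deconv output size for the 4x4-input, 5x5-kernel case exercised by the tests removed further down:

#include <cassert>
#include <cstdio>

int main() {
    for (int k = 1; k <= 7; ++k)
        for (int d = 1; d <= 4; ++d)
            assert(k + (k - 1) * (d - 1) == (k - 1) * d + 1); // same effective kernel

    const int iH = 4, kH = 5, sH = 1, pH = 0, dH = 1; // sizes from the removed CompactLaunchTests1
    const int ekH = (kH - 1) * dH + 1;                // 5 (no dilation)
    const int oH  = sH * (iH - 1) + ekH - 2 * pH;     // 8, i.e. 4x4 -> 8x8 deconv output
    printf("ekH=%d oH=%d\n", ekH, oH);
    return 0;
}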
@ -194,53 +194,53 @@ namespace nd4j {
}
static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int isSameMode, int& pH, int& pW, int& dH, int& dW) {
if(kH != 1) {
if(isSameMode) {
pH = (oH - 1) * sH - iH + kH - pH;
dH = dH - 1;
}
else
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
}
if(kW != 1) {
if(isSameMode) {
pW = (oW - 1) * sW - iW + kW - pW;
dW = dW - 1;
}
else
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
}
}
// if(kH != 1) {
// if(isSameMode) {
// pH = (oH - 1) * sH - iH + kH - pH;
// dH = dH - 1;
// }
// else
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
// }
// if(kW != 1) {
// if(isSameMode) {
// pW = (oW - 1) * sW - iW + kW - pW;
// dW = dW - 1;
// }
// else
// dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
// }
// }
static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
// static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int isSameMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
if(kD != 1) {
if(isSameMode) {
pD = (oD - 1) * sD - iD + kD - pD;
dD = dD - 1;
}
else
dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
}
if(kH != 1) {
if(isSameMode) {
pH = (oH - 1) * sH - iH + kH - pH;
dH = dH - 1;
}
else
dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
}
if(kW != 1) {
if(isSameMode) {
pW = (oW - 1) * sW - iW + kW - pW;
dW = dW - 1;
}
else
dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
}
}
// if(kD != 1) {
// if(isSameMode) {
// pD = (oD - 1) * sD - iD + kD - pD;
// dD = dD - 1;
// }
// else
// dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
// }
// if(kH != 1) {
// if(isSameMode) {
// pH = (oH - 1) * sH - iH + kH - pH;
// dH = dH - 1;
// }
// else
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
// }
// if(kW != 1) {
// if(isSameMode) {
// pW = (oW - 1) * sW - iW + kW - pW;
// dW = dW - 1;
// }
// else
// dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
// }
// }
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);

View File

@ -47,9 +47,12 @@ static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output,
const bool inOutAreSame = x == z;
int posOfNonUnityDim;
bias.isCommonVector(posOfNonUnityDim);
const uint bS = output.sizeAt(0); // batch size
const Nd4jLong yStrideC = bias.stridesOf()[0];
const Nd4jLong zStrideB = output.stridesOf()[0];
const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim);
const Nd4jLong zStrideB = output.strideAt(0);
if(output.rankOf() == 4) {
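The old code read bias.stridesOf()[0], which breaks when the bias is kept as a row vector such as [1, oC]: index 0 is then the stride of the unity dimension, not the step between channels. A standalone illustration with hypothetical values (not library code):

#include <cstdio>

int main() {
    const int oC = 5;
    float bias[1 * 5] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};   // shape [1, oC], c-order
    const long long strides[] = {oC, 1};                   // strides of [1, oC]
    const int posOfNonUnityDim = 1;                        // the oC axis
    const long long yStrideC = strides[posOfNonUnityDim];  // 1; strides[0] == 5 would step past every element
    for (int c = 0; c < oC; ++c)
        printf("%f ", bias[c * yStrideC]);                 // visits all oC values
    printf("\n");
    return 0;
}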

View File

@ -54,6 +54,7 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
const uint oW = output->sizeAt(2);
auto func = PRAGMA_THREADS_FOR_2D {
for (uint b = start_x; b < stop_x; b += inc_x) {
for (uint oh = start_y; oh < stop_y; oh += inc_y) {
for (uint ow = 0; ow < oW; ++ow) {
@ -69,13 +70,17 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const
const int iw = ow * sW - pW + kw * dW;
if (iw < 0 || iw >= iW) continue;
const X val = x[shape::getOffset(xShapeInfo, {b, (uint) ih, (uint) iw, c})] + y[shape::getOffset(yShapeInfo, {kh, kw, c})];
uint xCoords[4] = {b, (uint)ih, (uint)iw, c};
uint yCoords[3] = {kh, kw, c};
const X val = x[shape::getOffset(xShapeInfo, xCoords)] + y[shape::getOffset(yShapeInfo, yCoords)];
if (val > max)
max = val;
}
}
z[shape::getOffset(zShapeInfo, {b, oh, ow, c})] = static_cast<Z>(max);
uint zCoords[4] = {b, oh, ow, c};
z[shape::getOffset(zShapeInfo, zCoords)] = static_cast<Z>(max);
}
}
}

View File

@ -44,7 +44,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
const Y* y = reinterpret_cast<const Y*>(vy);
X* z = reinterpret_cast<X*>(vz);
__shared__ int rank, channelPosition;
__shared__ int rank, channelPosition, posOfNonUnityDim;
__shared__ Nd4jLong *sharedMem, len;
__shared__ bool xzSameOffsets, xzAreSame;
@ -58,6 +58,8 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
len = shape::length(xShapeInfo);
channelPosition = isNCHW ? 1 : rank - 1; // second or last
xzAreSame = x == z;
shape::isCommonVector(yShapeInfo, posOfNonUnityDim);
}
__syncthreads();
@ -69,7 +71,7 @@ __global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
const auto xOffsets = shape::getOffset(xShapeInfo, coords);
const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords);
const auto yOffsets = shape::getOffset(yShapeInfo, coords + channelPosition);
const auto yOffsets = coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim];
if(xzAreSame)
z[zOffsets] += static_cast<X>(y[yOffsets]);
@ -94,7 +96,7 @@ void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& b
PointersManager manager(block.launchContext(), "addBias");
const int threadsPerBlock = MAX_NUM_THREADS;
const int threadsPerBlock = MAX_NUM_THREADS/2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
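Same idea on the device side: once posOfNonUnityDim is known, the per-element bias offset collapses to the channel coordinate times one stride, replacing a general getOffset call in the inner loop. A tiny standalone check with illustrative values:

#include <cassert>

int main() {
    const long long biasStride = 1;           // stride at posOfNonUnityDim for a dense [oC] bias
    const int channelPosition = 1;            // NCHW: channels live on axis 1
    const long long coords[4] = {0, 3, 2, 5}; // (b, c, h, w) of one output element
    const long long yOffset = coords[channelPosition] * biasStride;
    assert(yOffset == 3);                     // bias element for channel 3
    return 0;
}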

View File

@ -34,64 +34,59 @@ static __global__ void col2imCuda(const void* columns, const Nd4jLong* colShapeI
const T* col = reinterpret_cast<const T*>(columns);
T* im = reinterpret_cast<T*>(image);
__shared__ int colRank, imRank, kHeff, kWeff, oH, oW;
__shared__ Nd4jLong *sharedMem, imLen;
__shared__ uint kH, kW, oH, oW, *sharedMem;
__shared__ Nd4jLong imLen;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
sharedMem = reinterpret_cast<uint*>(shmem);
kH = dH * (colShapeInfo[3] - 1) + 1;
kW = dW * (colShapeInfo[4] - 1) + 1;
oH = colShapeInfo[5];
oW = colShapeInfo[6];
kHeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dH - 1);
kWeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dW - 1);
imRank = 4;
colRank = 6;
imLen = shape::length(imShapeInfo);
}
__syncthreads();
const auto imInd = threadIdx.x + blockIdx.x * blockDim.x;
auto coords = sharedMem + threadIdx.x * 6;
if(imInd >= imLen)
return;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto coords = sharedMem + threadIdx.x * colRank;
for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) {
shape::index2coords(imInd, imShapeInfo, coords);
shape::index2coords(i, imShapeInfo, coords);
const auto imOffset = shape::getOffset(imShapeInfo, coords);
const int imH = coords[2] + pH;
const int imW = coords[3] + pW;
const auto bSiCoffset = coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8];
const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1;
const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1;
const uint imH = coords[2] + pH;
const uint imW = coords[3] + pW;
const int colHend = nd4j::math::nd4j_min<int>(imH / sH + 1, oH);
const int colWend = nd4j::math::nd4j_min<int>(imW / sW + 1, oW);
const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1;
const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1;
const uint colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
const uint colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
T val = 0;
for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) {
coords[2] = imH - coords[4] * sH;
if(coords[2] % dH != 0) continue;
for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) {
coords[3] = imW - coords[5] * sW;
if(coords[3] % dW != 0) continue;
if(coords[2] % dH == 0 && coords[3] % dW == 0) {
coords[2] /= dH;
coords[3] /= dW;
val += col[shape::getOffset(colShapeInfo, coords)];
val += col[bSiCoffset + (coords[2]/dH)*colShapeInfo[9] + (coords[3]/dW)*colShapeInfo[10] + coords[4]*colShapeInfo[11] + coords[5]*colShapeInfo[12]];
}
}
}
im[imOffset] = val;
}
}
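The rewritten kernel drops the one-thread-per-element early return in favor of a grid-stride loop, so a fixed launch configuration covers any imLen. A minimal self-contained CUDA sketch of the pattern (illustrative kernel, not the library's):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void fillKernel(float* out, long long len, float value) {
    const long long tid = blockIdx.x * blockDim.x + threadIdx.x;
    for (long long i = tid; i < len; i += (long long)gridDim.x * blockDim.x)
        out[i] = value; // each thread covers i, i + gridSize, i + 2*gridSize, ...
}

int main() {
    const long long len = 1 << 20;
    float* d = nullptr;
    cudaMalloc(&d, len * sizeof(float));
    fillKernel<<<256, 256>>>(d, len, 3.77f); // far fewer threads than elements
    cudaDeviceSynchronize();
    float last = 0;
    cudaMemcpy(&last, d + len - 1, sizeof(float), cudaMemcpyDeviceToHost);
    printf("last element: %f\n", last);
    cudaFree(d);
    return 0;
}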
////////////////////////////////////////////////////////////////////////
@ -184,8 +179,8 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc
void* image, const Nd4jLong* imShapeInfo,
const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
//col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
// col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
}
//////////////////////////////////////////////////////////////////////////
@ -195,7 +190,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256;
NDArray::prepareSpecialUse({&im}, {&col});
BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES);

View File

@ -122,74 +122,71 @@ static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShape
const T* col = reinterpret_cast<const T*>(columns);
T* vol = reinterpret_cast<T*>(volume);
__shared__ int colRank, volRank, kDeff, kHeff, kWeff, oD, oH, oW;
__shared__ Nd4jLong *sharedMem, volLen;
__shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem;
__shared__ Nd4jLong volLen;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
sharedMem = reinterpret_cast<uint*>(shmem);
oD = colShapeInfo[6];
oH = colShapeInfo[7];
oW = colShapeInfo[8];
kDeff = colShapeInfo[3] + (colShapeInfo[3] - 1) * (dD - 1);
kHeff = colShapeInfo[4] + (colShapeInfo[4] - 1) * (dH - 1);
kWeff = colShapeInfo[5] + (colShapeInfo[5] - 1) * (dW - 1);
volRank = 5;
colRank = 8;
kD = dD * (colShapeInfo[3] - 1) + 1;
kH = dH * (colShapeInfo[4] - 1) + 1;
kW = dW * (colShapeInfo[5] - 1) + 1;
volLen = shape::length(volShapeInfo);
}
__syncthreads();
const auto volInd = threadIdx.x + blockIdx.x * blockDim.x;
auto coords = sharedMem + threadIdx.x * 8;
if(volInd >= volLen)
return;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto coords = sharedMem + threadIdx.x * colRank;
for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) {
shape::index2coords(volInd, volShapeInfo, coords);
shape::index2coords(i, volShapeInfo, coords);
const auto volOffset = shape::getOffset(volShapeInfo, coords);
const int imD = coords[2] + pD;
const int imH = coords[3] + pH;
const int imW = coords[4] + pW;
const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10];
const int colDstart = (imD < kDeff) ? 0 : (imD - kDeff) / sD + 1;
const int colHstart = (imH < kHeff) ? 0 : (imH - kHeff) / sH + 1;
const int colWstart = (imW < kWeff) ? 0 : (imW - kWeff) / sW + 1;
const uint imD = coords[2] + pD;
const uint imH = coords[3] + pH;
const uint imW = coords[4] + pW;
const int colDend = nd4j::math::nd4j_min<uint>(imD / sD + 1, oD);
const int colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
const int colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1;
const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1;
const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1;
const uint colDend = nd4j::math::nd4j_min<uint>(imD / sD + 1, oD);
const uint colHend = nd4j::math::nd4j_min<uint>(imH / sH + 1, oH);
const uint colWend = nd4j::math::nd4j_min<uint>(imW / sW + 1, oW);
T val = 0;
for(coords[5] = colDstart; coords[5] < colDend; ++coords[5]) {
coords[2] = imD - coords[5] * sD;
for(uint colD = colDstart; colD < colDend; ++colD) {
coords[2] = imD - colD * sD;
if(coords[2] % dD != 0) continue;
for(coords[6] = colHstart; coords[6] < colHend; ++coords[6]) {
coords[3] = imH - coords[6] * sH;
for(uint colH = colHstart; colH < colHend; ++colH) {
coords[3] = imH - colH * sH;
if(coords[3] % dH != 0) continue;
for(coords[7] = colWstart; coords[7] < colWend; ++coords[7]) {
coords[4] = imW - coords[7] * sW;
for(uint colW = colWstart; colW < colWend; ++colW) {
coords[4] = imW - colW * sW;
if(coords[4] % dW != 0) continue;
if(coords[2] % dD == 0 && coords[3] % dH == 0 && coords[4] % dW == 0) {
coords[2] /= dD;
coords[3] /= dH;
coords[4] /= dW;
val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]];
val += col[shape::getOffset(colShapeInfo, coords)];
}
}
}
}
vol[volOffset] = val;
}
}
//////////////////////////////////////////////////////////////////////////
@ -209,7 +206,7 @@ void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col,
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256;
NDArray::prepareSpecialUse({&vol}, {&col});
BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES);
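Both kernels carve the dynamic shared buffer into one rank-sized uint coordinate slot per thread, which is why sharedMem is now computed as rankOf() * sizeof(uint) * threadsPerBlock plus slack. A minimal sketch of that carving (illustrative names and sizes):

#include <cuda_runtime.h>

__global__ void scratchDemo(unsigned int* out, int rank) {
    extern __shared__ unsigned char shmem[];
    unsigned int* sharedMem = reinterpret_cast<unsigned int*>(shmem);
    unsigned int* coords = sharedMem + threadIdx.x * rank; // this thread's private slot
    for (int i = 0; i < rank; ++i)
        coords[i] = threadIdx.x + i;                       // placeholder per-thread work
    out[blockIdx.x * blockDim.x + threadIdx.x] = coords[rank - 1];
}

int main() {
    const int rank = 8, threads = 128, blocks = 4; // rank 8 matches the col array above
    unsigned int* d = nullptr;
    cudaMalloc(&d, blocks * threads * sizeof(unsigned int));
    scratchDemo<<<blocks, threads, rank * sizeof(unsigned int) * threads + 128>>>(d, rank);
    cudaDeviceSynchronize();
    cudaFree(d);
    return 0;
}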

View File

@ -46,13 +46,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
dnnl::memory::dims dilation = { dHmkl, dWmkl };
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 };
// input type
dnnl::memory::data_type xType;
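oneDNN takes zero-based dilations (0 means a dense kernel), so the library's dH maps to dH - 1 here, and padding_r comes straight from the deconv geometry rather than from the removed calcPaddingAndDilationForConv2DMKL helper. A quick arithmetic check under assumed sizes:

#include <cassert>

int main() {
    const int iH = 5, kH = 3, sH = 2, pH = 0, dH = 1;
    const int oH = sH * (iH - 1) + ((kH - 1) * dH + 1) - 2 * pH; // deconv output: 11
    const int padding_r = (iH - 1) * sH - oH + kH - pH;          // 8 - 11 + 3 = 0
    assert(padding_r == 0);
    // zero-based dilation round trip: oneDNN's (k-1)*(d0+1)+1 with d0 = dH-1
    assert((kH - 1) * ((dH - 1) + 1) + 1 == (kH - 1) * dH + 1);
    return 0;
}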
@ -193,13 +190,10 @@ static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(oH, oW, iH, iW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
dnnl::memory::dims strides = { sH, sW };
dnnl::memory::dims padding = { pH, pW };
dnnl::memory::dims padding_r = { pHmkl, pWmkl };
dnnl::memory::dims dilation = { dHmkl, dWmkl };
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 };
// input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
// weights type
@ -423,12 +417,17 @@ PLATFORM_CHECK(deconv2d) {
auto output = INPUT_VARIABLE(0);
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType zType = output->dataType();
const DataType bType = bias != nullptr ? bias->dataType() : zType;
return block.isUseMKLDNN() && (
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) &&
(
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
);
@ -521,6 +520,9 @@ PLATFORM_CHECK(deconv2d_bp) {
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
@ -530,7 +532,7 @@ PLATFORM_CHECK(deconv2d_bp) {
const DataType gradWType = gradW->dataType();
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
}

View File

@ -47,13 +47,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW };
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
// input type
dnnl::memory::data_type xType;
@ -197,13 +194,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
dnnl::memory::dims strides = { sD, sH, sW };
dnnl::memory::dims padding = { pD, pH, pW };
dnnl::memory::dims padding_r = { pDmkl, pHmkl, pWmkl };
dnnl::memory::dims dilation = { dDmkl, dHmkl, dWmkl };
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
// input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
@ -437,12 +431,18 @@ PLATFORM_CHECK(deconv3d) {
auto output = INPUT_VARIABLE(0);
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-VALID, 1-SAME
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType zType = output->dataType();
const DataType bType = bias != nullptr ? bias->dataType() : zType;
return block.isUseMKLDNN() && (
return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) &&
(
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
);
@ -538,6 +538,11 @@ PLATFORM_CHECK(deconv3d_bp) {
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-VALID, 1-SAME
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType gradOType = gradO->dataType();
@ -546,7 +551,7 @@ PLATFORM_CHECK(deconv3d_bp) {
const DataType gradWType = gradW->dataType();
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
}
}

View File

@ -20,7 +20,6 @@
#include <dnnl_types.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace dnnl;
@ -155,19 +154,10 @@ namespace nd4j {
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
int dHmkl(dH), dWmkl(dW), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv2DMKL(iH, iW, oH, oW, kH, kW, sH, sW, isSameMode, pHmkl, pWmkl, dHmkl, dWmkl);
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { pHmkl, pWmkl };
conv_dilation = { dHmkl, dWmkl };
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
conv_dilation = { dH-1, dW-1};
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
auto type = dnnl::memory::data_type::f32;
auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
@ -243,13 +233,10 @@ namespace nd4j {
dnnl::memory::dims conv_bias_tz = { oC };
dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
int dDmkl(dD), dHmkl(dH), dWmkl(dW), pDmkl(pD), pHmkl(pH), pWmkl(pW);
nd4j::ops::ConvolutionUtils::calcPaddingAndDilationForConv3DMKL(iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, isSameMode, pDmkl, pHmkl, pWmkl, dDmkl, dHmkl, dWmkl);
conv_strides = { sD, sH, sW };
conv_padding = { pD, pH, pW };
conv_padding_r = { pDmkl, pHmkl, pWmkl };
conv_dilation = { dDmkl, dHmkl, dWmkl };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
conv_dilation = { dD-1, dH-1, dW-1};
auto type = dnnl::memory::data_type::f32;
auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;

File diff suppressed because one or more lines are too long

View File

@ -605,7 +605,6 @@ TEST_F(ConvolutionTests2, deconv3d_test5) {
ASSERT_EQ(Status::OK(), results->status());
auto output = results->at(0);
// output->printBuffer();
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));

View File

@ -2285,66 +2285,6 @@ TEST_F(DeclarableOpsTests1, IsMax4) {
ASSERT_EQ(e, z);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, CompactLaunchTests1) {
NDArray input('c', {2, 3, 4, 4}, nd4j::DataType::FLOAT32);
NDArray weights('c', {3, 3, 5, 5}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,
100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,
84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,
54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,
90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,
8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,
144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,
118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,
115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,
268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,
52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,
78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,
89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, nd4j::DataType::FLOAT32);
input.linspace(1);
weights.linspace(1);
weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0});
auto z = result->at(0);
// z->printShapeInfo();
// z->printBuffer();
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, CompactLaunchTests2) {
Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99};
double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,};
NDArray exp(_expB, _expS);
auto input = NDArrayFactory::create<double>('c', {2, 3, 4, 4});
auto weights = NDArrayFactory::create<double>('c', {3, 3, 5, 5});
auto z = NDArrayFactory::create<double>('c', {2, 3, 8, 8});
input.linspace(1);
weights.linspace(1);
weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op;
auto result = op.execute({&input, &weights}, {&z}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0},{});
ASSERT_EQ(ND4J_STATUS_OK, result);
ASSERT_TRUE(exp.isSameShape(&z));
ASSERT_TRUE(exp.equalsTo(&z));
}
////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, sru_old_test1) {

View File

@ -34,6 +34,7 @@
#include <ops/declarable/helpers/im2col.h>
#include <Loops.h>
#include <RandomLauncher.h>
#include <ops/declarable/helpers/convolutions.h>
#include <helpers/BenchmarkHelper.h>
#include <ops/declarable/helpers/scatter.h>
@ -279,24 +280,65 @@ TEST_F(PlaygroundTests, test_relubp_1) {
//////////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, my) {
int bS=1, iH=56,iW=56, iC=144,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oC=iC*mC;
int oH=56,oW=56;
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
int oD,oH,oW;
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
input = 2.;
weights.linspace(0.1, 0.1);
printf("!!%i, %i, %i\n", oD,oH,oW);
nd4j::ops::depthwise_conv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
delete results;
col = 3.77;
vol = -10.33;
auto variableSpace = new VariableSpace();
auto block = new Context(1, variableSpace, false); // not-in-place
auto timeStart = std::chrono::system_clock::now();
nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
auto timeEnd = std::chrono::system_clock::now();
auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
printf("time: %i \n", time);
delete block;
delete variableSpace;
}
TEST_F(PlaygroundTests, my) {
int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
int oD,oH,oW;
// nd4j::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0);
nd4j::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW,dH, dW, iH, iW, 0);
printf("!!%i, %i, %i\n", oD,oH,oW);
// NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, nd4j::DataType::DOUBLE);
// NDArray vol('c', {bS, iC, oD, oH, oW}, nd4j::DataType::DOUBLE);
NDArray col('c', {bS, iC, kH, kW, iH, iW}, nd4j::DataType::DOUBLE);
NDArray im('c', {bS, iC, oH, oW}, nd4j::DataType::DOUBLE);
col = 3.77;
// vol = -10.33;
im = -10.33;
auto variableSpace = new VariableSpace();
auto block = new Context(1, variableSpace, false); // not-in-place
auto timeStart = std::chrono::system_clock::now();
// nd4j::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW);
nd4j::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW);
auto timeEnd = std::chrono::system_clock::now();
auto time = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
printf("time: %i \n", time);
delete block;
delete variableSpace;
}
*/