Shugeo cuda solver fix (#383)

* Refactored cuSolver handle usage to handle LaunchContext instance properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored svd solver usage with LaunchContext instance singleton.

Signed-off-by: shugeo <sgazeos@gmail.com>

* add device locks for cuSolver uses

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
shugeo 2020-04-17 16:52:08 +03:00 committed by GitHub
parent 18d4eaa68d
commit 0eca33ad94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 35 deletions

View File

@ -341,14 +341,16 @@ namespace helpers {
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
auto stream = context->getCudaStream();
auto n = input->rows();
cusolverDnHandle_t cusolverH = nullptr;
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
// create solver handle
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot create cuSolver handle", status);
}
cusolverStatus_t status; //cusolverDnCreate(&cusolverH);
// if (CUSOLVER_STATUS_SUCCESS != status) {
// throw cuda_exception::build("Cannot create cuSolver handle", status);
// }
// set solver stream
status = cusolverDnSetStream(cusolverH, *stream);
status = cusolverDnSetStream(*cusolverH, *stream);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot set up stream for cuda solver", status);
}
@ -368,7 +370,7 @@ namespace helpers {
// compute internal buffer size
double *matrix = reinterpret_cast<double *>(input->specialBuffer());
status = cusolverDnDgetrf_bufferSize(
cusolverH,
*cusolverH,
n,
n,
matrix,
@ -386,7 +388,7 @@ namespace helpers {
if (permutation == nullptr) {
status = cusolverDnDgetrf(
cusolverH,
*cusolverH,
n,
n,
matrix,
@ -404,7 +406,7 @@ namespace helpers {
NDArray permutVector('c', {n}, sd::DataType::INT32, context);
int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
status = cusolverDnDgetrf(
cusolverH,
*cusolverH,
n,
n,
matrix,
@ -440,7 +442,7 @@ namespace helpers {
float *d_work = nullptr;
status = cusolverDnSgetrf_bufferSize(
cusolverH,
*cusolverH,
n,
n,
matrix,
@ -458,7 +460,7 @@ namespace helpers {
if (permutation == nullptr)
status = cusolverDnSgetrf(
cusolverH,
*cusolverH,
n,
n,
matrix,
@ -470,7 +472,7 @@ namespace helpers {
NDArray permutVector('c', {n}, DataType::INT32, context);
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
status = cusolverDnSgetrf(
cusolverH,
*cusolverH,
n,
n,
matrix,
@ -504,7 +506,7 @@ namespace helpers {
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
}
cusolverDnDestroy(cusolverH);
// cusolverDnDestroy(cusolverH);
// NDArray::registerSpecialUse({input}, {input});
input->tickWriteDevice();
}

View File

@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
}
}
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
// create cusolverDn handle
cusolverDnHandle_t handle = nullptr;
cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status);
cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
//cusolverStatus_t status = cusolverDnCreate(&handle);
if(handle == nullptr)
throw cuda_exception::build("svdQR: cuda failed !", -1);
// stream
status = cusolverDnSetStream(handle, *context->getCudaStream());
auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status);
// query working space of SVD
int lwork = 0;
if(A->dataType() == DataType::DOUBLE)
status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork);
else if(A->dataType() == DataType::FLOAT32)
status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork);
else
throw std::invalid_argument("svdQR: given data type is unsupported !");
@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
// choose appropriate cuda gemm api depending on data types
if(A->dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
}
else if(A->dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
}
else
throw std::invalid_argument("svdQR: given data type is unsupported !");
@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
if (rWork)
cudaFree(rWork);
if(handle)
cusolverDnDestroy(handle);
// if(handle)
// cusolverDnDestroy(handle);
// cudaDeviceReset();
}
@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
ldv = pV->strideAt(1);
}
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
// create cusolverDn handle
cusolverDnHandle_t handle = nullptr;
cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle();
//cusolverStatus_t status = cusolverDnCreate(&handle);
if(handle == nullptr)
throw cuda_exception::build("svdJcb: cuda failed !", -1);
// stream
status = cusolverDnSetStream(handle, *context->getCudaStream());
auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
// query working space of SVD
int lwork = 0;
if(A->dataType() == DataType::DOUBLE)
status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
else if(A->dataType() == DataType::FLOAT32)
status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
else
throw std::invalid_argument("svdJcb: given data type is unsupported !");
@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
// choose appropriate cuda gemm api depending on data types
if(A->dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
}
else if(A->dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
}
else
throw std::invalid_argument("svdJcb: given data type is unsupported !");
@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
cudaFree(devInfo);
if (dWork )
cudaFree(dWork);
if(handle)
cusolverDnDestroy(handle);
// if(handle)
// cusolverDnDestroy(handle);
if(gesvdjParams)
cusolverDnDestroyGesvdjInfo(gesvdjParams);