Shugeo cuda solver fix (#383)

* Refactored cuSolver handle usage to handle LaunchContext instance properly.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored svd solver usage with LaunchContext instance singleton.

Signed-off-by: shugeo <sgazeos@gmail.com>

* add device locks for cuSolver uses

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
shugeo 2020-04-17 16:52:08 +03:00 committed by GitHub
parent 18d4eaa68d
commit 0eca33ad94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 35 deletions

View File

@ -341,14 +341,16 @@ namespace helpers {
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto n = input->rows(); auto n = input->rows();
cusolverDnHandle_t cusolverH = nullptr; std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
// create solver handle // create solver handle
cusolverStatus_t status = cusolverDnCreate(&cusolverH); cusolverStatus_t status; //cusolverDnCreate(&cusolverH);
if (CUSOLVER_STATUS_SUCCESS != status) { // if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot create cuSolver handle", status); // throw cuda_exception::build("Cannot create cuSolver handle", status);
} // }
// set solver stream // set solver stream
status = cusolverDnSetStream(cusolverH, *stream); status = cusolverDnSetStream(*cusolverH, *stream);
if (CUSOLVER_STATUS_SUCCESS != status) { if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot set up stream for cuda solver", status); throw cuda_exception::build("Cannot set up stream for cuda solver", status);
} }
@ -368,7 +370,7 @@ namespace helpers {
// compute internal buffer size // compute internal buffer size
double *matrix = reinterpret_cast<double *>(input->specialBuffer()); double *matrix = reinterpret_cast<double *>(input->specialBuffer());
status = cusolverDnDgetrf_bufferSize( status = cusolverDnDgetrf_bufferSize(
cusolverH, *cusolverH,
n, n,
n, n,
matrix, matrix,
@ -386,7 +388,7 @@ namespace helpers {
if (permutation == nullptr) { if (permutation == nullptr) {
status = cusolverDnDgetrf( status = cusolverDnDgetrf(
cusolverH, *cusolverH,
n, n,
n, n,
matrix, matrix,
@ -404,7 +406,7 @@ namespace helpers {
NDArray permutVector('c', {n}, sd::DataType::INT32, context); NDArray permutVector('c', {n}, sd::DataType::INT32, context);
int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>(); int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
status = cusolverDnDgetrf( status = cusolverDnDgetrf(
cusolverH, *cusolverH,
n, n,
n, n,
matrix, matrix,
@ -440,7 +442,7 @@ namespace helpers {
float *d_work = nullptr; float *d_work = nullptr;
status = cusolverDnSgetrf_bufferSize( status = cusolverDnSgetrf_bufferSize(
cusolverH, *cusolverH,
n, n,
n, n,
matrix, matrix,
@ -458,7 +460,7 @@ namespace helpers {
if (permutation == nullptr) if (permutation == nullptr)
status = cusolverDnSgetrf( status = cusolverDnSgetrf(
cusolverH, *cusolverH,
n, n,
n, n,
matrix, matrix,
@ -470,7 +472,7 @@ namespace helpers {
NDArray permutVector('c', {n}, DataType::INT32, context); NDArray permutVector('c', {n}, DataType::INT32, context);
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer()); int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
status = cusolverDnSgetrf( status = cusolverDnSgetrf(
cusolverH, *cusolverH,
n, n,
n, n,
matrix, matrix,
@ -504,7 +506,7 @@ namespace helpers {
if (err) { if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
} }
cusolverDnDestroy(cusolverH); // cusolverDnDestroy(cusolverH);
// NDArray::registerSpecialUse({input}, {input}); // NDArray::registerSpecialUse({input}, {input});
input->tickWriteDevice(); input->tickWriteDevice();
} }

View File

@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
} }
} }
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
// create cusolverDn handle // create cusolverDn handle
cusolverDnHandle_t handle = nullptr; cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
cusolverStatus_t status = cusolverDnCreate(&handle); //cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS) if(handle == nullptr)
throw cuda_exception::build("svdQR: cuda failed !", status); throw cuda_exception::build("svdQR: cuda failed !", -1);
// stream // stream
status = cusolverDnSetStream(handle, *context->getCudaStream()); auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS) if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status); throw cuda_exception::build("svdQR: cuda failed !", status);
// query working space of SVD // query working space of SVD
int lwork = 0; int lwork = 0;
if(A->dataType() == DataType::DOUBLE) if(A->dataType() == DataType::DOUBLE)
status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork); status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork);
else if(A->dataType() == DataType::FLOAT32) else if(A->dataType() == DataType::FLOAT32)
status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork); status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork);
else else
throw std::invalid_argument("svdQR: given data type is unsupported !"); throw std::invalid_argument("svdQR: given data type is unsupported !");
@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
// choose appropriate cuda gemm api depending on data types // choose appropriate cuda gemm api depending on data types
if(A->dataType() == DataType::DOUBLE) { if(A->dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo); status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
} }
else if(A->dataType() == DataType::FLOAT32) { else if(A->dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo); status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
} }
else else
throw std::invalid_argument("svdQR: given data type is unsupported !"); throw std::invalid_argument("svdQR: given data type is unsupported !");
@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
if (rWork) if (rWork)
cudaFree(rWork); cudaFree(rWork);
if(handle) // if(handle)
cusolverDnDestroy(handle); // cusolverDnDestroy(handle);
// cudaDeviceReset(); // cudaDeviceReset();
} }
@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
ldv = pV->strideAt(1); ldv = pV->strideAt(1);
} }
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
// create cusolverDn handle // create cusolverDn handle
cusolverDnHandle_t handle = nullptr; cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle();
cusolverStatus_t status = cusolverDnCreate(&handle); //cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS) if(handle == nullptr)
throw cuda_exception::build("svdJcb: cuda failed !", status); throw cuda_exception::build("svdJcb: cuda failed !", -1);
// stream // stream
status = cusolverDnSetStream(handle, *context->getCudaStream()); auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS) if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status); throw cuda_exception::build("svdJcb: cuda failed !", status);
@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
// query working space of SVD // query working space of SVD
int lwork = 0; int lwork = 0;
if(A->dataType() == DataType::DOUBLE) if(A->dataType() == DataType::DOUBLE)
status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams); status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
else if(A->dataType() == DataType::FLOAT32) else if(A->dataType() == DataType::FLOAT32)
status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams); status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
else else
throw std::invalid_argument("svdJcb: given data type is unsupported !"); throw std::invalid_argument("svdJcb: given data type is unsupported !");
@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
// choose appropriate cuda gemm api depending on data types // choose appropriate cuda gemm api depending on data types
if(A->dataType() == DataType::DOUBLE) { if(A->dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams); status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
} }
else if(A->dataType() == DataType::FLOAT32) { else if(A->dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams); status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
} }
else else
throw std::invalid_argument("svdJcb: given data type is unsupported !"); throw std::invalid_argument("svdJcb: given data type is unsupported !");
@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
cudaFree(devInfo); cudaFree(devInfo);
if (dWork ) if (dWork )
cudaFree(dWork); cudaFree(dWork);
if(handle) // if(handle)
cusolverDnDestroy(handle); // cusolverDnDestroy(handle);
if(gesvdjParams) if(gesvdjParams)
cusolverDnDestroyGesvdjInfo(gesvdjParams); cusolverDnDestroyGesvdjInfo(gesvdjParams);