Shugeo cuda solver fix (#383)
* Refactored cuSolver handle usage to handle LaunchContext instance properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored svd solver usage with LaunchContext instance singleton. Signed-off-by: shugeo <sgazeos@gmail.com> * add device locks for cuSolver uses Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
18d4eaa68d
commit
0eca33ad94
|
@ -341,14 +341,16 @@ namespace helpers {
|
|||
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
|
||||
auto stream = context->getCudaStream();
|
||||
auto n = input->rows();
|
||||
cusolverDnHandle_t cusolverH = nullptr;
|
||||
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||
|
||||
cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
|
||||
// create solver handle
|
||||
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||
}
|
||||
cusolverStatus_t status; //cusolverDnCreate(&cusolverH);
|
||||
// if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||
// throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||
// }
|
||||
// set solver stream
|
||||
status = cusolverDnSetStream(cusolverH, *stream);
|
||||
status = cusolverDnSetStream(*cusolverH, *stream);
|
||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||
throw cuda_exception::build("Cannot set up stream for cuda solver", status);
|
||||
}
|
||||
|
@ -368,7 +370,7 @@ namespace helpers {
|
|||
// compute internal buffer size
|
||||
double *matrix = reinterpret_cast<double *>(input->specialBuffer());
|
||||
status = cusolverDnDgetrf_bufferSize(
|
||||
cusolverH,
|
||||
*cusolverH,
|
||||
n,
|
||||
n,
|
||||
matrix,
|
||||
|
@ -386,7 +388,7 @@ namespace helpers {
|
|||
|
||||
if (permutation == nullptr) {
|
||||
status = cusolverDnDgetrf(
|
||||
cusolverH,
|
||||
*cusolverH,
|
||||
n,
|
||||
n,
|
||||
matrix,
|
||||
|
@ -404,7 +406,7 @@ namespace helpers {
|
|||
NDArray permutVector('c', {n}, sd::DataType::INT32, context);
|
||||
int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
|
||||
status = cusolverDnDgetrf(
|
||||
cusolverH,
|
||||
*cusolverH,
|
||||
n,
|
||||
n,
|
||||
matrix,
|
||||
|
@ -440,7 +442,7 @@ namespace helpers {
|
|||
float *d_work = nullptr;
|
||||
|
||||
status = cusolverDnSgetrf_bufferSize(
|
||||
cusolverH,
|
||||
*cusolverH,
|
||||
n,
|
||||
n,
|
||||
matrix,
|
||||
|
@ -458,7 +460,7 @@ namespace helpers {
|
|||
|
||||
if (permutation == nullptr)
|
||||
status = cusolverDnSgetrf(
|
||||
cusolverH,
|
||||
*cusolverH,
|
||||
n,
|
||||
n,
|
||||
matrix,
|
||||
|
@ -470,7 +472,7 @@ namespace helpers {
|
|||
NDArray permutVector('c', {n}, DataType::INT32, context);
|
||||
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
||||
status = cusolverDnSgetrf(
|
||||
cusolverH,
|
||||
*cusolverH,
|
||||
n,
|
||||
n,
|
||||
matrix,
|
||||
|
@ -504,7 +506,7 @@ namespace helpers {
|
|||
if (err) {
|
||||
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
|
||||
}
|
||||
cusolverDnDestroy(cusolverH);
|
||||
// cusolverDnDestroy(cusolverH);
|
||||
// NDArray::registerSpecialUse({input}, {input});
|
||||
input->tickWriteDevice();
|
||||
}
|
||||
|
|
|
@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
|||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||
|
||||
// create cusolverDn handle
|
||||
cusolverDnHandle_t handle = nullptr;
|
||||
cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||
cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
|
||||
//cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||
if(handle == nullptr)
|
||||
throw cuda_exception::build("svdQR: cuda failed !", -1);
|
||||
|
||||
// stream
|
||||
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
||||
auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
|
||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||
|
||||
// query working space of SVD
|
||||
int lwork = 0;
|
||||
if(A->dataType() == DataType::DOUBLE)
|
||||
status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
|
||||
status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork);
|
||||
else if(A->dataType() == DataType::FLOAT32)
|
||||
status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
|
||||
status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork);
|
||||
else
|
||||
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
||||
|
||||
|
@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
|||
|
||||
// choose appropriate cuda gemm api depending on data types
|
||||
if(A->dataType() == DataType::DOUBLE) {
|
||||
status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
|
||||
status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
|
||||
}
|
||||
else if(A->dataType() == DataType::FLOAT32) {
|
||||
status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
|
||||
status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
|
||||
}
|
||||
else
|
||||
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
||||
|
@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
|||
if (rWork)
|
||||
cudaFree(rWork);
|
||||
|
||||
if(handle)
|
||||
cusolverDnDestroy(handle);
|
||||
// if(handle)
|
||||
// cusolverDnDestroy(handle);
|
||||
|
||||
// cudaDeviceReset();
|
||||
}
|
||||
|
@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
|||
ldv = pV->strideAt(1);
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||
|
||||
// create cusolverDn handle
|
||||
cusolverDnHandle_t handle = nullptr;
|
||||
cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||
cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle();
|
||||
//cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||
if(handle == nullptr)
|
||||
throw cuda_exception::build("svdJcb: cuda failed !", -1);
|
||||
|
||||
// stream
|
||||
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
||||
auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
|
||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||
|
||||
|
@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
|||
// query working space of SVD
|
||||
int lwork = 0;
|
||||
if(A->dataType() == DataType::DOUBLE)
|
||||
status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
|
||||
status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
|
||||
else if(A->dataType() == DataType::FLOAT32)
|
||||
status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
|
||||
status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
|
||||
else
|
||||
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
||||
|
||||
|
@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
|||
|
||||
// choose appropriate cuda gemm api depending on data types
|
||||
if(A->dataType() == DataType::DOUBLE) {
|
||||
status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
|
||||
status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
|
||||
}
|
||||
else if(A->dataType() == DataType::FLOAT32) {
|
||||
status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
|
||||
status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
|
||||
}
|
||||
else
|
||||
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
||||
|
@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
|||
cudaFree(devInfo);
|
||||
if (dWork )
|
||||
cudaFree(dWork);
|
||||
if(handle)
|
||||
cusolverDnDestroy(handle);
|
||||
// if(handle)
|
||||
// cusolverDnDestroy(handle);
|
||||
if(gesvdjParams)
|
||||
cusolverDnDestroyGesvdjInfo(gesvdjParams);
|
||||
|
||||
|
|
Loading…
Reference in New Issue