Shugeo cuda solver fix (#383)
* Refactored cuSolver handle usage to handle LaunchContext instance properly.
  Signed-off-by: shugeo <sgazeos@gmail.com>
* Refactored svd solver usage with LaunchContext instance singleton.
  Signed-off-by: shugeo <sgazeos@gmail.com>
* Add device locks for cuSolver uses.
  Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
parent
18d4eaa68d
commit
0eca33ad94
|
@ -341,14 +341,16 @@ namespace helpers {
|
||||||
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
|
static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
auto n = input->rows();
|
auto n = input->rows();
|
||||||
cusolverDnHandle_t cusolverH = nullptr;
|
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||||
|
|
||||||
|
cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
|
||||||
// create solver handle
|
// create solver handle
|
||||||
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
cusolverStatus_t status; //cusolverDnCreate(&cusolverH);
|
||||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
// if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
// throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||||
}
|
// }
|
||||||
// set solver stream
|
// set solver stream
|
||||||
status = cusolverDnSetStream(cusolverH, *stream);
|
status = cusolverDnSetStream(*cusolverH, *stream);
|
||||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
throw cuda_exception::build("Cannot set up stream for cuda solver", status);
|
throw cuda_exception::build("Cannot set up stream for cuda solver", status);
|
||||||
}
|
}
|
||||||
|
@ -368,7 +370,7 @@ namespace helpers {
|
||||||
// compute internal buffer size
|
// compute internal buffer size
|
||||||
double *matrix = reinterpret_cast<double *>(input->specialBuffer());
|
double *matrix = reinterpret_cast<double *>(input->specialBuffer());
|
||||||
status = cusolverDnDgetrf_bufferSize(
|
status = cusolverDnDgetrf_bufferSize(
|
||||||
cusolverH,
|
*cusolverH,
|
||||||
n,
|
n,
|
||||||
n,
|
n,
|
||||||
matrix,
|
matrix,
|
||||||
|
@ -386,7 +388,7 @@ namespace helpers {
|
||||||
|
|
||||||
if (permutation == nullptr) {
|
if (permutation == nullptr) {
|
||||||
status = cusolverDnDgetrf(
|
status = cusolverDnDgetrf(
|
||||||
cusolverH,
|
*cusolverH,
|
||||||
n,
|
n,
|
||||||
n,
|
n,
|
||||||
matrix,
|
matrix,
|
||||||
|
@ -404,7 +406,7 @@ namespace helpers {
|
||||||
NDArray permutVector('c', {n}, sd::DataType::INT32, context);
|
NDArray permutVector('c', {n}, sd::DataType::INT32, context);
|
||||||
int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
|
int* permutationBuf = permutVector.dataBuffer()->specialAsT<int>();
|
||||||
status = cusolverDnDgetrf(
|
status = cusolverDnDgetrf(
|
||||||
cusolverH,
|
*cusolverH,
|
||||||
n,
|
n,
|
||||||
n,
|
n,
|
||||||
matrix,
|
matrix,
|
||||||
|
@ -440,7 +442,7 @@ namespace helpers {
|
||||||
float *d_work = nullptr;
|
float *d_work = nullptr;
|
||||||
|
|
||||||
status = cusolverDnSgetrf_bufferSize(
|
status = cusolverDnSgetrf_bufferSize(
|
||||||
cusolverH,
|
*cusolverH,
|
||||||
n,
|
n,
|
||||||
n,
|
n,
|
||||||
matrix,
|
matrix,
|
||||||
|
@ -458,7 +460,7 @@ namespace helpers {
|
||||||
|
|
||||||
if (permutation == nullptr)
|
if (permutation == nullptr)
|
||||||
status = cusolverDnSgetrf(
|
status = cusolverDnSgetrf(
|
||||||
cusolverH,
|
*cusolverH,
|
||||||
n,
|
n,
|
||||||
n,
|
n,
|
||||||
matrix,
|
matrix,
|
||||||
|
@ -470,7 +472,7 @@ namespace helpers {
|
||||||
NDArray permutVector('c', {n}, DataType::INT32, context);
|
NDArray permutVector('c', {n}, DataType::INT32, context);
|
||||||
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
||||||
status = cusolverDnSgetrf(
|
status = cusolverDnSgetrf(
|
||||||
cusolverH,
|
*cusolverH,
|
||||||
n,
|
n,
|
||||||
n,
|
n,
|
||||||
matrix,
|
matrix,
|
||||||
|
@ -504,7 +506,7 @@ namespace helpers {
|
||||||
if (err) {
|
if (err) {
|
||||||
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
|
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
|
||||||
}
|
}
|
||||||
cusolverDnDestroy(cusolverH);
|
// cusolverDnDestroy(cusolverH);
|
||||||
// NDArray::registerSpecialUse({input}, {input});
|
// NDArray::registerSpecialUse({input}, {input});
|
||||||
input->tickWriteDevice();
|
input->tickWriteDevice();
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||||
|
|
||||||
// create cusolverDn handle
|
// create cusolverDn handle
|
||||||
cusolverDnHandle_t handle = nullptr;
|
cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr;
|
||||||
cusolverStatus_t status = cusolverDnCreate(&handle);
|
//cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
if(handle == nullptr)
|
||||||
throw cuda_exception::build("svdQR: cuda failed !", status);
|
throw cuda_exception::build("svdQR: cuda failed !", -1);
|
||||||
|
|
||||||
// stream
|
// stream
|
||||||
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
|
||||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
throw cuda_exception::build("svdQR: cuda failed !", status);
|
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||||
|
|
||||||
// query working space of SVD
|
// query working space of SVD
|
||||||
int lwork = 0;
|
int lwork = 0;
|
||||||
if(A->dataType() == DataType::DOUBLE)
|
if(A->dataType() == DataType::DOUBLE)
|
||||||
status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
|
status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork);
|
||||||
else if(A->dataType() == DataType::FLOAT32)
|
else if(A->dataType() == DataType::FLOAT32)
|
||||||
status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
|
status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork);
|
||||||
else
|
else
|
||||||
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
||||||
|
|
||||||
|
@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
||||||
|
|
||||||
// choose appropriate cuda gemm api depending on data types
|
// choose appropriate cuda gemm api depending on data types
|
||||||
if(A->dataType() == DataType::DOUBLE) {
|
if(A->dataType() == DataType::DOUBLE) {
|
||||||
status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
|
status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
|
||||||
}
|
}
|
||||||
else if(A->dataType() == DataType::FLOAT32) {
|
else if(A->dataType() == DataType::FLOAT32) {
|
||||||
status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
|
status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
||||||
|
@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
||||||
if (rWork)
|
if (rWork)
|
||||||
cudaFree(rWork);
|
cudaFree(rWork);
|
||||||
|
|
||||||
if(handle)
|
// if(handle)
|
||||||
cusolverDnDestroy(handle);
|
// cusolverDnDestroy(handle);
|
||||||
|
|
||||||
// cudaDeviceReset();
|
// cudaDeviceReset();
|
||||||
}
|
}
|
||||||
|
@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
||||||
ldv = pV->strideAt(1);
|
ldv = pV->strideAt(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||||
|
|
||||||
// create cusolverDn handle
|
// create cusolverDn handle
|
||||||
cusolverDnHandle_t handle = nullptr;
|
cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle();
|
||||||
cusolverStatus_t status = cusolverDnCreate(&handle);
|
//cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
if(handle == nullptr)
|
||||||
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
throw cuda_exception::build("svdJcb: cuda failed !", -1);
|
||||||
|
|
||||||
// stream
|
// stream
|
||||||
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
auto status = cusolverDnSetStream(*handle, *context->getCudaStream());
|
||||||
if(status != CUSOLVER_STATUS_SUCCESS)
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
|
||||||
|
@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
||||||
// query working space of SVD
|
// query working space of SVD
|
||||||
int lwork = 0;
|
int lwork = 0;
|
||||||
if(A->dataType() == DataType::DOUBLE)
|
if(A->dataType() == DataType::DOUBLE)
|
||||||
status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
|
status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, &lwork, gesvdjParams);
|
||||||
else if(A->dataType() == DataType::FLOAT32)
|
else if(A->dataType() == DataType::FLOAT32)
|
||||||
status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
|
status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, &lwork, gesvdjParams);
|
||||||
else
|
else
|
||||||
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
||||||
|
|
||||||
|
@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
||||||
|
|
||||||
// choose appropriate cuda gemm api depending on data types
|
// choose appropriate cuda gemm api depending on data types
|
||||||
if(A->dataType() == DataType::DOUBLE) {
|
if(A->dataType() == DataType::DOUBLE) {
|
||||||
status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
|
status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : reinterpret_cast<double*>(nullPtr), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
|
||||||
}
|
}
|
||||||
else if(A->dataType() == DataType::FLOAT32) {
|
else if(A->dataType() == DataType::FLOAT32) {
|
||||||
status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
|
status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : reinterpret_cast<float*>(nullPtr), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
||||||
|
@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
||||||
cudaFree(devInfo);
|
cudaFree(devInfo);
|
||||||
if (dWork )
|
if (dWork )
|
||||||
cudaFree(dWork);
|
cudaFree(dWork);
|
||||||
if(handle)
|
// if(handle)
|
||||||
cusolverDnDestroy(handle);
|
// cusolverDnDestroy(handle);
|
||||||
if(gesvdjParams)
|
if(gesvdjParams)
|
||||||
cusolverDnDestroyGesvdjInfo(gesvdjParams);
|
cusolverDnDestroyGesvdjInfo(gesvdjParams);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue