From 0eca33ad94030d5bbc6779e03107a60910125a0f Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 17 Apr 2020 16:52:08 +0300 Subject: [PATCH] Shugeo cuda solver fix (#383) * Refactored cuSolver handle usage to handle LaunchContext instance properly. Signed-off-by: shugeo * Refactored svd solver usage with LaunchContext instance singleton. Signed-off-by: shugeo * add device locks for cuSolver uses Signed-off-by: raver119 Co-authored-by: raver119 --- .../ops/declarable/helpers/cuda/lup.cu | 28 ++++++----- .../ops/declarable/helpers/cuda/svd.cu | 48 ++++++++++--------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 2ca731912..c986260e8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -341,14 +341,16 @@ namespace helpers { static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { auto stream = context->getCudaStream(); auto n = input->rows(); - cusolverDnHandle_t cusolverH = nullptr; + std::lock_guard lock(*LaunchContext::deviceMutex()); + + cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; // create solver handle - cusolverStatus_t status = cusolverDnCreate(&cusolverH); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot create cuSolver handle", status); - } + cusolverStatus_t status; //cusolverDnCreate(&cusolverH); +// if (CUSOLVER_STATUS_SUCCESS != status) { +// throw cuda_exception::build("Cannot create cuSolver handle", status); +// } // set solver stream - status = cusolverDnSetStream(cusolverH, *stream); + status = cusolverDnSetStream(*cusolverH, *stream); if (CUSOLVER_STATUS_SUCCESS != status) { throw cuda_exception::build("Cannot set up stream for cuda solver", status); } @@ -368,7 +370,7 @@ namespace helpers { // compute internal buffer size double *matrix = reinterpret_cast(input->specialBuffer()); status = cusolverDnDgetrf_bufferSize( - cusolverH, + *cusolverH, n, n, matrix, @@ -386,7 +388,7 @@ namespace helpers { if (permutation == nullptr) { status = cusolverDnDgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -404,7 +406,7 @@ namespace helpers { NDArray permutVector('c', {n}, sd::DataType::INT32, context); int* permutationBuf = permutVector.dataBuffer()->specialAsT(); status = cusolverDnDgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -440,7 +442,7 @@ namespace helpers { float *d_work = nullptr; status = cusolverDnSgetrf_bufferSize( - cusolverH, + *cusolverH, n, n, matrix, @@ -458,7 +460,7 @@ namespace helpers { if (permutation == nullptr) status = cusolverDnSgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -470,7 +472,7 @@ namespace helpers { NDArray permutVector('c', {n}, DataType::INT32, context); int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); status = cusolverDnSgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -504,7 +506,7 @@ namespace helpers { if (err) { throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); } - cusolverDnDestroy(cusolverH); +// cusolverDnDestroy(cusolverH); // NDArray::registerSpecialUse({input}, {input}); input->tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index 44f924bf0..5c3d2811c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr } } + std::lock_guard lock(*LaunchContext::deviceMutex()); + // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); + cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; + //cusolverStatus_t status = cusolverDnCreate(&handle); + if(handle == nullptr) + throw cuda_exception::build("svdQR: cuda failed !", -1); // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); if(status != CUSOLVER_STATUS_SUCCESS) throw cuda_exception::build("svdQR: cuda failed !", status); // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork); + status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork); + status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr if (rWork) cudaFree(rWork); - if(handle) - cusolverDnDestroy(handle); +// if(handle) +// cusolverDnDestroy(handle); // cudaDeviceReset(); } @@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA ldv = pV->strideAt(1); } + std::lock_guard lock(*LaunchContext::deviceMutex()); + // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); + cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); + //cusolverStatus_t status = cusolverDnCreate(&handle); + if(handle == nullptr) + throw cuda_exception::build("svdJcb: cuda failed !", -1); // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); if(status != CUSOLVER_STATUS_SUCCESS) throw cuda_exception::build("svdJcb: cuda failed !", status); @@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); + status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); + status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA cudaFree(devInfo); if (dWork ) cudaFree(dWork); - if(handle) - cusolverDnDestroy(handle); +// if(handle) +// cusolverDnDestroy(handle); if(gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams);