|
|
|
void bgemm(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, std::vector<NDArray*>& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) {
|
|
|
|
status = cublasDgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const double**)aBuffers, lda, (const double**)bBuffers, ldb, &beta, (double**)cBuffers, ldc, bS);
|
|
|
|
status = cublasSgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const float**)aBuffers, lda, (const float**)bBuffers, ldb, &beta, (float**)cBuffers, ldc, bS);
|
|
|
|
status = cublasHgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const __half**)aBuffers, lda, (const __half**)bBuffers, ldb, &beta, (__half**)cBuffers, ldc, bS);
|
|
|
|
status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_8I, lda, bBuffers, CUDA_R_8I, ldb, &beta, cBuffers, CUDA_R_32F, ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
|
|
|
|
status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_16F, lda, bBuffers, CUDA_R_16F, ldb, &beta, cBuffers, CUDA_R_32F, ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
|