[WIP] HGemm (#181)
* skip string arrays for device validation
  Signed-off-by: raver119 <raver119@gmail.com>
* confusion_matrix fix
  Signed-off-by: raver119 <raver119@gmail.com>
* exclude cublasHGemm from archs < 530
  Signed-off-by: raver119 <raver119@gmail.com>

master
parent
0e523490e9
commit
7f0c660d8b
|
@ -228,6 +228,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
float alphaF(alpha), betaF(beta);
|
float alphaF(alpha), betaF(beta);
|
||||||
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
|
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
|
||||||
}
|
}
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
else if(ABC && aType == DataType::HALF) {
|
else if(ABC && aType == DataType::HALF) {
|
||||||
float16 alphaH(alpha), betaH(beta);
|
float16 alphaH(alpha), betaH(beta);
|
||||||
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
|
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
|
||||||
|
@ -240,6 +241,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
float alphaF(alpha), betaF(beta);
|
float alphaF(alpha), betaF(beta);
|
||||||
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
|
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
else {
|
else {
|
||||||
dim3 threadsPerBlock(N, M);
|
dim3 threadsPerBlock(N, M);
|
||||||
dim3 blocksPerGrid(1, 1);
|
dim3 blocksPerGrid(1, 1);
|
||||||
|
|
Loading…
Reference in New Issue