[WIP] HGemm (#181)

* skip string arrays for device validation

Signed-off-by: raver119 <raver119@gmail.com>

* confusion_matrix fix

Signed-off-by: raver119 <raver119@gmail.com>

* exclude cublasHGemm from archs < 530

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-27 15:05:43 +03:00 committed by GitHub
parent 0e523490e9
commit 7f0c660d8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -228,6 +228,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
float alphaF(alpha), betaF(beta);
status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
}
#if __CUDA_ARCH__ >= 530
else if(ABC && aType == DataType::HALF) {
float16 alphaH(alpha), betaH(beta);
status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
@ -240,6 +241,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
float alphaF(alpha), betaF(beta);
status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc);
}
#endif
else {
dim3 threadsPerBlock(N, M);
dim3 blocksPerGrid(1, 1);