// cavis/libnd4j/include/helpers/BlasHelper.h

/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_BLAS_HELPER_H
#define LIBND4J_BLAS_HELPER_H
#include <system/pointercast.h>
#include <types/float16.h>
#include <cblas.h>
#include <helpers/logger.h>
// Calling convention used by the cuBLAS / cuSOLVER function-pointer typedefs
// below: __stdcall on Windows, the platform default elsewhere.
// Each macro is only defined if nobody defined it before (mirroring the guard
// used in the CUDA headers), so a definition coming from cublas_api.h,
// cusolver_common.h or the build command line is respected rather than
// silently redefined.
#ifndef CUBLASWINAPI
#ifdef _WIN32
#define CUBLASWINAPI __stdcall
#else
#define CUBLASWINAPI
#endif
#endif
#ifndef CUSOLVERAPI
#ifdef _WIN32
#define CUSOLVERAPI __stdcall
#else
#define CUSOLVERAPI
#endif
#endif
namespace sd {
// Status codes returned through the cuBLAS function pointers declared in this
// header. Numeric values mirror the cuBLAS library's own cublasStatus_t, so a
// value returned by the real library compares correctly against these names.
typedef enum {
  CUBLAS_STATUS_SUCCESS          = 0,
  CUBLAS_STATUS_NOT_INITIALIZED  = 1,
  CUBLAS_STATUS_ALLOC_FAILED     = 3,
  CUBLAS_STATUS_INVALID_VALUE    = 7,
  CUBLAS_STATUS_ARCH_MISMATCH    = 8,
  CUBLAS_STATUS_MAPPING_ERROR    = 11,
  CUBLAS_STATUS_EXECUTION_FAILED = 13,
  CUBLAS_STATUS_INTERNAL_ERROR   = 14,
  CUBLAS_STATUS_NOT_SUPPORTED    = 15,
  CUBLAS_STATUS_LICENSE_ERROR    = 16,
} cublasStatus_t;
// Transpose selector for the cuBLAS GEMV/GEMM pointers; values mirror cuBLAS.
typedef enum {
  CUBLAS_OP_N = 0,  // use the matrix as-is
  CUBLAS_OP_T = 1,  // transpose
  CUBLAS_OP_C = 2,  // conjugate transpose
} cublasOperation_t;
// Opaque cuBLAS context: never defined here, only handled through a pointer.
struct cublasContext;
using cublasHandle_t = cublasContext *;
// Element-type tags used by cublasSgemmEx (mirrors CUDA's cudaDataType values).
// Listed in ascending numeric order; every value is explicit, so order is irrelevant.
typedef enum {
  CUDA_R_32F = 0,   /* real as a float */
  CUDA_R_64F = 1,   /* real as a double */
  CUDA_R_16F = 2,   /* real as a half */
  CUDA_R_8I  = 3,   /* real as a signed char */
  CUDA_C_32F = 4,   /* complex as a pair of float numbers */
  CUDA_C_64F = 5,   /* complex as a pair of double numbers */
  CUDA_C_16F = 6,   /* complex as a pair of half numbers */
  CUDA_C_8I  = 7,   /* complex as a pair of signed char numbers */
  CUDA_R_8U  = 8,   /* real as a unsigned char */
  CUDA_C_8U  = 9,   /* complex as a pair of unsigned char numbers */
  CUDA_R_32I = 10,  /* real as a signed int */
  CUDA_C_32I = 11,  /* complex as a pair of signed int numbers */
  CUDA_R_32U = 12,  /* real as a unsigned int */
  CUDA_C_32U = 13,  /* complex as a pair of unsigned int numbers */
} cublasDataType_t;
typedef void (*CblasSgemv)(CBLAS_ORDER Layout,
CBLAS_TRANSPOSE TransA, int M, int N,
float alpha, float *A, int lda,
float *X, int incX, float beta,
float *Y, int incY);
typedef void (*CblasDgemv)(CBLAS_ORDER Layout,
CBLAS_TRANSPOSE TransA, int M, int N,
double alpha, double *A, int lda,
double *X, int incX, double beta,
double *Y, int incY);
typedef void (*CblasSgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB, int M, int N,
int K, float alpha, float *A,
int lda, float *B, int ldb,
float beta, float *C, int ldc);
typedef void (*CblasDgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB, int M, int N,
int K, double alpha, double *A,
int lda, double *B, int ldb,
double beta, double *C, int ldc);
typedef void (*CblasSgemmBatch)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array,
CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array,
int *K_Array, float *alpha_Array, float **A_Array,
int *lda_Array, float **B_Array, int *ldb_Array,
float *beta_Array, float **C_Array, int *ldc_Array,
int group_count, int *group_size);
typedef void (*CblasDgemmBatch)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array,
CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array,
int *K_Array, double *alpha_Array, double **A_Array,
int *lda_Array, double **B_Array, int* ldb_Array,
double *beta_Array, double **C_Array, int *ldc_Array,
int group_count, int *group_size);
// If a previously included header (e.g. lapacke.h) defined LAPACK_ROW_MAJOR /
// LAPACK_COL_MAJOR as macros, drop them so the enumerators below can reuse
// those names.
#ifdef LAPACK_ROW_MAJOR
#undef LAPACK_ROW_MAJOR
#endif
#ifdef LAPACK_COL_MAJOR
#undef LAPACK_COL_MAJOR
#endif
// Matrix-layout selector passed as the first argument of the LAPACKE_* function
// pointers declared below. LAPACKE declares that parameter as a plain `int`, so
// the underlying type is pinned to `int` instead of being left
// implementation-defined (it could otherwise shrink, e.g. under -fshort-enums,
// and no longer match the callee's expectations).
enum LAPACK_LAYOUT : int { LAPACK_ROW_MAJOR = 101, LAPACK_COL_MAJOR = 102 };
typedef int (*LapackeSgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt,
int m, int n, float* a, int lda,
float* s, float* u, int ldu, float* vt,
int ldvt, float* superb);
typedef int (*LapackeDgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt,
int m, int n, double* a,
int lda, double* s, double* u, int ldu,
double* vt, int ldvt, double* superb);
typedef int (*LapackeSgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m,
int n, float* a, int lda, float* s,
float* u, int ldu, float* vt,
int ldvt);
typedef int (*LapackeDgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m,
int n, double* a, int lda, double* s,
double* u, int ldu, double* vt,
int ldvt);
typedef cublasStatus_t (CUBLASWINAPI *CublasSgemv)(cublasHandle_t handle,
cublasOperation_t trans,
int m,
int n,
float *alpha, /* host or device pointer */
float *A,
int lda,
float *x,
int incx,
float *beta, /* host or device pointer */
float *y,
int incy);
typedef cublasStatus_t (CUBLASWINAPI *CublasDgemv)(cublasHandle_t handle,
cublasOperation_t trans,
int m,
int n,
double *alpha, /* host or device pointer */
double *A,
int lda,
double *x,
int incx,
double *beta, /* host or device pointer */
double *y,
int incy);
typedef cublasStatus_t (CUBLASWINAPI *CublasHgemm)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
__half *alpha, /* host or device pointer */
__half *A,
int lda,
__half *B,
int ldb,
__half *beta, /* host or device pointer */
__half *C,
int ldc);
typedef cublasStatus_t (CUBLASWINAPI *CublasSgemm)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host or device pointer */
float *A,
int lda,
float *B,
int ldb,
float *beta, /* host or device pointer */
float *C,
int ldc);
typedef cublasStatus_t (CUBLASWINAPI *CublasDgemm)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
double *alpha, /* host or device pointer */
double *A,
int lda,
double *B,
int ldb,
double *beta, /* host or device pointer */
double *C,
int ldc);
typedef cublasStatus_t (CUBLASWINAPI *CublasSgemmEx)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host or device pointer */
void *A,
cublasDataType_t Atype,
int lda,
void *B,
cublasDataType_t Btype,
int ldb,
float *beta, /* host or device pointer */
void *C,
cublasDataType_t Ctype,
int ldc);
typedef cublasStatus_t (CUBLASWINAPI *CublasHgemmBatched)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
__half *alpha, /* host or device pointer */
__half *Aarray[],
int lda,
__half *Barray[],
int ldb,
__half *beta, /* host or device pointer */
__half *Carray[],
int ldc,
int batchCount);
typedef cublasStatus_t (CUBLASWINAPI *CublasSgemmBatched)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host or device pointer */
float *Aarray[],
int lda,
float *Barray[],
int ldb,
float *beta, /* host or device pointer */
float *Carray[],
int ldc,
int batchCount);
typedef cublasStatus_t (CUBLASWINAPI *CublasDgemmBatched)(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
double *alpha, /* host or device pointer */
double *Aarray[],
int lda,
double *Barray[],
int ldb,
double *beta, /* host or device pointer */
double *Carray[],
int ldc,
int batchCount);
// Status codes returned through the cuSOLVER function pointers declared in this
// header; numeric values mirror the cuSOLVER library's cusolverStatus_t.
typedef enum {
  CUSOLVER_STATUS_SUCCESS                   = 0,
  CUSOLVER_STATUS_NOT_INITIALIZED           = 1,
  CUSOLVER_STATUS_ALLOC_FAILED              = 2,
  CUSOLVER_STATUS_INVALID_VALUE             = 3,
  CUSOLVER_STATUS_ARCH_MISMATCH             = 4,
  CUSOLVER_STATUS_MAPPING_ERROR             = 5,
  CUSOLVER_STATUS_EXECUTION_FAILED          = 6,
  CUSOLVER_STATUS_INTERNAL_ERROR            = 7,
  CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED = 8,
  CUSOLVER_STATUS_NOT_SUPPORTED             = 9,
  CUSOLVER_STATUS_ZERO_PIVOT                = 10,
  CUSOLVER_STATUS_INVALID_LICENSE           = 11,
} cusolverStatus_t;
// Generalized-eigenproblem type selector; values mirror cuSOLVER's cusolverEigType_t.
typedef enum {
  CUSOLVER_EIG_TYPE_1 = 1,
  CUSOLVER_EIG_TYPE_2 = 2,
  CUSOLVER_EIG_TYPE_3 = 3,
} cusolverEigType_t;
// Eigen-solver output selector (eigenvalues only vs. eigenvalues + eigenvectors);
// values mirror cuSOLVER's cusolverEigMode_t.
typedef enum {
  CUSOLVER_EIG_MODE_NOVECTOR = 0,
  CUSOLVER_EIG_MODE_VECTOR   = 1,
} cusolverEigMode_t;
// Opaque dense cuSOLVER context: never defined here, only handled through a pointer.
struct cusolverDnContext;
using cusolverDnHandle_t = cusolverDnContext *;
typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnSgesvdBufferSize)(
cusolverDnHandle_t handle,
int m,
int n,
int *lwork);
typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnDgesvdBufferSize)(
cusolverDnHandle_t handle,
int m,
int n,
int *lwork);
typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnSgesvd)(
cusolverDnHandle_t handle,
signed char jobu,
signed char jobvt,
int m,
int n,
float *A,
int lda,
float *S,
float *U,
int ldu,
float *VT,
int ldvt,
float *work,
int lwork,
float *rwork,
int *info);
typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnDgesvd)(
cusolverDnHandle_t handle,
signed char jobu,
signed char jobvt,
int m,
int n,
double *A,
int lda,
double *S,
double *U,
int ldu,
double *VT,
int ldvt,
double *work,
int lwork,
double *rwork,
int *info);
// Identifies a BLAS routine family: GEMV (matrix-vector) or GEMM (matrix-matrix).
enum BlasFunctions { GEMV = 0, GEMM = 1 };
/**
 * Holder for host BLAS / LAPACKE entry points and device cuBLAS / cuSOLVER
 * entry points that are supplied at runtime as raw function-pointer tables,
 * plus flags recording which GEMV / GEMM / batched-GEMM variants are usable.
 *
 * Every function-pointer member defaults to nullptr. Previously only the bool
 * flags had initializers, so the pointers were indeterminate in any instance
 * without static storage duration. A pointer that was never installed is now
 * reliably null.
 *
 * NOTE(review): the expected layout of the `functions` tables and the rules for
 * setting the _has* flags are defined in the implementation file, not here.
 */
class BlasHelper {
 private:
  // Availability flags for half / single / double GEMV, GEMM and batched GEMM.
  bool _hasHgemv = false;
  bool _hasHgemm = false;
  bool _hasHgemmBatch = false;
  bool _hasSgemv = false;
  bool _hasSgemm = false;
  bool _hasSgemmBatch = false;
  bool _hasDgemv = false;
  bool _hasDgemm = false;
  bool _hasDgemmBatch = false;

  // Host CBLAS entry points.
  CblasSgemv cblasSgemv = nullptr;
  CblasDgemv cblasDgemv = nullptr;
  CblasSgemm cblasSgemm = nullptr;
  CblasDgemm cblasDgemm = nullptr;
  CblasSgemmBatch cblasSgemmBatch = nullptr;
  CblasDgemmBatch cblasDgemmBatch = nullptr;

  // Host LAPACKE SVD entry points.
  LapackeSgesvd lapackeSgesvd = nullptr;
  LapackeDgesvd lapackeDgesvd = nullptr;
  LapackeSgesdd lapackeSgesdd = nullptr;
  LapackeDgesdd lapackeDgesdd = nullptr;

  // Device cuBLAS entry points.
  CublasSgemv cublasSgemv = nullptr;
  CublasDgemv cublasDgemv = nullptr;
  CublasHgemm cublasHgemm = nullptr;
  CublasSgemm cublasSgemm = nullptr;
  CublasDgemm cublasDgemm = nullptr;
  CublasSgemmEx cublasSgemmEx = nullptr;
  CublasHgemmBatched cublasHgemmBatched = nullptr;
  CublasSgemmBatched cublasSgemmBatched = nullptr;
  CublasDgemmBatched cublasDgemmBatched = nullptr;

  // Device cuSOLVER SVD entry points.
  CusolverDnSgesvdBufferSize cusolverDnSgesvdBufferSize = nullptr;
  CusolverDnDgesvdBufferSize cusolverDnDgesvdBufferSize = nullptr;
  CusolverDnSgesvd cusolverDnSgesvd = nullptr;
  CusolverDnDgesvd cusolverDnDgesvd = nullptr;

 public:
  /// Returns the shared BlasHelper. NOTE(review): presumably a process-wide
  /// singleton; confirm against the definition.
  static BlasHelper& getInstance();

  /// Presumably installs the host (CBLAS / LAPACKE) pointers from `functions`.
  /// NOTE(review): table ordering/length is defined by the implementation.
  void initializeFunctions(Nd4jPointer *functions);

  /// Presumably installs the device (cuBLAS / cuSOLVER) pointers from
  /// `functions`. Same caveat about table layout as above.
  void initializeDeviceFunctions(Nd4jPointer *functions);

  /// Whether a GEMV routine is available for element type T.
  template <typename T>
  bool hasGEMV();

  /// Whether a GEMM routine is available for element type T.
  template <typename T>
  bool hasGEMM();

  /// Runtime-dtype variants of hasGEMM<T>() / hasGEMV<T>().
  bool hasGEMM(const sd::DataType dtype);
  bool hasGEMV(const sd::DataType dtype);

  /// Whether a batched GEMM routine is available for element type T.
  template <typename T>
  bool hasBatchedGEMM();

  // Accessors for the stored host entry points. These may return nullptr if the
  // corresponding pointer was never installed.
  CblasSgemv sgemv();
  CblasDgemv dgemv();
  CblasSgemm sgemm();
  CblasDgemm dgemm();
  CblasSgemmBatch sgemmBatched();
  CblasDgemmBatch dgemmBatched();
  LapackeSgesvd sgesvd();
  LapackeDgesvd dgesvd();
  LapackeSgesdd sgesdd();
  LapackeDgesdd dgesdd();

  // destructor
  ~BlasHelper() noexcept;
};
}
#endif