Skip to content

Commit 7252551

Browse files
committed
Add support for tensor cores in float16.
1 parent ff234cf commit 7252551

3 files changed

Lines changed: 55 additions & 11 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -443,7 +443,7 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
443443
ASSERT_BUF(B);
444444
ASSERT_BUF(C);
445445

446-
if (cublasSgemmEx == NULL)
446+
if (cublasGemmEx == NULL && cublasSgemmEx == NULL)
447447
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmEx unavailable");
448448

449449
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
@@ -476,16 +476,30 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
476476
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
477477
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
478478

479-
CUBLAS_EXIT_ON_ERROR(ctx, cublasSgemmEx(h->h, convT(transA), convT(transB),
480-
M, N, K,
481-
&alpha, ((uint16_t *)A->ptr) + offA,
482-
CUDA_R_16F,
483-
lda, ((uint16_t *)B->ptr) + offB,
484-
CUDA_R_16F,
485-
ldb, &beta, ((uint16_t *)C->ptr) + offC,
486-
CUDA_R_16F,
487-
ldc));
488-
479+
if (cublasGemmEx) {
480+
CUBLAS_EXIT_ON_ERROR(ctx, cublasGemmEx(h->h, convT(transA), convT(transB),
481+
M, N, K,
482+
&alpha, ((uint16_t *)A->ptr) + offA,
483+
CUDA_R_16F,
484+
lda, ((uint16_t *)B->ptr) + offB,
485+
CUDA_R_16F,
486+
ldb, &beta, ((uint16_t *)C->ptr) + offC,
487+
CUDA_R_16F,
488+
ldc,
489+
CUDA_R_32F,
490+
CUBLAS_GEMM_DFALT_TENSOR_OP));
491+
} else {
492+
CUBLAS_EXIT_ON_ERROR(ctx, cublasSgemmEx(h->h, convT(transA), convT(transB),
493+
M, N, K,
494+
&alpha, ((uint16_t *)A->ptr) + offA,
495+
CUDA_R_16F,
496+
lda, ((uint16_t *)B->ptr) + offB,
497+
CUDA_R_16F,
498+
ldb, &beta, ((uint16_t *)C->ptr) + offC,
499+
CUDA_R_16F,
500+
ldc));
501+
}
502+
489503
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
490504
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
491505
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));

src/loaders/libcublas.fn

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ DEF_PROC_V2(cublasDger, (cublasHandle_t handle, int m, int n, const double *alph
2121

2222
DEF_PROC_OPT(cublasSgemmEx, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const float *beta, void *C, cudaDataType Ctype, int ldc));
2323

24+
DEF_PROC_OPT(cublasGemmEx, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo));
25+
2426
DEF_PROC(cublasSgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *Aarray[], int lda, const float *Barray[], int ldb, const float *beta, float *Carray[], int ldc, int batchCount));
2527
DEF_PROC(cublasDgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *Aarray[], int lda, const double *Barray[], int ldb, const double *beta, double *Carray[], int ldc, int batchCount));
2628

src/loaders/libcublas.h

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,34 @@ typedef enum cudaDataType_t
3434
CUDA_C_32U= 13 // complex as a pair of unsigned int numbers
3535
} cudaDataType;
3636

37+
typedef cudaDataType cudaDataType_t;
38+
39+
typedef enum {
40+
CUBLAS_GEMM_DFALT = -1,
41+
CUBLAS_GEMM_ALGO0 = 0,
42+
CUBLAS_GEMM_ALGO1 = 1,
43+
CUBLAS_GEMM_ALGO2 = 2,
44+
CUBLAS_GEMM_ALGO3 = 3,
45+
CUBLAS_GEMM_ALGO4 = 4,
46+
CUBLAS_GEMM_ALGO5 = 5,
47+
CUBLAS_GEMM_ALGO6 = 6,
48+
CUBLAS_GEMM_ALGO7 = 7,
49+
CUBLAS_GEMM_ALGO8 = 8,
50+
CUBLAS_GEMM_ALGO9 = 9,
51+
CUBLAS_GEMM_ALGO10 = 10,
52+
CUBLAS_GEMM_ALGO11 = 11,
53+
CUBLAS_GEMM_ALGO12 = 12,
54+
CUBLAS_GEMM_ALGO13 = 13,
55+
CUBLAS_GEMM_ALGO14 = 14,
56+
CUBLAS_GEMM_ALGO15 = 15,
57+
CUBLAS_GEMM_ALGO16 = 16,
58+
CUBLAS_GEMM_ALGO17 = 17,
59+
CUBLAS_GEMM_DFALT_TENSOR_OP = 99,
60+
CUBLAS_GEMM_ALGO0_TENSOR_OP = 100,
61+
CUBLAS_GEMM_ALGO1_TENSOR_OP = 101,
62+
CUBLAS_GEMM_ALGO2_TENSOR_OP = 102
63+
} cublasGemmAlgo_t;
64+
3765
typedef struct CUstream_st *cudaStream_t;
3866

3967
typedef enum {

0 commit comments

Comments
 (0)