Skip to content

Commit 1c1e068

Browse files
authored
Merge pull request #538 from abergeron/test
Add support for Tensor Cores in the BLAS bindings.
2 parents 2438e7a + cb1219a commit 1c1e068

5 files changed

Lines changed: 78 additions & 12 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 45 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ typedef struct _blas_handle {
7272
GpuKernel dgemvBH_T_a1_b1_small;
7373
GpuKernel sgerBH_gen_small;
7474
GpuKernel dgerBH_gen_small;
75+
uint8_t tensorCore;
7576
} blas_handle;
7677

7778
#define LARGE_VAL(v) (v >= INT_MAX)
@@ -199,8 +200,10 @@ static const char *code_dgerBH_gen_small = \
199200
static int setup(gpucontext *c) {
200201
cuda_context *ctx = (cuda_context *)c;
201202
blas_handle *handle;
203+
CUdevice dev;
202204
cublasStatus_t err;
203205
int types[10];
206+
int major, minor;
204207
int e;
205208

206209
if (ctx->blas_handle != NULL)
@@ -211,6 +214,23 @@ static int setup(gpucontext *c) {
211214
return error_sys(ctx->err, "calloc");
212215

213216
cuda_enter(ctx);
217+
{
218+
CUresult err;
219+
err = cuCtxGetDevice(&dev);
220+
if (err != CUDA_SUCCESS) {
221+
cuda_exit(ctx);
222+
return error_cuda(ctx->err, "cuCtxGetDevice", err);
223+
}
224+
}
225+
GA_CUDA_EXIT_ON_ERROR(ctx, get_cc(dev, &major, &minor, ctx->err));
226+
227+
/* Only try to use tensor core on cuda 9 and up */
228+
if (ctx->major >= 9 && major >= 7 && minor >= 0) {
229+
handle->tensorCore = 1;
230+
} else {
231+
handle->tensorCore = 0;
232+
}
233+
214234
err = cublasCreate(&handle->h);
215235
if (err != CUBLAS_STATUS_SUCCESS) {
216236
cuda_exit(ctx);
@@ -443,8 +463,8 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
443463
ASSERT_BUF(B);
444464
ASSERT_BUF(C);
445465

446-
if (cublasSgemmEx == NULL)
447-
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmEx unavailable");
466+
if (cublasSgemmEx == NULL && (cublasGemmEx == NULL || h->tensorCore == 0))
467+
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmEx|cublasGemmEx unavailable");
448468

449469
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
450470
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
@@ -476,15 +496,29 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
476496
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
477497
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
478498

479-
CUBLAS_EXIT_ON_ERROR(ctx, cublasSgemmEx(h->h, convT(transA), convT(transB),
480-
M, N, K,
481-
&alpha, ((uint16_t *)A->ptr) + offA,
482-
CUDA_R_16F,
483-
lda, ((uint16_t *)B->ptr) + offB,
484-
CUDA_R_16F,
485-
ldb, &beta, ((uint16_t *)C->ptr) + offC,
486-
CUDA_R_16F,
487-
ldc));
499+
if (cublasGemmEx != NULL && h->tensorCore) {
500+
CUBLAS_EXIT_ON_ERROR(ctx, cublasGemmEx(h->h, convT(transA), convT(transB),
501+
M, N, K,
502+
&alpha, ((uint16_t *)A->ptr) + offA,
503+
CUDA_R_16F,
504+
lda, ((uint16_t *)B->ptr) + offB,
505+
CUDA_R_16F,
506+
ldb, &beta, ((uint16_t *)C->ptr) + offC,
507+
CUDA_R_16F,
508+
ldc,
509+
CUDA_R_32F,
510+
CUBLAS_GEMM_DFALT_TENSOR_OP));
511+
} else {
512+
CUBLAS_EXIT_ON_ERROR(ctx, cublasSgemmEx(h->h, convT(transA), convT(transB),
513+
M, N, K,
514+
&alpha, ((uint16_t *)A->ptr) + offA,
515+
CUDA_R_16F,
516+
lda, ((uint16_t *)B->ptr) + offB,
517+
CUDA_R_16F,
518+
ldb, &beta, ((uint16_t *)C->ptr) + offC,
519+
CUDA_R_16F,
520+
ldc));
521+
}
488522

489523
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
490524
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));

src/gpuarray_buffer_cuda.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1048,7 +1048,7 @@ static int cuda_memset(gpudata *dst, size_t dstoff, int data) {
10481048
return GA_NO_ERROR;
10491049
}
10501050

1051-
static int get_cc(CUdevice dev, int *maj, int *min, error *e) {
1051+
int get_cc(CUdevice dev, int *maj, int *min, error *e) {
10521052
CUresult err;
10531053
err = cuDeviceGetAttribute(maj,
10541054
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,

src/loaders/libcublas.fn

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ DEF_PROC_V2(cublasDger, (cublasHandle_t handle, int m, int n, const double *alph
2121

2222
DEF_PROC_OPT(cublasSgemmEx, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const void *A, cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, const float *beta, void *C, cudaDataType Ctype, int ldc));
2323

24+
DEF_PROC_OPT(cublasGemmEx, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void *alpha, const void *A, cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo));
25+
2426
DEF_PROC(cublasSgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *Aarray[], int lda, const float *Barray[], int ldb, const float *beta, float *Carray[], int ldc, int batchCount));
2527
DEF_PROC(cublasDgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *Aarray[], int lda, const double *Barray[], int ldb, const double *beta, double *Carray[], int ldc, int batchCount));
2628

src/loaders/libcublas.h

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,34 @@ typedef enum cudaDataType_t
3434
CUDA_C_32U= 13 // complex as a pair of unsigned int numbers
3535
} cudaDataType;
3636

37+
typedef cudaDataType cudaDataType_t;
38+
39+
typedef enum {
40+
CUBLAS_GEMM_DFALT = -1,
41+
CUBLAS_GEMM_ALGO0 = 0,
42+
CUBLAS_GEMM_ALGO1 = 1,
43+
CUBLAS_GEMM_ALGO2 = 2,
44+
CUBLAS_GEMM_ALGO3 = 3,
45+
CUBLAS_GEMM_ALGO4 = 4,
46+
CUBLAS_GEMM_ALGO5 = 5,
47+
CUBLAS_GEMM_ALGO6 = 6,
48+
CUBLAS_GEMM_ALGO7 = 7,
49+
CUBLAS_GEMM_ALGO8 = 8,
50+
CUBLAS_GEMM_ALGO9 = 9,
51+
CUBLAS_GEMM_ALGO10 = 10,
52+
CUBLAS_GEMM_ALGO11 = 11,
53+
CUBLAS_GEMM_ALGO12 = 12,
54+
CUBLAS_GEMM_ALGO13 = 13,
55+
CUBLAS_GEMM_ALGO14 = 14,
56+
CUBLAS_GEMM_ALGO15 = 15,
57+
CUBLAS_GEMM_ALGO16 = 16,
58+
CUBLAS_GEMM_ALGO17 = 17,
59+
CUBLAS_GEMM_DFALT_TENSOR_OP = 99,
60+
CUBLAS_GEMM_ALGO0_TENSOR_OP = 100,
61+
CUBLAS_GEMM_ALGO1_TENSOR_OP = 101,
62+
CUBLAS_GEMM_ALGO2_TENSOR_OP = 102
63+
} cublasGemmAlgo_t;
64+
3765
typedef struct CUstream_st *cudaStream_t;
3866

3967
typedef enum {

src/private_cuda.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -157,4 +157,6 @@ struct _gpukernel {
157157
#endif
158158
};
159159

160+
int get_cc(CUdevice dev, int *maj, int *min, error *e);
161+
160162
#endif

0 commit comments

Comments (0)