Skip to content

Commit c193b87

Browse files
committed
Don't attempt to use tensor core on CUDA < 9.0.
1 parent 7252551 commit c193b87

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ typedef struct _blas_handle {
7272
GpuKernel dgemvBH_T_a1_b1_small;
7373
GpuKernel sgerBH_gen_small;
7474
GpuKernel dgerBH_gen_small;
75+
uint8_t tensorCore;
7576
} blas_handle;
7677

7778
#define LARGE_VAL(v) (v >= INT_MAX)
@@ -210,6 +211,13 @@ static int setup(gpucontext *c) {
210211
if (handle == NULL)
211212
return error_sys(ctx->err, "calloc");
212213

214+
/* Only try to use tensor core on cuda 9 and up */
215+
if (ctx->major >= 9) {
216+
handle->tensorCore = 1;
217+
} else {
218+
handle->tensorCore = 0;
219+
}
220+
213221
cuda_enter(ctx);
214222
err = cublasCreate(&handle->h);
215223
if (err != CUBLAS_STATUS_SUCCESS) {
@@ -443,8 +451,8 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
443451
ASSERT_BUF(B);
444452
ASSERT_BUF(C);
445453

446-
if (cublasGemmEx == NULL && cublasSgemmEx == NULL)
447-
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmEx unavailable");
454+
if (cublasSgemmEx == NULL && (cublasGemmEx == NULL || h->tensorCore == 0))
455+
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmEx|cublasGemmEx unavailable");
448456

449457
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
450458
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
@@ -476,7 +484,7 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
476484
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
477485
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
478486

479-
if (cublasGemmEx) {
487+
if (cublasGemmEx != NULL && h->tensorCore) {
480488
CUBLAS_EXIT_ON_ERROR(ctx, cublasGemmEx(h->h, convT(transA), convT(transB),
481489
M, N, K,
482490
&alpha, ((uint16_t *)A->ptr) + offA,

0 commit comments

Comments
 (0)