@@ -72,6 +72,7 @@ typedef struct _blas_handle {
7272 GpuKernel dgemvBH_T_a1_b1_small ;
7373 GpuKernel sgerBH_gen_small ;
7474 GpuKernel dgerBH_gen_small ;
75+ uint8_t tensorCore ;
7576} blas_handle ;
7677
/*
 * LARGE_VAL(v): non-zero when `v` is greater than or equal to INT_MAX.
 *
 * Must be a function-like macro: there is no whitespace between the name
 * and the opening parenthesis, otherwise the preprocessor treats it as an
 * object-like macro and `LARGE_VAL(x)` no longer expands to a comparison.
 * The argument is parenthesized so that expressions containing operators
 * of lower precedence than `>=` (e.g. `c ? a : b`) are compared as a whole.
 */
#define LARGE_VAL(v) ((v) >= INT_MAX)
@@ -210,6 +211,13 @@ static int setup(gpucontext *c) {
210211 if (handle == NULL )
211212 return error_sys (ctx -> err , "calloc" );
212213
214+ /* Only try to use tensor core on cuda 9 and up */
215+ if (ctx -> major >= 9 ) {
216+ handle -> tensorCore = 1 ;
217+ } else {
218+ handle -> tensorCore = 0 ;
219+ }
220+
213221 cuda_enter (ctx );
214222 err = cublasCreate (& handle -> h );
215223 if (err != CUBLAS_STATUS_SUCCESS ) {
@@ -443,8 +451,8 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
443451 ASSERT_BUF (B );
444452 ASSERT_BUF (C );
445453
446- if (cublasGemmEx == NULL && cublasSgemmEx == NULL )
447- return error_set (ctx -> err , GA_DEVSUP_ERROR , "cublasSgemmEx unavailable" );
454+ if (cublasSgemmEx == NULL && ( cublasGemmEx == NULL || h -> tensorCore == 0 ) )
455+ return error_set (ctx -> err , GA_DEVSUP_ERROR , "cublasSgemmEx|cublasGemmEx unavailable" );
448456
449457 if (LARGE_VAL (M ) || LARGE_VAL (N ) || LARGE_VAL (K ) ||
450458 LARGE_VAL (lda ) || LARGE_VAL (ldb ) || LARGE_VAL (ldc ) ||
@@ -476,7 +484,7 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
476484 GA_CUDA_EXIT_ON_ERROR (ctx , cuda_wait (B , CUDA_WAIT_READ ));
477485 GA_CUDA_EXIT_ON_ERROR (ctx , cuda_wait (C , CUDA_WAIT_ALL ));
478486
479- if (cublasGemmEx ) {
487+ if (cublasGemmEx != NULL && h -> tensorCore ) {
480488 CUBLAS_EXIT_ON_ERROR (ctx , cublasGemmEx (h -> h , convT (transA ), convT (transB ),
481489 M , N , K ,
482490 & alpha , ((uint16_t * )A -> ptr ) + offA ,
0 commit comments