@@ -200,8 +200,10 @@ static const char *code_dgerBH_gen_small = \
200200static int setup (gpucontext * c ) {
201201 cuda_context * ctx = (cuda_context * )c ;
202202 blas_handle * handle ;
203+ CUdevice dev ;
203204 cublasStatus_t err ;
204205 int types [10 ];
206+ int major , minor ;
205207 int e ;
206208
207209 if (ctx -> blas_handle != NULL )
@@ -211,14 +213,24 @@ static int setup(gpucontext *c) {
211213 if (handle == NULL )
212214 return error_sys (ctx -> err , "calloc" );
213215
216+ cuda_enter (ctx );
217+ {
218+ CUresult err ;
219+ err = cuCtxGetDevice (& dev );
220+ if (err != CUDA_SUCCESS ) {
221+ cuda_exit (ctx );
222+ return error_cuda (ctx -> err , "cuCtxGetDevice" , err );
223+ }
224+ }
225+ GA_CUDA_EXIT_ON_ERROR (ctx , get_cc (dev , & major , & minor , ctx -> err ));
226+
214227 /* Only try to use tensor core on cuda 9 and up */
215- if (ctx -> major >= 9 ) {
228+ if (ctx -> major >= 9 && major >= 7 && minor >= 0 ) {
216229 handle -> tensorCore = 1 ;
217230 } else {
218231 handle -> tensorCore = 0 ;
219232 }
220233
221- cuda_enter (ctx );
222234 err = cublasCreate (& handle -> h );
223235 if (err != CUBLAS_STATUS_SUCCESS ) {
224236 cuda_exit (ctx );
@@ -507,7 +519,7 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
507519 CUDA_R_16F ,
508520 ldc ));
509521 }
510-
522+
511523 GA_CUDA_EXIT_ON_ERROR (ctx , cuda_record (A , CUDA_WAIT_READ ));
512524 GA_CUDA_EXIT_ON_ERROR (ctx , cuda_record (B , CUDA_WAIT_READ ));
513525 GA_CUDA_EXIT_ON_ERROR (ctx , cuda_record (C , CUDA_WAIT_ALL ));
0 commit comments