Skip to content

Commit cb1219a

Browse files
committed
Actually check for the compute capability of the current device and only ask for tensor core when it is available.
1 parent c193b87 commit cb1219a

3 files changed

Lines changed: 18 additions & 4 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,10 @@ static const char *code_dgerBH_gen_small = \
200200
static int setup(gpucontext *c) {
201201
cuda_context *ctx = (cuda_context *)c;
202202
blas_handle *handle;
203+
CUdevice dev;
203204
cublasStatus_t err;
204205
int types[10];
206+
int major, minor;
205207
int e;
206208

207209
if (ctx->blas_handle != NULL)
@@ -211,14 +213,24 @@ static int setup(gpucontext *c) {
211213
if (handle == NULL)
212214
return error_sys(ctx->err, "calloc");
213215

216+
cuda_enter(ctx);
217+
{
218+
CUresult err;
219+
err = cuCtxGetDevice(&dev);
220+
if (err != CUDA_SUCCESS) {
221+
cuda_exit(ctx);
222+
return error_cuda(ctx->err, "cuCtxGetDevice", err);
223+
}
224+
}
225+
GA_CUDA_EXIT_ON_ERROR(ctx, get_cc(dev, &major, &minor, ctx->err));
226+
214227
/* Only try to use tensor core on cuda 9 and up */
215-
if (ctx->major >= 9) {
228+
if (ctx->major >= 9 && major >= 7 && minor >= 0) {
216229
handle->tensorCore = 1;
217230
} else {
218231
handle->tensorCore = 0;
219232
}
220233

221-
cuda_enter(ctx);
222234
err = cublasCreate(&handle->h);
223235
if (err != CUBLAS_STATUS_SUCCESS) {
224236
cuda_exit(ctx);
@@ -507,7 +519,7 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
507519
CUDA_R_16F,
508520
ldc));
509521
}
510-
522+
511523
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
512524
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
513525
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));

src/gpuarray_buffer_cuda.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ static int cuda_memset(gpudata *dst, size_t dstoff, int data) {
10481048
return GA_NO_ERROR;
10491049
}
10501050

1051-
static int get_cc(CUdevice dev, int *maj, int *min, error *e) {
1051+
int get_cc(CUdevice dev, int *maj, int *min, error *e) {
10521052
CUresult err;
10531053
err = cuDeviceGetAttribute(maj,
10541054
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,

src/private_cuda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,6 @@ struct _gpukernel {
157157
#endif
158158
};
159159

160+
int get_cc(CUdevice dev, int *maj, int *min, error *e);
161+
160162
#endif

0 commit comments

Comments
 (0)