Skip to content

Commit e17f492

Browse files
committed
Check that we don't overflow int arguments in cublas wrapper.
1 parent 9f6b6df commit e17f492

3 files changed

Lines changed: 43 additions & 1 deletion

File tree

src/gpuarray/error.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ enum ga_error {
3434
GA_NODEV_ERROR,
3535
GA_MISC_ERROR,
3636
GA_COMM_ERROR,
37+
GA_XLARGE_ERROR,
3738
/* Add more error types if needed, but at the end */
3839
/* Don't forget to sync with Gpu_error() */
3940
};

src/gpuarray_blas_cuda_cublas.c

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
#include "gpuarray/kernel.h"
66
#include "gpuarray/error.h"
77

8-
#include "cublas_v2.h"
8+
#include <limits.h>
9+
10+
#include <cublas_v2.h>
911

1012
extern const gpuarray_buffer_ops cuda_ops;
1113

@@ -33,6 +35,8 @@ typedef struct _blas_handle {
3335
cublasStatus_t err;
3436
} blas_handle;
3537

38+
#define LARGE_VAL(v) (v >= INT_MAX)
39+
3640
static const char *code_sgemvBH_N_a1_b1_small = \
3741
"extern \"C\"__global__ void sgemv(const float *A[], size_t lda, " \
3842
" const float *x[], size_t incx, " \
@@ -326,6 +330,10 @@ static int sgemm(cb_order order, cb_transpose transA, cb_transpose transB,
326330
ASSERT_BUF(B);
327331
ASSERT_BUF(C);
328332

333+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
334+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
335+
return GA_XLARGE_ERROR;
336+
329337
if (order == cb_c) {
330338
/* swap A and B */
331339
t = N;
@@ -386,6 +394,10 @@ static int dgemm(cb_order order, cb_transpose transA, cb_transpose transB,
386394
ASSERT_BUF(B);
387395
ASSERT_BUF(C);
388396

397+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
398+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
399+
return GA_XLARGE_ERROR;
400+
389401
if (order == cb_c) {
390402
/* swap A and B */
391403
t = N;
@@ -450,6 +462,10 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
450462
ASSERT_BUF(B);
451463
ASSERT_BUF(C);
452464

465+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
466+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
467+
return GA_XLARGE_ERROR;
468+
453469
if (order == cb_c) {
454470
/* swap A and B */
455471
t = N;
@@ -539,6 +555,10 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
539555

540556
if (batchCount == 0) return GA_NO_ERROR;
541557

558+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
559+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
560+
return GA_XLARGE_ERROR;
561+
542562
ASSERT_BUF(A[0]);
543563
ctx = A[0]->ctx;
544564
h = (blas_handle *)ctx->blas_handle;
@@ -659,6 +679,10 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
659679

660680
if (batchCount == 0) return GA_NO_ERROR;
661681

682+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
683+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
684+
return GA_XLARGE_ERROR;
685+
662686
ASSERT_BUF(A[0]);
663687
ctx = A[0]->ctx;
664688
h = (blas_handle *)ctx->blas_handle;
@@ -782,6 +806,10 @@ static int sgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
782806
ASSERT_BUF(X);
783807
ASSERT_BUF(Y);
784808

809+
if (LARGE_VAL(M) || LARGE_VAL(N) ||
810+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
811+
return GA_XLARGE_ERROR;
812+
785813
if (order == cb_c) {
786814
t = N;
787815
N = M;
@@ -833,6 +861,10 @@ static int dgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
833861
ASSERT_BUF(X);
834862
ASSERT_BUF(Y);
835863

864+
if (LARGE_VAL(M) || LARGE_VAL(N) ||
865+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
866+
return GA_XLARGE_ERROR;
867+
836868
if (order == cb_c) {
837869
t = N;
838870
N = M;
@@ -1149,6 +1181,10 @@ static int sger(cb_order order, size_t M, size_t N, float alpha, gpudata *X,
11491181
ASSERT_BUF(Y);
11501182
ASSERT_BUF(A);
11511183

1184+
if (LARGE_VAL(M) || LARGE_VAL(N) ||
1185+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
1186+
return GA_XLARGE_ERROR;
1187+
11521188
if (order == cb_c) {
11531189
t = M;
11541190
M = N;
@@ -1202,6 +1238,10 @@ static int dger(cb_order order, size_t M, size_t N, double alpha, gpudata *X,
12021238
ASSERT_BUF(Y);
12031239
ASSERT_BUF(A);
12041240

1241+
if (LARGE_VAL(M) || LARGE_VAL(N) ||
1242+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
1243+
return GA_XLARGE_ERROR;
1244+
12051245
if (order == cb_c) {
12061246
t = M;
12071247
M = N;

src/gpuarray_error.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const char *gpuarray_error_str(int err) {
2323
case GA_NODEV_ERROR: return "No devices are available";
2424
case GA_MISC_ERROR: return "Undeterminate error";
2525
case GA_COMM_ERROR: return "Error in collectives call";
26+
case GA_XLARGE_ERROR: return "Input size too large for operation";
2627
default: return "Unknown GA error";
2728
}
2829
}

0 commit comments

Comments
 (0)