|
5 | 5 | #include "gpuarray/kernel.h" |
6 | 6 | #include "gpuarray/error.h" |
7 | 7 |
|
8 | | -#include "cublas_v2.h" |
| 8 | +#include <limits.h> |
| 9 | + |
| 10 | +#include <cublas_v2.h> |
9 | 11 |
|
10 | 12 | extern const gpuarray_buffer_ops cuda_ops; |
11 | 13 |
|
@@ -33,6 +35,8 @@ typedef struct _blas_handle { |
33 | 35 | cublasStatus_t err; |
34 | 36 | } blas_handle; |
35 | 37 |
|
/*
 * True when `v` is greater than or equal to INT_MAX.
 *
 * Callers in this file apply it to sizes, leading dimensions, increments
 * and size products, and return GA_XLARGE_ERROR when it is true.  Note that
 * the comparison is `>=`, so INT_MAX itself is treated as too large.
 *
 * The argument is parenthesized so that expressions containing operators of
 * lower precedence than `>=` (e.g. `a | b`, `c ? x : y`) are compared as a
 * whole.  `v` is evaluated exactly once.
 */
#define LARGE_VAL(v) ((v) >= INT_MAX)
| 39 | + |
36 | 40 | static const char *code_sgemvBH_N_a1_b1_small = \ |
37 | 41 | "extern \"C\"__global__ void sgemv(const float *A[], size_t lda, " \ |
38 | 42 | " const float *x[], size_t incx, " \ |
@@ -326,6 +330,11 @@ static int sgemm(cb_order order, cb_transpose transA, cb_transpose transB, |
326 | 330 | ASSERT_BUF(B); |
327 | 331 | ASSERT_BUF(C); |
328 | 332 |
|
| 333 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 334 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) || |
| 335 | + LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N)) |
| 336 | + return GA_XLARGE_ERROR; |
| 337 | + |
329 | 338 | if (order == cb_c) { |
330 | 339 | /* swap A and B */ |
331 | 340 | t = N; |
@@ -386,6 +395,11 @@ static int dgemm(cb_order order, cb_transpose transA, cb_transpose transB, |
386 | 395 | ASSERT_BUF(B); |
387 | 396 | ASSERT_BUF(C); |
388 | 397 |
|
| 398 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 399 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) || |
| 400 | + LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N)) |
| 401 | + return GA_XLARGE_ERROR; |
| 402 | + |
389 | 403 | if (order == cb_c) { |
390 | 404 | /* swap A and B */ |
391 | 405 | t = N; |
@@ -450,6 +464,11 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB, |
450 | 464 | ASSERT_BUF(B); |
451 | 465 | ASSERT_BUF(C); |
452 | 466 |
|
| 467 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 468 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) || |
| 469 | + LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N)) |
| 470 | + return GA_XLARGE_ERROR; |
| 471 | + |
453 | 472 | if (order == cb_c) { |
454 | 473 | /* swap A and B */ |
455 | 474 | t = N; |
@@ -539,6 +558,11 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB, |
539 | 558 |
|
540 | 559 | if (batchCount == 0) return GA_NO_ERROR; |
541 | 560 |
|
| 561 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 562 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) || |
| 563 | + LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N)) |
| 564 | + return GA_XLARGE_ERROR; |
| 565 | + |
542 | 566 | ASSERT_BUF(A[0]); |
543 | 567 | ctx = A[0]->ctx; |
544 | 568 | h = (blas_handle *)ctx->blas_handle; |
@@ -659,6 +683,11 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB, |
659 | 683 |
|
660 | 684 | if (batchCount == 0) return GA_NO_ERROR; |
661 | 685 |
|
| 686 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 687 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) || |
| 688 | + LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N)) |
| 689 | + return GA_XLARGE_ERROR; |
| 690 | + |
662 | 691 | ASSERT_BUF(A[0]); |
663 | 692 | ctx = A[0]->ctx; |
664 | 693 | h = (blas_handle *)ctx->blas_handle; |
@@ -782,6 +811,10 @@ static int sgemv(cb_order order, cb_transpose transA, size_t M, size_t N, |
782 | 811 | ASSERT_BUF(X); |
783 | 812 | ASSERT_BUF(Y); |
784 | 813 |
|
| 814 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) || |
| 815 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 816 | + return GA_XLARGE_ERROR; |
| 817 | + |
785 | 818 | if (order == cb_c) { |
786 | 819 | t = N; |
787 | 820 | N = M; |
@@ -833,6 +866,10 @@ static int dgemv(cb_order order, cb_transpose transA, size_t M, size_t N, |
833 | 866 | ASSERT_BUF(X); |
834 | 867 | ASSERT_BUF(Y); |
835 | 868 |
|
| 869 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) || |
| 870 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 871 | + return GA_XLARGE_ERROR; |
| 872 | + |
836 | 873 | if (order == cb_c) { |
837 | 874 | t = N; |
838 | 875 | N = M; |
@@ -1149,6 +1186,10 @@ static int sger(cb_order order, size_t M, size_t N, float alpha, gpudata *X, |
1149 | 1186 | ASSERT_BUF(Y); |
1150 | 1187 | ASSERT_BUF(A); |
1151 | 1188 |
|
| 1189 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) || |
| 1190 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 1191 | + return GA_XLARGE_ERROR; |
| 1192 | + |
1152 | 1193 | if (order == cb_c) { |
1153 | 1194 | t = M; |
1154 | 1195 | M = N; |
@@ -1202,6 +1243,10 @@ static int dger(cb_order order, size_t M, size_t N, double alpha, gpudata *X, |
1202 | 1243 | ASSERT_BUF(Y); |
1203 | 1244 | ASSERT_BUF(A); |
1204 | 1245 |
|
| 1246 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) || |
| 1247 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 1248 | + return GA_XLARGE_ERROR; |
| 1249 | + |
1205 | 1250 | if (order == cb_c) { |
1206 | 1251 | t = M; |
1207 | 1252 | M = N; |
|
0 commit comments