|
5 | 5 | #include "gpuarray/kernel.h" |
6 | 6 | #include "gpuarray/error.h" |
7 | 7 |
|
8 | | -#include "cublas_v2.h" |
| 8 | +#include <limits.h> |
| 9 | + |
| 10 | +#include <cublas_v2.h> |
9 | 11 |
|
10 | 12 | extern const gpuarray_buffer_ops cuda_ops; |
11 | 13 |
|
@@ -33,6 +35,8 @@ typedef struct _blas_handle { |
33 | 35 | cublasStatus_t err; |
34 | 36 | } blas_handle; |
35 | 37 |
|
/*
 * Evaluates to non-zero when v is at or above INT_MAX.
 *
 * The BLAS wrappers in this file apply it to sizes, leading dimensions and
 * increments and return GA_XLARGE_ERROR when any of them trips the check.
 *
 * The argument is wrapped in parentheses so that compound expressions
 * (ternaries, bitwise or logical operators) are compared as a whole instead
 * of having only their last operand bound to the >= operator.
 */
#define LARGE_VAL(v) ((v) >= INT_MAX)
| 39 | + |
36 | 40 | static const char *code_sgemvBH_N_a1_b1_small = \ |
37 | 41 | "extern \"C\"__global__ void sgemv(const float *A[], size_t lda, " \ |
38 | 42 | " const float *x[], size_t incx, " \ |
@@ -326,6 +330,10 @@ static int sgemm(cb_order order, cb_transpose transA, cb_transpose transB, |
326 | 330 | ASSERT_BUF(B); |
327 | 331 | ASSERT_BUF(C); |
328 | 332 |
|
| 333 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 334 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc)) |
| 335 | + return GA_XLARGE_ERROR; |
| 336 | + |
329 | 337 | if (order == cb_c) { |
330 | 338 | /* swap A and B */ |
331 | 339 | t = N; |
@@ -386,6 +394,10 @@ static int dgemm(cb_order order, cb_transpose transA, cb_transpose transB, |
386 | 394 | ASSERT_BUF(B); |
387 | 395 | ASSERT_BUF(C); |
388 | 396 |
|
| 397 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 398 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc)) |
| 399 | + return GA_XLARGE_ERROR; |
| 400 | + |
389 | 401 | if (order == cb_c) { |
390 | 402 | /* swap A and B */ |
391 | 403 | t = N; |
@@ -450,6 +462,10 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB, |
450 | 462 | ASSERT_BUF(B); |
451 | 463 | ASSERT_BUF(C); |
452 | 464 |
|
| 465 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 466 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc)) |
| 467 | + return GA_XLARGE_ERROR; |
| 468 | + |
453 | 469 | if (order == cb_c) { |
454 | 470 | /* swap A and B */ |
455 | 471 | t = N; |
@@ -539,6 +555,10 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB, |
539 | 555 |
|
540 | 556 | if (batchCount == 0) return GA_NO_ERROR; |
541 | 557 |
|
| 558 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 559 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc)) |
| 560 | + return GA_XLARGE_ERROR; |
| 561 | + |
542 | 562 | ASSERT_BUF(A[0]); |
543 | 563 | ctx = A[0]->ctx; |
544 | 564 | h = (blas_handle *)ctx->blas_handle; |
@@ -659,6 +679,10 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB, |
659 | 679 |
|
660 | 680 | if (batchCount == 0) return GA_NO_ERROR; |
661 | 681 |
|
| 682 | + if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) || |
| 683 | + LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc)) |
| 684 | + return GA_XLARGE_ERROR; |
| 685 | + |
662 | 686 | ASSERT_BUF(A[0]); |
663 | 687 | ctx = A[0]->ctx; |
664 | 688 | h = (blas_handle *)ctx->blas_handle; |
@@ -782,6 +806,10 @@ static int sgemv(cb_order order, cb_transpose transA, size_t M, size_t N, |
782 | 806 | ASSERT_BUF(X); |
783 | 807 | ASSERT_BUF(Y); |
784 | 808 |
|
| 809 | + if (LARGE_VAL(M) || LARGE_VAL(N) || |
| 810 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 811 | + return GA_XLARGE_ERROR; |
| 812 | + |
785 | 813 | if (order == cb_c) { |
786 | 814 | t = N; |
787 | 815 | N = M; |
@@ -833,6 +861,10 @@ static int dgemv(cb_order order, cb_transpose transA, size_t M, size_t N, |
833 | 861 | ASSERT_BUF(X); |
834 | 862 | ASSERT_BUF(Y); |
835 | 863 |
|
| 864 | + if (LARGE_VAL(M) || LARGE_VAL(N) || |
| 865 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 866 | + return GA_XLARGE_ERROR; |
| 867 | + |
836 | 868 | if (order == cb_c) { |
837 | 869 | t = N; |
838 | 870 | N = M; |
@@ -1149,6 +1181,10 @@ static int sger(cb_order order, size_t M, size_t N, float alpha, gpudata *X, |
1149 | 1181 | ASSERT_BUF(Y); |
1150 | 1182 | ASSERT_BUF(A); |
1151 | 1183 |
|
| 1184 | + if (LARGE_VAL(M) || LARGE_VAL(N) || |
| 1185 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 1186 | + return GA_XLARGE_ERROR; |
| 1187 | + |
1152 | 1188 | if (order == cb_c) { |
1153 | 1189 | t = M; |
1154 | 1190 | M = N; |
@@ -1202,6 +1238,10 @@ static int dger(cb_order order, size_t M, size_t N, double alpha, gpudata *X, |
1202 | 1238 | ASSERT_BUF(Y); |
1203 | 1239 | ASSERT_BUF(A); |
1204 | 1240 |
|
| 1241 | + if (LARGE_VAL(M) || LARGE_VAL(N) || |
| 1242 | + LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY)) |
| 1243 | + return GA_XLARGE_ERROR; |
| 1244 | + |
1205 | 1245 | if (order == cb_c) { |
1206 | 1246 | t = M; |
1207 | 1247 | M = N; |
|
0 commit comments