Skip to content

Commit b4a5b84

Browse files
authored
Merge pull request #281 from abergeron/err_blas
Check for overflow in cublas wrapper
2 parents 8e5f40f + ce8c98d commit b4a5b84

4 files changed

Lines changed: 52 additions & 2 deletions

File tree

pygpu/gpuarray.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1978,7 +1978,10 @@ cdef class GpuArray:
19781978
return str(numpy.asarray(self))
19791979

19801980
def __repr__(self):
1981-
return 'gpuarray.' + repr(numpy.asarray(self))
1981+
try:
1982+
return 'gpuarray.' + repr(numpy.asarray(self))
1983+
except Exception:
1984+
return 'gpuarray.array(<content not available>)'
19821985

19831986

19841987

src/gpuarray/error.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ enum ga_error {
3434
GA_NODEV_ERROR,
3535
GA_MISC_ERROR,
3636
GA_COMM_ERROR,
37+
GA_XLARGE_ERROR,
3738
/* Add more error types if needed, but at the end */
3839
/* Don't forget to sync with Gpu_error() */
3940
};

src/gpuarray_blas_cuda_cublas.c

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
#include "gpuarray/kernel.h"
66
#include "gpuarray/error.h"
77

8-
#include "cublas_v2.h"
8+
#include <limits.h>
9+
10+
#include <cublas_v2.h>
911

1012
extern const gpuarray_buffer_ops cuda_ops;
1113

@@ -33,6 +35,8 @@ typedef struct _blas_handle {
3335
cublasStatus_t err;
3436
} blas_handle;
3537

38+
/*
 * Evaluates to non-zero when v is greater than or equal to INT_MAX.
 * Presumably used because the wrapped BLAS entry points take int-sized
 * dimensions -- TODO confirm against the cuBLAS prototypes.
 *
 * The parameter is parenthesized so that a compound argument (a
 * conditional, a bitwise expression, ...) is compared as a whole instead
 * of being split by operator precedence.
 *
 * NOTE(review): an argument such as M * N is computed in the operands'
 * own type before the comparison; with size_t operands the product can
 * wrap around and slip past this check -- confirm callers cannot hit that.
 */
#define LARGE_VAL(v) ((v) >= INT_MAX)
39+
3640
static const char *code_sgemvBH_N_a1_b1_small = \
3741
"extern \"C\"__global__ void sgemv(const float *A[], size_t lda, " \
3842
" const float *x[], size_t incx, " \
@@ -326,6 +330,11 @@ static int sgemm(cb_order order, cb_transpose transA, cb_transpose transB,
326330
ASSERT_BUF(B);
327331
ASSERT_BUF(C);
328332

333+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
334+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
335+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
336+
return GA_XLARGE_ERROR;
337+
329338
if (order == cb_c) {
330339
/* swap A and B */
331340
t = N;
@@ -386,6 +395,11 @@ static int dgemm(cb_order order, cb_transpose transA, cb_transpose transB,
386395
ASSERT_BUF(B);
387396
ASSERT_BUF(C);
388397

398+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
399+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
400+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
401+
return GA_XLARGE_ERROR;
402+
389403
if (order == cb_c) {
390404
/* swap A and B */
391405
t = N;
@@ -450,6 +464,11 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
450464
ASSERT_BUF(B);
451465
ASSERT_BUF(C);
452466

467+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
468+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
469+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
470+
return GA_XLARGE_ERROR;
471+
453472
if (order == cb_c) {
454473
/* swap A and B */
455474
t = N;
@@ -539,6 +558,11 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
539558

540559
if (batchCount == 0) return GA_NO_ERROR;
541560

561+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
562+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
563+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
564+
return GA_XLARGE_ERROR;
565+
542566
ASSERT_BUF(A[0]);
543567
ctx = A[0]->ctx;
544568
h = (blas_handle *)ctx->blas_handle;
@@ -659,6 +683,11 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
659683

660684
if (batchCount == 0) return GA_NO_ERROR;
661685

686+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
687+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
688+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
689+
return GA_XLARGE_ERROR;
690+
662691
ASSERT_BUF(A[0]);
663692
ctx = A[0]->ctx;
664693
h = (blas_handle *)ctx->blas_handle;
@@ -782,6 +811,10 @@ static int sgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
782811
ASSERT_BUF(X);
783812
ASSERT_BUF(Y);
784813

814+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
815+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
816+
return GA_XLARGE_ERROR;
817+
785818
if (order == cb_c) {
786819
t = N;
787820
N = M;
@@ -833,6 +866,10 @@ static int dgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
833866
ASSERT_BUF(X);
834867
ASSERT_BUF(Y);
835868

869+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
870+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
871+
return GA_XLARGE_ERROR;
872+
836873
if (order == cb_c) {
837874
t = N;
838875
N = M;
@@ -1149,6 +1186,10 @@ static int sger(cb_order order, size_t M, size_t N, float alpha, gpudata *X,
11491186
ASSERT_BUF(Y);
11501187
ASSERT_BUF(A);
11511188

1189+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
1190+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
1191+
return GA_XLARGE_ERROR;
1192+
11521193
if (order == cb_c) {
11531194
t = M;
11541195
M = N;
@@ -1202,6 +1243,10 @@ static int dger(cb_order order, size_t M, size_t N, double alpha, gpudata *X,
12021243
ASSERT_BUF(Y);
12031244
ASSERT_BUF(A);
12041245

1246+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
1247+
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
1248+
return GA_XLARGE_ERROR;
1249+
12051250
if (order == cb_c) {
12061251
t = M;
12071252
M = N;

src/gpuarray_error.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const char *gpuarray_error_str(int err) {
2323
case GA_NODEV_ERROR: return "No devices are available";
2424
case GA_MISC_ERROR: return "Undeterminate error";
2525
case GA_COMM_ERROR: return "Error in collectives call";
26+
case GA_XLARGE_ERROR: return "Input size too large for operation";
2627
default: return "Unknown GA error";
2728
}
2829
}

0 commit comments

Comments
 (0)