Skip to content

Commit ce8c98d

Browse files
committed
Also check products to protect cublas which doesn't do this check.
1 parent e17f492 commit ce8c98d

1 file changed

Lines changed: 14 additions & 9 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ static int sgemm(cb_order order, cb_transpose transA, cb_transpose transB,
331331
ASSERT_BUF(C);
332332

333333
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
334-
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
334+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
335+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
335336
return GA_XLARGE_ERROR;
336337

337338
if (order == cb_c) {
@@ -395,7 +396,8 @@ static int dgemm(cb_order order, cb_transpose transA, cb_transpose transB,
395396
ASSERT_BUF(C);
396397

397398
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
398-
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
399+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
400+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
399401
return GA_XLARGE_ERROR;
400402

401403
if (order == cb_c) {
@@ -463,7 +465,8 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
463465
ASSERT_BUF(C);
464466

465467
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
466-
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
468+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
469+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
467470
return GA_XLARGE_ERROR;
468471

469472
if (order == cb_c) {
@@ -556,7 +559,8 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
556559
if (batchCount == 0) return GA_NO_ERROR;
557560

558561
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
559-
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
562+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
563+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
560564
return GA_XLARGE_ERROR;
561565

562566
ASSERT_BUF(A[0]);
@@ -680,7 +684,8 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
680684
if (batchCount == 0) return GA_NO_ERROR;
681685

682686
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
683-
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc))
687+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
688+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
684689
return GA_XLARGE_ERROR;
685690

686691
ASSERT_BUF(A[0]);
@@ -806,7 +811,7 @@ static int sgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
806811
ASSERT_BUF(X);
807812
ASSERT_BUF(Y);
808813

809-
if (LARGE_VAL(M) || LARGE_VAL(N) ||
814+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
810815
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
811816
return GA_XLARGE_ERROR;
812817

@@ -861,7 +866,7 @@ static int dgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
861866
ASSERT_BUF(X);
862867
ASSERT_BUF(Y);
863868

864-
if (LARGE_VAL(M) || LARGE_VAL(N) ||
869+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
865870
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
866871
return GA_XLARGE_ERROR;
867872

@@ -1181,7 +1186,7 @@ static int sger(cb_order order, size_t M, size_t N, float alpha, gpudata *X,
11811186
ASSERT_BUF(Y);
11821187
ASSERT_BUF(A);
11831188

1184-
if (LARGE_VAL(M) || LARGE_VAL(N) ||
1189+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
11851190
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
11861191
return GA_XLARGE_ERROR;
11871192

@@ -1238,7 +1243,7 @@ static int dger(cb_order order, size_t M, size_t N, double alpha, gpudata *X,
12381243
ASSERT_BUF(Y);
12391244
ASSERT_BUF(A);
12401245

1241-
if (LARGE_VAL(M) || LARGE_VAL(N) ||
1246+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
12421247
LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
12431248
return GA_XLARGE_ERROR;
12441249

0 commit comments

Comments
 (0)