Skip to content

Commit 0767181

Browse files
committed
Add gemm3D for batch gemm of 3d matrices.
1 parent ddb016b commit 0767181

6 files changed

Lines changed: 264 additions & 54 deletions

File tree

src/gpuarray/buffer_blas.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,30 @@ GPUARRAY_PUBLIC int gpublas_hgemmBatch(
115115
float beta, gpudata **C, size_t *offC, size_t ldc,
116116
size_t batchCount, int flags);
117117

118-
//TODO: float should be half
119-
GPUARRAY_PUBLIC int gpublas_hgemmStridedBatch(
118+
GPUARRAY_PUBLIC int gpublas_hgemm3D(
120119
cb_order order, cb_transpose transA, cb_transpose transB,
121120
size_t M, size_t N, size_t K, float alpha,
122121
gpudata *A, size_t lda, ssize_t strideA,
123122
gpudata *B, size_t ldb, ssize_t strideB,
124123
float beta, gpudata *C, size_t ldc, ssize_t strideC,
125124
size_t batchCount, int flags);
126125

126+
GPUARRAY_PUBLIC int gpublas_sgemm3D(
127+
cb_order order, cb_transpose transA, cb_transpose transB,
128+
size_t M, size_t N, size_t K, float alpha,
129+
gpudata *A, size_t lda, ssize_t strideA,
130+
gpudata *B, size_t ldb, ssize_t strideB,
131+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
132+
size_t batchCount, int flags);
133+
134+
GPUARRAY_PUBLIC int gpublas_dgemm3D(
135+
cb_order order, cb_transpose transA, cb_transpose transB,
136+
size_t M, size_t N, size_t K, double alpha,
137+
gpudata *A, size_t lda, ssize_t strideA,
138+
gpudata *B, size_t ldb, ssize_t strideB,
139+
double beta, gpudata *C, size_t ldc, ssize_t strideC,
140+
size_t batchCount, int flags);
141+
127142
GPUARRAY_PUBLIC int gpublas_sgemmBatch(
128143
cb_order order, cb_transpose transA, cb_transpose transB,
129144
size_t M, size_t N, size_t K, float alpha,

src/gpuarray_blas_cuda_cublas.c

Lines changed: 186 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -510,37 +510,33 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
510510
cuda_exit(ctx);
511511
return GA_NO_ERROR;
512512
}
513-
//TODO: change float to half
514-
static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose transB,
515-
size_t M, size_t N, size_t K, float alpha,
516-
gpudata *A, size_t lda, ssize_t strideA,
517-
gpudata *B, size_t ldb, ssize_t strideB,
518-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
519-
size_t batchCount) {
513+
514+
static int hgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
515+
size_t M, size_t N, size_t K, float alpha,
516+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
517+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
518+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
519+
size_t batchCount) {
520520
cuda_context *ctx;
521521
blas_handle *h;
522522
size_t t;
523-
ssize_t lt;
523+
ssize_t st;
524524
gpudata *T;
525525
cb_transpose transT;
526526
cublasStatus_t err;
527-
__half halpha, hbeta;
528-
529-
//ignore overflow, underflow, denormalized and inf values. Mayve also nan.
530-
uint32_t x = (uint32_t)alpha;
531-
alpha = ((x>>16)&0x8000)|((((x&0x7f800000)-0x38000000)>>13)&0x7c00)|((x>>13)&0x03ff);
532-
x = (uint32_t)beta;
533-
beta = ((x>>16)&0x8000)|((((x&0x7f800000)-0x38000000)>>13)&0x7c00)|((x>>13)&0x03ff);
534-
527+
ga_half_t halpha, hbeta;
528+
535529
ASSERT_BUF(A);
536-
if (cublasHgemmStridedBatched == NULL)
537-
return GA_DEVSUP_ERROR;
530+
ASSERT_BUF(B);
531+
ASSERT_BUF(C);
538532

539533
ctx = A->ctx;
540-
// TODO: stride* are long long int in cuda, LARGE_VAL check for int.
534+
535+
if (cublasHgemmStridedBatched == NULL)
536+
return error_set(ctx->error, GA_DEVSUP_ERROR, "cublasHgemmStridedBatched not available in your version of cuBLAS");
537+
541538
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
542539
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
543-
LARGE_VAL(strideA) || LARGE_VAL(strideB) || LARGE_VAL(strideC) ||
544540
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
545541
return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");
546542

@@ -558,28 +554,108 @@ static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose t
558554
t = lda;
559555
lda = ldb;
560556
ldb = t;
557+
t = offA;
558+
offA = offB;
559+
offB = t;
561560
transT = transA;
562561
transA = transB;
563562
transB = transT;
564-
lt = strideA;
563+
st = strideA;
565564
strideA = strideB;
566-
strideB = lt;
565+
strideB = st;
567566
}
568567

569-
ASSERT_BUF(A);
570-
ASSERT_BUF(B);
571-
ASSERT_BUF(C);
568+
halpha = ga_float2half(alpha);
569+
hbeta = ga_float2half(beta);
570+
572571
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
573572
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
574573
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
575-
raise(SIGINT);
576574
err = cublasHgemmStridedBatched(h->h,
577575
convT(transA), convT(transB),
578576
M, N, K, &halpha,
579-
(__half *)(A->ptr), (int) lda, strideA,
580-
(__half *)(B->ptr), (int) ldb, strideB,
577+
((__half *)A->ptr) + offA, lda, strideA,
578+
((__half *)B->ptr) + offB, ldb, strideB,
581579
&hbeta,
582-
(__half *)(C->ptr), (int) ldc, strideB,
580+
((__half *)C->ptr) + offC, ldc, strideB,
581+
batchCount);
582+
if (err != CUBLAS_STATUS_SUCCESS) {
583+
cuda_exit(ctx);
584+
return error_cublas(ctx->err, "cublasHgemmStridedBatched", err);
585+
}
586+
587+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
588+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
589+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));
590+
591+
cuda_exit(ctx);
592+
return GA_NO_ERROR;
593+
}
594+
595+
static int sgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
596+
size_t M, size_t N, size_t K, float alpha,
597+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
598+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
599+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
600+
size_t batchCount) {
601+
cuda_context *ctx;
602+
blas_handle *h;
603+
size_t t;
604+
ssize_t st;
605+
gpudata *T;
606+
cb_transpose transT;
607+
cublasStatus_t err;
608+
609+
ASSERT_BUF(A);
610+
ASSERT_BUF(B);
611+
ASSERT_BUF(C);
612+
613+
ctx = A->ctx;
614+
615+
if (cublasSgemmStridedBatched == NULL)
616+
return error_set(ctx->error, GA_DEVSUP_ERROR, "cublasSgemmStridedBatched not available in your version of cuBLAS");
617+
618+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
619+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
620+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
621+
return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");
622+
623+
h = (blas_handle *)ctx->blas_handle;
624+
cuda_enter(ctx);
625+
626+
if (order == cb_c) {
627+
/* swap A and B */
628+
t = N;
629+
N = M;
630+
M = t;
631+
T = A;
632+
A = B;
633+
B = T;
634+
t = lda;
635+
lda = ldb;
636+
ldb = t;
637+
t = offA;
638+
offA = offB;
639+
offB = t;
640+
transT = transA;
641+
transA = transB;
642+
transB = transT;
643+
st = strideA;
644+
strideA = strideB;
645+
strideB = st;
646+
}
647+
648+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
649+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
650+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
651+
652+
err = cublasSgemmStridedBatched(h->h,
653+
convT(transA), convT(transB),
654+
M, N, K, &alpha,
655+
((float *)A->ptr) + offA, (int)lda, strideA,
656+
((float *)B->ptr) + offB, (int)ldb, strideB,
657+
&beta,
658+
((float *)C->ptr) + offC, (int)ldc, strideB,
583659
batchCount);
584660
if (err != CUBLAS_STATUS_SUCCESS) {
585661
cuda_exit(ctx);
@@ -594,6 +670,84 @@ static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose t
594670
return GA_NO_ERROR;
595671
}
596672

673+
/*
 * Batched double-precision GEMM over a 3d layout via
 * cublasDgemmStridedBatched: for each i in [0, batchCount),
 *   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]
 * where the i-th matrix of X begins at ((double *)X->ptr) + offX + i*strideX
 * (offsets and strides are element counts, per the pointer arithmetic below).
 *
 * Fixes vs. the committed version:
 *  - the C-matrix stride handed to cuBLAS was strideB (copy/paste error);
 *    it is now strideC;
 *  - the DEVSUP path used ctx->error while every other error path in this
 *    function uses ctx->err.
 * Strides are deliberately not LARGE_VAL-checked: cuBLAS takes them as
 * long long, not int.
 */
static int dgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
                   size_t M, size_t N, size_t K, double alpha,
                   gpudata *A, size_t offA, size_t lda, ssize_t strideA,
                   gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
                   double beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
                   size_t batchCount) {
  cuda_context *ctx;
  blas_handle *h;
  size_t t;
  ssize_t st;
  gpudata *T;
  cb_transpose transT;
  cublasStatus_t err;

  ASSERT_BUF(A);
  ASSERT_BUF(B);
  ASSERT_BUF(C);

  ctx = A->ctx;

  /* cublasDgemmStridedBatched is resolved dynamically; it is absent in
     cuBLAS versions before 8.0. */
  if (cublasDgemmStridedBatched == NULL)
    return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasDgemmStridedBatched not available in your version of cuBLAS");

  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
      LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
      LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
    return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");

  h = (blas_handle *)ctx->blas_handle;
  cuda_enter(ctx);

  if (order == cb_c) {
    /* cuBLAS is column-major; emulate row-major by swapping the roles of
       A and B (and M/N, lda/ldb, offsets, transposes, strides). */
    t = N;
    N = M;
    M = t;
    T = A;
    A = B;
    B = T;
    t = lda;
    lda = ldb;
    ldb = t;
    t = offA;
    offA = offB;
    offB = t;
    transT = transA;
    transA = transB;
    transB = transT;
    st = strideA;
    strideA = strideB;
    strideB = st;
  }

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));

  err = cublasDgemmStridedBatched(h->h,
                                  convT(transA), convT(transB),
                                  M, N, K, &alpha,
                                  ((double *)A->ptr) + offA, (int)lda, strideA,
                                  ((double *)B->ptr) + offB, (int)ldb, strideB,
                                  &beta,
                                  /* was strideB: passed B's batch stride for
                                     the C matrix, corrupting every batch
                                     entry past the first when the strides
                                     differ. */
                                  ((double *)C->ptr) + offC, (int)ldc, strideC,
                                  batchCount);
  if (err != CUBLAS_STATUS_SUCCESS) {
    cuda_exit(ctx);
    return error_cublas(ctx->err, "cublasDgemmStridedBatched", err);
  }

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));

  cuda_exit(ctx);
  return GA_NO_ERROR;
}
750+
597751
static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
598752
size_t M, size_t N, size_t K, float alpha,
599753
gpudata **A, size_t *offA, size_t lda,
@@ -1662,5 +1816,7 @@ gpuarray_blas_ops cublas_ops = {
16621816
NULL, /* hgerBatch */
16631817
sgerBatch,
16641818
dgerBatch,
1665-
hgemmStridedBatch,
1819+
hgemm3D,
1820+
sgemm3D,
1821+
dgemm3D
16661822
};

src/gpuarray_blas_opencl_clblas.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,5 +449,7 @@ gpuarray_blas_ops clblas_ops = {
449449
NULL, /* hgerBatch */
450450
NULL, /* sgerBatch */
451451
NULL, /* dgerBatch */
452-
NULL, /* hgemmStridedzBatch */
452+
NULL, /* hgemm3D */
453+
NULL, /* sgemm3D */
454+
NULL, /* dgemm3D */
453455
};

src/gpuarray_blas_opencl_clblast.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,5 +524,7 @@ gpuarray_blas_ops clblast_ops = {
524524
NULL, /* hgerBatch */
525525
NULL, /* sgerBatch */
526526
NULL, /* dgerBatch */
527-
NULL, /* hgemmStridedzBatch */
527+
NULL, /* hgemm3D */
528+
NULL, /* sgemm3D */
529+
NULL, /* dgemm3D */
528530
};

src/gpuarray_buffer_blas.c

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,19 +169,6 @@ int gpublas_hgemmBatch(
169169
B, offB, ldb, beta, C, offC, ldc, batchCount));
170170
}
171171

172-
//TODO: use half and not float here.
173-
int gpublas_hgemmStridedBatch(
174-
cb_order order, cb_transpose transA, cb_transpose transB,
175-
size_t M, size_t N, size_t K, float alpha,
176-
gpudata *A, size_t lda, ssize_t strideA,
177-
gpudata *B, size_t ldb, ssize_t strideB,
178-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
179-
size_t batchCount, int flags) {
180-
BLAS_OPF(A, hgemmStridedBatch,
181-
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
182-
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
183-
}
184-
185172
int gpublas_sgemmBatch(
186173
cb_order order, cb_transpose transA, cb_transpose transB,
187174
size_t M, size_t N, size_t K, float alpha,
@@ -271,3 +258,40 @@ int gpublas_dgerBatch(cb_order order, size_t M, size_t N, double alpha,
271258
(order, M, N, alpha, x, offX, incX, y, offY, incY,
272259
A, offA, lda, batchCount, flags));
273260
}
261+
262+
263+
int gpublas_hgemm3d(
264+
cb_order order, cb_transpose transA, cb_transpose transB,
265+
size_t M, size_t N, size_t K, float alpha,
266+
gpudata *A, size_t lda, ssize_t strideA,
267+
gpudata *B, size_t ldb, ssize_t strideB,
268+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
269+
size_t batchCount, int flags) {
270+
BLAS_OPBF(A, hgemm3d,
271+
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
272+
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
273+
}
274+
275+
int gpublas_sgemm3d(
276+
cb_order order, cb_transpose transA, cb_transpose transB,
277+
size_t M, size_t N, size_t K, float alpha,
278+
gpudata *A, size_t lda, ssize_t strideA,
279+
gpudata *B, size_t ldb, ssize_t strideB,
280+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
281+
size_t batchCount, int flags) {
282+
BLAS_OPBF(A, sgemm3d,
283+
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
284+
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
285+
}
286+
287+
int gpublas_dgemm3d(
288+
cb_order order, cb_transpose transA, cb_transpose transB,
289+
size_t M, size_t N, size_t K, float alpha,
290+
gpudata *A, size_t lda, ssize_t strideA,
291+
gpudata *B, size_t ldb, ssize_t strideB,
292+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
293+
size_t batchCount, int flags) {
294+
BLAS_OPBF(A, dgemm3d,
295+
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
296+
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
297+
}

0 commit comments

Comments
 (0)