Skip to content

Commit 0b91bdf

Browse files
committed
Clean up problems.
1 parent cb4a79f commit 0b91bdf

5 files changed

Lines changed: 64 additions & 57 deletions

File tree

src/gpuarray/buffer_blas.h

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -118,25 +118,25 @@ GPUARRAY_PUBLIC int gpublas_hgemmBatch(
118118
GPUARRAY_PUBLIC int gpublas_hgemm3D(
119119
cb_order order, cb_transpose transA, cb_transpose transB,
120120
size_t M, size_t N, size_t K, float alpha,
121-
gpudata *A, size_t lda, ssize_t strideA,
122-
gpudata *B, size_t ldb, ssize_t strideB,
123-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
121+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
122+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
123+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
124124
size_t batchCount, int flags);
125125

126126
GPUARRAY_PUBLIC int gpublas_sgemm3D(
127127
cb_order order, cb_transpose transA, cb_transpose transB,
128128
size_t M, size_t N, size_t K, float alpha,
129-
gpudata *A, size_t lda, ssize_t strideA,
130-
gpudata *B, size_t ldb, ssize_t strideB,
131-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
129+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
130+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
131+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
132132
size_t batchCount, int flags);
133133

134134
GPUARRAY_PUBLIC int gpublas_dgemm3D(
135135
cb_order order, cb_transpose transA, cb_transpose transB,
136136
size_t M, size_t N, size_t K, double alpha,
137-
gpudata *A, size_t lda, ssize_t strideA,
138-
gpudata *B, size_t ldb, ssize_t strideB,
139-
double beta, gpudata *C, size_t ldc, ssize_t strideC,
137+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
138+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
139+
double beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
140140
size_t batchCount, int flags);
141141

142142
GPUARRAY_PUBLIC int gpublas_sgemmBatch(

src/gpuarray_array_blas.c

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -482,9 +482,6 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
482482
cb_order o;
483483
int cA, cB, cC;
484484
int err;
485-
gpudata **A_datas = NULL, **B_datas = NULL, **C_datas = NULL;
486-
size_t *A_offsets = NULL, *B_offsets = NULL, *C_offsets = NULL;
487-
size_t i;
488485

489486
if (A->typecode != GA_FLOAT && A->typecode != GA_DOUBLE && A->typecode != GA_HALF)
490487
return error_set(ctx->err, GA_INVALID_ERROR, "Unsupported dtype");
@@ -627,23 +624,23 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
627624

628625
switch (C->typecode) {
629626
case GA_HALF:
630-
err = gpublas_hgemm3d(o, transA, transB, m, n, k, (float)alpha,
627+
err = gpublas_hgemm3D(o, transA, transB, m, n, k, (float)alpha,
631628
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
632629
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
633630
(float)beta,
634631
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
635632
batchCount, 0);
636633
break;
637634
case GA_FLOAT:
638-
err = gpublas_sgemm3d(o, transA, transB, m, n, k, (float)alpha,
635+
err = gpublas_sgemm3D(o, transA, transB, m, n, k, (float)alpha,
639636
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
640637
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
641638
(float)beta,
642639
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
643640
batchCount, 0);
644641
break;
645642
case GA_DOUBLE:
646-
err = gpublas_dgemm3d(o, transA, transB, m, n, k, (double)alpha,
643+
err = gpublas_dgemm3D(o, transA, transB, m, n, k, (double)alpha,
647644
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
648645
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
649646
(double)beta,

src/gpuarray_blas_cuda_cublas.c

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -533,7 +533,7 @@ static int hgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
533533
ctx = A->ctx;
534534

535535
if (cublasHgemmStridedBatched == NULL)
536-
return error_set(ctx->error, GA_DEVSUP_ERROR, "cublasHgemmStridedBatched not available in your version of cuBLAS");
536+
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasHgemmStridedBatched not available in your version of cuBLAS");
537537

538538
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
539539
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
@@ -573,11 +573,11 @@ static int hgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
573573
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
574574
err = cublasHgemmStridedBatched(h->h,
575575
convT(transA), convT(transB),
576-
M, N, K, &halpha,
576+
M, N, K, (__half *)&halpha,
577577
((__half *)A->ptr) + offA, lda, strideA,
578578
((__half *)B->ptr) + offB, ldb, strideB,
579-
&hbeta,
580-
((__half *)C->ptr) + offC, ldc, strideB,
579+
(__half *)&hbeta,
580+
((__half *)C->ptr) + offC, ldc, strideC,
581581
batchCount);
582582
if (err != CUBLAS_STATUS_SUCCESS) {
583583
cuda_exit(ctx);
@@ -613,7 +613,7 @@ static int sgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
613613
ctx = A->ctx;
614614

615615
if (cublasSgemmStridedBatched == NULL)
616-
return error_set(ctx->error, GA_DEVSUP_ERROR, "cublasSgemmStridedBatched not available in your version of cuBLAS");
616+
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmStridedBatched not available in your version of cuBLAS");
617617

618618
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
619619
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
@@ -655,7 +655,7 @@ static int sgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
655655
((float *)A->ptr) + offA, (int)lda, strideA,
656656
((float *)B->ptr) + offB, (int)ldb, strideB,
657657
&beta,
658-
((float *)C->ptr) + offC, (int)ldc, strideB,
658+
((float *)C->ptr) + offC, (int)ldc, strideC,
659659
batchCount);
660660
if (err != CUBLAS_STATUS_SUCCESS) {
661661
cuda_exit(ctx);
@@ -691,7 +691,7 @@ static int dgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
691691
ctx = A->ctx;
692692

693693
if (cublasDgemmStridedBatched == NULL)
694-
return error_set(ctx->error, GA_DEVSUP_ERROR, "cublasDgemmStridedBatched not available in your version of cuBLAS");
694+
return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasDgemmStridedBatched not available in your version of cuBLAS");
695695

696696
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
697697
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
@@ -733,7 +733,7 @@ static int dgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
733733
((double *)A->ptr) + offA, (int)lda, strideA,
734734
((double *)B->ptr) + offB, (int)ldb, strideB,
735735
&beta,
736-
((double *)C->ptr) + offC, (int)ldc, strideB,
736+
((double *)C->ptr) + offC, (int)ldc, strideC,
737737
batchCount);
738738
if (err != CUBLAS_STATUS_SUCCESS) {
739739
cuda_exit(ctx);

src/gpuarray_buffer_blas.c

Lines changed: 32 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -260,38 +260,48 @@ int gpublas_dgerBatch(cb_order order, size_t M, size_t N, double alpha,
260260
}
261261

262262

263-
int gpublas_hgemm3d(
263+
#define BLAS_OP3F(b, name, args) \
264+
gpucontext *ctx; \
265+
if (batchCount == 0) return GA_NO_ERROR; \
266+
ctx = gpudata_context(b); \
267+
if (flags != 0) return error_set(ctx->err, GA_INVALID_ERROR, "flags is not 0"); \
268+
if (ctx->blas_ops->name) \
269+
return ctx->blas_ops->name args; \
270+
else \
271+
return error_fmt(ctx->err, GA_DEVSUP_ERROR, "Blas operation not supported by library in use: %s", #name)
272+
273+
int gpublas_hgemm3D(
264274
cb_order order, cb_transpose transA, cb_transpose transB,
265275
size_t M, size_t N, size_t K, float alpha,
266-
gpudata *A, size_t lda, ssize_t strideA,
267-
gpudata *B, size_t ldb, ssize_t strideB,
268-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
276+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
277+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
278+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
269279
size_t batchCount, int flags) {
270-
BLAS_OPBF(A, hgemm3d,
271-
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
272-
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
280+
BLAS_OP3F(A, hgemm3D,
281+
(order, transA, transB, M, N, K, alpha, A, offA, lda, strideA,
282+
B, offB, ldb, strideB, beta, C, offC, ldc, strideC, batchCount));
273283
}
274284

275-
int gpublas_sgemm3d(
285+
int gpublas_sgemm3D(
276286
cb_order order, cb_transpose transA, cb_transpose transB,
277287
size_t M, size_t N, size_t K, float alpha,
278-
gpudata *A, size_t lda, ssize_t strideA,
279-
gpudata *B, size_t ldb, ssize_t strideB,
280-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
288+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
289+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
290+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
281291
size_t batchCount, int flags) {
282-
BLAS_OPBF(A, sgemm3d,
283-
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
284-
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
292+
BLAS_OP3F(A, sgemm3D,
293+
(order, transA, transB, M, N, K, alpha, A, offA, lda, strideA,
294+
B, offB, ldb, strideB, beta, C, offC, ldc, strideC, batchCount));
285295
}
286296

287-
int gpublas_dgemm3d(
297+
int gpublas_dgemm3D(
288298
cb_order order, cb_transpose transA, cb_transpose transB,
289-
size_t M, size_t N, size_t K, float alpha,
290-
gpudata *A, size_t lda, ssize_t strideA,
291-
gpudata *B, size_t ldb, ssize_t strideB,
292-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
299+
size_t M, size_t N, size_t K, double alpha,
300+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
301+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
302+
double beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
293303
size_t batchCount, int flags) {
294-
BLAS_OPBF(A, dgemm3d,
295-
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
296-
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
304+
BLAS_OP3F(A, dgemm3D,
305+
(order, transA, transB, M, N, K, alpha, A, offA, lda, strideA,
306+
B, offB, ldb, strideB, beta, C, offC, ldc, strideC, batchCount));
297307
}

src/private.h

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -214,23 +214,23 @@ struct _gpuarray_blas_ops {
214214
gpudata **y, size_t *offY, size_t incY,
215215
gpudata **A, size_t *offA, size_t lda,
216216
size_t batchCount, int flags);
217-
int (*hgemm3d)(cb_order order, cb_transpose transA, cb_transpose transB,
217+
int (*hgemm3D)(cb_order order, cb_transpose transA, cb_transpose transB,
218218
size_t M, size_t N, size_t K, float alpha,
219-
gpudata *A, size_t lda, ssize_t strideA,
220-
gpudata *B, size_t ldb, ssize_t strideB,
221-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
219+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
220+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
221+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
222222
size_t batchCount);
223-
int (*sgemm3d)(cb_order order, cb_transpose transA, cb_transpose transB,
223+
int (*sgemm3D)(cb_order order, cb_transpose transA, cb_transpose transB,
224224
size_t M, size_t N, size_t K, float alpha,
225-
gpudata *A, size_t lda, ssize_t strideA,
226-
gpudata *B, size_t ldb, ssize_t strideB,
227-
float beta, gpudata *C, size_t ldc, ssize_t strideC,
225+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
226+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
227+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
228228
size_t batchCount);
229-
int (*dgemm3d)(cb_order order, cb_transpose transA, cb_transpose transB,
229+
int (*dgemm3D)(cb_order order, cb_transpose transA, cb_transpose transB,
230230
size_t M, size_t N, size_t K, double alpha,
231-
gpudata *A, size_t lda, ssize_t strideA,
232-
gpudata *B, size_t ldb, ssize_t strideB,
233-
double beta, gpudata *C, size_t ldc, ssize_t strideC,
231+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
232+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
233+
double beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
234234
size_t batchCount);
235235
};
236236

0 commit comments

Comments
 (0)