Skip to content

Commit 9c2e317

Browse files
nouizabergeron
authored andcommitted
First version that execute, but give wrong result!
1 parent 1137c81 commit 9c2e317

10 files changed

Lines changed: 149 additions & 3 deletions

src/gpuarray/blas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ GPUARRAY_PUBLIC int GpuArray_rger(double alpha, GpuArray *X, GpuArray *Y,
3434
GPUARRAY_PUBLIC int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB,
3535
double alpha, GpuArray *A, GpuArray *B,
3636
double beta, GpuArray *C, int nocopy);
37+
#define GpuArray_hgemmBatch_3d GpuArray_rgemmBatch_3d
3738
#define GpuArray_sgemmBatch_3d GpuArray_rgemmBatch_3d
3839
#define GpuArray_dgemmBatch_3d GpuArray_rgemmBatch_3d
3940

src/gpuarray/buffer_blas.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,15 @@ GPUARRAY_PUBLIC int gpublas_hgemmBatch(
115115
float beta, gpudata **C, size_t *offC, size_t ldc,
116116
size_t batchCount, int flags);
117117

118+
//TODO: float should be half
119+
GPUARRAY_PUBLIC int gpublas_hgemmStridedBatch(
120+
cb_order order, cb_transpose transA, cb_transpose transB,
121+
size_t M, size_t N, size_t K, float alpha,
122+
gpudata *A, size_t lda, ssize_t strideA,
123+
gpudata *B, size_t ldb, ssize_t strideB,
124+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
125+
size_t batchCount, int flags);
126+
118127
GPUARRAY_PUBLIC int gpublas_sgemmBatch(
119128
cb_order order, cb_transpose transA, cb_transpose transB,
120129
size_t M, size_t N, size_t K, float alpha,

src/gpuarray_array_blas.c

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
486486
size_t *A_offsets = NULL, *B_offsets = NULL, *C_offsets = NULL;
487487
size_t i;
488488

489-
if (A->typecode != GA_FLOAT && A->typecode != GA_DOUBLE)
489+
if (A->typecode != GA_FLOAT && A->typecode != GA_DOUBLE && A->typecode != GA_HALF)
490490
return error_set(ctx->err, GA_INVALID_ERROR, "Unsupported dtype");
491491

492492
if (A->nd != 3 || B->nd != 3 || C->nd != 3)
@@ -625,6 +625,21 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
625625
if (err != GA_NO_ERROR)
626626
goto cleanup;
627627

628+
if(C->typecode == GA_HALF){
629+
//TODO: handle offset
630+
assert (Ap->offset == 0);
631+
assert (Bp->offset == 0);
632+
assert (Cp->offset == 0);
633+
//TODO: float should be half
634+
err = gpublas_hgemmStridedBatch(o, transA, transB, m, n, k, alpha,
635+
Ap->data, lda, Ap->strides[0]/elsize,
636+
Bp->data, ldb, Bp->strides[0]/elsize,
637+
beta,
638+
Cp->data, ldc, Cp->strides[0]/elsize,
639+
batchCount, 0);
640+
goto cleanup;
641+
}
642+
628643
A_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
629644
B_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
630645
C_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));

src/gpuarray_blas_cuda_cublas.c

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,89 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
510510
cuda_exit(ctx);
511511
return GA_NO_ERROR;
512512
}
513+
//TODO: change float to half
514+
static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose transB,
515+
size_t M, size_t N, size_t K, float alpha,
516+
gpudata *A, size_t lda, ssize_t strideA,
517+
gpudata *B, size_t ldb, ssize_t strideB,
518+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
519+
size_t batchCount) {
520+
cuda_context *ctx;
521+
blas_handle *h;
522+
size_t t;
523+
ssize_t lt;
524+
gpudata *T;
525+
cb_transpose transT;
526+
cublasStatus_t err;
527+
__half halpha, hbeta;
528+
529+
//ignore overflow, underflow, denormalized and inf values. Mayve also nan.
530+
uint32_t x = (uint32_t)alpha;
531+
alpha = ((x>>16)&0x8000)|((((x&0x7f800000)-0x38000000)>>13)&0x7c00)|((x>>13)&0x03ff);
532+
x = (uint32_t)beta;
533+
beta = ((x>>16)&0x8000)|((((x&0x7f800000)-0x38000000)>>13)&0x7c00)|((x>>13)&0x03ff);
534+
535+
ASSERT_BUF(A);
536+
if (cublasHgemmStridedBatched == NULL)
537+
return GA_DEVSUP_ERROR;
538+
539+
ctx = A->ctx;
540+
// TODO: stride* are long long int in cuda, LARGE_VAL check for int.
541+
if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
542+
LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
543+
LARGE_VAL(strideA) || LARGE_VAL(strideB) || LARGE_VAL(strideC) ||
544+
LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
545+
return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");
546+
547+
h = (blas_handle *)ctx->blas_handle;
548+
cuda_enter(ctx);
549+
550+
if (order == cb_c) {
551+
/* swap A and B */
552+
t = N;
553+
N = M;
554+
M = t;
555+
T = A;
556+
A = B;
557+
B = T;
558+
t = lda;
559+
lda = ldb;
560+
ldb = t;
561+
transT = transA;
562+
transA = transB;
563+
transB = transT;
564+
lt = strideA;
565+
strideA = strideB;
566+
strideB = lt;
567+
}
568+
569+
ASSERT_BUF(A);
570+
ASSERT_BUF(B);
571+
ASSERT_BUF(C);
572+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
573+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
574+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
575+
raise(SIGINT);
576+
err = cublasHgemmStridedBatched(h->h,
577+
convT(transA), convT(transB),
578+
M, N, K, &halpha,
579+
(__half *)(A->ptr), (int) lda, strideA,
580+
(__half *)(B->ptr), (int) ldb, strideB,
581+
&hbeta,
582+
(__half *)(C->ptr), (int) ldc, strideB,
583+
batchCount);
584+
if (err != CUBLAS_STATUS_SUCCESS) {
585+
cuda_exit(ctx);
586+
return error_cublas(ctx->err, "cublasHgemmStridedBatched", err);
587+
}
588+
589+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
590+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
591+
GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));
592+
593+
cuda_exit(ctx);
594+
return GA_NO_ERROR;
595+
}
513596

514597
static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
515598
size_t M, size_t N, size_t K, float alpha,
@@ -1578,5 +1661,6 @@ gpuarray_blas_ops cublas_ops = {
15781661
dgemvBatch,
15791662
NULL, /* hgerBatch */
15801663
sgerBatch,
1581-
dgerBatch
1664+
dgerBatch,
1665+
hgemmStridedBatch,
15821666
};

src/gpuarray_blas_opencl_clblas.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,4 +449,5 @@ gpuarray_blas_ops clblas_ops = {
449449
NULL, /* hgerBatch */
450450
NULL, /* sgerBatch */
451451
NULL, /* dgerBatch */
452+
NULL, /* hgemmStridedzBatch */
452453
};

src/gpuarray_blas_opencl_clblast.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,4 +524,5 @@ gpuarray_blas_ops clblast_ops = {
524524
NULL, /* hgerBatch */
525525
NULL, /* sgerBatch */
526526
NULL, /* dgerBatch */
527+
NULL, /* hgemmStridedzBatch */
527528
};

src/gpuarray_buffer_blas.c

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,21 @@ const char *gpublas_error(gpucontext *ctx) {
1717
return ctx->err->msg;
1818
}
1919

20-
#define BLAS_OP(buf,name, args) \
20+
#define BLAS_OP(buf, name, args) \
2121
gpucontext *ctx = gpudata_context(buf); \
2222
if (ctx->blas_ops->name) \
2323
return ctx->blas_ops->name args; \
2424
else \
2525
return error_fmt(ctx->err, GA_DEVSUP_ERROR, "Blas operation not supported by device or missing library: %s", #name)
2626

27+
#define BLAS_OPF(buf, name, args) \
28+
gpucontext *ctx = gpudata_context(buf); \
29+
if (flags != 0) return error_set(ctx->err, GA_INVALID_ERROR, "flags is not 0"); \
30+
if (ctx->blas_ops->name) \
31+
return ctx->blas_ops->name args; \
32+
else \
33+
return error_fmt(ctx->err, GA_DEVSUP_ERROR, "Blas operation not supported by device or missing library: %s", #name)
34+
2735

2836
int gpublas_hdot(
2937
size_t N,
@@ -161,6 +169,19 @@ int gpublas_hgemmBatch(
161169
B, offB, ldb, beta, C, offC, ldc, batchCount));
162170
}
163171

172+
//TODO: use half and not float here.
173+
int gpublas_hgemmStridedBatch(
174+
cb_order order, cb_transpose transA, cb_transpose transB,
175+
size_t M, size_t N, size_t K, float alpha,
176+
gpudata *A, size_t lda, ssize_t strideA,
177+
gpudata *B, size_t ldb, ssize_t strideB,
178+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
179+
size_t batchCount, int flags) {
180+
BLAS_OPF(A, hgemmStridedBatch,
181+
(order, transA, transB, M, N, K, alpha, A, lda, strideA,
182+
B, ldb, strideB, beta, C, ldc, strideC, batchCount));
183+
}
184+
164185
int gpublas_sgemmBatch(
165186
cb_order order, cb_transpose transA, cb_transpose transB,
166187
size_t M, size_t N, size_t K, float alpha,

src/loaders/libcublas.fn

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,5 @@ DEF_PROC_OPT(cublasSgemmEx, (cublasHandle_t handle, cublasOperation_t transa, cu
2323

2424
DEF_PROC(cublasSgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *Aarray[], int lda, const float *Barray[], int ldb, const float *beta, float *Carray[], int ldc, int batchCount));
2525
DEF_PROC(cublasDgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *Aarray[], int lda, const double *Barray[], int ldb, const double *beta, double *Carray[], int ldc, int batchCount));
26+
27+
DEF_PROC(cublasHgemmStridedBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half *alpha, const __half *A, int lda, long long int strideA, const __half *B, int ldb, long long int strideB, const __half *beta, __half *C, int ldc, long long int strideC, int batchCount));

src/loaders/libcublas.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
#define LOADER_LIBCUBLAS_H
33

44
#include "util/error.h"
5+
//TODO: how to have it work with align?
6+
typedef struct {//__align__(2) {
7+
unsigned short x;
8+
} __half;
9+
510

611
/** @cond NEVER */
712

src/private.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ struct _gpuarray_blas_ops {
214214
gpudata **y, size_t *offY, size_t incY,
215215
gpudata **A, size_t *offA, size_t lda,
216216
size_t batchCount, int flags);
217+
//TODO: float should be half
218+
int (*hgemmStridedBatch)(cb_order order, cb_transpose transA, cb_transpose transB,
219+
size_t M, size_t N, size_t K, float alpha,
220+
gpudata *A, size_t lda, ssize_t strideA,
221+
gpudata *B, size_t ldb, ssize_t strideB,
222+
float beta, gpudata *C, size_t ldc, ssize_t strideC,
223+
size_t batchCount);
217224
};
218225

219226
struct _gpuarray_comm_ops {

0 commit comments

Comments
 (0)