Skip to content

Commit 2237b40

Browse files
authored
Merge pull request #398 from abergeron/fix_bgemm
Fix contiguous detection to properly handle F-contiguous input.
2 parents e9abc18 + 89a2f44 commit 2237b40

4 files changed

Lines changed: 133 additions & 24 deletions

File tree

src/gpuarray_array.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,9 @@ int GpuArray_fdump(FILE *fd, const GpuArray *a) {
10961096
case GA_LONG:
10971097
fprintf(fd, "%lld", (long long)*(int64_t *)p);
10981098
break;
1099+
case GA_FLOAT:
1100+
fprintf(fd, "%f", *(float *)p);
1101+
break;
10991102
case GA_SSIZE:
11001103
fprintf(fd, "%" SPREFIX "d", *(ssize_t *)p);
11011104
break;

src/gpuarray_array_blas.c

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,23 @@ int GpuArray_rger(double alpha, GpuArray *X, GpuArray *Y, GpuArray *A,
439439
return err;
440440
}
441441

442+
static inline int is_last_2d_contiguous(const GpuArray *a) {
443+
size_t size = GpuArray_ITEMSIZE(a);
444+
445+
if (GpuArray_IS_C_CONTIGUOUS(a))
446+
return 1; // C contiguous
447+
448+
if (a->strides[a->nd - 2] <= 0 || a->strides[a->nd - 1] <= 0)
449+
return 0;
450+
451+
if (a->strides[a->nd - 2] == size)
452+
return 2; // F contiguous
453+
if (a->strides[a->nd - 1] == size)
454+
return 1; // C contiguous
455+
456+
return 0;
457+
}
458+
442459
int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alpha,
443460
GpuArray *A, GpuArray *B, double beta, GpuArray *C,
444461
int nocopy) {
@@ -451,6 +468,7 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
451468
size_t elsize;
452469
size_t batchCount, m, n, k, lda, ldb, ldc;
453470
cb_order o;
471+
int cA, cB, cC;
454472
int err;
455473
gpudata **A_datas = NULL, **B_datas = NULL, **C_datas = NULL;
456474
size_t *A_offsets = NULL, *B_offsets = NULL, *C_offsets = NULL;
@@ -495,52 +513,56 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
495513

496514
elsize = gpuarray_get_elsize(A->typecode);
497515

498-
// FIXME: these conditions are overly restrictive; the first axis need not be contiguous
499-
if (!GpuArray_ISONESEGMENT(A)) {
516+
cA = is_last_2d_contiguous(A);
517+
if (!cA) {
500518
if (nocopy)
501519
return GA_COPY_ERROR;
502520
else {
503-
err = GpuArray_copy(&copyA, A, GA_F_ORDER);
521+
err = GpuArray_copy(&copyA, A, GA_C_ORDER);
522+
cA = 1;
504523
if (err != GA_NO_ERROR)
505524
goto cleanup;
506525
Ap = &copyA;
507526
}
508527
}
509-
if (!GpuArray_ISONESEGMENT(B)) {
528+
cB = is_last_2d_contiguous(B);
529+
if (!cB) {
510530
if (nocopy)
511531
return GA_COPY_ERROR;
512532
else {
513-
err = GpuArray_copy(&copyB, B, GA_F_ORDER);
533+
err = GpuArray_copy(&copyB, B, GA_C_ORDER);
534+
cB = 1;
514535
if (err != GA_NO_ERROR)
515536
goto cleanup;
516537
Bp = &copyB;
517538
}
518539
}
519-
if (!GpuArray_ISONESEGMENT(C)) {
540+
cC = is_last_2d_contiguous(C);
541+
if (!cC) {
520542
err = GA_VALUE_ERROR;
521543
goto cleanup;
522544
}
523545

524-
if (Cp->flags & GA_F_CONTIGUOUS) {
546+
if (cC == 2) {
525547
o = cb_fortran;
526-
ldc = Cp->dimensions[1];
527-
} else if (Cp->flags & GA_C_CONTIGUOUS) {
548+
ldc = Cp->strides[2] / elsize;
549+
} else if (cC == 1) {
528550
o = cb_c;
529-
ldc = Cp->dimensions[2];
551+
ldc = Cp->strides[1] / elsize;
530552
} else {
531553
err = GA_VALUE_ERROR;
532554
goto cleanup;
533555
}
534-
if (Ap->flags & GA_F_CONTIGUOUS) {
535-
lda = Ap->dimensions[1];
556+
if (cA == 2) {
557+
lda = Ap->strides[2] / elsize;
536558
if (o == cb_c) {
537559
if (transA == cb_no_trans)
538560
transA = cb_trans;
539561
else
540562
transA = cb_no_trans;
541563
}
542-
} else if (Ap->flags & GA_C_CONTIGUOUS) {
543-
lda = Ap->dimensions[2];
564+
} else if (cA == 1) {
565+
lda = Ap->strides[1] / elsize;
544566
if (o == cb_fortran) {
545567
if (transA == cb_no_trans)
546568
transA = cb_trans;
@@ -551,16 +573,16 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
551573
err = GA_VALUE_ERROR;
552574
goto cleanup;
553575
}
554-
if (Bp->flags & GA_F_CONTIGUOUS) {
555-
ldb = Bp->dimensions[1];
576+
if (cB == 2) {
577+
ldb = Bp->strides[2] / elsize;
556578
if (o == cb_c) {
557579
if (transB == cb_no_trans)
558580
transB = cb_trans;
559581
else
560582
transB = cb_no_trans;
561583
}
562-
} else if (Bp->flags & GA_C_CONTIGUOUS) {
563-
ldb = Bp->dimensions[2];
584+
} else if (cB == 1) {
585+
ldb = Bp->strides[1] / elsize;
564586
if (o == cb_fortran) {
565587
if (transB == cb_no_trans)
566588
transB = cb_trans;

src/gpuarray_blas_cuda_cublas.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,10 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
589589
h->err = cublasSgemm(h->h,
590590
convT(transA), convT(transB),
591591
M, N, K, &alpha,
592-
(float*)A[i]->ptr + offA[i], lda,
593-
(float*)B[i]->ptr + offB[i], ldb,
592+
((float*)A[i]->ptr) + offA[i], lda,
593+
((float*)B[i]->ptr) + offB[i], ldb,
594594
&beta,
595-
(float*)C[i]->ptr + offC[i], ldc);
595+
((float*)C[i]->ptr) + offC[i], ldc);
596596
if (h->err != CUBLAS_STATUS_SUCCESS) {
597597
cuda_exit(ctx);
598598
if (h->err == CUBLAS_STATUS_ARCH_MISMATCH)

tests/check_blas.c

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,100 @@ void teardown(void);
1414

1515
#define ga_assert_ok(e) ck_assert_int_eq(e, GA_NO_ERROR)
1616

17-
START_TEST(test_gemmBatch_3d) {
17+
static inline void ck_assert_fbuf_eq(const float *b, const float *r,
18+
unsigned int n) {
19+
unsigned int i;
20+
for (i = 0; i < n; i++) {
21+
ck_assert_msg(b[i] == r[i], "Difference at %u: %f != %f(ref)", i, b[i], r[i]);
22+
}
23+
}
24+
25+
START_TEST(test_gemmBatch_3d_C) {
1826
GpuArray A;
1927
GpuArray B;
2028
GpuArray C;
2129

22-
size_t dims[3] = {32, 32, 32};
30+
size_t dims[3] = {2, 3, 3};
31+
float data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9,
32+
1, 2, 3, 4, 5, 6, 7, 8, 9};
33+
const float res[] = {30, 36, 42, 66, 81, 96, 102, 126, 150,
34+
30, 36, 42, 66, 81, 96, 102, 126, 150};
2335

2436
ga_assert_ok(GpuArray_empty(&A, ctx, GA_FLOAT, 3, dims, GA_C_ORDER));
2537
ga_assert_ok(GpuArray_empty(&B, ctx, GA_FLOAT, 3, dims, GA_C_ORDER));
2638
ga_assert_ok(GpuArray_empty(&C, ctx, GA_FLOAT, 3, dims, GA_C_ORDER));
2739

40+
ga_assert_ok(GpuArray_write(&A, data, sizeof(data)));
41+
ga_assert_ok(GpuArray_write(&B, data, sizeof(data)));
42+
2843
ga_assert_ok(GpuArray_rgemmBatch_3d(cb_no_trans, cb_no_trans, 1, &A, &B, 0, &C, 1));
44+
45+
ga_assert_ok(GpuArray_read(data, sizeof(data), &C));
46+
47+
ck_assert_fbuf_eq(data, res, sizeof(res)/sizeof(float));
48+
}
49+
END_TEST
50+
51+
START_TEST(test_gemmBatch_3d_F) {
52+
GpuArray A;
53+
GpuArray B;
54+
GpuArray C;
55+
56+
size_t dims[3] = {2, 3, 3};
57+
float data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9,
58+
1, 2, 3, 4, 5, 6, 7, 8, 9};
59+
const float res[] = {42, 78, 78, 60, 114, 114, 51, 69, 96,
60+
66, 39, 111, 54, 54, 90, 78, 78, 132};
61+
62+
ga_assert_ok(GpuArray_empty(&A, ctx, GA_FLOAT, 3, dims, GA_F_ORDER));
63+
ga_assert_ok(GpuArray_empty(&B, ctx, GA_FLOAT, 3, dims, GA_F_ORDER));
64+
ga_assert_ok(GpuArray_empty(&C, ctx, GA_FLOAT, 3, dims, GA_C_ORDER));
65+
66+
ga_assert_ok(GpuArray_write(&A, data, sizeof(data)));
67+
ga_assert_ok(GpuArray_write(&B, data, sizeof(data)));
68+
69+
ga_assert_ok(GpuArray_rgemmBatch_3d(cb_no_trans, cb_no_trans, 1, &A, &B, 0, &C, 0));
70+
71+
ga_assert_ok(GpuArray_read(data, sizeof(data), &C));
72+
73+
ck_assert_fbuf_eq(data, res, sizeof(res)/sizeof(float));
74+
}
75+
END_TEST
76+
77+
START_TEST(test_gemmBatch_3d_S) {
78+
GpuArray A;
79+
GpuArray B;
80+
GpuArray C;
81+
ssize_t t;
82+
83+
size_t dims[3] = {2, 3, 3};
84+
float data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9,
85+
1, 2, 3, 4, 5, 6, 7, 8, 9};
86+
const float res[] = {14, 32, 50, 50, 122, 194, 32, 77, 122,
87+
26, 62, 98, 17, 53, 89, 44, 107, 170};
88+
89+
ga_assert_ok(GpuArray_empty(&A, ctx, GA_FLOAT, 3, dims, GA_F_ORDER));
90+
ga_assert_ok(GpuArray_empty(&B, ctx, GA_FLOAT, 3, dims, GA_C_ORDER));
91+
ga_assert_ok(GpuArray_empty(&C, ctx, GA_FLOAT, 3, dims, GA_C_ORDER));
92+
93+
ga_assert_ok(GpuArray_write(&A, data, sizeof(data)));
94+
ga_assert_ok(GpuArray_write(&B, data, sizeof(data)));
95+
96+
A.strides[0] = 8;
97+
A.strides[1] = 24;
98+
A.strides[2] = 4;
99+
GpuArray_fix_flags(&A);
100+
101+
t = B.strides[1];
102+
B.strides[1] = B.strides[2];
103+
B.strides[2] = t;
104+
GpuArray_fix_flags(&B);
105+
106+
ga_assert_ok(GpuArray_rgemmBatch_3d(cb_no_trans, cb_no_trans, 1, &A, &B, 0, &C, 1));
107+
108+
ga_assert_ok(GpuArray_read(data, sizeof(data), &C));
109+
110+
ck_assert_fbuf_eq(data, res, sizeof(res)/sizeof(float));
29111
}
30112
END_TEST
31113

@@ -34,7 +116,9 @@ Suite *get_suite(void) {
34116
TCase *tc = tcase_create("all");
35117
tcase_add_checked_fixture(tc, setup, teardown);
36118
tcase_set_timeout(tc, 16.0);
37-
tcase_add_test(tc, test_gemmBatch_3d);
119+
tcase_add_test(tc, test_gemmBatch_3d_C);
120+
tcase_add_test(tc, test_gemmBatch_3d_F);
121+
tcase_add_test(tc, test_gemmBatch_3d_S);
38122
suite_add_tcase(s, tc);
39123
return s;
40124
}

0 commit comments

Comments
 (0)