Skip to content

Commit 4f396cd

Browse files
committed
Add fallback to the old code when the new functions don't work.
1 parent a7ade4a commit 4f396cd

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

src/gpuarray_array_blas.c

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,62 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
649649
break;
650650
}
651651

652+
if (err == GA_DEVSUP_ERROR) {
653+
gpudata **A_datas = NULL, **B_datas = NULL, **C_datas = NULL;
654+
size_t *A_offsets = NULL, *B_offsets = NULL, *C_offsets = NULL;
655+
size_t i;
656+
657+
A_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
658+
B_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
659+
C_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
660+
661+
A_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
662+
B_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
663+
C_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
664+
665+
if (A_datas == NULL || B_datas == NULL || C_datas == NULL ||
666+
A_offsets == NULL || B_offsets == NULL || C_offsets) {
667+
err = error_sys(ctx->err, "malloc");
668+
goto old_cleanup;
669+
}
670+
671+
for (i = 0; i < batchCount; i++) {
672+
A_datas[i] = Ap->data;
673+
B_datas[i] = Bp->data;
674+
C_datas[i] = Cp->data;
675+
A_offsets[i] = (Ap->offset + i * Ap->strides[0]) / elsize;
676+
B_offsets[i] = (Bp->offset + i * Bp->strides[0]) / elsize;
677+
C_offsets[i] = (Cp->offset + i * Cp->strides[0]) / elsize;
678+
}
679+
680+
switch (C->typecode) {
681+
case GA_HALF:
682+
err = gpublas_hgemmBatch(o, transA, transB, m, n, k, (float)alpha,
683+
A_datas, A_offsets, lda,
684+
B_datas, B_offsets, ldb,
685+
(float)beta,
686+
C_datas, C_offsets, ldc, batchCount, 0);
687+
break;
688+
case GA_FLOAT:
689+
err = gpublas_sgemmBatch(o, transA, transB, m, n, k, (float)alpha,
690+
A_datas, A_offsets, lda,
691+
B_datas, B_offsets, ldb,
692+
(float)beta,
693+
C_datas, C_offsets, ldc, batchCount, 0);
694+
break;
695+
case GA_DOUBLE:
696+
err = gpublas_dgemmBatch(o, transA, transB, m, n, k, (double)alpha,
697+
A_datas, A_offsets, lda,
698+
B_datas, B_offsets, ldb,
699+
(double)beta,
700+
C_datas, C_offsets, ldc, batchCount, 0);
701+
break;
702+
}
703+
old_cleanup:
704+
free(A_datas); free(B_datas); free(C_datas);
705+
free(A_offsets); free(B_offsets); free(C_offsets);
706+
}
707+
652708
cleanup:
653709
if (Ap == &copyA)
654710
GpuArray_clear(&copyA);

0 commit comments

Comments
 (0)