Skip to content

Commit cb4a79f

Browse files
committed
Change GpuArray_rgemmBatch_3d to use the new gemm3d functions.
1 parent 0767181 commit cb4a79f

1 file changed

Lines changed: 18 additions & 49 deletions

File tree

src/gpuarray_array_blas.c

Lines changed: 18 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -625,65 +625,34 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
625625
if (err != GA_NO_ERROR)
626626
goto cleanup;
627627

628-
if(C->typecode == GA_HALF){
629-
//TODO: handle offset
630-
assert (Ap->offset == 0);
631-
assert (Bp->offset == 0);
632-
assert (Cp->offset == 0);
633-
//TODO: float should be half
634-
err = gpublas_hgemmStridedBatch(o, transA, transB, m, n, k, alpha,
635-
Ap->data, lda, Ap->strides[0]/elsize,
636-
Bp->data, ldb, Bp->strides[0]/elsize,
637-
beta,
638-
Cp->data, ldc, Cp->strides[0]/elsize,
639-
batchCount, 0);
640-
goto cleanup;
641-
}
642-
643-
A_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
644-
B_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
645-
C_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
646-
647-
A_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
648-
B_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
649-
C_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
650-
651-
for (i = 0; i < batchCount; i++) {
652-
A_datas[i] = Ap->data;
653-
B_datas[i] = Bp->data;
654-
C_datas[i] = Cp->data;
655-
A_offsets[i] = (Ap->offset + i * Ap->strides[0]) / elsize;
656-
B_offsets[i] = (Bp->offset + i * Bp->strides[0]) / elsize;
657-
C_offsets[i] = (Cp->offset + i * Cp->strides[0]) / elsize;
658-
}
659-
660628
switch (C->typecode) {
661629
case GA_HALF:
662-
err = gpublas_hgemmBatch(o, transA, transB, m, n, k, (float)alpha,
663-
A_datas, A_offsets, lda,
664-
B_datas, B_offsets, ldb,
665-
(float)beta,
666-
C_datas, C_offsets, ldc, batchCount, 0);
630+
err = gpublas_hgemm3d(o, transA, transB, m, n, k, (float)alpha,
631+
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
632+
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
633+
(float)beta,
634+
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
635+
batchCount, 0);
667636
break;
668637
case GA_FLOAT:
669-
err = gpublas_sgemmBatch(o, transA, transB, m, n, k, (float)alpha,
670-
A_datas, A_offsets, lda,
671-
B_datas, B_offsets, ldb,
672-
(float)beta,
673-
C_datas, C_offsets, ldc, batchCount, 0);
638+
err = gpublas_sgemm3d(o, transA, transB, m, n, k, (float)alpha,
639+
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
640+
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
641+
(float)beta,
642+
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
643+
batchCount, 0);
674644
break;
675645
case GA_DOUBLE:
676-
err = gpublas_dgemmBatch(o, transA, transB, m, n, k, (double)alpha,
677-
A_datas, A_offsets, lda,
678-
B_datas, B_offsets, ldb,
679-
(double)beta,
680-
C_datas, C_offsets, ldc, batchCount, 0);
646+
err = gpublas_dgemm3d(o, transA, transB, m, n, k, (double)alpha,
647+
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
648+
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
649+
(double)beta,
650+
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
651+
batchCount, 0);
681652
break;
682653
}
683654

684655
cleanup:
685-
free(A_datas); free(B_datas); free(C_datas);
686-
free(A_offsets); free(B_offsets); free(C_offsets);
687656
if (Ap == &copyA)
688657
GpuArray_clear(&copyA);
689658
if (Bp == &copyB)

0 commit comments

Comments
 (0)