Skip to content

Commit 68c9916

Browse files
committed
Add another case
1 parent 322681b commit 68c9916

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/gpuarray_array_blas.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,15 +563,19 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
563563
goto cleanup;
564564
}
565565
if (cA == 2) {
566-
lda = Ap->strides[2] / elsize;
566+
lda = Ap->dimensions[2] > 1
567+
? Ap->strides[2] / elsize
568+
: Ap->dimensions[1];
567569
if (o == cb_c) {
568570
if (transA == cb_no_trans)
569571
transA = cb_trans;
570572
else
571573
transA = cb_no_trans;
572574
}
573575
} else if (cA == 1) {
574-
lda = Ap->strides[1] / elsize;
576+
lda = Ap->dimensions[1] > 1
577+
? Ap->strides[1] / elsize
578+
: Ap->dimensions[2];
575579
if (o == cb_fortran) {
576580
if (transA == cb_no_trans)
577581
transA = cb_trans;

0 commit comments

Comments
 (0)