@@ -649,6 +649,62 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
649649 break ;
650650 }
651651
652+ if (err == GA_DEVSUP_ERROR ) {
653+ gpudata * * A_datas = NULL , * * B_datas = NULL , * * C_datas = NULL ;
654+ size_t * A_offsets = NULL , * B_offsets = NULL , * C_offsets = NULL ;
655+ size_t i ;
656+
657+ A_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
658+ B_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
659+ C_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
660+
661+ A_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
662+ B_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
663+ C_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
664+
665+ if (A_datas == NULL || B_datas == NULL || C_datas == NULL ||
666+ A_offsets == NULL || B_offsets == NULL || C_offsets ) {
667+ err = error_sys (ctx -> err , "malloc" );
668+ goto old_cleanup ;
669+ }
670+
671+ for (i = 0 ; i < batchCount ; i ++ ) {
672+ A_datas [i ] = Ap -> data ;
673+ B_datas [i ] = Bp -> data ;
674+ C_datas [i ] = Cp -> data ;
675+ A_offsets [i ] = (Ap -> offset + i * Ap -> strides [0 ]) / elsize ;
676+ B_offsets [i ] = (Bp -> offset + i * Bp -> strides [0 ]) / elsize ;
677+ C_offsets [i ] = (Cp -> offset + i * Cp -> strides [0 ]) / elsize ;
678+ }
679+
680+ switch (C -> typecode ) {
681+ case GA_HALF :
682+ err = gpublas_hgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
683+ A_datas , A_offsets , lda ,
684+ B_datas , B_offsets , ldb ,
685+ (float )beta ,
686+ C_datas , C_offsets , ldc , batchCount , 0 );
687+ break ;
688+ case GA_FLOAT :
689+ err = gpublas_sgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
690+ A_datas , A_offsets , lda ,
691+ B_datas , B_offsets , ldb ,
692+ (float )beta ,
693+ C_datas , C_offsets , ldc , batchCount , 0 );
694+ break ;
695+ case GA_DOUBLE :
696+ err = gpublas_dgemmBatch (o , transA , transB , m , n , k , (double )alpha ,
697+ A_datas , A_offsets , lda ,
698+ B_datas , B_offsets , ldb ,
699+ (double )beta ,
700+ C_datas , C_offsets , ldc , batchCount , 0 );
701+ break ;
702+ }
703+ old_cleanup :
704+ free (A_datas ); free (B_datas ); free (C_datas );
705+ free (A_offsets ); free (B_offsets ); free (C_offsets );
706+ }
707+
652708 cleanup :
653709 if (Ap == & copyA )
654710 GpuArray_clear (& copyA );
0 commit comments