@@ -625,65 +625,34 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
625625 if (err != GA_NO_ERROR )
626626 goto cleanup ;
627627
628- if (C -> typecode == GA_HALF ){
629- //TODO: handle offset
630- assert (Ap -> offset == 0 );
631- assert (Bp -> offset == 0 );
632- assert (Cp -> offset == 0 );
633- //TODO: float should be half
634- err = gpublas_hgemmStridedBatch (o , transA , transB , m , n , k , alpha ,
635- Ap -> data , lda , Ap -> strides [0 ]/elsize ,
636- Bp -> data , ldb , Bp -> strides [0 ]/elsize ,
637- beta ,
638- Cp -> data , ldc , Cp -> strides [0 ]/elsize ,
639- batchCount , 0 );
640- goto cleanup ;
641- }
642-
643- A_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
644- B_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
645- C_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
646-
647- A_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
648- B_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
649- C_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
650-
651- for (i = 0 ; i < batchCount ; i ++ ) {
652- A_datas [i ] = Ap -> data ;
653- B_datas [i ] = Bp -> data ;
654- C_datas [i ] = Cp -> data ;
655- A_offsets [i ] = (Ap -> offset + i * Ap -> strides [0 ]) / elsize ;
656- B_offsets [i ] = (Bp -> offset + i * Bp -> strides [0 ]) / elsize ;
657- C_offsets [i ] = (Cp -> offset + i * Cp -> strides [0 ]) / elsize ;
658- }
659-
660628 switch (C -> typecode ) {
661629 case GA_HALF :
662- err = gpublas_hgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
663- A_datas , A_offsets , lda ,
664- B_datas , B_offsets , ldb ,
665- (float )beta ,
666- C_datas , C_offsets , ldc , batchCount , 0 );
630+ err = gpublas_hgemm3d (o , transA , transB , m , n , k , (float )alpha ,
631+ Ap -> data , Ap -> offset /elsize , lda , Ap -> strides [0 ]/elsize ,
632+ Bp -> data , Bp -> offset /elsize , ldb , Bp -> strides [0 ]/elsize ,
633+ (float )beta ,
634+ Cp -> data , Cp -> offset /elsize , ldc , Cp -> strides [0 ]/elsize ,
635+ batchCount , 0 );
667636 break ;
668637 case GA_FLOAT :
669- err = gpublas_sgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
670- A_datas , A_offsets , lda ,
671- B_datas , B_offsets , ldb ,
672- (float )beta ,
673- C_datas , C_offsets , ldc , batchCount , 0 );
638+ err = gpublas_sgemm3d (o , transA , transB , m , n , k , (float )alpha ,
639+ Ap -> data , Ap -> offset /elsize , lda , Ap -> strides [0 ]/elsize ,
640+ Bp -> data , Bp -> offset /elsize , ldb , Bp -> strides [0 ]/elsize ,
641+ (float )beta ,
642+ Cp -> data , Cp -> offset /elsize , ldc , Cp -> strides [0 ]/elsize ,
643+ batchCount , 0 );
674644 break ;
675645 case GA_DOUBLE :
676- err = gpublas_dgemmBatch (o , transA , transB , m , n , k , (double )alpha ,
677- A_datas , A_offsets , lda ,
678- B_datas , B_offsets , ldb ,
679- (double )beta ,
680- C_datas , C_offsets , ldc , batchCount , 0 );
646+ err = gpublas_dgemm3d (o , transA , transB , m , n , k , (double )alpha ,
647+ Ap -> data , Ap -> offset /elsize , lda , Ap -> strides [0 ]/elsize ,
648+ Bp -> data , Bp -> offset /elsize , ldb , Bp -> strides [0 ]/elsize ,
649+ (double )beta ,
650+ Cp -> data , Cp -> offset /elsize , ldc , Cp -> strides [0 ]/elsize ,
651+ batchCount , 0 );
681652 break ;
682653 }
683654
684655 cleanup :
685- free (A_datas ); free (B_datas ); free (C_datas );
686- free (A_offsets ); free (B_offsets ); free (C_offsets );
687656 if (Ap == & copyA )
688657 GpuArray_clear (& copyA );
689658 if (Bp == & copyB )
0 commit comments