@@ -510,37 +510,33 @@ static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
   cuda_exit(ctx);
   return GA_NO_ERROR;
 }
-//TODO: change float to half
-static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose transB,
-                             size_t M, size_t N, size_t K, float alpha,
-                             gpudata *A, size_t lda, ssize_t strideA,
-                             gpudata *B, size_t ldb, ssize_t strideB,
-                             float beta, gpudata *C, size_t ldc, ssize_t strideC,
-                             size_t batchCount) {
+
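+/* Strided-batched GEMM: for each i in [0, batchCount) compute
+ * C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where matrix X_i starts at
+ * ((element type *)X->ptr) + offX + i * strideX.  The offsets and strides
+ * are counted in elements (they are added to typed pointers and forwarded
+ * to cuBLAS as element strides), not in bytes. */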
+static int hgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
+                   size_t M, size_t N, size_t K, float alpha,
+                   gpudata *A, size_t offA, size_t lda, ssize_t strideA,
+                   gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
+                   float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
+                   size_t batchCount) {
   cuda_context *ctx;
   blas_handle *h;
   size_t t;
-  ssize_t lt;
+  ssize_t st;
   gpudata *T;
   cb_transpose transT;
   cublasStatus_t err;
-  __half halpha, hbeta;
-
-  //ignore overflow, underflow, denormalized and inf values. Maybe also nan.
-  uint32_t x = (uint32_t)alpha;
-  alpha = ((x >> 16) & 0x8000) | ((((x & 0x7f800000) - 0x38000000) >> 13) & 0x7c00) | ((x >> 13) & 0x03ff);
-  x = (uint32_t)beta;
-  beta = ((x >> 16) & 0x8000) | ((((x & 0x7f800000) - 0x38000000) >> 13) & 0x7c00) | ((x >> 13) & 0x03ff);
-
+  ga_half_t halpha, hbeta;
+
   ASSERT_BUF(A);
-  if (cublasHgemmStridedBatched == NULL)
-    return GA_DEVSUP_ERROR;
+  ASSERT_BUF(B);
+  ASSERT_BUF(C);
 
   ctx = A->ctx;
-  // TODO: stride* are long long int in cuda, LARGE_VAL check for int.
+
+  if (cublasHgemmStridedBatched == NULL)
+    return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasHgemmStridedBatched not available in your version of cuBLAS");
+
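+  /* LARGE_VAL guards the values that cuBLAS receives as int; the batch
+     strides are passed as long long in the *StridedBatched interface, so
+     they no longer need this check. */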
   if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
       LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
-      LARGE_VAL(strideA) || LARGE_VAL(strideB) || LARGE_VAL(strideC) ||
       LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
     return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");
 
@@ -558,28 +554,108 @@ static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose transB,
     t = lda;
     lda = ldb;
     ldb = t;
+    t = offA;
+    offA = offB;
+    offB = t;
     transT = transA;
     transA = transB;
     transB = transT;
-    lt = strideA;
+    st = strideA;
     strideA = strideB;
-    strideB = lt;
+    strideB = st;
   }
 
-  ASSERT_BUF(A);
-  ASSERT_BUF(B);
-  ASSERT_BUF(C);
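+  /* cublasHgemmStridedBatched expects alpha and beta as __half, so the
+     float scalars are converted host-side first. */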
+  halpha = ga_float2half(alpha);
+  hbeta = ga_float2half(beta);
+
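+  /* Make the stream wait on outstanding work for the operands before the
+     call: A and B are only read, C is both read and written. */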
   GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
   GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
   GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
-  raise(SIGINT);
   err = cublasHgemmStridedBatched(h->h,
                                   convT(transA), convT(transB),
                                   M, N, K, &halpha,
-                                  (__half *)(A->ptr), (int)lda, strideA,
-                                  (__half *)(B->ptr), (int)ldb, strideB,
+                                  ((__half *)A->ptr) + offA, (int)lda, strideA,
+                                  ((__half *)B->ptr) + offB, (int)ldb, strideB,
                                   &hbeta,
-                                  (__half *)(C->ptr), (int)ldc, strideB,
+                                  ((__half *)C->ptr) + offC, (int)ldc, strideC,
+                                  batchCount);
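+  /* cuda_exit() must balance the cuda_enter() above on every exit path,
+     including the error return below. */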
+  if (err != CUBLAS_STATUS_SUCCESS) {
+    cuda_exit(ctx);
+    return error_cublas(ctx->err, "cublasHgemmStridedBatched", err);
+  }
+
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));
+
+  cuda_exit(ctx);
+  return GA_NO_ERROR;
+}
+
+static int sgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
+                   size_t M, size_t N, size_t K, float alpha,
+                   gpudata *A, size_t offA, size_t lda, ssize_t strideA,
+                   gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
+                   float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
+                   size_t batchCount) {
+  cuda_context *ctx;
+  blas_handle *h;
+  size_t t;
+  ssize_t st;
+  gpudata *T;
+  cb_transpose transT;
+  cublasStatus_t err;
+
+  ASSERT_BUF(A);
+  ASSERT_BUF(B);
+  ASSERT_BUF(C);
+
+  ctx = A->ctx;
+
+  if (cublasSgemmStridedBatched == NULL)
+    return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasSgemmStridedBatched not available in your version of cuBLAS");
+
+  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
+      LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
+      LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
+    return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");
+
+  h = (blas_handle *)ctx->blas_handle;
+  cuda_enter(ctx);
+
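+  /* cuBLAS only works in column-major order; a row-major (cb_c) request is
+     handled by computing the transposed product, which amounts to swapping
+     A and B, M and N, the leading dimensions, offsets, transposes and
+     batch strides. */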
+  if (order == cb_c) {
+    /* swap A and B */
+    t = N;
+    N = M;
+    M = t;
+    T = A;
+    A = B;
+    B = T;
+    t = lda;
+    lda = ldb;
+    ldb = t;
+    t = offA;
+    offA = offB;
+    offB = t;
+    transT = transA;
+    transA = transB;
+    transB = transT;
+    st = strideA;
+    strideA = strideB;
+    strideB = st;
+  }
+
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
+
+  err = cublasSgemmStridedBatched(h->h,
+                                  convT(transA), convT(transB),
+                                  M, N, K, &alpha,
+                                  ((float *)A->ptr) + offA, (int)lda, strideA,
+                                  ((float *)B->ptr) + offB, (int)ldb, strideB,
+                                  &beta,
+                                  ((float *)C->ptr) + offC, (int)ldc, strideC,
                                   batchCount);
   if (err != CUBLAS_STATUS_SUCCESS) {
     cuda_exit(ctx);
@@ -594,6 +670,84 @@ static int hgemmStridedBatch(cb_order order, cb_transpose transA, cb_transpose transB,
   return GA_NO_ERROR;
 }
 
+static int dgemm3D(cb_order order, cb_transpose transA, cb_transpose transB,
+                   size_t M, size_t N, size_t K, double alpha,
+                   gpudata *A, size_t offA, size_t lda, ssize_t strideA,
+                   gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
+                   double beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
+                   size_t batchCount) {
+  cuda_context *ctx;
+  blas_handle *h;
+  size_t t;
+  ssize_t st;
+  gpudata *T;
+  cb_transpose transT;
+  cublasStatus_t err;
+
+  ASSERT_BUF(A);
+  ASSERT_BUF(B);
+  ASSERT_BUF(C);
+
+  ctx = A->ctx;
+
+  if (cublasDgemmStridedBatched == NULL)
+    return error_set(ctx->err, GA_DEVSUP_ERROR, "cublasDgemmStridedBatched not available in your version of cuBLAS");
+
+  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
+      LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
+      LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
+    return error_set(ctx->err, GA_XLARGE_ERROR, "Passed-in sizes would overflow the ints in the cublas interface");
+
+  h = (blas_handle *)ctx->blas_handle;
+  cuda_enter(ctx);
+
+  if (order == cb_c) {
+    /* swap A and B */
+    t = N;
+    N = M;
+    M = t;
+    T = A;
+    A = B;
+    B = T;
+    t = lda;
+    lda = ldb;
+    ldb = t;
+    t = offA;
+    offA = offB;
+    offB = t;
+    transT = transA;
+    transA = transB;
+    transB = transT;
+    st = strideA;
+    strideA = strideB;
+    strideB = st;
+  }
+
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));
+
+  err = cublasDgemmStridedBatched(h->h,
+                                  convT(transA), convT(transB),
+                                  M, N, K, &alpha,
+                                  ((double *)A->ptr) + offA, (int)lda, strideA,
+                                  ((double *)B->ptr) + offB, (int)ldb, strideB,
+                                  &beta,
+                                  ((double *)C->ptr) + offC, (int)ldc, strideC,
+                                  batchCount);
+  if (err != CUBLAS_STATUS_SUCCESS) {
+    cuda_exit(ctx);
+    return error_cublas(ctx->err, "cublasDgemmStridedBatched", err);
+  }
+
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
+  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));
+
+  cuda_exit(ctx);
+  return GA_NO_ERROR;
+}
+
 static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
                       size_t M, size_t N, size_t K, float alpha,
                       gpudata **A, size_t *offA, size_t lda,
@@ -1662,5 +1816,7 @@ gpuarray_blas_ops cublas_ops = {
   NULL, /* hgerBatch */
   sgerBatch,
   dgerBatch,
-  hgemmStridedBatch,
+  hgemm3D,
+  sgemm3D,
+  dgemm3D
 };