@@ -835,14 +835,23 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
835835 return ctx -> err -> code ;
836836 }
837837
838- GA_CUDA_EXIT_ON_ERROR (ctx , cuda_wait (Ta , CUDA_WAIT_READ ));
838+ if (cuda_wait (Ta , CUDA_WAIT_READ ) != GA_NO_ERROR ) {
839+ gpudata_release (Ta );
840+ cuda_exit (ctx );
841+ return ctx -> err -> code ;
842+ }
839843
840844 err = cublasSgemmBatched (h -> h ,
841845 convT (transA ), convT (transB ),
842846 M , N , K , & alpha ,
843847 (const float * * )Aa , lda ,
844848 (const float * * )Ba , ldb , & beta ,
845849 (float * * )Ca , ldc , batchCount );
850+ if (cuda_record (Ta , CUDA_WAIT_READ ) != GA_NO_ERROR ) {
851+ gpudata_release (Ta );
852+ cuda_exit (ctx );
853+ return ctx -> err -> code ;
854+ }
846855 gpudata_release (Ta );
847856 if (err != CUBLAS_STATUS_SUCCESS ) {
848857 cuda_exit (ctx );
@@ -964,15 +973,26 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
964973 return ctx -> err -> code ;
965974 }
966975
967- GA_CUDA_EXIT_ON_ERROR (ctx , cuda_wait (Ta , CUDA_WAIT_READ ));
976+ if (cuda_wait (Ta , CUDA_WAIT_READ ) != GA_NO_ERROR ) {
977+ gpudata_release (Ta );
978+ cuda_exit (ctx );
979+ return ctx -> err -> code ;
980+ }
968981
969982 err = cublasDgemmBatched (h -> h ,
970983 convT (transA ), convT (transB ),
971984 M , N , K , & alpha ,
972985 (const double * * )Aa , lda ,
973986 (const double * * )Ba , ldb , & beta ,
974987 (double * * )Ca , ldc , batchCount );
988+
989+ if (cuda_record (Ta , CUDA_WAIT_READ ) != GA_NO_ERROR ) {
990+ gpudata_release (Ta );
991+ cuda_exit (ctx );
992+ return ctx -> err -> code ;
993+ }
975994 gpudata_release (Ta );
995+
976996 if (err != CUBLAS_STATUS_SUCCESS ) {
977997 cuda_exit (ctx );
978998 return error_cublas (ctx -> err , "cublasDgemmBatched" , err );