Skip to content

Commit ddb016b

Browse files
committed
Add in functions for 3d batch gemm for all dtypes and make them optional.
1 parent ed74310 commit ddb016b

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

src/loaders/libcublas.fn

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ DEF_PROC_OPT(cublasSgemmEx, (cublasHandle_t handle, cublasOperation_t transa, cu
2424
DEF_PROC(cublasSgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *Aarray[], int lda, const float *Barray[], int ldb, const float *beta, float *Carray[], int ldc, int batchCount));
2525
DEF_PROC(cublasDgemmBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *Aarray[], int lda, const double *Barray[], int ldb, const double *beta, double *Carray[], int ldc, int batchCount));
2626

27-
DEF_PROC(cublasHgemmStridedBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half *alpha, const __half *A, int lda, long long int strideA, const __half *B, int ldb, long long int strideB, const __half *beta, __half *C, int ldc, long long int strideC, int batchCount));
27+
DEF_PROC_OPT(cublasHgemmStridedBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half *alpha, const __half *A, int lda, long long int strideA, const __half *B, int ldb, long long int strideB, const __half *beta, __half *C, int ldc, long long int strideC, int batchCount));
28+
DEF_PROC_OPT(cublasSgemmStridedBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *A, int lda, long long int strideA, const float *B, int ldb, long long int strideB, const float *beta, float *C, int ldc, long long int strideC, int batchCount));
29+
DEF_PROC_OPT(cublasDgemmStridedBatched, (cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double *alpha, const double *A, int lda, long long int strideA, const double *B, int ldb, long long int strideB, const double *beta, double *C, int ldc, long long int strideC, int batchCount));

0 commit comments

Comments
 (0)