@@ -260,38 +260,48 @@ int gpublas_dgerBatch(cb_order order, size_t M, size_t N, double alpha,
260260}
261261
262262
263- int gpublas_hgemm3d (
263+ #define BLAS_OP3F(b , name , args ) \
264+ gpucontext *ctx; \
265+ if (batchCount == 0) return GA_NO_ERROR; \
266+ ctx = gpudata_context(b); \
267+ if (flags != 0) return error_set(ctx->err, GA_INVALID_ERROR, "flags is not 0"); \
268+ if (ctx->blas_ops->name) \
269+ return ctx->blas_ops->name args; \
270+ else \
271+ return error_fmt(ctx->err, GA_DEVSUP_ERROR, "Blas operation not supported by library in use: %s", #name)
272+
273+ int gpublas_hgemm3D (
264274 cb_order order , cb_transpose transA , cb_transpose transB ,
265275 size_t M , size_t N , size_t K , float alpha ,
266- gpudata * A , size_t lda , ssize_t strideA ,
267- gpudata * B , size_t ldb , ssize_t strideB ,
268- float beta , gpudata * C , size_t ldc , ssize_t strideC ,
276+ gpudata * A , size_t offA , size_t lda , ssize_t strideA ,
277+ gpudata * B , size_t offB , size_t ldb , ssize_t strideB ,
278+ float beta , gpudata * C , size_t offC , size_t ldc , ssize_t strideC ,
269279 size_t batchCount , int flags ) {
270- BLAS_OPBF (A , hgemm3d ,
271- (order , transA , transB , M , N , K , alpha , A , lda , strideA ,
272- B , ldb , strideB , beta , C , ldc , strideC , batchCount ));
280+ BLAS_OP3F (A , hgemm3D ,
281+ (order , transA , transB , M , N , K , alpha , A , offA , lda , strideA ,
282+ B , offB , ldb , strideB , beta , C , offC , ldc , strideC , batchCount ));
273283}
274284
275- int gpublas_sgemm3d (
285+ int gpublas_sgemm3D (
276286 cb_order order , cb_transpose transA , cb_transpose transB ,
277287 size_t M , size_t N , size_t K , float alpha ,
278- gpudata * A , size_t lda , ssize_t strideA ,
279- gpudata * B , size_t ldb , ssize_t strideB ,
280- float beta , gpudata * C , size_t ldc , ssize_t strideC ,
288+ gpudata * A , size_t offA , size_t lda , ssize_t strideA ,
289+ gpudata * B , size_t offB , size_t ldb , ssize_t strideB ,
290+ float beta , gpudata * C , size_t offC , size_t ldc , ssize_t strideC ,
281291 size_t batchCount , int flags ) {
282- BLAS_OPBF (A , sgemm3d ,
283- (order , transA , transB , M , N , K , alpha , A , lda , strideA ,
284- B , ldb , strideB , beta , C , ldc , strideC , batchCount ));
292+ BLAS_OP3F (A , sgemm3D ,
293+ (order , transA , transB , M , N , K , alpha , A , offA , lda , strideA ,
294+ B , offB , ldb , strideB , beta , C , offC , ldc , strideC , batchCount ));
285295}
286296
287- int gpublas_dgemm3d (
297+ int gpublas_dgemm3D (
288298 cb_order order , cb_transpose transA , cb_transpose transB ,
289- size_t M , size_t N , size_t K , float alpha ,
290- gpudata * A , size_t lda , ssize_t strideA ,
291- gpudata * B , size_t ldb , ssize_t strideB ,
292- float beta , gpudata * C , size_t ldc , ssize_t strideC ,
299+ size_t M , size_t N , size_t K , double alpha ,
300+ gpudata * A , size_t offA , size_t lda , ssize_t strideA ,
301+ gpudata * B , size_t offB , size_t ldb , ssize_t strideB ,
302+ double beta , gpudata * C , size_t offC , size_t ldc , ssize_t strideC ,
293303 size_t batchCount , int flags ) {
294- BLAS_OPBF (A , dgemm3d ,
295- (order , transA , transB , M , N , K , alpha , A , lda , strideA ,
296- B , ldb , strideB , beta , C , ldc , strideC , batchCount ));
304+ BLAS_OP3F (A , dgemm3D ,
305+ (order , transA , transB , M , N , K , alpha , A , offA , lda , strideA ,
306+ B , offB , ldb , strideB , beta , C , offC , ldc , strideC , batchCount ));
297307}
0 commit comments