@@ -18,6 +18,9 @@ cdef extern from "gpuarray/blas.h":
1818 double beta, _GpuArray * C, int nocopy)
1919 int GpuArray_rger(double alpha, _GpuArray * X, _GpuArray * Y, _GpuArray * A,
2020 int nocopy)
21+ int GpuArray_rgemmBatch_3d(
22+ cb_transpose transA, cb_transpose transB, double alpha,
23+ _GpuArray * A, _GpuArray * B, double beta, _GpuArray * C, int nocopy)
2124
2225cdef api int pygpu_blas_rdot(GpuArray X, GpuArray Y, GpuArray Z, bint nocopy) except - 1 :
2326 cdef int err
@@ -52,6 +55,17 @@ cdef api int pygpu_blas_rger(double alpha, GpuArray X, GpuArray Y, GpuArray A,
5255 raise GpuArrayException(GpuArray_error(& X.ga, err), err)
5356 return 0
5457
# C-API entry point around GpuArray_rgemmBatch_3d.
#
# Forwards the transpose flags, the two scalars and the addresses of the
# underlying `ga` structs of A, B and C to the library call, together with
# the `nocopy` flag.
#
# Returns 0 when the library reports GA_NO_ERROR.  Any other status is turned
# into a GpuArrayException whose message is obtained from GpuArray_error() on
# A's array; the `except -1` clause lets C callers detect that failure.
cdef api int pygpu_blas_rgemmBatch_3d(cb_transpose transA, cb_transpose transB,
                                      double alpha, GpuArray A, GpuArray B,
                                      double beta, GpuArray C,
                                      bint nocopy) except -1:
    cdef int status = GpuArray_rgemmBatch_3d(transA, transB,
                                             alpha, &A.ga, &B.ga,
                                             beta, &C.ga, nocopy)
    if status == GA_NO_ERROR:
        return 0
    # The error text is looked up through A, as in the original code.
    raise GpuArrayException(GpuArray_error(&A.ga, status), status)
68+
5569
5670def dot (GpuArray X , GpuArray Y , GpuArray Z = None , overwrite_z = False ):
5771 """ dot(X, Y, Z=None, overwrite_z=False)
@@ -78,14 +92,14 @@ def gemv(double alpha, GpuArray A, GpuArray X, double beta=0.0,
7892 transA = cb_no_trans
7993
8094 if A.ga.nd != 2 :
81- raise TypeError , " A is not a matrix"
95+ raise TypeError ( " A is not a matrix" )
8296 if transA == cb_no_trans:
8397 Yshp = A.ga.dimensions[0 ]
8498 else :
8599 Yshp = A.ga.dimensions[1 ]
86100 if Y is None :
87101 if beta != 0.0 :
88- raise ValueError , " Y not provided and beta != 0"
102+ raise ValueError ( " Y not provided and beta != 0" )
89103 Y = pygpu_empty(1 , & Yshp, A.ga.typecode, GA_ANY_ORDER, A.context, None )
90104 overwrite_y = True
91105
@@ -113,9 +127,9 @@ def gemm(double alpha, GpuArray A, GpuArray B, double beta, GpuArray C=None,
113127 transB = cb_no_trans
114128
115129 if A.ga.nd != 2 :
116- raise TypeError , " A is not a matrix"
130+ raise TypeError ( " A is not a matrix" )
117131 if B.ga.nd != 2 :
118- raise TypeError , " B is not a matrix"
132+ raise TypeError ( " B is not a matrix" )
119133 if transA == cb_no_trans:
120134 Cshp[0 ] = A.ga.dimensions[0 ]
121135 else :
@@ -126,7 +140,7 @@ def gemm(double alpha, GpuArray A, GpuArray B, double beta, GpuArray C=None,
126140 Cshp[1 ] = B.ga.dimensions[0 ]
127141 if C is None :
128142 if beta != 0.0 :
129- raise ValueError , " C not provided and beta != 0"
143+ raise ValueError ( " C not provided and beta != 0" )
130144 C = pygpu_empty(2 , Cshp, A.ga.typecode, GA_ANY_ORDER, A.context, None )
131145 overwrite_c = True
132146
@@ -153,3 +167,45 @@ def ger(double alpha, GpuArray X, GpuArray Y, GpuArray A=None,
153167 pygpu_blas_rger(alpha, X, Y, A, 0 )
154168
155169 return A
170+
def gemmBatch_3d(double alpha, GpuArray A, GpuArray B,
                 double beta, GpuArray C=None,
                 trans_a=False, trans_b=False, overwrite_c=False):
    """gemmBatch_3d(alpha, A, B, beta, C=None, trans_a=False, trans_b=False, overwrite_c=False)

    Batched matrix-matrix product on 3-dimensional arrays.

    NOTE(review): presumably computes, for every index ``i`` along the
    first axis, ``C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]`` with
    ``op`` being an optional transpose -- confirm against the
    ``GpuArray_rgemmBatch_3d`` documentation.

    :param alpha: scalar forwarded to the BLAS call
    :param A: first operand, must have exactly 3 dimensions
    :param B: second operand, must have exactly 3 dimensions
    :param beta: scalar forwarded to the BLAS call; must be 0 when `C`
        is not supplied
    :param C: optional output/accumulator array.  When omitted, a new
        array is allocated with A's typecode and context and shape
        ``(A.shape[0], m, n)``, where ``m`` is ``A.shape[2]`` if
        `trans_a` else ``A.shape[1]`` and ``n`` is ``B.shape[1]`` if
        `trans_b` else ``B.shape[2]``.
    :param trans_a: if true, request the transposed form of `A`
    :param trans_b: if true, request the transposed form of `B`
    :param overwrite_c: if false and `C` is given, the call operates on
        a copy of `C` instead of `C` itself
    :returns: the array handed to the BLAS call (the supplied `C`, a
        copy of it, or the newly allocated array)
    :raises TypeError: if `A` or `B` is not 3-dimensional
    :raises ValueError: if `C` is omitted while `beta` is non-zero
    :raises GpuArrayException: if the underlying library call fails
    """
    cdef cb_transpose op_a = cb_no_trans
    cdef cb_transpose op_b = cb_no_trans
    cdef size_t[3] out_shape

    # Map the Python-level truth values onto the BLAS transpose enum.
    if trans_a:
        op_a = cb_trans
    if trans_b:
        op_b = cb_trans

    if A.ga.nd != 3:
        raise TypeError(" A is not a batch of matrices")
    if B.ga.nd != 3:
        raise TypeError(" B is not a batch of matrices")

    # Shape used only when the output has to be allocated here: leading
    # axis from A, then one axis of A and one axis of B chosen by the
    # transpose flags.
    out_shape[0] = A.ga.dimensions[0]
    out_shape[1] = (A.ga.dimensions[1] if op_a == cb_no_trans
                    else A.ga.dimensions[2])
    out_shape[2] = (B.ga.dimensions[2] if op_b == cb_no_trans
                    else B.ga.dimensions[1])

    if C is not None:
        if not overwrite_c:
            # Work on a copy so the caller's array is left untouched.
            C = pygpu_copy(C, GA_ANY_ORDER)
    elif beta != 0.0:
        raise ValueError(" C not provided and beta != 0")
    else:
        C = pygpu_empty(3, out_shape, A.ga.typecode, GA_ANY_ORDER,
                        A.context, None)

    # The wrapper is always called with nocopy=0.
    pygpu_blas_rgemmBatch_3d(op_a, op_b, alpha, A, B, beta, C, 0)

    return C
0 commit comments