Skip to content

Commit 4fd50fb

Browse files
authored
Merge pull request #419 from khaotik/gemmb3d_py
python interface for rgemmBatch_3d
2 parents 5db51f9 + 1cc02f5 commit 4fd50fb

2 files changed

Lines changed: 117 additions & 5 deletions

File tree

pygpu/blas.pyx

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ cdef extern from "gpuarray/blas.h":
1818
double beta, _GpuArray *C, int nocopy)
1919
int GpuArray_rger(double alpha, _GpuArray *X, _GpuArray *Y, _GpuArray *A,
2020
int nocopy)
21+
int GpuArray_rgemmBatch_3d(
22+
cb_transpose transA, cb_transpose transB, double alpha,
23+
_GpuArray *A, _GpuArray *B, double beta, _GpuArray *C, int nocopy)
2124

2225
cdef api int pygpu_blas_rdot(GpuArray X, GpuArray Y, GpuArray Z, bint nocopy) except -1:
2326
cdef int err
@@ -52,6 +55,17 @@ cdef api int pygpu_blas_rger(double alpha, GpuArray X, GpuArray Y, GpuArray A,
5255
raise GpuArrayException(GpuArray_error(&X.ga, err), err)
5356
return 0
5457

58+
cdef api int pygpu_blas_rgemmBatch_3d(cb_transpose transA, cb_transpose transB,
59+
double alpha, GpuArray A, GpuArray B,
60+
double beta, GpuArray C, bint nocopy) except -1:
61+
cdef int err
62+
err = GpuArray_rgemmBatch_3d(transA, transB,
63+
alpha, &A.ga, &B.ga,
64+
beta, &C.ga, nocopy)
65+
if err != GA_NO_ERROR:
66+
raise GpuArrayException(GpuArray_error(&A.ga, err), err)
67+
return 0
68+
5569

5670
def dot(GpuArray X, GpuArray Y, GpuArray Z=None, overwrite_z=False):
5771
"""dot(X, Y, Z=None, overwrite_z=False)
@@ -78,14 +92,14 @@ def gemv(double alpha, GpuArray A, GpuArray X, double beta=0.0,
7892
transA = cb_no_trans
7993

8094
if A.ga.nd != 2:
81-
raise TypeError, "A is not a matrix"
95+
raise TypeError("A is not a matrix")
8296
if transA == cb_no_trans:
8397
Yshp = A.ga.dimensions[0]
8498
else:
8599
Yshp = A.ga.dimensions[1]
86100
if Y is None:
87101
if beta != 0.0:
88-
raise ValueError, "Y not provided and beta != 0"
102+
raise ValueError("Y not provided and beta != 0")
89103
Y = pygpu_empty(1, &Yshp, A.ga.typecode, GA_ANY_ORDER, A.context, None)
90104
overwrite_y = True
91105

@@ -113,9 +127,9 @@ def gemm(double alpha, GpuArray A, GpuArray B, double beta, GpuArray C=None,
113127
transB = cb_no_trans
114128

115129
if A.ga.nd != 2:
116-
raise TypeError, "A is not a matrix"
130+
raise TypeError("A is not a matrix")
117131
if B.ga.nd != 2:
118-
raise TypeError, "B is not a matrix"
132+
raise TypeError("B is not a matrix")
119133
if transA == cb_no_trans:
120134
Cshp[0] = A.ga.dimensions[0]
121135
else:
@@ -126,7 +140,7 @@ def gemm(double alpha, GpuArray A, GpuArray B, double beta, GpuArray C=None,
126140
Cshp[1] = B.ga.dimensions[0]
127141
if C is None:
128142
if beta != 0.0:
129-
raise ValueError, "C not provided and beta != 0"
143+
raise ValueError("C not provided and beta != 0")
130144
C = pygpu_empty(2, Cshp, A.ga.typecode, GA_ANY_ORDER, A.context, None)
131145
overwrite_c = True
132146

@@ -153,3 +167,45 @@ def ger(double alpha, GpuArray X, GpuArray Y, GpuArray A=None,
153167
pygpu_blas_rger(alpha, X, Y, A, 0)
154168

155169
return A
170+
171+
def gemmBatch_3d(double alpha, GpuArray A, GpuArray B,
172+
double beta, GpuArray C=None,
173+
trans_a=False, trans_b=False, overwrite_c=False):
174+
"""gemmBatch_3d(alpha, A, B, beta, C=None, trans_a=False, trans_b=False, overwrite_c=False)
175+
"""
176+
cdef cb_transpose transA
177+
cdef cb_transpose transB
178+
cdef size_t[3] Cshp
179+
180+
if trans_a:
181+
transA = cb_trans
182+
else:
183+
transA = cb_no_trans
184+
if trans_b:
185+
transB = cb_trans
186+
else:
187+
transB = cb_no_trans
188+
189+
if A.ga.nd != 3:
190+
raise TypeError("A is not a batch of matrices")
191+
if B.ga.nd != 3:
192+
raise TypeError("B is not a batch of matrices")
193+
194+
Cshp[0] = A.ga.dimensions[0]
195+
if transA == cb_no_trans:
196+
Cshp[1] = A.ga.dimensions[1]
197+
else:
198+
Cshp[1] = A.ga.dimensions[2]
199+
if transB == cb_no_trans:
200+
Cshp[2] = B.ga.dimensions[2]
201+
else:
202+
Cshp[2] = B.ga.dimensions[1]
203+
if C is None:
204+
if beta != 0.0:
205+
raise ValueError("C not provided and beta != 0")
206+
C = pygpu_empty(3, Cshp, A.ga.typecode, GA_ANY_ORDER, A.context, None)
207+
elif not overwrite_c:
208+
C = pygpu_copy(C, GA_ANY_ORDER)
209+
pygpu_blas_rgemmBatch_3d(transA, transB, alpha, A, B, beta, C, 0)
210+
211+
return C

pygpu/tests/test_blas.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,59 @@ def ger(m, n, dtype, order, sliced_x, sliced_y, init_res, overwrite=False):
167167
gr = gblas.ger(1.0, gX, gY, gA, overwrite_a=overwrite)
168168

169169
numpy.testing.assert_allclose(cr, numpy.asarray(gr), rtol=1e-6)
170+
171+
def test_rgemmBatch_3d():
172+
bools = [False, True]
173+
for b, (m, n, k), order, trans, offseted_o in product(
174+
[1, 17, 31], [(24, 7, 16), (7, 16, 24)], list(product('fc', 'fc', 'c')),
175+
list(product(bools, bools)), bools):
176+
yield rgemmBatch_3d, b, m, n, k, 'float32', order, trans, \
177+
offseted_o, 1, False, False
178+
for sliced, overwrite, init_res in product(
179+
[1, 2, -1, -2], bools, bools):
180+
yield rgemmBatch_3d, 5, 4, 3, 2, 'float32', ('f', 'f', 'c'), \
181+
(False, False), False, sliced, overwrite, init_res
182+
yield rgemmBatch_3d, 16, 16, 16, 16, 'float64', ('f', 'f', 'c'), (False, False), \
183+
False, 1, False, False
184+
for alpha, beta, overwrite in product(
185+
[0, 1, -1, 0.6], [0, 1, -1, 0.6], bools):
186+
yield rgemmBatch_3d, 16, 16, 9, 16, 'float32', ('f', 'f', 'c'), \
187+
(False, False), False, 1, overwrite, True, alpha, beta
188+
189+
@guard_devsup
190+
def rgemmBatch_3d(b, m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
191+
init_res, alpha=1.0, beta=0.0):
192+
if trans[0]:
193+
shpA = (b,k,m)
194+
else:
195+
shpA = (b,m,k)
196+
if trans[1]:
197+
shpB = (b,n,k)
198+
else:
199+
shpB = (b,k,n)
200+
201+
cA, gA = gen_gpuarray(shpA, dtype, order=order[0],
202+
offseted_outer=offseted_o,
203+
sliced=sliced, ctx=context)
204+
cB, gB = gen_gpuarray(shpB, dtype, order=order[1],
205+
offseted_outer=offseted_o,
206+
sliced=sliced, ctx=context)
207+
if init_res:
208+
cC, gC = gen_gpuarray((b,m,n), dtype, order=order[2], ctx=context)
209+
else:
210+
cC, gC = None, None
211+
212+
cr = numpy.empty((b,m,n), dtype=dtype)
213+
if dtype == 'float32':
214+
fn_gemm_c = fblas.sgemm
215+
else:
216+
fn_gemm_c = fblas.dgemm
217+
for i in range(b):
218+
cCi = cC if cC is None else cC[i]
219+
cr[i] = fn_gemm_c(alpha, cA[i], cB[i], beta, cCi, trans_a=trans[0],
220+
trans_b=trans[1], overwrite_c=overwrite)
221+
222+
gr = gblas.gemmBatch_3d(alpha, gA, gB, beta, gC, trans_a=trans[0],
223+
trans_b=trans[1], overwrite_c=overwrite)
224+
225+
numpy.testing.assert_allclose(cr, numpy.asarray(gr), rtol=1e-5)

0 commit comments

Comments
 (0)