1515
1616import pygpu .blas as gblas
1717
18+
def test_dot():
    """Yield the parameter combinations exercised by the ``dot`` check.

    Every yielded tuple starts with the ``dot`` callable, followed by its
    positional arguments ``(N, dtype, offseted_i, sliced, overwrite,
    init_z)``.
    NOTE(review): presumably consumed by a yield-test-aware runner
    (e.g. nose) -- confirm against the project's test setup.
    """
    flags = [True, False]
    sizes = [1, 256, 1337]
    dtypes = ['float32', 'float64']

    # Sweep vector size, dtype and the two layout flags; overwrite is
    # always True and init_z always False in this group.
    for N, dtype, offseted_i, sliced in product(sizes, dtypes, flags, flags):
        yield dot, N, dtype, offseted_i, sliced, True, False

    # One fixed float32 case for every overwrite / init_z combination.
    for overwrite, init_z in product(flags, flags):
        yield dot, 666, 'float32', False, False, overwrite, init_z
2526
27+
2628@guard_devsup
2729def dot (N , dtype , offseted_i , sliced , overwrite , init_z ):
2830 cX , gX = gen_gpuarray ((N ,), dtype , offseted_inner = offseted_i ,
2931 sliced = sliced , ctx = context )
3032 cY , gY = gen_gpuarray ((N ,), dtype , offseted_inner = offseted_i ,
3133 sliced = sliced , ctx = context )
3234 if init_z :
33- _ , gZ = gen_gpuarray ((), dtype , offseted_inner = offseted_i ,
34- sliced = sliced , ctx = context )
35+ gZ = gen_gpuarray ((), dtype , offseted_inner = offseted_i ,
36+ sliced = sliced , ctx = context )[ 1 ]
3537 else :
36- _ , gZ = None , None
38+ gZ = None
3739
3840 if dtype == 'float32' :
3941 cr = fblas .sdot (cX , cY )
@@ -46,21 +48,22 @@ def dot(N, dtype, offseted_i, sliced, overwrite, init_z):
def test_gemv():
    """Yield the parameter combinations exercised by the ``gemv`` check.

    Every yielded tuple starts with the ``gemv`` callable, followed by its
    positional arguments ``(shp, dtype, order, trans, offseted_i, sliced,
    overwrite, init_y[, alpha, beta])``.
    NOTE(review): presumably consumed by a yield-test-aware runner
    (e.g. nose) -- confirm against the project's test setup.
    """
    flags = [False, True]
    shapes = [(100, 128), (128, 50)]
    slicings = [1, 2, -1, -2]
    scalars = [0, 1, -1, 0.6]

    # Shape / order / transposition / offset / slicing sweep in float32.
    for shape, order, trans, offseted_i, sliced in product(
            shapes, 'fc', flags, flags, slicings):
        yield (gemv, shape, 'float32', order, trans,
               offseted_i, sliced, True, False)

    # Every overwrite / init_y combination on a small fixed case.
    for overwrite, init_y in product(flags, flags):
        yield (gemv, (4, 3), 'float32', 'f', False, False, 1,
               overwrite, init_y)

    # Single float64 case.
    yield gemv, (32, 32), 'float64', 'f', False, False, 1, True, False

    # alpha / beta sweep with an initialised output vector.
    for alpha, beta, overwrite in product(scalars, scalars, flags):
        yield (gemv, (32, 32), 'float32', 'f', False, False, 1,
               overwrite, True, alpha, beta)
6062
63+
6164@guard_devsup
6265def gemv (shp , dtype , order , trans , offseted_i , sliced ,
63- overwrite , init_y , alpha = 1.0 , beta = 0.0 ):
66+ overwrite , init_y , alpha = 1.0 , beta = 0.0 ):
6467 cA , gA = gen_gpuarray (shp , dtype , order = order , offseted_inner = offseted_i ,
6568 sliced = sliced , ctx = context )
6669 if trans :
@@ -92,31 +95,31 @@ def test_gemm():
9295 bools = [False , True ]
9396 for (m , n , k ), order , trans , offseted_o in product (
9497 [(48 , 15 , 32 ), (15 , 32 , 48 )], list (product (* ['fc' ]* 3 )),
95- list (product (bools , bools )), bools ):
98+ list (product (bools , bools )), bools ):
9699 yield gemm , m , n , k , 'float32' , order , trans , \
97100 offseted_o , 1 , False , False
98- for sliced , overwrite , init_res in product (
99- [1 , 2 , - 1 , - 2 ], bools , bools ):
101+ for sliced , overwrite , init_res in product ([1 , 2 , - 1 , - 2 ], bools , bools ):
100102 yield gemm , 4 , 3 , 2 , 'float32' , ('f' , 'f' , 'f' ), \
101103 (False , False ), False , sliced , overwrite , init_res
102104 yield gemm , 32 , 32 , 32 , 'float64' , ('f' , 'f' , 'f' ), (False , False ), \
103105 False , 1 , False , False
104106 for alpha , beta , overwrite in product (
105- [0 , 1 , - 1 , 0.6 ], [0 , 1 , - 1 , 0.6 ], bools ):
107+ [0 , 1 , - 1 , 0.6 ], [0 , 1 , - 1 , 0.6 ], bools ):
106108 yield gemm , 32 , 23 , 32 , 'float32' , ('f' , 'f' , 'f' ), \
107109 (False , False ), False , 1 , overwrite , True , alpha , beta
108110
111+
109112@guard_devsup
110113def gemm (m , n , k , dtype , order , trans , offseted_o , sliced , overwrite ,
111114 init_res , alpha = 1.0 , beta = 0.0 ):
112115 if trans [0 ]:
113- shpA = (k ,m )
116+ shpA = (k , m )
114117 else :
115- shpA = (m ,k )
118+ shpA = (m , k )
116119 if trans [1 ]:
117- shpB = (n ,k )
120+ shpB = (n , k )
118121 else :
119- shpB = (k ,n )
122+ shpB = (k , n )
120123
121124 cA , gA = gen_gpuarray (shpA , dtype , order = order [0 ],
122125 offseted_outer = offseted_o ,
@@ -125,7 +128,7 @@ def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
125128 offseted_outer = offseted_o ,
126129 sliced = sliced , ctx = context )
127130 if init_res :
128- cC , gC = gen_gpuarray ((m ,n ), dtype , order = order [2 ], ctx = context )
131+ cC , gC = gen_gpuarray ((m , n ), dtype , order = order [2 ], ctx = context )
129132 else :
130133 cC , gC = None , None
131134
@@ -143,13 +146,14 @@ def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
143146
def test_ger():
    """Yield the parameter combinations exercised by the ``ger`` check.

    Every yielded tuple starts with the ``ger`` callable, followed by its
    positional arguments ``(m, n, dtype, order, sliced_x, sliced_y,
    init_res[, overwrite])``.
    NOTE(review): presumably consumed by a yield-test-aware runner
    (e.g. nose) -- confirm against the project's test setup.
    """
    flags = [False, True]
    slicings = [1, 2, -2, -1]

    # Order and per-vector slicing sweep on a single (4, 5) float32 shape.
    for (m, n), order, sliced_x, sliced_y in product(
            [(4, 5)], 'fc', slicings, slicings):
        yield ger, m, n, 'float32', order, sliced_x, sliced_y, False

    # Single float64 case.
    yield ger, 4, 5, 'float64', 'f', 1, 1, False

    # Every init_res / overwrite combination.
    for init_res, overwrite in product(flags, flags):
        yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite
152155
156+
153157def ger (m , n , dtype , order , sliced_x , sliced_y , init_res , overwrite = False ):
154158 cX , gX = gen_gpuarray ((m ,), dtype , order , sliced = sliced_x , ctx = context )
155159 cY , gY = gen_gpuarray ((n ,), dtype , order , sliced = sliced_y , ctx = context )
@@ -168,35 +172,37 @@ def ger(m, n, dtype, order, sliced_x, sliced_y, init_res, overwrite=False):
168172
169173 numpy .testing .assert_allclose (cr , numpy .asarray (gr ), rtol = 1e-6 )
170174
175+
def test_rgemmBatch_3d():
    """Yield the parameter combinations for the ``rgemmBatch_3d`` check.

    Every yielded tuple starts with the ``rgemmBatch_3d`` callable,
    followed by its positional arguments ``(b, m, n, k, dtype, order,
    trans, offseted_o, sliced, overwrite, init_res[, alpha, beta])``.
    NOTE(review): presumably consumed by a yield-test-aware runner
    (e.g. nose) -- confirm against the project's test setup.
    """
    flags = [False, True]
    batch_sizes = [1, 17, 31]
    dims = [(24, 7, 16), (7, 16, 24)]
    orders = list(product('fc', 'fc', 'c'))
    transpositions = list(product(flags, flags))
    scalars = [0, 1, -1, 0.6]
    no_trans = (False, False)

    # Batch size / dimensions / order / transposition / offset sweep.
    for b, (m, n, k), order, trans, offseted_o in product(
            batch_sizes, dims, orders, transpositions, flags):
        yield (rgemmBatch_3d, b, m, n, k, 'float32', order, trans,
               offseted_o, 1, False, False)

    # Slicing / overwrite / init_res sweep on a small fixed case.
    for sliced, overwrite, init_res in product([1, 2, -1, -2], flags, flags):
        yield (rgemmBatch_3d, 5, 4, 3, 2, 'float32', ('f', 'f', 'c'),
               no_trans, False, sliced, overwrite, init_res)

    # Single float64 case.
    yield (rgemmBatch_3d, 16, 16, 16, 16, 'float64', ('f', 'f', 'c'),
           no_trans, False, 1, False, False)

    # alpha / beta sweep with an initialised result array.
    for alpha, beta, overwrite in product(scalars, scalars, flags):
        yield (rgemmBatch_3d, 16, 16, 9, 16, 'float32', ('f', 'f', 'c'),
               no_trans, False, 1, overwrite, True, alpha, beta)
188193
194+
189195@guard_devsup
190- def rgemmBatch_3d (b , m , n , k , dtype , order , trans , offseted_o , sliced , overwrite ,
191- init_res , alpha = 1.0 , beta = 0.0 ):
196+ def rgemmBatch_3d (b , m , n , k , dtype , order , trans , offseted_o , sliced ,
197+ overwrite , init_res , alpha = 1.0 , beta = 0.0 ):
192198 if trans [0 ]:
193- shpA = (b ,k , m )
199+ shpA = (b , k , m )
194200 else :
195- shpA = (b ,m , k )
201+ shpA = (b , m , k )
196202 if trans [1 ]:
197- shpB = (b ,n , k )
203+ shpB = (b , n , k )
198204 else :
199- shpB = (b ,k , n )
205+ shpB = (b , k , n )
200206
201207 cA , gA = gen_gpuarray (shpA , dtype , order = order [0 ],
202208 offseted_outer = offseted_o ,
@@ -205,21 +211,21 @@ def rgemmBatch_3d(b, m, n, k, dtype, order, trans, offseted_o, sliced, overwrite
205211 offseted_outer = offseted_o ,
206212 sliced = sliced , ctx = context )
207213 if init_res :
208- cC , gC = gen_gpuarray ((b ,m , n ), dtype , order = order [2 ], ctx = context )
214+ cC , gC = gen_gpuarray ((b , m , n ), dtype , order = order [2 ], ctx = context )
209215 else :
210216 cC , gC = None , None
211217
212- cr = numpy .empty ((b ,m , n ), dtype = dtype )
218+ cr = numpy .empty ((b , m , n ), dtype = dtype )
213219 if dtype == 'float32' :
214220 fn_gemm_c = fblas .sgemm
215221 else :
216222 fn_gemm_c = fblas .dgemm
217223 for i in range (b ):
218224 cCi = cC if cC is None else cC [i ]
219225 cr [i ] = fn_gemm_c (alpha , cA [i ], cB [i ], beta , cCi , trans_a = trans [0 ],
220- trans_b = trans [1 ], overwrite_c = overwrite )
226+ trans_b = trans [1 ], overwrite_c = overwrite )
221227
222228 gr = gblas .gemmBatch_3d (alpha , gA , gB , beta , gC , trans_a = trans [0 ],
223- trans_b = trans [1 ], overwrite_c = overwrite )
229+ trans_b = trans [1 ], overwrite_c = overwrite )
224230
225231 numpy .testing .assert_allclose (cr , numpy .asarray (gr ), rtol = 1e-5 )
0 commit comments