|
1 | | -import numpy |
| 1 | +from itertools import product |
| 2 | +import numpy |
2 | 3 | from nose.plugins.skip import SkipTest |
3 | 4 |
|
4 | 5 | from .support import (guard_devsup, gen_gpuarray, context) |
|
14 | 15 |
|
15 | 16 | import pygpu.blas as gblas |
16 | 17 |
|
| 18 | +def test_dot(): |
| 19 | + bools = [True, False] |
| 20 | + for N, dtype, offseted_i, sliced in product( |
| 21 | + [1, 256, 1337], ['float32', 'float64'], bools, bools): |
| 22 | + yield dot, N, dtype, offseted_i, sliced, True, False |
| 23 | + for overwrite, init_z in product(bools, bools): |
| 24 | + yield dot, 666, 'float32', False, False, overwrite, init_z |
| 25 | + |
| 26 | +@guard_devsup |
| 27 | +def dot(N, dtype, offseted_i, sliced, overwrite, init_z): |
| 28 | + cX, gX = gen_gpuarray((N,), dtype, offseted_inner=offseted_i, |
| 29 | + sliced=sliced, ctx=context) |
| 30 | + cY, gY = gen_gpuarray((N,), dtype, offseted_inner=offseted_i, |
| 31 | + sliced=sliced, ctx=context) |
| 32 | + if init_z: |
| 33 | + _, gZ = gen_gpuarray((), dtype, offseted_inner=offseted_i, |
| 34 | + sliced=sliced, ctx=context) |
| 35 | + else: |
| 36 | + _, gZ = None, None |
| 37 | + |
| 38 | + if dtype == 'float32': |
| 39 | + cr = fblas.sdot(cX, cY) |
| 40 | + else: |
| 41 | + cr = fblas.ddot(cX, cY) |
| 42 | + gr = gblas.dot(gX, gY, gZ, overwrite_z=overwrite) |
| 43 | + numpy.testing.assert_allclose(cr, numpy.asarray(gr), rtol=1e-6) |
| 44 | + |
| 45 | + |
17 | 46 | def test_gemv(): |
18 | | - for shape in [(100, 128), (128, 50)]: |
19 | | - for order in ['f', 'c']: |
20 | | - for trans in [False, True]: |
21 | | - for offseted_i in [True, False]: |
22 | | - for sliced in [1, 2, -1, -2]: |
23 | | - yield gemv, shape, 'float32', order, trans, \ |
24 | | - offseted_i, sliced, True, False |
25 | | - for overwrite in [True, False]: |
26 | | - for init_y in [True, False]: |
27 | | - yield gemv, (4, 3), 'float32', 'f', False, False, 1, \ |
28 | | - overwrite, init_y |
| 47 | + bools = [False, True] |
| 48 | + for shape, order, trans, offseted_i, sliced in product( |
| 49 | + [(100, 128), (128, 50)], 'fc', bools, bools, [1, 2, -1, -2]): |
| 50 | + yield gemv, shape, 'float32', order, trans, \ |
| 51 | + offseted_i, sliced, True, False |
| 52 | + for overwrite, init_y in product(bools, bools): |
| 53 | + yield gemv, (4, 3), 'float32', 'f', False, False, 1, \ |
| 54 | + overwrite, init_y |
29 | 55 | yield gemv, (32, 32), 'float64', 'f', False, False, 1, True, False |
30 | | - for alpha in [0, 1, -1, 0.6]: |
31 | | - for beta in [0, 1, -1, 0.6]: |
32 | | - for overwite in [True, False]: |
33 | | - yield gemv, (32, 32), 'float32', 'f', False, False, 1, \ |
34 | | - overwrite, True, alpha, beta |
35 | | - |
| 56 | + for alpha, beta, overwrite in product( |
| 57 | + [0, 1, -1, 0.6], [0, 1, -1, 0.6], bools): |
| 58 | + yield gemv, (32, 32), 'float32', 'f', False, False, 1, \ |
| 59 | + overwrite, True, alpha, beta |
36 | 60 |
|
37 | 61 | @guard_devsup |
38 | 62 | def gemv(shp, dtype, order, trans, offseted_i, sliced, |
@@ -65,28 +89,22 @@ def gemv(shp, dtype, order, trans, offseted_i, sliced, |
65 | 89 |
|
66 | 90 |
|
67 | 91 | def test_gemm(): |
68 | | - for m, n, k in [(48, 15, 32), (15, 32, 48)]: |
69 | | - for order in [('f', 'f', 'f'), ('c', 'c', 'c'), |
70 | | - ('f', 'f', 'c'), ('f', 'c', 'f'), |
71 | | - ('f', 'c', 'c'), ('c', 'f', 'f'), |
72 | | - ('c', 'f', 'c'), ('c', 'c', 'f')]: |
73 | | - for trans in [(False, False), (True, True), |
74 | | - (False, True), (True, False)]: |
75 | | - for offseted_o in [False, True]: |
76 | | - yield gemm, m, n, k, 'float32', order, trans, \ |
77 | | - offseted_o, 1, False, False |
78 | | - for sliced in [1, 2, -1, -2]: |
79 | | - for overwrite in [True, False]: |
80 | | - for init_res in [True, False]: |
81 | | - yield gemm, 4, 3, 2, 'float32', ('f', 'f', 'f'), \ |
82 | | - (False, False), False, sliced, overwrite, init_res |
| 92 | + bools = [False, True] |
| 93 | + for (m, n, k), order, trans, offseted_o in product( |
| 94 | + [(48, 15, 32), (15, 32, 48)], list(product(*['fc']*3)), |
| 95 | + list(product(bools, bools)), bools): |
| 96 | + yield gemm, m, n, k, 'float32', order, trans, \ |
| 97 | + offseted_o, 1, False, False |
| 98 | + for sliced, overwrite, init_res in product( |
| 99 | + [1, 2, -1, -2], bools, bools): |
| 100 | + yield gemm, 4, 3, 2, 'float32', ('f', 'f', 'f'), \ |
| 101 | + (False, False), False, sliced, overwrite, init_res |
83 | 102 | yield gemm, 32, 32, 32, 'float64', ('f', 'f', 'f'), (False, False), \ |
84 | 103 | False, 1, False, False |
85 | | - for alpha in [0, 1, -1, 0.6]: |
86 | | - for beta in [0, 1, -1, 0.6]: |
87 | | - for overwrite in [True, False]: |
88 | | - yield gemm, 32, 23, 32, 'float32', ('f', 'f', 'f'), \ |
89 | | - (False, False), False, 1, overwrite, True, alpha, beta |
| 104 | + for alpha, beta, overwrite in product( |
| 105 | + [0, 1, -1, 0.6], [0, 1, -1, 0.6], bools): |
| 106 | + yield gemm, 32, 23, 32, 'float32', ('f', 'f', 'f'), \ |
| 107 | + (False, False), False, 1, overwrite, True, alpha, beta |
90 | 108 |
|
91 | 109 | @guard_devsup |
92 | 110 | def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite, |
@@ -124,19 +142,13 @@ def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite, |
124 | 142 |
|
125 | 143 |
|
126 | 144 | def test_ger(): |
127 | | - for m, n in [(4, 5)]: |
128 | | - for order in ['f', 'c']: |
129 | | - for sliced_x in [1, 2, -2, -1]: |
130 | | - for sliced_y in [1, 2, -2, -1]: |
131 | | - yield ger, m, n, 'float32', order, sliced_x, sliced_y, \ |
132 | | - False |
133 | | - |
| 145 | + bools = [False, True] |
| 146 | + for (m,n), order, sliced_x, sliced_y in product( |
| 147 | + [(4,5)], 'fc', [1, 2, -2, -1], [1, 2, -2, -1]): |
| 148 | + yield ger, m, n, 'float32', order, sliced_x, sliced_y, False |
134 | 149 | yield ger, 4, 5, 'float64', 'f', 1, 1, False |
135 | | - |
136 | | - for init_res in [True, False]: |
137 | | - for overwrite in [True, False]: |
138 | | - yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite |
139 | | - |
| 150 | + for init_res, overwrite in product(bools, bools): |
| 151 | + yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite |
140 | 152 |
|
141 | 153 | def ger(m, n, dtype, order, sliced_x, sliced_y, init_res, overwrite=False): |
142 | 154 | cX, gX = gen_gpuarray((m,), dtype, order, sliced=sliced_x, ctx=context) |
|
0 commit comments