Skip to content

Commit 757f96d

Browse files
authored
Merge pull request #294 from khaotik/blas_dot
BLAS vector-vector dot
2 parents 07269f1 + cae3671 commit 757f96d

14 files changed

Lines changed: 529 additions & 55 deletions

pygpu/blas.pyx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ cdef extern from "gpuarray/buffer_blas.h":
1010
cb_conj_trans
1111

1212
cdef extern from "gpuarray/blas.h":
13+
int GpuArray_rdot(_GpuArray *X, _GpuArray *Y, _GpuArray *Z, int nocopy)
1314
int GpuArray_rgemv(cb_transpose transA, double alpha, _GpuArray *A,
1415
_GpuArray *X, double beta, _GpuArray *Y, int nocopy)
1516
int GpuArray_rgemm(cb_transpose transA, cb_transpose transB,
@@ -18,6 +19,13 @@ cdef extern from "gpuarray/blas.h":
1819
int GpuArray_rger(double alpha, _GpuArray *X, _GpuArray *Y, _GpuArray *A,
1920
int nocopy)
2021

22+
cdef api int pygpu_blas_rdot(GpuArray X, GpuArray Y, GpuArray Z, bint nocopy) except -1:
23+
cdef int err
24+
err = GpuArray_rdot(&X.ga, &Y.ga, &Z.ga, nocopy)
25+
if err != GA_NO_ERROR:
26+
raise GpuArrayException(GpuArray_error(&X.ga, err), err)
27+
return 0
28+
2129
cdef api int pygpu_blas_rgemv(cb_transpose transA, double alpha, GpuArray A,
2230
GpuArray X, double beta, GpuArray Y,
2331
bint nocopy) except -1:
@@ -45,6 +53,16 @@ cdef api int pygpu_blas_rger(double alpha, GpuArray X, GpuArray Y, GpuArray A,
4553
return 0
4654

4755

56+
def dot(GpuArray X, GpuArray Y, GpuArray Z=None, overwrite_z=False):
57+
if Z is None:
58+
Z = pygpu_empty(0, NULL, X.typecode, GA_ANY_ORDER, X.context, None)
59+
overwrite_z = True
60+
61+
if not overwrite_z:
62+
Z = pygpu_copy(Z, GA_ANY_ORDER)
63+
pygpu_blas_rdot(X, Y, Z, 0)
64+
return Z
65+
4866
def gemv(double alpha, GpuArray A, GpuArray X, double beta=0.0,
4967
GpuArray Y=None, trans_a=False, overwrite_y=False):
5068
cdef cb_transpose transA

pygpu/tests/test_blas.py

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import numpy
1+
from itertools import product
2+
import numpy
23
from nose.plugins.skip import SkipTest
34

45
from .support import (guard_devsup, gen_gpuarray, context)
@@ -14,25 +15,48 @@
1415

1516
import pygpu.blas as gblas
1617

18+
def test_dot():
19+
bools = [True, False]
20+
for N, dtype, offseted_i, sliced in product(
21+
[1, 256, 1337], ['float32', 'float64'], bools, bools):
22+
yield dot, N, dtype, offseted_i, sliced, True, False
23+
for overwrite, init_z in product(bools, bools):
24+
yield dot, 666, 'float32', False, False, overwrite, init_z
25+
26+
@guard_devsup
27+
def dot(N, dtype, offseted_i, sliced, overwrite, init_z):
28+
cX, gX = gen_gpuarray((N,), dtype, offseted_inner=offseted_i,
29+
sliced=sliced, ctx=context)
30+
cY, gY = gen_gpuarray((N,), dtype, offseted_inner=offseted_i,
31+
sliced=sliced, ctx=context)
32+
if init_z:
33+
_, gZ = gen_gpuarray((), dtype, offseted_inner=offseted_i,
34+
sliced=sliced, ctx=context)
35+
else:
36+
_, gZ = None, None
37+
38+
if dtype == 'float32':
39+
cr = fblas.sdot(cX, cY)
40+
else:
41+
cr = fblas.ddot(cX, cY)
42+
gr = gblas.dot(gX, gY, gZ, overwrite_z=overwrite)
43+
numpy.testing.assert_allclose(cr, numpy.asarray(gr), rtol=1e-6)
44+
45+
1746
def test_gemv():
18-
for shape in [(100, 128), (128, 50)]:
19-
for order in ['f', 'c']:
20-
for trans in [False, True]:
21-
for offseted_i in [True, False]:
22-
for sliced in [1, 2, -1, -2]:
23-
yield gemv, shape, 'float32', order, trans, \
24-
offseted_i, sliced, True, False
25-
for overwrite in [True, False]:
26-
for init_y in [True, False]:
27-
yield gemv, (4, 3), 'float32', 'f', False, False, 1, \
28-
overwrite, init_y
47+
bools = [False, True]
48+
for shape, order, trans, offseted_i, sliced in product(
49+
[(100, 128), (128, 50)], 'fc', bools, bools, [1, 2, -1, -2]):
50+
yield gemv, shape, 'float32', order, trans, \
51+
offseted_i, sliced, True, False
52+
for overwrite, init_y in product(bools, bools):
53+
yield gemv, (4, 3), 'float32', 'f', False, False, 1, \
54+
overwrite, init_y
2955
yield gemv, (32, 32), 'float64', 'f', False, False, 1, True, False
30-
for alpha in [0, 1, -1, 0.6]:
31-
for beta in [0, 1, -1, 0.6]:
32-
for overwite in [True, False]:
33-
yield gemv, (32, 32), 'float32', 'f', False, False, 1, \
34-
overwrite, True, alpha, beta
35-
56+
for alpha, beta, overwrite in product(
57+
[0, 1, -1, 0.6], [0, 1, -1, 0.6], bools):
58+
yield gemv, (32, 32), 'float32', 'f', False, False, 1, \
59+
overwrite, True, alpha, beta
3660

3761
@guard_devsup
3862
def gemv(shp, dtype, order, trans, offseted_i, sliced,
@@ -65,28 +89,22 @@ def gemv(shp, dtype, order, trans, offseted_i, sliced,
6589

6690

6791
def test_gemm():
68-
for m, n, k in [(48, 15, 32), (15, 32, 48)]:
69-
for order in [('f', 'f', 'f'), ('c', 'c', 'c'),
70-
('f', 'f', 'c'), ('f', 'c', 'f'),
71-
('f', 'c', 'c'), ('c', 'f', 'f'),
72-
('c', 'f', 'c'), ('c', 'c', 'f')]:
73-
for trans in [(False, False), (True, True),
74-
(False, True), (True, False)]:
75-
for offseted_o in [False, True]:
76-
yield gemm, m, n, k, 'float32', order, trans, \
77-
offseted_o, 1, False, False
78-
for sliced in [1, 2, -1, -2]:
79-
for overwrite in [True, False]:
80-
for init_res in [True, False]:
81-
yield gemm, 4, 3, 2, 'float32', ('f', 'f', 'f'), \
82-
(False, False), False, sliced, overwrite, init_res
92+
bools = [False, True]
93+
for (m, n, k), order, trans, offseted_o in product(
94+
[(48, 15, 32), (15, 32, 48)], list(product(*['fc']*3)),
95+
list(product(bools, bools)), bools):
96+
yield gemm, m, n, k, 'float32', order, trans, \
97+
offseted_o, 1, False, False
98+
for sliced, overwrite, init_res in product(
99+
[1, 2, -1, -2], bools, bools):
100+
yield gemm, 4, 3, 2, 'float32', ('f', 'f', 'f'), \
101+
(False, False), False, sliced, overwrite, init_res
83102
yield gemm, 32, 32, 32, 'float64', ('f', 'f', 'f'), (False, False), \
84103
False, 1, False, False
85-
for alpha in [0, 1, -1, 0.6]:
86-
for beta in [0, 1, -1, 0.6]:
87-
for overwrite in [True, False]:
88-
yield gemm, 32, 23, 32, 'float32', ('f', 'f', 'f'), \
89-
(False, False), False, 1, overwrite, True, alpha, beta
104+
for alpha, beta, overwrite in product(
105+
[0, 1, -1, 0.6], [0, 1, -1, 0.6], bools):
106+
yield gemm, 32, 23, 32, 'float32', ('f', 'f', 'f'), \
107+
(False, False), False, 1, overwrite, True, alpha, beta
90108

91109
@guard_devsup
92110
def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
@@ -124,19 +142,13 @@ def gemm(m, n, k, dtype, order, trans, offseted_o, sliced, overwrite,
124142

125143

126144
def test_ger():
127-
for m, n in [(4, 5)]:
128-
for order in ['f', 'c']:
129-
for sliced_x in [1, 2, -2, -1]:
130-
for sliced_y in [1, 2, -2, -1]:
131-
yield ger, m, n, 'float32', order, sliced_x, sliced_y, \
132-
False
133-
145+
bools = [False, True]
146+
for (m,n), order, sliced_x, sliced_y in product(
147+
[(4,5)], 'fc', [1, 2, -2, -1], [1, 2, -2, -1]):
148+
yield ger, m, n, 'float32', order, sliced_x, sliced_y, False
134149
yield ger, 4, 5, 'float64', 'f', 1, 1, False
135-
136-
for init_res in [True, False]:
137-
for overwrite in [True, False]:
138-
yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite
139-
150+
for init_res, overwrite in product(bools, bools):
151+
yield ger, 4, 5, 'float32', 'f', 1, 1, init_res, overwrite
140152

141153
def ger(m, n, dtype, order, sliced_x, sliced_y, init_res, overwrite=False):
142154
cX, gX = gen_gpuarray((m,), dtype, order, sliced=sliced_x, ctx=context)

setup.py

100644100755
File mode changed.

src/gpuarray/blas.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
extern "C" {
99
#endif
1010

11+
// only for vector-vector dot
12+
GPUARRAY_PUBLIC int GpuArray_rdot( GpuArray *X, GpuArray *Y,
13+
GpuArray *Z, int nocopy);
14+
#define GpuArray_hdot GpuArray_rdot
15+
#define GpuArray_sdot GpuArray_rdot
16+
#define GpuArray_ddot GpuArray_rdot
1117
GPUARRAY_PUBLIC int GpuArray_rgemv(cb_transpose transA, double alpha,
1218
GpuArray *A, GpuArray *X, double beta,
1319
GpuArray *Y, int nocopy);

src/gpuarray/buffer_blas.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,24 @@ GPUARRAY_PUBLIC void gpublas_teardown(gpucontext *ctx);
3838

3939
GPUARRAY_PUBLIC const char *gpublas_error(gpucontext *ctx);
4040

41+
GPUARRAY_PUBLIC int gpublas_hdot(
42+
size_t N,
43+
gpudata *X, size_t offX, size_t incX,
44+
gpudata *Y, size_t offY, size_t incY,
45+
gpudata *Z, size_t offZ);
46+
47+
GPUARRAY_PUBLIC int gpublas_sdot(
48+
size_t N,
49+
gpudata *X, size_t offX, size_t incX,
50+
gpudata *Y, size_t offY, size_t incY,
51+
gpudata *Z, size_t offZ);
52+
53+
GPUARRAY_PUBLIC int gpublas_ddot(
54+
size_t N,
55+
gpudata *X, size_t offX, size_t incX,
56+
gpudata *Y, size_t offY, size_t incY,
57+
gpudata *Z, size_t offZ);
58+
4159
GPUARRAY_PUBLIC int gpublas_hgemv(
4260
cb_order order, cb_transpose transA, size_t M, size_t N, float alpha,
4361
gpudata *A, size_t offA, size_t lda, gpudata *X, size_t offX, int incX,

src/gpuarray_array_blas.c

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,91 @@
55
#include "gpuarray/util.h"
66
#include "gpuarray/error.h"
77

8+
int GpuArray_rdot( GpuArray *X, GpuArray *Y,
9+
GpuArray *Z, int nocopy) {
10+
GpuArray *Xp = X;
11+
GpuArray copyX;
12+
GpuArray *Yp = Y;
13+
GpuArray copyY;
14+
GpuArray *Zp = Z;
15+
size_t n;
16+
void *ctx;
17+
size_t elsize;
18+
int err;
19+
20+
if (X->typecode != GA_HALF &&
21+
X->typecode != GA_FLOAT &&
22+
X->typecode != GA_DOUBLE)
23+
return GA_INVALID_ERROR;
24+
25+
if (X->nd != 1 || Y->nd != 1 || Z->nd != 0 ||
26+
X->typecode != Y->typecode || X->typecode != Z->typecode)
27+
return GA_VALUE_ERROR;
28+
n = X->dimensions[0];
29+
if (!(X->flags & GA_ALIGNED) || !(Y->flags & GA_ALIGNED) ||
30+
!(Z->flags & GA_ALIGNED))
31+
return GA_UNALIGNED_ERROR;
32+
if (X->dimensions[0] != Y->dimensions[0])
33+
return GA_VALUE_ERROR;
34+
35+
elsize = gpuarray_get_elsize(X->typecode);
36+
if (X->strides[0] < 0) {
37+
if (nocopy)
38+
return GA_COPY_ERROR;
39+
else {
40+
err = GpuArray_copy(&copyX, X, GA_ANY_ORDER);
41+
if (err != GA_NO_ERROR)
42+
goto cleanup;
43+
Xp = &copyX;
44+
}
45+
}
46+
if (Y->strides[0] < 0) {
47+
if (nocopy)
48+
return GA_COPY_ERROR;
49+
else {
50+
err = GpuArray_copy(&copyY, Y, GA_ANY_ORDER);
51+
if (err != GA_NO_ERROR)
52+
goto cleanup;
53+
Yp = &copyY;
54+
}
55+
}
56+
57+
ctx = gpudata_context(Xp->data);
58+
err = gpublas_setup(ctx);
59+
if (err != GA_NO_ERROR)
60+
goto cleanup;
61+
62+
switch (Xp->typecode) {
63+
case GA_HALF:
64+
err = gpublas_hdot(
65+
n,
66+
Xp->data, Xp->offset / elsize, Xp->strides[0] / elsize,
67+
Yp->data, Yp->offset / elsize, Yp->strides[0] / elsize,
68+
Zp->data, Zp->offset / elsize);
69+
break;
70+
case GA_FLOAT:
71+
err = gpublas_sdot(
72+
n,
73+
Xp->data, Xp->offset / elsize, Xp->strides[0] / elsize,
74+
Yp->data, Yp->offset / elsize, Yp->strides[0] / elsize,
75+
Zp->data, Zp->offset / elsize);
76+
break;
77+
case GA_DOUBLE:
78+
err = gpublas_ddot(
79+
n,
80+
Xp->data, Xp->offset / elsize, Xp->strides[0] / elsize,
81+
Yp->data, Yp->offset / elsize, Yp->strides[0] / elsize,
82+
Zp->data, Zp->offset / elsize);
83+
break;
84+
}
85+
cleanup:
86+
if (Xp == &copyX)
87+
GpuArray_clear(&copyX);
88+
if (Yp == &copyY)
89+
GpuArray_clear(&copyY);
90+
return err;
91+
}
92+
893
int GpuArray_rgemv(cb_transpose transA, double alpha, GpuArray *A,
994
GpuArray *X, double beta, GpuArray *Y, int nocopy) {
1095
GpuArray *Ap = A;
@@ -24,8 +109,7 @@ int GpuArray_rgemv(cb_transpose transA, double alpha, GpuArray *A,
24109
return GA_INVALID_ERROR;
25110

26111
if (A->nd != 2 || X->nd != 1 || Y->nd != 1 ||
27-
A->typecode != A->typecode || X->typecode != A->typecode ||
28-
Y->typecode != A->typecode)
112+
X->typecode != A->typecode || Y->typecode != A->typecode)
29113
return GA_VALUE_ERROR;
30114

31115
if (!(A->flags & GA_ALIGNED) || !(X->flags & GA_ALIGNED) ||

0 commit comments

Comments
 (0)