-
Notifications
You must be signed in to change notification settings - Fork 92
Expand file tree
/
Copy pathbasic.py
More file actions
98 lines (88 loc) · 2.67 KB
/
basic.py
File metadata and controls
98 lines (88 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from string import Template
from .gpuarray import GpuArray, GpuKernel, SIZE, dtype_to_ctype
import numpy
# Kernel source for triu/tril. Each thread handles one linear index `idx`
# of the buffer and writes 0 when `ix ${le} iy`, with ix = idx / cols and
# iy = idx % cols; threads with idx >= N do nothing.
_EXTRACT_TRI_TMPL = Template("""
#include "cluda.h"
KERNEL void extract_tri(GLOBAL_MEM ${ctype} *a, ga_size a_off, ga_uint N) {
a = (GLOBAL_MEM ${ctype} *)(((GLOBAL_MEM char *)a) + a_off);
unsigned int idx = GID_1 * LDIM_0 * GDIM_0 +
GID_0 * LDIM_0 + LID_0;
unsigned int ix = idx/${cols};
unsigned int iy = idx%${cols};
if (idx < N) {
if (ix ${le} iy)
a[idx] = 0.0;
}
}
""")


def _dtype_flags(dtype):
    """Return the ``(have_small, have_double, have_complex)`` kernel flags.

    :param dtype: anything accepted by ``numpy.dtype``.
    :returns: a tuple of bools -- element size below 4 bytes; dtype is
        float64 or complex128; dtype is complex64 or complex128.
    """
    dtype = numpy.dtype(dtype)
    have_small = dtype.itemsize < 4
    have_double = dtype in (numpy.float64, numpy.complex128)
    have_complex = dtype in (numpy.complex64, numpy.complex128)
    return have_small, have_double, have_complex


def _generate_kernel(ctx, cols, dtype, upper=True):
    """Compile the ``extract_tri`` kernel used by ``triu``/``tril``.

    The kernel zeroes element ``idx`` of a contiguous buffer when
    ``idx // cols > idx % cols`` (``upper=True``) or
    ``idx // cols < idx % cols`` (``upper=False``).

    :param ctx: context the kernel is compiled for.
    :param cols: length of the fastest-varying axis of the buffer; it is
        baked into the kernel source as a constant.
    :param dtype: element type, anything accepted by ``numpy.dtype``.
    :param upper: selects the ``>`` (True) or ``<`` (False) comparison.
    :returns: a ``GpuKernel`` taking ``(GpuArray, SIZE offset, uint32 N)``.
    """
    # Normalize so callers may pass a dtype, a scalar type or a string.
    dtype = numpy.dtype(dtype)
    src = _EXTRACT_TRI_TMPL.substitute(cols=cols,
                                       ctype=dtype_to_ctype(dtype),
                                       le='>' if upper else '<')
    have_small, have_double, have_complex = _dtype_flags(dtype)
    return GpuKernel(src, "extract_tri", [GpuArray, SIZE, 'uint32'],
                     context=ctx, have_double=have_double,
                     have_small=have_small, have_complex=have_complex)
def _launch_dims(n, max_ls=256):
    """Return ``(ls, gs)`` giving at least *n* (> 0) work items.

    The local size is capped at *max_ls* and the group count is rounded up,
    so ``ls * gs >= n``. The kernel must ignore indices ``>= n``.
    """
    ls = min(max_ls, n)
    gs = -(-n // ls)  # ceiling division
    return ls, gs


def _extract_tri(A, inplace, keep_upper, name):
    """Zero the triangle of 2-d *A* opposite the one being kept.

    :param A: 2-d array, C- or F-contiguous.
    :param inplace: modify *A* itself when True, otherwise a copy of it.
    :param keep_upper: True keeps the upper triangle (``triu``), False keeps
        the lower one (``tril``); the main diagonal is always kept.
    :param name: public function name used in error messages.
    :returns: the modified array (*A* or its copy).
    :raises ValueError: if *A* is not 2-d or is neither C- nor F-contiguous.
    """
    if A.ndim != 2:
        raise ValueError("%s only works for 2d arrays" % name)
    # Robust even if the flags are not the bool singletons.
    if not (A.flags.c_contiguous or A.flags.f_contiguous):
        raise ValueError("%s only works for contiguous arrays" % name)
    if not inplace:
        A = A.copy()
    n = int(A.shape[0] * A.shape[1])
    if n == 0:
        # Empty matrix: nothing to zero, and launching would otherwise
        # request a local size of 0.
        return A
    if A.flags.f_contiguous:
        # Column-major storage: idx // cols indexes columns, so the kernel's
        # comparison is mirrored to zero the same logical triangle.
        upper = not keep_upper
        cols = A.shape[0]
    else:
        upper = keep_upper
        cols = A.shape[1]
    k = _generate_kernel(A.context, cols, A.dtype, upper)
    ls, gs = _launch_dims(n)
    k(A, A.offset, n, ls=ls, gs=gs)
    return A


def triu(A, inplace=True):
    """Zero the elements of 2-d *A* below the main diagonal.

    :param A: 2-d, C- or F-contiguous array.
    :param inplace: when False, operate on ``A.copy()`` instead of *A*.
    :returns: the array that was modified.
    :raises ValueError: if *A* is not 2-d or not contiguous.
    """
    return _extract_tri(A, inplace, True, "triu")


def tril(A, inplace=True):
    """Zero the elements of 2-d *A* above the main diagonal.

    :param A: 2-d, C- or F-contiguous array.
    :param inplace: when False, operate on ``A.copy()`` instead of *A*.
    :returns: the array that was modified.
    :raises ValueError: if *A* is not 2-d or not contiguous.
    """
    return _extract_tri(A, inplace, False, "tril")