Skip to content

Commit c068f74

Browse files
committed
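Fix offset in triu/tril: pass A.offset to the extract_tri kernel, which now adds it (as a byte offset) to the base pointer before indexing.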
1 parent f863ab2 commit c068f74

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

pygpu/basic.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from string import Template
2-
from .gpuarray import GpuArray, GpuKernel
2+
from .gpuarray import GpuArray, GpuKernel, SIZE
33

44

55
def _generate_kernel(ctx, cols, upper=True):
66
tmpl = Template("""
7-
KERNEL void extract_tri(GLOBAL_MEM ga_float *a, ga_uint N) {
7+
KERNEL void extract_tri(GLOBAL_MEM ga_float *a, ga_size a_off, ga_uint N) {
8+
a = (GLOBAL_MEM ga_float *)(((char *)a) + a_off);
89
unsigned int idx = GID_1 * LDIM_0 * GDIM_0 +
910
GID_0 * LDIM_0 + LID_0;
1011
unsigned int ix = idx/${cols};
@@ -20,7 +21,7 @@ def _generate_kernel(ctx, cols, upper=True):
2021
else:
2122
le = '<'
2223
src = tmpl.substitute(cols=cols, le=le)
23-
spec = [GpuArray, 'uint32']
24+
spec = [GpuArray, SIZE, 'uint32']
2425
k = GpuKernel(src, "extract_tri", spec, context=ctx)
2526
return k
2627

@@ -40,7 +41,7 @@ def triu(A, inplace=True):
4041
upper = True
4142
cols = A.shape[1]
4243
k = _generate_kernel(A.context, cols, upper)
43-
k(A, A.shape[0] * A.shape[1], n=A.shape[0] * A.shape[1])
44+
k(A, A.offset, A.shape[0] * A.shape[1], n=A.shape[0] * A.shape[1])
4445
return A
4546

4647

@@ -59,5 +60,5 @@ def tril(A, inplace=True):
5960
upper = False
6061
cols = A.shape[1]
6162
k = _generate_kernel(A.context, cols, upper)
62-
k(A, A.shape[0] * A.shape[1], n=A.shape[0] * A.shape[1])
63+
k(A, A.offset, A.shape[0] * A.shape[1], n=A.shape[0] * A.shape[1])
6364
return A

0 commit comments

Comments
 (0)