11from string import Template
2- from .gpuarray import GpuArray , GpuKernel
2+ from .gpuarray import GpuArray , GpuKernel , SIZE
33
44
55def _generate_kernel (ctx , cols , upper = True ):
66 tmpl = Template ("""
7- KERNEL void extract_tri(GLOBAL_MEM ga_float *a, ga_uint N) {
7+ KERNEL void extract_tri(GLOBAL_MEM ga_float *a, ga_size a_off, ga_uint N) {
8+ a = (GLOBAL_MEM ga_float *)(((char *)a) + a_off);
89 unsigned int idx = GID_1 * LDIM_0 * GDIM_0 +
910 GID_0 * LDIM_0 + LID_0;
1011 unsigned int ix = idx/${cols};
@@ -20,7 +21,7 @@ def _generate_kernel(ctx, cols, upper=True):
2021 else :
2122 le = '<'
2223 src = tmpl .substitute (cols = cols , le = le )
23- spec = [GpuArray , 'uint32' ]
24+ spec = [GpuArray , SIZE , 'uint32' ]
2425 k = GpuKernel (src , "extract_tri" , spec , context = ctx )
2526 return k
2627
@@ -40,7 +41,7 @@ def triu(A, inplace=True):
4041 upper = True
4142 cols = A .shape [1 ]
4243 k = _generate_kernel (A .context , cols , upper )
43- k (A , A .shape [0 ] * A .shape [1 ], n = A .shape [0 ] * A .shape [1 ])
44+ k (A , A .offset , A . shape [0 ] * A .shape [1 ], n = A .shape [0 ] * A .shape [1 ])
4445 return A
4546
4647
@@ -59,5 +60,5 @@ def tril(A, inplace=True):
5960 upper = False
6061 cols = A .shape [1 ]
6162 k = _generate_kernel (A .context , cols , upper )
62- k (A , A .shape [0 ] * A .shape [1 ], n = A .shape [0 ] * A .shape [1 ])
63+ k (A , A .offset , A . shape [0 ] * A .shape [1 ], n = A .shape [0 ] * A .shape [1 ])
6364 return A
0 commit comments