Skip to content

Commit 9612a40

Browse files
First attempt at implementing View for KernelParticle
1 parent c2f148b commit 9612a40

2 files changed

Lines changed: 361 additions & 10 deletions

File tree

src/parcels/_core/particle.py

Lines changed: 361 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,27 @@ def add_variable(self, variable: Variable | list[Variable]):
117117

118118

119119
class KernelParticle:
120-
"""Simple class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
120+
"""Class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
121121

122122
def __init__(self, data, index):
123123
self._data = data
124124
self._index = index
125125

126126
def __getattr__(self, name):
127+
# Return a proxy that behaves like the underlying numpy array but
128+
# writes back into the parent arrays when sliced/modified. This
129+
# enables constructs like `particles.dlon[mask] += vals` to update
130+
# the parent arrays rather than temporary copies.
131+
if name in self._data:
132+
# If this KernelParticle represents a single particle (integer
133+
# index), return the underlying scalar directly to preserve
134+
# user-facing semantics (e.g., `pset[0].time` should be a number).
135+
if isinstance(self._index, (int, np.integer)):
136+
return self._data[name][self._index]
137+
# For 0-d numpy integer scalars
138+
if isinstance(self._index, np.ndarray) and self._index.ndim == 0:
139+
return self._data[name][int(self._index)]
140+
return KernelParticleArray(self._data, self._index, name)
127141
return self._data[name][self._index]
128142

129143
def __setattr__(self, name, value):
@@ -133,13 +147,357 @@ def __setattr__(self, name, value):
133147
self._data[name][self._index] = value
134148

135149
def __getitem__(self, index):
136-
self._index = index
137-
return self
150+
# normalize single-element tuple indexing (e.g., (inds,))
151+
if isinstance(index, tuple) and len(index) == 1:
152+
index = index[0]
153+
154+
base = self._index
155+
new_index = np.zeros_like(base, dtype=bool)
156+
157+
# Boolean mask (could be local-length or global-length)
158+
if isinstance(index, (np.ndarray, list)) and np.asarray(index).dtype == bool:
159+
arr = np.asarray(index)
160+
if arr.size == base.size:
161+
# global mask
162+
new_index = arr
163+
elif arr.size == int(np.sum(base)):
164+
new_index[base] = arr
165+
else:
166+
raise ValueError(
167+
f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
168+
)
169+
return KernelParticle(self._data, new_index)
170+
171+
# Integer array / list of indices relative to local view
172+
if isinstance(index, (np.ndarray, list)):
173+
idx_arr = np.asarray(index)
174+
if idx_arr.dtype == bool:
175+
# handled above, but keep for safety
176+
if idx_arr.size == base.size:
177+
new_index = idx_arr
178+
else:
179+
new_index[base] = idx_arr
180+
else:
181+
if base.dtype == bool:
182+
particle_idxs = np.flatnonzero(base)
183+
sel = particle_idxs[idx_arr]
184+
new_index[sel] = True
185+
else:
186+
base_arr = np.asarray(base)
187+
sel = base_arr[idx_arr]
188+
new_index[sel] = True
189+
return KernelParticle(self._data, new_index)
190+
191+
# Slice or single integer index relative to local view
192+
if isinstance(index, slice) or isinstance(index, int):
193+
if base.dtype == bool:
194+
particle_idxs = np.flatnonzero(base)
195+
sel = particle_idxs[index]
196+
new_index[sel] = True
197+
else:
198+
base_arr = np.asarray(base)
199+
sel = base_arr[index]
200+
new_index[sel] = True
201+
return KernelParticle(self._data, new_index)
202+
203+
# Fallback: try to assign directly (preserves previous behaviour for other index types)
204+
try:
205+
new_index[base] = index
206+
return KernelParticle(self._data, new_index)
207+
except Exception as e:
208+
raise TypeError(f"Unsupported index type for KernelParticle.__getitem__: {type(index)!r}") from e
209+
210+
# def __setitem__(self, index, value):
211+
# """Assign to a subset of particles represented by `index` relative to
212+
# this KernelParticle's current selection.
213+
214+
# The incoming `index` is interpreted in the same way as for
215+
# `__getitem__`: it indexes into the subset defined by `self._index`.
216+
217+
# `value` may be another KernelParticle (in which case common variables
218+
# are copied), or a dict mapping variable names to arrays/scalars which
219+
# will be written into the parent arrays at the computed positions.
220+
# """
221+
# # Map the provided index (which indexes into the current subset)
222+
# # back to the full parent-array index.
223+
# new_index = np.zeros_like(self._index, dtype=bool)
224+
# new_index[self._index] = index
225+
226+
# # Helper to perform assignment for a given variable name
227+
# def _assign(varname, src):
228+
# # write into parent array at positions new_index
229+
# self._data[varname][new_index] = src
230+
231+
# # Case: assign from another KernelParticle-like object
232+
# if isinstance(value, KernelParticle):
233+
# # copy across common fields
234+
# for k in set(self._data.keys()).intersection(value._data.keys()):
235+
# _assign(k, value._data[k][value._index])
236+
# return
237+
238+
# # Case: assign from a dict-like mapping variable names -> values
239+
# if isinstance(value, dict):
240+
# for k, v in value.items():
241+
# if k not in self._data:
242+
# raise KeyError(f"Unknown particle variable: {k}")
243+
# _assign(k, v)
244+
# return
245+
246+
# # Otherwise, if a scalar/array is provided, assign it to all variables
247+
# # is ambiguous: raise TypeError to avoid surprising behaviour.
248+
# raise TypeError("Unsupported value for KernelParticle.__setitem__; provide a KernelParticle or dict of variable values")
138249

139250
def __len__(self):
140251
return len(self._index)
141252

142253

254+
class KernelParticleArray:
255+
"""Array-like proxy for a particle variable that writes through to the
256+
parent arrays when mutated.
257+
258+
Parameters
259+
----------
260+
data : dict-like
261+
Parent particle storage (mapping varname -> ndarray)
262+
index : array-like
263+
Index representing the subset in the parent arrays (boolean mask or integer indices)
264+
name : str
265+
Variable name in `data` to proxy
266+
"""
267+
268+
def __init__(self, data, index, name):
269+
self._data = data
270+
self._index = index
271+
self._name = name
272+
273+
def __array__(self, dtype=None):
274+
arr = self._data[self._name][self._index]
275+
return arr.astype(dtype) if dtype is not None else arr
276+
277+
def __repr__(self):
278+
return repr(self.__array__())
279+
280+
def __len__(self):
281+
return len(self.__array__())
282+
283+
def _to_global_index(self, subindex=None):
284+
"""Return a global index (boolean mask or integer indices) that
285+
addresses the parent arrays. If `subindex` is provided it selects
286+
within the current local view and maps back to the global index.
287+
"""
288+
base = self._index
289+
if subindex is None:
290+
return base
291+
292+
# If subindex is a boolean array, support both local-length masks
293+
# (length == base.sum()) and global-length masks (length == base.size).
294+
if isinstance(subindex, (np.ndarray, list)) and np.asarray(subindex).dtype == bool:
295+
arr = np.asarray(subindex)
296+
if arr.size == base.size:
297+
# already a global mask
298+
return arr
299+
if arr.size == int(np.sum(base)):
300+
global_mask = np.zeros_like(base, dtype=bool)
301+
global_mask[base] = arr
302+
return global_mask
303+
raise ValueError(
304+
f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
305+
)
306+
307+
# Handle tuple indexing where the first axis indexes particles
308+
# and later axes index into the per-particle array shape (e.g. ei[:, igrid])
309+
if isinstance(subindex, tuple):
310+
first, *rest = subindex
311+
# map the first index (local selection) to global particle indices
312+
if base.dtype == bool:
313+
particle_idxs = np.flatnonzero(base)
314+
if isinstance(first, slice):
315+
sel = particle_idxs[first]
316+
elif isinstance(first, (np.ndarray, list)):
317+
first_arr = np.asarray(first)
318+
if first_arr.dtype == bool:
319+
sel = particle_idxs[first_arr]
320+
else:
321+
sel = particle_idxs[first_arr]
322+
elif isinstance(first, int):
323+
sel = particle_idxs[first]
324+
else:
325+
sel = particle_idxs[first]
326+
else:
327+
base_arr = np.asarray(base)
328+
if isinstance(first, slice):
329+
sel = base_arr[first]
330+
else:
331+
sel = base_arr[first]
332+
333+
# if rest contains a single int (e.g., column), return tuple index
334+
if len(rest) == 1:
335+
return (sel, rest[0])
336+
# return full tuple (sel, ...) for higher-dim cases
337+
return tuple([sel] + rest)
338+
339+
# If base is a boolean mask over the parent array and subindex is
340+
# an integer or slice relative to the local view, map it to integer
341+
# indices in the parent array.
342+
if base.dtype == bool:
343+
if isinstance(subindex, (slice, int)):
344+
rel = np.flatnonzero(base)[subindex]
345+
return rel
346+
# otherwise assume subindex is an integer/array selection relative
347+
# to the local view and map to global indices
348+
global_mask = np.zeros_like(base, dtype=bool)
349+
global_mask[base] = subindex
350+
return global_mask
351+
352+
# If base is an array of integer indices
353+
base_arr = np.asarray(base)
354+
try:
355+
return base_arr[subindex]
356+
except Exception:
357+
return base_arr[np.asarray(subindex, dtype=bool)]
358+
359+
def __getitem__(self, subindex):
360+
# Handle tuple indexing (e.g. [:, igrid]) by applying the tuple
361+
# to the local selection first. This covers the common case
362+
# `particles.ei[:, igrid]` where `ei` is a 2D parent array and the
363+
# second index selects the grid index.
364+
if isinstance(subindex, tuple):
365+
local = self._data[self._name][self._index]
366+
return local[subindex]
367+
368+
new_index = self._to_global_index(subindex)
369+
return KernelParticleArray(self._data, new_index, self._name)
370+
371+
def __setitem__(self, subindex, value):
372+
tgt = self._to_global_index(subindex)
373+
self._data[self._name][tgt] = value
374+
375+
# in-place ops must write back into the parent array
376+
def __iadd__(self, other):
377+
vals = self._data[self._name][self._index] + (
378+
other.__array__() if isinstance(other, KernelParticleArray) else other
379+
)
380+
self._data[self._name][self._index] = vals
381+
return self
382+
383+
def __isub__(self, other):
384+
vals = self._data[self._name][self._index] - (
385+
other.__array__() if isinstance(other, KernelParticleArray) else other
386+
)
387+
self._data[self._name][self._index] = vals
388+
return self
389+
390+
def __imul__(self, other):
391+
vals = self._data[self._name][self._index] * (
392+
other.__array__() if isinstance(other, KernelParticleArray) else other
393+
)
394+
self._data[self._name][self._index] = vals
395+
return self
396+
397+
# Provide simple numpy-like evaluation for binary ops by delegating to ndarray
398+
def __add__(self, other):
399+
return self.__array__() + (other.__array__() if isinstance(other, KernelParticleArray) else other)
400+
401+
def __sub__(self, other):
402+
return self.__array__() - (other.__array__() if isinstance(other, KernelParticleArray) else other)
403+
404+
def __mul__(self, other):
405+
return self.__array__() * (other.__array__() if isinstance(other, KernelParticleArray) else other)
406+
407+
def __truediv__(self, other):
408+
return self.__array__() / (other.__array__() if isinstance(other, KernelParticleArray) else other)
409+
410+
def __floordiv__(self, other):
411+
return self.__array__() // (other.__array__() if isinstance(other, KernelParticleArray) else other)
412+
413+
def __pow__(self, other):
414+
return self.__array__() ** (other.__array__() if isinstance(other, KernelParticleArray) else other)
415+
416+
def __neg__(self):
417+
return -self.__array__()
418+
419+
def __pos__(self):
420+
return +self.__array__()
421+
422+
def __abs__(self):
423+
return abs(self.__array__())
424+
425+
# Right-hand operations to handle cases like `scalar - KernelParticleArray`
426+
def __radd__(self, other):
427+
return (other.__array__() if isinstance(other, KernelParticleArray) else other) + self.__array__()
428+
429+
def __rsub__(self, other):
430+
return (other.__array__() if isinstance(other, KernelParticleArray) else other) - self.__array__()
431+
432+
def __rmul__(self, other):
433+
return (other.__array__() if isinstance(other, KernelParticleArray) else other) * self.__array__()
434+
435+
def __rtruediv__(self, other):
436+
return (other.__array__() if isinstance(other, KernelParticleArray) else other) / self.__array__()
437+
438+
def __rfloordiv__(self, other):
439+
return (other.__array__() if isinstance(other, KernelParticleArray) else other) // self.__array__()
440+
441+
def __rpow__(self, other):
442+
return (other.__array__() if isinstance(other, KernelParticleArray) else other) ** self.__array__()
443+
444+
# Comparison operators should return plain numpy boolean arrays so that
445+
# expressions like `mask = particles.gridID == gid` produce an ndarray
446+
# usable for indexing (rather than another KernelParticleArray).
447+
def __eq__(self, other):
448+
left = np.asarray(self.__array__())
449+
if isinstance(other, KernelParticleArray):
450+
right = np.asarray(other.__array__())
451+
else:
452+
right = other
453+
return left == right
454+
455+
def __ne__(self, other):
456+
left = np.asarray(self.__array__())
457+
if isinstance(other, KernelParticleArray):
458+
right = np.asarray(other.__array__())
459+
else:
460+
right = other
461+
return left != right
462+
463+
def __lt__(self, other):
464+
left = np.asarray(self.__array__())
465+
if isinstance(other, KernelParticleArray):
466+
right = np.asarray(other.__array__())
467+
else:
468+
right = other
469+
return left < right
470+
471+
def __le__(self, other):
472+
left = np.asarray(self.__array__())
473+
if isinstance(other, KernelParticleArray):
474+
right = np.asarray(other.__array__())
475+
else:
476+
right = other
477+
return left <= right
478+
479+
def __gt__(self, other):
480+
left = np.asarray(self.__array__())
481+
if isinstance(other, KernelParticleArray):
482+
right = np.asarray(other.__array__())
483+
else:
484+
right = other
485+
return left > right
486+
487+
def __ge__(self, other):
488+
left = np.asarray(self.__array__())
489+
if isinstance(other, KernelParticleArray):
490+
right = np.asarray(other.__array__())
491+
else:
492+
right = other
493+
return left >= right
494+
495+
# Allow attribute access like .dtype etc. by forwarding to the ndarray
496+
def __getattr__(self, item):
497+
arr = self.__array__()
498+
return getattr(arr, item)
499+
500+
143501
def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_vars: list[Variable]):
144502
existing_names = {var.name for var in existing_vars}
145503
for var in new_vars:

0 commit comments

Comments
 (0)