@@ -117,13 +117,27 @@ def add_variable(self, variable: Variable | list[Variable]):
117117
118118
119119class KernelParticle :
120- """Simple class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
120+ """Class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
121121
122122 def __init__ (self , data , index ):
123123 self ._data = data
124124 self ._index = index
125125
126126 def __getattr__ (self , name ):
127+ # Return a proxy that behaves like the underlying numpy array but
128+ # writes back into the parent arrays when sliced/modified. This
129+ # enables constructs like `particles.dlon[mask] += vals` to update
130+ # the parent arrays rather than temporary copies.
131+ if name in self ._data :
132+ # If this KernelParticle represents a single particle (integer
133+ # index), return the underlying scalar directly to preserve
134+ # user-facing semantics (e.g., `pset[0].time` should be a number).
135+ if isinstance (self ._index , (int , np .integer )):
136+ return self ._data [name ][self ._index ]
137+ # For 0-d numpy integer scalars
138+ if isinstance (self ._index , np .ndarray ) and self ._index .ndim == 0 :
139+ return self ._data [name ][int (self ._index )]
140+ return KernelParticleArray (self ._data , self ._index , name )
127141 return self ._data [name ][self ._index ]
128142
129143 def __setattr__ (self , name , value ):
@@ -133,13 +147,357 @@ def __setattr__(self, name, value):
133147 self ._data [name ][self ._index ] = value
134148
135149 def __getitem__ (self , index ):
136- self ._index = index
137- return self
150+ # normalize single-element tuple indexing (e.g., (inds,))
151+ if isinstance (index , tuple ) and len (index ) == 1 :
152+ index = index [0 ]
153+
154+ base = self ._index
155+ new_index = np .zeros_like (base , dtype = bool )
156+
157+ # Boolean mask (could be local-length or global-length)
158+ if isinstance (index , (np .ndarray , list )) and np .asarray (index ).dtype == bool :
159+ arr = np .asarray (index )
160+ if arr .size == base .size :
161+ # global mask
162+ new_index = arr
163+ elif arr .size == int (np .sum (base )):
164+ new_index [base ] = arr
165+ else :
166+ raise ValueError (
167+ f"Boolean index has incompatible length { arr .size } for selection of size { int (np .sum (base ))} "
168+ )
169+ return KernelParticle (self ._data , new_index )
170+
171+ # Integer array / list of indices relative to local view
172+ if isinstance (index , (np .ndarray , list )):
173+ idx_arr = np .asarray (index )
174+ if idx_arr .dtype == bool :
175+ # handled above, but keep for safety
176+ if idx_arr .size == base .size :
177+ new_index = idx_arr
178+ else :
179+ new_index [base ] = idx_arr
180+ else :
181+ if base .dtype == bool :
182+ particle_idxs = np .flatnonzero (base )
183+ sel = particle_idxs [idx_arr ]
184+ new_index [sel ] = True
185+ else :
186+ base_arr = np .asarray (base )
187+ sel = base_arr [idx_arr ]
188+ new_index [sel ] = True
189+ return KernelParticle (self ._data , new_index )
190+
191+ # Slice or single integer index relative to local view
192+ if isinstance (index , slice ) or isinstance (index , int ):
193+ if base .dtype == bool :
194+ particle_idxs = np .flatnonzero (base )
195+ sel = particle_idxs [index ]
196+ new_index [sel ] = True
197+ else :
198+ base_arr = np .asarray (base )
199+ sel = base_arr [index ]
200+ new_index [sel ] = True
201+ return KernelParticle (self ._data , new_index )
202+
203+ # Fallback: try to assign directly (preserves previous behaviour for other index types)
204+ try :
205+ new_index [base ] = index
206+ return KernelParticle (self ._data , new_index )
207+ except Exception as e :
208+ raise TypeError (f"Unsupported index type for KernelParticle.__getitem__: { type (index )!r} " ) from e
209+
210+ # def __setitem__(self, index, value):
211+ # """Assign to a subset of particles represented by `index` relative to
212+ # this KernelParticle's current selection.
213+
214+ # The incoming `index` is interpreted in the same way as for
215+ # `__getitem__`: it indexes into the subset defined by `self._index`.
216+
217+ # `value` may be another KernelParticle (in which case common variables
218+ # are copied), or a dict mapping variable names to arrays/scalars which
219+ # will be written into the parent arrays at the computed positions.
220+ # """
221+ # # Map the provided index (which indexes into the current subset)
222+ # # back to the full parent-array index.
223+ # new_index = np.zeros_like(self._index, dtype=bool)
224+ # new_index[self._index] = index
225+
226+ # # Helper to perform assignment for a given variable name
227+ # def _assign(varname, src):
228+ # # write into parent array at positions new_index
229+ # self._data[varname][new_index] = src
230+
231+ # # Case: assign from another KernelParticle-like object
232+ # if isinstance(value, KernelParticle):
233+ # # copy across common fields
234+ # for k in set(self._data.keys()).intersection(value._data.keys()):
235+ # _assign(k, value._data[k][value._index])
236+ # return
237+
238+ # # Case: assign from a dict-like mapping variable names -> values
239+ # if isinstance(value, dict):
240+ # for k, v in value.items():
241+ # if k not in self._data:
242+ # raise KeyError(f"Unknown particle variable: {k}")
243+ # _assign(k, v)
244+ # return
245+
246+ # # Otherwise, if a scalar/array is provided, assign it to all variables
247+ # # is ambiguous: raise TypeError to avoid surprising behaviour.
248+ # raise TypeError("Unsupported value for KernelParticle.__setitem__; provide a KernelParticle or dict of variable values")
138249
139250 def __len__ (self ):
140251 return len (self ._index )
141252
142253
254+ class KernelParticleArray :
255+ """Array-like proxy for a particle variable that writes through to the
256+ parent arrays when mutated.
257+
258+ Parameters
259+ ----------
260+ data : dict-like
261+ Parent particle storage (mapping varname -> ndarray)
262+ index : array-like
263+ Index representing the subset in the parent arrays (boolean mask or integer indices)
264+ name : str
265+ Variable name in `data` to proxy
266+ """
267+
268+ def __init__ (self , data , index , name ):
269+ self ._data = data
270+ self ._index = index
271+ self ._name = name
272+
273+ def __array__ (self , dtype = None ):
274+ arr = self ._data [self ._name ][self ._index ]
275+ return arr .astype (dtype ) if dtype is not None else arr
276+
277+ def __repr__ (self ):
278+ return repr (self .__array__ ())
279+
280+ def __len__ (self ):
281+ return len (self .__array__ ())
282+
283+ def _to_global_index (self , subindex = None ):
284+ """Return a global index (boolean mask or integer indices) that
285+ addresses the parent arrays. If `subindex` is provided it selects
286+ within the current local view and maps back to the global index.
287+ """
288+ base = self ._index
289+ if subindex is None :
290+ return base
291+
292+ # If subindex is a boolean array, support both local-length masks
293+ # (length == base.sum()) and global-length masks (length == base.size).
294+ if isinstance (subindex , (np .ndarray , list )) and np .asarray (subindex ).dtype == bool :
295+ arr = np .asarray (subindex )
296+ if arr .size == base .size :
297+ # already a global mask
298+ return arr
299+ if arr .size == int (np .sum (base )):
300+ global_mask = np .zeros_like (base , dtype = bool )
301+ global_mask [base ] = arr
302+ return global_mask
303+ raise ValueError (
304+ f"Boolean index has incompatible length { arr .size } for selection of size { int (np .sum (base ))} "
305+ )
306+
307+ # Handle tuple indexing where the first axis indexes particles
308+ # and later axes index into the per-particle array shape (e.g. ei[:, igrid])
309+ if isinstance (subindex , tuple ):
310+ first , * rest = subindex
311+ # map the first index (local selection) to global particle indices
312+ if base .dtype == bool :
313+ particle_idxs = np .flatnonzero (base )
314+ if isinstance (first , slice ):
315+ sel = particle_idxs [first ]
316+ elif isinstance (first , (np .ndarray , list )):
317+ first_arr = np .asarray (first )
318+ if first_arr .dtype == bool :
319+ sel = particle_idxs [first_arr ]
320+ else :
321+ sel = particle_idxs [first_arr ]
322+ elif isinstance (first , int ):
323+ sel = particle_idxs [first ]
324+ else :
325+ sel = particle_idxs [first ]
326+ else :
327+ base_arr = np .asarray (base )
328+ if isinstance (first , slice ):
329+ sel = base_arr [first ]
330+ else :
331+ sel = base_arr [first ]
332+
333+ # if rest contains a single int (e.g., column), return tuple index
334+ if len (rest ) == 1 :
335+ return (sel , rest [0 ])
336+ # return full tuple (sel, ...) for higher-dim cases
337+ return tuple ([sel ] + rest )
338+
339+ # If base is a boolean mask over the parent array and subindex is
340+ # an integer or slice relative to the local view, map it to integer
341+ # indices in the parent array.
342+ if base .dtype == bool :
343+ if isinstance (subindex , (slice , int )):
344+ rel = np .flatnonzero (base )[subindex ]
345+ return rel
346+ # otherwise assume subindex is an integer/array selection relative
347+ # to the local view and map to global indices
348+ global_mask = np .zeros_like (base , dtype = bool )
349+ global_mask [base ] = subindex
350+ return global_mask
351+
352+ # If base is an array of integer indices
353+ base_arr = np .asarray (base )
354+ try :
355+ return base_arr [subindex ]
356+ except Exception :
357+ return base_arr [np .asarray (subindex , dtype = bool )]
358+
359+ def __getitem__ (self , subindex ):
360+ # Handle tuple indexing (e.g. [:, igrid]) by applying the tuple
361+ # to the local selection first. This covers the common case
362+ # `particles.ei[:, igrid]` where `ei` is a 2D parent array and the
363+ # second index selects the grid index.
364+ if isinstance (subindex , tuple ):
365+ local = self ._data [self ._name ][self ._index ]
366+ return local [subindex ]
367+
368+ new_index = self ._to_global_index (subindex )
369+ return KernelParticleArray (self ._data , new_index , self ._name )
370+
371+ def __setitem__ (self , subindex , value ):
372+ tgt = self ._to_global_index (subindex )
373+ self ._data [self ._name ][tgt ] = value
374+
375+ # in-place ops must write back into the parent array
376+ def __iadd__ (self , other ):
377+ vals = self ._data [self ._name ][self ._index ] + (
378+ other .__array__ () if isinstance (other , KernelParticleArray ) else other
379+ )
380+ self ._data [self ._name ][self ._index ] = vals
381+ return self
382+
383+ def __isub__ (self , other ):
384+ vals = self ._data [self ._name ][self ._index ] - (
385+ other .__array__ () if isinstance (other , KernelParticleArray ) else other
386+ )
387+ self ._data [self ._name ][self ._index ] = vals
388+ return self
389+
390+ def __imul__ (self , other ):
391+ vals = self ._data [self ._name ][self ._index ] * (
392+ other .__array__ () if isinstance (other , KernelParticleArray ) else other
393+ )
394+ self ._data [self ._name ][self ._index ] = vals
395+ return self
396+
397+ # Provide simple numpy-like evaluation for binary ops by delegating to ndarray
398+ def __add__ (self , other ):
399+ return self .__array__ () + (other .__array__ () if isinstance (other , KernelParticleArray ) else other )
400+
401+ def __sub__ (self , other ):
402+ return self .__array__ () - (other .__array__ () if isinstance (other , KernelParticleArray ) else other )
403+
404+ def __mul__ (self , other ):
405+ return self .__array__ () * (other .__array__ () if isinstance (other , KernelParticleArray ) else other )
406+
407+ def __truediv__ (self , other ):
408+ return self .__array__ () / (other .__array__ () if isinstance (other , KernelParticleArray ) else other )
409+
410+ def __floordiv__ (self , other ):
411+ return self .__array__ () // (other .__array__ () if isinstance (other , KernelParticleArray ) else other )
412+
413+ def __pow__ (self , other ):
414+ return self .__array__ () ** (other .__array__ () if isinstance (other , KernelParticleArray ) else other )
415+
416+ def __neg__ (self ):
417+ return - self .__array__ ()
418+
419+ def __pos__ (self ):
420+ return + self .__array__ ()
421+
422+ def __abs__ (self ):
423+ return abs (self .__array__ ())
424+
425+ # Right-hand operations to handle cases like `scalar - KernelParticleArray`
426+ def __radd__ (self , other ):
427+ return (other .__array__ () if isinstance (other , KernelParticleArray ) else other ) + self .__array__ ()
428+
429+ def __rsub__ (self , other ):
430+ return (other .__array__ () if isinstance (other , KernelParticleArray ) else other ) - self .__array__ ()
431+
432+ def __rmul__ (self , other ):
433+ return (other .__array__ () if isinstance (other , KernelParticleArray ) else other ) * self .__array__ ()
434+
435+ def __rtruediv__ (self , other ):
436+ return (other .__array__ () if isinstance (other , KernelParticleArray ) else other ) / self .__array__ ()
437+
438+ def __rfloordiv__ (self , other ):
439+ return (other .__array__ () if isinstance (other , KernelParticleArray ) else other ) // self .__array__ ()
440+
441+ def __rpow__ (self , other ):
442+ return (other .__array__ () if isinstance (other , KernelParticleArray ) else other ) ** self .__array__ ()
443+
444+ # Comparison operators should return plain numpy boolean arrays so that
445+ # expressions like `mask = particles.gridID == gid` produce an ndarray
446+ # usable for indexing (rather than another KernelParticleArray).
447+ def __eq__ (self , other ):
448+ left = np .asarray (self .__array__ ())
449+ if isinstance (other , KernelParticleArray ):
450+ right = np .asarray (other .__array__ ())
451+ else :
452+ right = other
453+ return left == right
454+
455+ def __ne__ (self , other ):
456+ left = np .asarray (self .__array__ ())
457+ if isinstance (other , KernelParticleArray ):
458+ right = np .asarray (other .__array__ ())
459+ else :
460+ right = other
461+ return left != right
462+
463+ def __lt__ (self , other ):
464+ left = np .asarray (self .__array__ ())
465+ if isinstance (other , KernelParticleArray ):
466+ right = np .asarray (other .__array__ ())
467+ else :
468+ right = other
469+ return left < right
470+
471+ def __le__ (self , other ):
472+ left = np .asarray (self .__array__ ())
473+ if isinstance (other , KernelParticleArray ):
474+ right = np .asarray (other .__array__ ())
475+ else :
476+ right = other
477+ return left <= right
478+
479+ def __gt__ (self , other ):
480+ left = np .asarray (self .__array__ ())
481+ if isinstance (other , KernelParticleArray ):
482+ right = np .asarray (other .__array__ ())
483+ else :
484+ right = other
485+ return left > right
486+
487+ def __ge__ (self , other ):
488+ left = np .asarray (self .__array__ ())
489+ if isinstance (other , KernelParticleArray ):
490+ right = np .asarray (other .__array__ ())
491+ else :
492+ right = other
493+ return left >= right
494+
495+ # Allow attribute access like .dtype etc. by forwarding to the ndarray
496+ def __getattr__ (self , item ):
497+ arr = self .__array__ ()
498+ return getattr (arr , item )
499+
500+
143501def _assert_no_duplicate_variable_names (* , existing_vars : list [Variable ], new_vars : list [Variable ]):
144502 existing_names = {var .name for var in existing_vars }
145503 for var in new_vars :
0 commit comments