Skip to content

Commit 6e0c074

Browse files
Cleaning up ParticleSetView
1 parent c839c0f commit 6e0c074

1 file changed

Lines changed: 10 additions & 68 deletions

File tree

src/parcels/_core/particle.py

Lines changed: 10 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def add_variable(self, variable: Variable | list[Variable]):
117117

118118

119119
class ParticleSetView:
120-
"""Class to be used in a kernel that links a particle (on the kernel level) to a particle dataset."""
120+
"""Class to be used in a kernel that links a View of the ParticleSet (on the kernel level) to a ParticleSet."""
121121

122122
def __init__(self, data, index):
123123
self._data = data
@@ -134,7 +134,6 @@ def __getattr__(self, name):
134134
# user-facing semantics (e.g., `pset[0].time` should be a number).
135135
if isinstance(self._index, (int, np.integer)):
136136
return self._data[name][self._index]
137-
# For 0-d numpy integer scalars
138137
if isinstance(self._index, np.ndarray) and self._index.ndim == 0:
139138
return self._data[name][int(self._index)]
140139
return ParticleSetViewArray(self._data, self._index, name)
@@ -168,36 +167,19 @@ def __getitem__(self, index):
168167
)
169168
return ParticleSetView(self._data, new_index)
170169

171-
# Integer array / list of indices relative to local view
172-
if isinstance(index, (np.ndarray, list)):
173-
idx_arr = np.asarray(index)
174-
if idx_arr.dtype == bool:
175-
# handled above, but keep for safety
176-
if idx_arr.size == base.size:
177-
new_index = idx_arr
178-
else:
179-
new_index[base] = idx_arr
180-
else:
181-
if base.dtype == bool:
182-
particle_idxs = np.flatnonzero(base)
183-
sel = particle_idxs[idx_arr]
184-
new_index[sel] = True
185-
else:
186-
base_arr = np.asarray(base)
187-
sel = base_arr[idx_arr]
188-
new_index[sel] = True
189-
return ParticleSetView(self._data, new_index)
190-
191-
# Slice or single integer index relative to local view
192-
if isinstance(index, slice) or isinstance(index, int):
170+
# Integer array/list, slice or single integer relative to the local view
171+
# (boolean masks were handled above). Normalize and map to global
172+
# particle indices for both boolean-base and integer-base `self._index`.
173+
if isinstance(index, (np.ndarray, list, slice, int)):
174+
# convert list/ndarray to ndarray, keep slice/int as-is
175+
idx = np.asarray(index) if isinstance(index, (np.ndarray, list)) else index
193176
if base.dtype == bool:
194177
particle_idxs = np.flatnonzero(base)
195-
sel = particle_idxs[index]
196-
new_index[sel] = True
178+
sel = particle_idxs[idx]
197179
else:
198180
base_arr = np.asarray(base)
199-
sel = base_arr[index]
200-
new_index[sel] = True
181+
sel = base_arr[idx]
182+
new_index[sel] = True
201183
return ParticleSetView(self._data, new_index)
202184

203185
# Fallback: try to assign directly (preserves previous behaviour for other index types)
@@ -207,46 +189,6 @@ def __getitem__(self, index):
207189
except Exception as e:
208190
raise TypeError(f"Unsupported index type for ParticleSetView.__getitem__: {type(index)!r}") from e
209191

210-
# def __setitem__(self, index, value):
211-
# """Assign to a subset of particles represented by `index` relative to
212-
# this ParticleSetView's current selection.
213-
214-
# The incoming `index` is interpreted in the same way as for
215-
# `__getitem__`: it indexes into the subset defined by `self._index`.
216-
217-
# `value` may be another ParticleSetView (in which case common variables
218-
# are copied), or a dict mapping variable names to arrays/scalars which
219-
# will be written into the parent arrays at the computed positions.
220-
# """
221-
# # Map the provided index (which indexes into the current subset)
222-
# # back to the full parent-array index.
223-
# new_index = np.zeros_like(self._index, dtype=bool)
224-
# new_index[self._index] = index
225-
226-
# # Helper to perform assignment for a given variable name
227-
# def _assign(varname, src):
228-
# # write into parent array at positions new_index
229-
# self._data[varname][new_index] = src
230-
231-
# # Case: assign from another ParticleSetView-like object
232-
# if isinstance(value, ParticleSetView):
233-
# # copy across common fields
234-
# for k in set(self._data.keys()).intersection(value._data.keys()):
235-
# _assign(k, value._data[k][value._index])
236-
# return
237-
238-
# # Case: assign from a dict-like mapping variable names -> values
239-
# if isinstance(value, dict):
240-
# for k, v in value.items():
241-
# if k not in self._data:
242-
# raise KeyError(f"Unknown particle variable: {k}")
243-
# _assign(k, v)
244-
# return
245-
246-
# # Otherwise, if a scalar/array is provided, assign it to all variables
247-
# # is ambiguous: raise TypeError to avoid surprising behaviour.
248-
# raise TypeError("Unsupported value for ParticleSetView.__setitem__; provide a ParticleSetView or dict of variable values")
249-
250192
def __len__(self):
251193
return len(self._index)
252194

0 commit comments

Comments
 (0)