Skip to content

Commit 910310f

Browse files
Improve SGRID metadata tooling, add sgrid datasets (#2431)
1 parent 551cf3a commit 910310f

4 files changed

Lines changed: 290 additions & 49 deletions

File tree

src/parcels/_core/utils/sgrid.py

Lines changed: 120 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import xarray as xr
2121

22+
from parcels._python import repr_from_dunder_dict
23+
2224
RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)"
2325

2426
Dim = str
@@ -31,12 +33,21 @@ class Padding(enum.Enum):
3133
BOTH = "both"
3234

3335

34-
class SGridMetadataProtocol(Protocol):
36+
# Lookup table from each SGrid ``Padding`` member to the xgcm position string
# used as a key when building the xgcm coordinate description.
# Note: the "center" position has no SGrid padding equivalent - in SGrid that
# would simply be the edges/faces themselves.
SGRID_PADDING_TO_XGCM_POSITION = {
    Padding.LOW: "right",
    Padding.HIGH: "left",
    Padding.BOTH: "inner",
    Padding.NONE: "outer",
}
43+
44+
45+
class AttrsSerializable(Protocol):
3546
def to_attrs(self) -> dict[str, str | int]: ...
3647
def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...
3748

3849

39-
class Grid2DMetadata(SGridMetadataProtocol):
50+
class Grid2DMetadata(AttrsSerializable):
4051
def __init__(
4152
self,
4253
cf_role: Literal["grid_topology"],
@@ -94,16 +105,13 @@ def __init__(
94105
#! Important optional attribute for 2D grids with vertical layering
95106
self.vertical_dimensions = vertical_dimensions
96107

108+
def __repr__(self) -> str:
109+
return repr_from_dunder_dict(self)
110+
97111
def __eq__(self, other: Any) -> bool:
98112
if not isinstance(other, Grid2DMetadata):
99113
return NotImplemented
100-
return (
101-
self.cf_role == other.cf_role
102-
and self.topology_dimension == other.topology_dimension
103-
and self.node_dimensions == other.node_dimensions
104-
and self.face_dimensions == other.face_dimensions
105-
and self.vertical_dimensions == other.vertical_dimensions
106-
)
114+
return self.to_attrs() == other.to_attrs()
107115

108116
@classmethod
109117
def from_attrs(cls, attrs):
@@ -129,8 +137,11 @@ def to_attrs(self) -> dict[str, str | int]:
129137
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
130138
return d
131139

140+
def rename_dims(self, dims_dict: dict[str, str]) -> Self:
141+
return _metadata_rename_dims(self, dims_dict)
142+
132143

133-
class Grid3DMetadata(SGridMetadataProtocol):
144+
class Grid3DMetadata(AttrsSerializable):
134145
def __init__(
135146
self,
136147
cf_role: Literal["grid_topology"],
@@ -180,15 +191,13 @@ def __init__(
180191
# face *i_coordinates*
181192
# volume_coordinates
182193

194+
def __repr__(self) -> str:
195+
return repr_from_dunder_dict(self)
196+
183197
def __eq__(self, other: Any) -> bool:
184198
if not isinstance(other, Grid3DMetadata):
185199
return NotImplemented
186-
return (
187-
self.cf_role == other.cf_role
188-
and self.topology_dimension == other.topology_dimension
189-
and self.node_dimensions == other.node_dimensions
190-
and self.volume_dimensions == other.volume_dimensions
191-
)
200+
return self.to_attrs() == other.to_attrs()
192201

193202
@classmethod
194203
def from_attrs(cls, attrs):
@@ -210,6 +219,9 @@ def to_attrs(self) -> dict[str, str | int]:
210219
volume_dimensions=dump_mappings(self.volume_dimensions),
211220
)
212221

222+
def rename_dims(self, dims_dict: dict[str, str]) -> Self:
223+
return _metadata_rename_dims(self, dims_dict)
224+
213225

214226
@dataclass
215227
class DimDimPadding:
@@ -318,15 +330,6 @@ def maybe_load_mappings(s):
318330
return load_mappings(s)
319331

320332

321-
SGRID_PADDING_TO_XGCM_POSITION = {
322-
Padding.LOW: "right",
323-
Padding.HIGH: "left",
324-
Padding.BOTH: "inner",
325-
Padding.NONE: "outer",
326-
# "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves
327-
}
328-
329-
330333
class SGridParsingException(Exception):
331334
"""Exception raised when parsing SGrid attributes fails."""
332335

@@ -378,3 +381,95 @@ def parse_sgrid(ds: xr.Dataset):
378381
xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1}
379382

380383
return (ds, {"coords": xgcm_coords})
384+
385+
386+
def rename_dims(ds: xr.Dataset, dims_dict: dict[str, str]) -> xr.Dataset:
    """Rename dimensions of an SGrid dataset and update its grid metadata to match.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing a variable whose ``cf_role`` attribute is
        ``"grid_topology"``.
    dims_dict : dict[str, str]
        Mapping of old dimension names to new dimension names.

    Returns
    -------
    xr.Dataset
        The result of ``ds.rename_dims(dims_dict)`` in which the attrs of the
        grid topology variable have been rewritten with the renamed dimensions.

    Raises
    ------
    ValueError
        If no grid topology variable is found in ``ds``.
    """
    grid_da = get_grid_topology(ds)
    if grid_da is None:
        raise ValueError(
            "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make sure your dataset conforms to SGrid conventions."
        )

    ds = ds.rename_dims(dims_dict)

    # The topology attrs refer to dimensions by name, so they have to be
    # regenerated from the parsed metadata after the rename.
    grid = parse_grid_attrs(grid_da.attrs)
    ds[grid_da.name].attrs = grid.rename_dims(dims_dict).to_attrs()
    return ds
399+
400+
401+
def get_unique_dim_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
    """Collect every dimension name referenced by the given SGrid metadata object.

    All instance attributes except ``cf_role`` and ``topology_dimension`` are
    inspected; attributes set to ``None`` are ignored. Each remaining attribute
    must be a tuple whose items are either plain dimension names or
    ``DimDimPadding`` entries (which contribute both ``dim1`` and ``dim2``).
    """
    scalar_attrs = ("cf_role", "topology_dimension")
    found = set(grid.node_dimensions)

    for attr_name, entries in grid.__dict__.items():
        if entries is None or attr_name in scalar_attrs:
            continue
        assert isinstance(entries, tuple), (
            f"Expected sgrid metadata attribute to be represented as a tuple, got {entries!r}. This is an internal error to Parcels - please post an issue if you encounter this."
        )
        for entry in entries:
            if not isinstance(entry, DimDimPadding):
                assert isinstance(entry, str)
                found.add(entry)
                continue
            found.update((entry.dim1, entry.dim2))
    return found
419+
420+
421+
@overload
def _metadata_rename_dims(grid: Grid2DMetadata, dims_dict: dict[str, str]) -> Grid2DMetadata: ...


@overload
def _metadata_rename_dims(grid: Grid3DMetadata, dims_dict: dict[str, str]) -> Grid3DMetadata: ...


def _metadata_rename_dims(grid, dims_dict):
    """
    Renames dimensions in SGrid metadata.

    Similar in API to xr.Dataset.rename_dims. Renames dimensions according to dims_dict mapping
    of old dimension names to new dimension names.

    The input mapping is not modified. A new metadata object of the same type as ``grid`` is
    returned; ``grid`` itself is left untouched.

    Raises
    ------
    ValueError
        If a key of ``dims_dict`` is not a dimension referenced by ``grid``, if the rename would
        give two distinct dimensions the same name, or if ``grid`` carries an attribute of an
        unexpected type.
    """
    dims_dict = dims_dict.copy()
    assert len(dims_dict) == len(set(dims_dict.values())), "dims_dict contains duplicate target dimension names"

    existing_dims = get_unique_dim_names(grid)
    for dim in dims_dict:
        if dim not in existing_dims:
            raise ValueError(f"Dimension {dim!r} not found in SGrid metadata dimensions {existing_dims!r}")

    for dim in existing_dims:
        dims_dict.setdefault(dim, dim)  # identity mapping for dimensions not being renamed

    # With the identity entries filled in, the mapping covers every dimension of the grid and
    # must stay one-to-one. This catches renaming a dimension onto another existing dimension
    # that is not itself being renamed, which would otherwise silently merge the two.
    targets = list(dims_dict.values())
    clashing = sorted({name for name in targets if targets.count(name) > 1})
    if clashing:
        raise ValueError(
            f"Renaming would map several distinct dimensions onto the same name(s) {clashing!r}. "
            f"Existing SGrid metadata dimensions: {existing_dims!r}"
        )

    kwargs = {}
    for key, value in grid.__dict__.items():
        if isinstance(value, tuple):
            new_value = []
            for item in value:
                if isinstance(item, DimDimPadding):
                    new_value.append(
                        DimDimPadding(
                            dim1=dims_dict[item.dim1],
                            dim2=dims_dict[item.dim2],
                            padding=item.padding,
                        )
                    )
                else:
                    assert isinstance(item, str)
                    new_value.append(dims_dict[item])
            kwargs[key] = tuple(new_value)
            continue

        # Attributes that do not hold dimension names are passed through unchanged.
        if key in ("cf_role", "topology_dimension") or value is None:
            kwargs[key] = value
            continue

        if isinstance(value, str):
            kwargs[key] = dims_dict[value]
            continue

        raise ValueError(f"Unexpected attribute {key!r} on {grid!r}")
    return type(grid)(**kwargs)

src/parcels/_datasets/structured/generic.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,35 @@
11
import numpy as np
22
import xarray as xr
33

4+
from parcels._core.utils.sgrid import (
5+
DimDimPadding,
6+
Grid2DMetadata,
7+
Grid3DMetadata,
8+
Padding,
9+
)
10+
from parcels._core.utils.sgrid import (
11+
rename_dims as sgrid_rename_dims,
12+
)
13+
414
from . import T, X, Y, Z
515

616
__all__ = ["T", "X", "Y", "Z", "datasets"]
717

818
TIME = xr.date_range("2000", "2001", T)
919

1020

21+
def _attach_sgrid_metadata(ds, grid: Grid2DMetadata | Grid3DMetadata):
    """Return a copy of ``ds`` carrying the SGRID metadata of ``grid``.

    The metadata is stored as the attrs of a dimensionless ``"grid"`` variable
    holding the value 0, and the ``"Conventions"`` dataset attribute of the
    copy is set to ``"SGRID"``.
    """
    annotated = ds.copy()
    topology_variable = ([], 0, grid.to_attrs())
    annotated["grid"] = topology_variable
    annotated.attrs["Conventions"] = "SGRID"
    return annotated
31+
32+
1133
def _rotated_curvilinear_grid():
1234
XG = np.arange(X)
1335
YG = np.arange(Y)
@@ -225,3 +247,54 @@ def _unrolled_cone_curvilinear_grid():
225247
),
226248
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),
227249
}
250+
251+
# Dimension renames applied to the generic datasets when building the SGRID variants.
# Note "2D SGRID" here is meant in the context of SGRID convention (i.e., 1D depth)
_COMODO_TO_2D_SGRID = {
    "XG": "node_dimension1",
    "YG": "node_dimension2",
    "XC": "face_dimension1",
    "YC": "face_dimension2",
    "ZG": "vertical_dimensions_dim1",
    "ZC": "vertical_dimensions_dim2",
}


def _padded_2d_sgrid_dataset(source_key, padding):
    """Build an SGRID dataset from ``datasets[source_key]`` using ``padding`` on every axis.

    The 2D grid metadata is attached first, then the dimensions are renamed
    according to ``_COMODO_TO_2D_SGRID``.
    """
    grid = Grid2DMetadata(
        cf_role="grid_topology",
        topology_dimension=2,
        node_dimensions=("YG", "XG"),
        face_dimensions=(
            DimDimPadding("YC", "YG", padding),
            DimDimPadding("XC", "XG", padding),
        ),
        vertical_dimensions=(DimDimPadding("ZC", "ZG", padding),),
    )
    with_metadata = _attach_sgrid_metadata(datasets[source_key], grid)
    return sgrid_rename_dims(with_metadata, _COMODO_TO_2D_SGRID)


datasets_sgrid = {
    "ds_2d_padded_high": _padded_2d_sgrid_dataset("ds_2d_left", Padding.HIGH),
    "ds_2d_padded_low": _padded_2d_sgrid_dataset("ds_2d_right", Padding.LOW),
}

src/parcels/_python.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Generic Python helpers
22
import inspect
3-
from collections.abc import Callable
3+
from types import FunctionType
44

55

66
def isinstance_noimport(obj, class_or_tuple):
@@ -14,7 +14,13 @@ def isinstance_noimport(obj, class_or_tuple):
1414
)
1515

1616

17-
def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) -> None:
17+
def repr_from_dunder_dict(obj: object) -> str:
    """Build a dataclass-style ``ClassName(attr=value, ...)`` repr from ``obj.__dict__``.

    Attributes appear in the order they are stored in the instance dictionary,
    each rendered with ``repr``.
    """
    rendered_fields = ", ".join(f"{name}={value!r}" for name, value in obj.__dict__.items())
    class_name = obj.__class__.__qualname__
    return f"{class_name}({rendered_fields})"
21+
22+
23+
def assert_same_function_signature(f: FunctionType, *, ref: FunctionType, context: str) -> None:
1824
"""Ensures a function `f` has the same signature as the reference function `ref`."""
1925
sig_ref = inspect.signature(ref)
2026
sig = inspect.signature(f)

0 commit comments

Comments
 (0)