Skip to content

Commit 6e47206

Browse files
Add FieldSet.from_sgrid_conventions() (#2432)
1 parent 0bcf7d8 commit 6e47206

5 files changed

Lines changed: 130 additions & 49 deletions

File tree

src/parcels/_core/fieldset.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import xgcm
1212

1313
from parcels._core.field import Field, VectorField
14+
from parcels._core.utils import sgrid
1415
from parcels._core.utils.string import _assert_str_and_python_varname
1516
from parcels._core.utils.time import get_datetime_type_calendar
1617
from parcels._core.utils.time import is_compatible as datetime_is_compatible
@@ -295,6 +296,92 @@ def from_fesom2(ds: ux.UxDataset):
295296

296297
return FieldSet(list(fields.values()))
297298

299+
def from_sgrid_conventions(
300+
ds: xr.Dataset, mesh: Mesh
301+
): # TODO: Update mesh to be discovered from the dataset metadata
302+
"""Create a FieldSet from a dataset using SGRID convention metadata.
303+
304+
This is the primary ingestion method in Parcels for structured grid datasets.
305+
306+
Assumes that U, V, (and optionally W) variables are named 'U', 'V', and 'W' in the dataset.
307+
308+
Parameters
309+
----------
310+
ds : xarray.Dataset
311+
xarray.Dataset with SGRID convention metadata.
312+
mesh : str
313+
String indicating the type of mesh coordinates and units used during
314+
velocity interpolation. Options are "spherical" or "flat".
315+
316+
Returns
317+
-------
318+
FieldSet
319+
FieldSet object containing the fields from the dataset that can be used for a Parcels simulation.
320+
321+
Notes
322+
-----
323+
This method uses the SGRID convention metadata to parse the grid structure
324+
and create appropriate Fields for a Parcels simulation. The dataset should
325+
contain a variable with 'cf_role' attribute set to 'grid_topology'.
326+
327+
See https://sgrid.github.io/ for more information on the SGRID conventions.
328+
"""
329+
ds = ds.copy()
330+
331+
# Ensure time dimension has axis attribute if present
332+
if "time" in ds.dims and "time" in ds.coords:
333+
if "axis" not in ds["time"].attrs:
334+
logger.debug(
335+
"Dataset contains 'time' dimension but no 'axis' attribute. Setting 'axis' attribute to 'T'."
336+
)
337+
ds["time"].attrs["axis"] = "T"
338+
339+
# Find time dimension based on axis attribute and rename to `time`
340+
if (time_dims := ds.cf.axes.get("T")) is not None:
341+
if len(time_dims) > 1:
342+
raise ValueError("Multiple time coordinates found in dataset. This is not supported by Parcels.")
343+
(time_dim,) = time_dims
344+
if time_dim != "time":
345+
logger.debug(f"Renaming time axis coordinate from {time_dim} to 'time'.")
346+
ds = ds.rename({time_dim: "time"})
347+
348+
# Parse SGRID metadata and get xgcm kwargs
349+
_, xgcm_kwargs = sgrid.parse_sgrid(ds)
350+
351+
# Add time axis to xgcm_kwargs if present
352+
if "time" in ds.dims:
353+
if "T" not in xgcm_kwargs["coords"]:
354+
xgcm_kwargs["coords"]["T"] = {"center": "time"}
355+
356+
# Create xgcm Grid object
357+
xgcm_grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs, **_DEFAULT_XGCM_KWARGS)
358+
359+
# Wrap in XGrid
360+
grid = XGrid(xgcm_grid, mesh=mesh)
361+
362+
# Create fields from data variables, skipping grid metadata variables
363+
# Skip variables that are SGRID metadata (have cf_role='grid_topology')
364+
skip_vars = set()
365+
for var in ds.data_vars:
366+
if ds[var].attrs.get("cf_role") == "grid_topology":
367+
skip_vars.add(var)
368+
369+
fields = {}
370+
if "U" in ds.data_vars and "V" in ds.data_vars:
371+
fields["U"] = Field("U", ds["U"], grid, XLinear)
372+
fields["V"] = Field("V", ds["V"], grid, XLinear)
373+
374+
if "W" in ds.data_vars:
375+
fields["W"] = Field("W", ds["W"], grid, XLinear)
376+
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
377+
else:
378+
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
379+
380+
for varname in set(ds.data_vars) - set(fields.keys()) - skip_vars:
381+
fields[varname] = Field(varname, ds[varname], grid, XLinear)
382+
383+
return FieldSet(list(fields.values()))
384+
298385

299386
class CalendarError(Exception): # TODO: Move to a parcels errors module
300387
"""Exception raised when the calendar of a field is not compatible with the rest of the Fields. The user should ensure that they only add fields to a FieldSet that have compatible CFtime calendars."""

src/parcels/_core/utils/sgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def parse_sgrid(ds: xr.Dataset):
378378
xgcm_coords = {}
379379
for dim_dim_padding, axis in zip(dimensions, "XYZ", strict=False):
380380
xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[dim_dim_padding.padding]
381-
xgcm_coords[axis] = {"center": dim_dim_padding.dim2, xgcm_position: dim_dim_padding.dim1}
381+
xgcm_coords[axis] = {"center": dim_dim_padding.dim1, xgcm_position: dim_dim_padding.dim2}
382382

383383
return (ds, {"coords": xgcm_coords})
384384

src/parcels/_datasets/structured/generic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,10 @@ def _unrolled_cone_curvilinear_grid():
264264
Grid2DMetadata(
265265
cf_role="grid_topology",
266266
topology_dimension=2,
267-
node_dimensions=("YG", "XG"),
267+
node_dimensions=("XG", "YG"),
268268
face_dimensions=(
269-
DimDimPadding("YC", "YG", Padding.HIGH),
270269
DimDimPadding("XC", "XG", Padding.HIGH),
270+
DimDimPadding("YC", "YG", Padding.HIGH),
271271
),
272272
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),),
273273
),
@@ -284,10 +284,10 @@ def _unrolled_cone_curvilinear_grid():
284284
Grid2DMetadata(
285285
cf_role="grid_topology",
286286
topology_dimension=2,
287-
node_dimensions=("YG", "XG"),
287+
node_dimensions=("XG", "YG"),
288288
face_dimensions=(
289-
DimDimPadding("YC", "YG", Padding.LOW),
290289
DimDimPadding("XC", "XG", Padding.LOW),
290+
DimDimPadding("YC", "YG", Padding.LOW),
291291
),
292292
vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),),
293293
),

tests/test_fieldset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from parcels._datasets.structured.circulation_models import datasets as datasets_circulation_models
1212
from parcels._datasets.structured.generic import T as T_structured
1313
from parcels._datasets.structured.generic import datasets as datasets_structured
14+
from parcels._datasets.structured.generic import datasets_sgrid
1415
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
1516
from parcels.interpolators import XLinear
1617
from tests import utils
@@ -342,3 +343,11 @@ def test_fieldset_from_fesom2_missingUV():
342343
with pytest.raises(ValueError) as info:
343344
_ = FieldSet.from_fesom2(localds)
344345
assert "Dataset has only one of the two variables 'U' and 'V'" in str(info)
346+
347+
348+
@pytest.mark.parametrize("ds_name", list(datasets_sgrid.keys()))
349+
def test_fieldset_from_sgrid_conventions(ds_name):
350+
ds = datasets_sgrid[ds_name]
351+
fieldset = FieldSet.from_sgrid_conventions(ds, mesh="flat")
352+
assert isinstance(fieldset, FieldSet)
353+
assert len(fieldset.fields) > 0

tests/utils/test_sgrid.py

Lines changed: 29 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,29 @@
77
from parcels._core.utils import sgrid
88
from tests.strategies import sgrid as sgrid_strategies
99

10+
grid2dmetadata = sgrid.Grid2DMetadata(
11+
cf_role="grid_topology",
12+
topology_dimension=2,
13+
node_dimensions=("node_dimension1", "node_dimension2"),
14+
face_dimensions=(
15+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
16+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
17+
),
18+
vertical_dimensions=(
19+
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
20+
),
21+
)
1022

11-
@pytest.fixture
12-
def grid2dmetadata():
13-
return sgrid.Grid2DMetadata(
14-
cf_role="grid_topology",
15-
topology_dimension=2,
16-
node_dimensions=("node_dimension1", "node_dimension2"),
17-
face_dimensions=(
18-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
19-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
20-
),
21-
vertical_dimensions=(
22-
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
23-
),
24-
)
23+
grid3dmetadata = sgrid.Grid3DMetadata(
24+
cf_role="grid_topology",
25+
topology_dimension=3,
26+
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
27+
volume_dimensions=(
28+
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
29+
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
30+
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
31+
),
32+
)
2533

2634

2735
def dummy_sgrid_ds(grid: sgrid.Grid2DMetadata | sgrid.Grid3DMetadata) -> xr.Dataset:
@@ -151,39 +159,15 @@ def test_load_dump_mappings(input_, expected):
151159
assert sgrid.load_mappings(input_) == expected
152160

153161

154-
@example(
155-
grid=sgrid.Grid2DMetadata(
156-
cf_role="grid_topology",
157-
topology_dimension=2,
158-
node_dimensions=("node_dimension1", "node_dimension2"),
159-
face_dimensions=(
160-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
161-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
162-
),
163-
vertical_dimensions=(
164-
sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),
165-
),
166-
)
167-
)
162+
@example(grid2dmetadata)
168163
@given(sgrid_strategies.grid2Dmetadata())
169164
def test_Grid2DMetadata_roundtrip(grid: sgrid.Grid2DMetadata):
170165
attrs = grid.to_attrs()
171166
parsed = sgrid.Grid2DMetadata.from_attrs(attrs)
172167
assert parsed == grid
173168

174169

175-
@example(
176-
grid=sgrid.Grid3DMetadata(
177-
cf_role="grid_topology",
178-
topology_dimension=3,
179-
node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"),
180-
volume_dimensions=(
181-
sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW),
182-
sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW),
183-
sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW),
184-
),
185-
)
186-
)
170+
@example(grid3dmetadata)
187171
@given(sgrid_strategies.grid3Dmetadata())
188172
def test_Grid3DMetadata_roundtrip(grid: sgrid.Grid3DMetadata):
189173
attrs = grid.to_attrs()
@@ -198,6 +182,7 @@ def test_parse_grid_attrs(grid: sgrid.AttrsSerializable):
198182
assert parsed == grid
199183

200184

185+
@example(grid2dmetadata)
201186
@given(sgrid_strategies.grid2Dmetadata())
202187
def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
203188
"""Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided"""
@@ -207,7 +192,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
207192
grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs)
208193

209194
for ddp, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True):
210-
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
195+
dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding
211196
coords = grid.axes[axis].coords
212197
assert coords["center"] == dim_edge
213198
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
@@ -216,7 +201,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata):
216201
assert "Z" not in grid.axes
217202
else:
218203
ddp = grid_metadata.vertical_dimensions[0]
219-
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
204+
dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding
220205
coords = grid.axes["Z"].coords
221206
assert coords["center"] == dim_edge
222207
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
@@ -231,7 +216,7 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata):
231216
grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs)
232217

233218
for ddp, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True):
234-
dim_node, dim_edge, padding = ddp.dim1, ddp.dim2, ddp.padding
219+
dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding
235220
coords = grid.axes[axis].coords
236221
assert coords["center"] == dim_edge
237222
assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node
@@ -291,7 +276,7 @@ def test_rename_dims(grid):
291276
assert grid == grid_new.rename_dims(dims_dict_inv)
292277

293278

294-
def test_rename_dims_errors(grid2dmetadata):
279+
def test_rename_dims_errors():
295280
# Test various error modes of rename_dims
296281
grid = grid2dmetadata
297282
# Non-unique target dimension names

0 commit comments

Comments
 (0)