Skip to content

Commit d18c7d6

Browse files
Merge branch 'v4-dev' into renaming_pyfunc_to_kernels
2 parents 822db95 + 36d73f0 commit d18c7d6

9 files changed

Lines changed: 237 additions & 104 deletions

File tree

docs/user_guide/examples/tutorial_croco_3D.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"cell_type": "markdown",
6060
"metadata": {},
6161
"source": [
62-
"Now we create a FieldSet object using the `FieldSet.from_croco()` method. Note that CROCO is a C-grid (with similar indexing at MITgcm)."
62+
"Now we create a FieldSet object using the `convert.croco_to_sgrid()` function to first create an S-Grid compliant datatset, and then use that in `FieldSet.from_sgrid_conventions()` to create the FieldSet."
6363
]
6464
},
6565
{
@@ -86,8 +86,7 @@
8686
" \"Cs_w\": ds_fields[\"Cs_w\"],\n",
8787
"}\n",
8888
"\n",
89-
"coords = ds_fields[[\"x_rho\", \"y_rho\", \"s_w\", \"time\"]]\n",
90-
"ds_fset = parcels.convert.croco_to_sgrid(fields=fields, coords=coords)\n",
89+
"ds_fset = parcels.convert.croco_to_sgrid(fields=fields, coords=ds_fields)\n",
9190
"\n",
9291
"fieldset = parcels.FieldSet.from_sgrid_conventions(ds_fset)\n",
9392
"\n",

docs/user_guide/examples/tutorial_mitgcm.ipynb

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@
3838
"id": "3",
3939
"metadata": {},
4040
"source": [
41-
"We can use a combination of `parcels.convert.mitgcm_to_sgrid` and `FieldSet.from_sgrid_conventions` to read in the data. See below for an example.\n",
42-
"\n",
43-
"```{note}\n",
44-
"It is very important that you provide the corner nodes as coordinates when converting MITgcm data to S-grid conventions. These corner nodes are typically called `XG` and `YG` in MITgcm output. Failing to do so will lead to incorrect interpolation of the velocity fields.\n",
45-
"```"
41+
"We can use a combination of `parcels.convert.mitgcm_to_sgrid` and `FieldSet.from_sgrid_conventions` to read in the data. See below for an example."
4642
]
4743
},
4844
{
@@ -52,9 +48,8 @@
5248
"metadata": {},
5349
"outputs": [],
5450
"source": [
55-
"coords = ds_fields[[\"XG\", \"YG\", \"Zl\", \"time\"]]\n",
5651
"ds_fset = parcels.convert.mitgcm_to_sgrid(\n",
57-
" fields={\"U\": ds_fields.UVEL, \"V\": ds_fields.VVEL}, coords=coords\n",
52+
" fields={\"U\": ds_fields.UVEL, \"V\": ds_fields.VVEL}, coords=ds_fields\n",
5853
")\n",
5954
"fieldset = parcels.FieldSet.from_sgrid_conventions(ds_fset)"
6055
]

docs/user_guide/examples/tutorial_nemo.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@
369369
],
370370
"metadata": {
371371
"kernelspec": {
372-
"display_name": "parcels",
372+
"display_name": "docs",
373373
"language": "python",
374374
"name": "python3"
375375
},
@@ -383,7 +383,7 @@
383383
"name": "python",
384384
"nbconvert_exporter": "python",
385385
"pygments_lexer": "ipython3",
386-
"version": "3.12.3"
386+
"version": "3.14.2"
387387
}
388388
},
389389
"nbformat": 4,

src/parcels/_core/fieldset.py

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ def gridset(self) -> list[BaseGrid]:
183183
@classmethod
184184
def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
185185
"""Create a FieldSet from a Parcels compliant uxarray.UxDataset.
186+
187+
This is the primary ingestion method in Parcels for structured grid datasets.
188+
186189
The main requirements for a uxDataset are naming conventions for vertical grid dimensions & coordinates
187190
188191
zf - Name for coordinate and dimension for vertical positions at layer interfaces
@@ -225,63 +228,6 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
225228

226229
return cls(list(fields.values()))
227230

228-
@classmethod
229-
def from_fesom2(cls, ds: ux.UxDataset, mesh: str = "spherical"):
230-
"""Create a FieldSet from a FESOM2 uxarray.UxDataset.
231-
232-
Parameters
233-
----------
234-
ds : uxarray.UxDataset
235-
uxarray.UxDataset as obtained from the uxarray package.
236-
237-
Returns
238-
-------
239-
FieldSet
240-
FieldSet object containing the fields from the dataset that can be used for a Parcels simulation.
241-
"""
242-
ds = ds.copy()
243-
ds_dims = list(ds.dims)
244-
if not all(dim in ds_dims for dim in ["time", "nz", "nz1"]):
245-
raise ValueError(
246-
f"Dataset missing one of the required dimensions 'time', 'nz', or 'nz1' for FESOM data. Found dimensions {ds_dims}"
247-
)
248-
ds = ds.rename(
249-
{
250-
"nz": "zf", # Vertical Interface
251-
"nz1": "zc", # Vertical Center
252-
}
253-
).set_index(zf="zf", zc="zc")
254-
255-
return FieldSet.from_ugrid_conventions(ds, mesh=mesh)
256-
257-
@classmethod
258-
def from_icon(cls, ds: ux.UxDataset, mesh: str = "spherical"):
259-
"""Create a FieldSet from a ICON uxarray.UxDataset.
260-
261-
Parameters
262-
----------
263-
ds : uxarray.UxDataset
264-
uxarray.UxDataset as obtained from the uxarray package.
265-
266-
Returns
267-
-------
268-
FieldSet
269-
FieldSet object containing the fields from the dataset that can be used for a Parcels simulation.
270-
"""
271-
ds = ds.copy()
272-
ds_dims = list(ds.dims)
273-
if not all(dim in ds_dims for dim in ["time", "depth", "depth_2"]):
274-
raise ValueError(
275-
f"Dataset missing one of the required dimensions 'time', 'depth', or 'depth_2' for ICON data. Found dimensions {ds_dims}"
276-
)
277-
ds = ds.rename(
278-
{
279-
"depth_2": "zf", # Vertical Interface
280-
"depth": "zc", # Vertical Center
281-
}
282-
).set_index(zf="zf", zc="zc")
283-
return FieldSet.from_ugrid_conventions(ds, mesh=mesh)
284-
285231
@classmethod
286232
def from_sgrid_conventions(
287233
cls, ds: xr.Dataset, mesh: Mesh | None = None

src/parcels/convert.py

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
if typing.TYPE_CHECKING:
2424
import uxarray as ux
2525

26+
_NEMO_EXPECTED_COORDS = ["glamf", "gphif"]
27+
2628
_NEMO_DIMENSION_COORD_NAMES = ["x", "y", "time", "x", "x_center", "y", "y_center", "depth", "glamf", "gphif"]
2729

2830
_NEMO_AXIS_VARNAMES = {
@@ -42,6 +44,8 @@
4244
"wo": "W",
4345
}
4446

47+
_MITGCM_EXPECTED_COORDS = ["XG", "YG", "Zl"]
48+
4549
_MITGCM_AXIS_VARNAMES = {
4650
"XC": "X",
4751
"XG": "X",
@@ -70,13 +74,25 @@
7074
"T": "time",
7175
}
7276

77+
_CROCO_EXPECTED_COORDS = ["x_rho", "y_rho", "s_w", "time"]
78+
7379
_CROCO_VARNAMES_MAPPING = {
7480
"x_rho": "lon",
7581
"y_rho": "lat",
7682
"s_w": "depth",
7783
}
7884

7985

86+
def _pick_expected_coords(coords: xr.Dataset, expected_coord_names: list[str]) -> xr.Dataset:
87+
coords_to_use = {}
88+
for name in expected_coord_names:
89+
if name in coords:
90+
coords_to_use[name] = coords[name]
91+
else:
92+
raise ValueError(f"Expected coordinate '{name}' not found in provided coords dataset.")
93+
return xr.Dataset(coords_to_use)
94+
95+
8096
def _maybe_bring_other_depths_to_depth(ds):
8197
if "depth" in ds.coords:
8298
for var in ds.data_vars:
@@ -246,7 +262,7 @@ def nemo_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.Da
246262
247263
"""
248264
fields = fields.copy()
249-
coords = coords[["gphif", "glamf"]]
265+
coords = _pick_expected_coords(coords, _NEMO_EXPECTED_COORDS)
250266

251267
for name, field_da in fields.items():
252268
if isinstance(field_da, xr.Dataset):
@@ -358,6 +374,8 @@ def mitgcm_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.
358374
field_da = field_da.rename(name)
359375
fields[name] = field_da
360376

377+
coords = _pick_expected_coords(coords, _MITGCM_EXPECTED_COORDS)
378+
361379
ds = xr.merge(list(fields.values()) + [coords])
362380
ds.attrs.clear() # Clear global attributes from the merging
363381

@@ -418,6 +436,8 @@ def croco_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.D
418436
field_da = field_da.rename(name)
419437
fields[name] = field_da
420438

439+
coords = _pick_expected_coords(coords, _CROCO_EXPECTED_COORDS)
440+
421441
ds = xr.merge(list(fields.values()) + [coords])
422442
ds.attrs.clear() # Clear global attributes from the merging
423443

@@ -509,3 +529,192 @@ def copernicusmarine_to_sgrid(
509529
)
510530

511531
return ds
532+
533+
534+
# Known vertical dimension mappings by model
535+
_FESOM2_VERTICAL_DIMS = {"interface": "nz", "center": "nz1"}
536+
_ICON_VERTICAL_DIMS = {"interface": "depth_2", "center": "depth"}
537+
538+
539+
def _detect_vertical_coordinates(
540+
ds: ux.UxDataset,
541+
known_mappings: dict[str, str] | None = None,
542+
) -> tuple[str, str]:
543+
"""Detect vertical coordinate dimensions for faces (zf) and centers (zc).
544+
545+
Detection strategy (with fallback):
546+
1. Use known_mappings if provided and dimensions exist
547+
2. Look for CF convention axis='Z' metadata
548+
3. Find dimension pairs where sizes differ by exactly 1
549+
550+
Parameters
551+
----------
552+
ds : ux.UxDataset
553+
UxDataset to analyze for vertical coordinates.
554+
known_mappings : dict[str, str] | None
555+
Optional mapping with keys "interface" and "center" specifying
556+
the dimension names for layer interfaces (zf) and centers (zc).
557+
558+
Returns
559+
-------
560+
tuple[str, str]
561+
Tuple of (interface_dim_name, center_dim_name).
562+
563+
Raises
564+
------
565+
ValueError
566+
If vertical coordinates cannot be detected.
567+
"""
568+
ds_dims = set(ds.dims)
569+
570+
# Strategy 1: Use known mappings if provided and dimensions exist
571+
if known_mappings is not None:
572+
interface_dim = known_mappings.get("interface")
573+
center_dim = known_mappings.get("center")
574+
if interface_dim in ds_dims and center_dim in ds_dims:
575+
logger.info(
576+
f"Using known vertical dimension mapping: {interface_dim!r} (interfaces) and {center_dim!r} (centers)."
577+
)
578+
return interface_dim, center_dim
579+
logger.debug(f"Known mappings {known_mappings} not found in dataset dimensions {ds_dims}. Trying CF metadata.")
580+
581+
# Strategy 2: Look for CF convention axis='Z' metadata
582+
z_dims = []
583+
for dim in ds_dims:
584+
if dim in ds.coords:
585+
coord = ds.coords[dim]
586+
if coord.attrs.get("axis") == "Z":
587+
z_dims.append(dim)
588+
elif coord.attrs.get("positive") in ("up", "down"):
589+
z_dims.append(dim)
590+
elif "depth" in coord.attrs.get("standard_name", "").lower():
591+
z_dims.append(dim)
592+
593+
if len(z_dims) == 2:
594+
# Sort by size - interface has n+1 values, center has n
595+
z_dims_sorted = sorted(z_dims, key=lambda d: ds.sizes[d], reverse=True)
596+
interface_dim, center_dim = z_dims_sorted
597+
if ds.sizes[interface_dim] == ds.sizes[center_dim] + 1:
598+
logger.info(
599+
f"Detected vertical dimensions from CF metadata: {interface_dim!r} (interfaces) and {center_dim!r} (centers)."
600+
)
601+
return interface_dim, center_dim
602+
603+
# Strategy 3: Find dimension pairs where sizes differ by exactly 1
604+
# Skip known non-vertical dimensions
605+
skip_dims = {"time", "n_face", "n_node", "n_edge", "n_max_face_nodes"}
606+
candidate_dims = [d for d in ds_dims if d not in skip_dims]
607+
608+
for dim1 in candidate_dims:
609+
for dim2 in candidate_dims:
610+
if dim1 != dim2:
611+
size1, size2 = ds.sizes[dim1], ds.sizes[dim2]
612+
if size1 == size2 + 1:
613+
logger.info(
614+
f"Auto-detected vertical dimensions by size difference: {dim1!r} (interfaces, size={size1}) "
615+
f"and {dim2!r} (centers, size={size2})."
616+
)
617+
return dim1, dim2
618+
619+
raise ValueError(
620+
f"Could not detect vertical coordinate dimensions in dataset with dims {list(ds_dims)}. "
621+
"Please ensure the dataset has vertical layer interface and center dimensions, "
622+
"or rename them manually to 'zf' (interfaces) and 'zc' (centers)."
623+
)
624+
625+
626+
def _rename_vertical_dims(
627+
ds: ux.UxDataset,
628+
interface_dim: str,
629+
center_dim: str,
630+
) -> ux.UxDataset:
631+
"""Rename vertical dimensions to zf (interfaces) and zc (centers).
632+
633+
Parameters
634+
----------
635+
ds : ux.UxDataset
636+
UxDataset with vertical dimensions to rename.
637+
interface_dim : str
638+
Current name of the interface dimension.
639+
center_dim : str
640+
Current name of the center dimension.
641+
642+
Returns
643+
-------
644+
ux.UxDataset
645+
Dataset with renamed dimensions and indexed coordinates.
646+
"""
647+
rename_dict = {}
648+
if interface_dim != "zf":
649+
rename_dict[interface_dim] = "zf"
650+
if center_dim != "zc":
651+
rename_dict[center_dim] = "zc"
652+
653+
if rename_dict:
654+
logger.info(f"Renaming vertical dimensions: {rename_dict}")
655+
ds = ds.rename(rename_dict)
656+
657+
ds = ds.set_index(zf="zf", zc="zc")
658+
return ds
659+
660+
661+
def fesom_to_ugrid(ds: ux.UxDataset) -> ux.UxDataset:
662+
"""Convert FESOM2 UxDataset to Parcels UGRID-compliant format.
663+
664+
Renames vertical dimensions:
665+
- nz -> zf (vertical layer faces/interfaces)
666+
- nz1 -> zc (vertical layer centers)
667+
668+
Parameters
669+
----------
670+
ds : ux.UxDataset
671+
FESOM2 UxDataset as obtained from uxarray.
672+
673+
Returns
674+
-------
675+
ux.UxDataset
676+
UGRID-compliant dataset ready for FieldSet.from_ugrid_conventions().
677+
678+
Examples
679+
--------
680+
>>> import uxarray as ux
681+
>>> from parcels import FieldSet
682+
>>> from parcels.convert import fesom_to_ugrid
683+
>>> ds = ux.open_mfdataset(grid_path, data_path)
684+
>>> ds_ugrid = fesom_to_ugrid(ds)
685+
>>> fieldset = FieldSet.from_ugrid_conventions(ds_ugrid, mesh="flat")
686+
"""
687+
ds = ds.copy()
688+
interface_dim, center_dim = _detect_vertical_coordinates(ds, _FESOM2_VERTICAL_DIMS)
689+
return _rename_vertical_dims(ds, interface_dim, center_dim)
690+
691+
692+
def icon_to_ugrid(ds: ux.UxDataset) -> ux.UxDataset:
693+
"""Convert ICON UxDataset to Parcels UGRID-compliant format.
694+
695+
Renames vertical dimensions:
696+
- depth_2 -> zf (vertical layer faces/interfaces)
697+
- depth -> zc (vertical layer centers)
698+
699+
Parameters
700+
----------
701+
ds : ux.UxDataset
702+
ICON UxDataset as obtained from uxarray.
703+
704+
Returns
705+
-------
706+
ux.UxDataset
707+
UGRID-compliant dataset ready for FieldSet.from_ugrid_conventions().
708+
709+
Examples
710+
--------
711+
>>> import uxarray as ux
712+
>>> from parcels import FieldSet
713+
>>> from parcels.convert import icon_to_ugrid
714+
>>> ds = ux.open_mfdataset(grid_path, data_path)
715+
>>> ds_ugrid = icon_to_ugrid(ds)
716+
>>> fieldset = FieldSet.from_ugrid_conventions(ds_ugrid, mesh="flat")
717+
"""
718+
ds = ds.copy()
719+
interface_dim, center_dim = _detect_vertical_coordinates(ds, _ICON_VERTICAL_DIMS)
720+
return _rename_vertical_dims(ds, interface_dim, center_dim)

0 commit comments

Comments
 (0)