|
23 | 23 | if typing.TYPE_CHECKING: |
24 | 24 | import uxarray as ux |
25 | 25 |
|
| 26 | +_NEMO_EXPECTED_COORDS = ["glamf", "gphif"] |
| 27 | + |
26 | 28 | _NEMO_DIMENSION_COORD_NAMES = ["x", "y", "time", "x", "x_center", "y", "y_center", "depth", "glamf", "gphif"] |
27 | 29 |
|
28 | 30 | _NEMO_AXIS_VARNAMES = { |
|
42 | 44 | "wo": "W", |
43 | 45 | } |
44 | 46 |
|
| 47 | +_MITGCM_EXPECTED_COORDS = ["XG", "YG", "Zl"] |
| 48 | + |
45 | 49 | _MITGCM_AXIS_VARNAMES = { |
46 | 50 | "XC": "X", |
47 | 51 | "XG": "X", |
|
70 | 74 | "T": "time", |
71 | 75 | } |
72 | 76 |
|
| 77 | +_CROCO_EXPECTED_COORDS = ["x_rho", "y_rho", "s_w", "time"] |
| 78 | + |
73 | 79 | _CROCO_VARNAMES_MAPPING = { |
74 | 80 | "x_rho": "lon", |
75 | 81 | "y_rho": "lat", |
76 | 82 | "s_w": "depth", |
77 | 83 | } |
78 | 84 |
|
79 | 85 |
|
| 86 | +def _pick_expected_coords(coords: xr.Dataset, expected_coord_names: list[str]) -> xr.Dataset: |
| 87 | + coords_to_use = {} |
| 88 | + for name in expected_coord_names: |
| 89 | + if name in coords: |
| 90 | + coords_to_use[name] = coords[name] |
| 91 | + else: |
| 92 | + raise ValueError(f"Expected coordinate '{name}' not found in provided coords dataset.") |
| 93 | + return xr.Dataset(coords_to_use) |
| 94 | + |
| 95 | + |
80 | 96 | def _maybe_bring_other_depths_to_depth(ds): |
81 | 97 | if "depth" in ds.coords: |
82 | 98 | for var in ds.data_vars: |
@@ -246,7 +262,7 @@ def nemo_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.Da |
246 | 262 |
|
247 | 263 | """ |
248 | 264 | fields = fields.copy() |
249 | | - coords = coords[["gphif", "glamf"]] |
| 265 | + coords = _pick_expected_coords(coords, _NEMO_EXPECTED_COORDS) |
250 | 266 |
|
251 | 267 | for name, field_da in fields.items(): |
252 | 268 | if isinstance(field_da, xr.Dataset): |
@@ -358,6 +374,8 @@ def mitgcm_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr. |
358 | 374 | field_da = field_da.rename(name) |
359 | 375 | fields[name] = field_da |
360 | 376 |
|
| 377 | + coords = _pick_expected_coords(coords, _MITGCM_EXPECTED_COORDS) |
| 378 | + |
361 | 379 | ds = xr.merge(list(fields.values()) + [coords]) |
362 | 380 | ds.attrs.clear() # Clear global attributes from the merging |
363 | 381 |
|
@@ -418,6 +436,8 @@ def croco_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.D |
418 | 436 | field_da = field_da.rename(name) |
419 | 437 | fields[name] = field_da |
420 | 438 |
|
| 439 | + coords = _pick_expected_coords(coords, _CROCO_EXPECTED_COORDS) |
| 440 | + |
421 | 441 | ds = xr.merge(list(fields.values()) + [coords]) |
422 | 442 | ds.attrs.clear() # Clear global attributes from the merging |
423 | 443 |
|
@@ -509,3 +529,192 @@ def copernicusmarine_to_sgrid( |
509 | 529 | ) |
510 | 530 |
|
511 | 531 | return ds |
| 532 | + |
| 533 | + |
| 534 | +# Known vertical dimension mappings by model |
| 535 | +_FESOM2_VERTICAL_DIMS = {"interface": "nz", "center": "nz1"} |
| 536 | +_ICON_VERTICAL_DIMS = {"interface": "depth_2", "center": "depth"} |
| 537 | + |
| 538 | + |
| 539 | +def _detect_vertical_coordinates( |
| 540 | + ds: ux.UxDataset, |
| 541 | + known_mappings: dict[str, str] | None = None, |
| 542 | +) -> tuple[str, str]: |
| 543 | + """Detect vertical coordinate dimensions for faces (zf) and centers (zc). |
| 544 | +
|
| 545 | + Detection strategy (with fallback): |
| 546 | + 1. Use known_mappings if provided and dimensions exist |
| 547 | + 2. Look for CF convention axis='Z' metadata |
| 548 | + 3. Find dimension pairs where sizes differ by exactly 1 |
| 549 | +
|
| 550 | + Parameters |
| 551 | + ---------- |
| 552 | + ds : ux.UxDataset |
| 553 | + UxDataset to analyze for vertical coordinates. |
| 554 | + known_mappings : dict[str, str] | None |
| 555 | + Optional mapping with keys "interface" and "center" specifying |
| 556 | + the dimension names for layer interfaces (zf) and centers (zc). |
| 557 | +
|
| 558 | + Returns |
| 559 | + ------- |
| 560 | + tuple[str, str] |
| 561 | + Tuple of (interface_dim_name, center_dim_name). |
| 562 | +
|
| 563 | + Raises |
| 564 | + ------ |
| 565 | + ValueError |
| 566 | + If vertical coordinates cannot be detected. |
| 567 | + """ |
| 568 | + ds_dims = set(ds.dims) |
| 569 | + |
| 570 | + # Strategy 1: Use known mappings if provided and dimensions exist |
| 571 | + if known_mappings is not None: |
| 572 | + interface_dim = known_mappings.get("interface") |
| 573 | + center_dim = known_mappings.get("center") |
| 574 | + if interface_dim in ds_dims and center_dim in ds_dims: |
| 575 | + logger.info( |
| 576 | + f"Using known vertical dimension mapping: {interface_dim!r} (interfaces) and {center_dim!r} (centers)." |
| 577 | + ) |
| 578 | + return interface_dim, center_dim |
| 579 | + logger.debug(f"Known mappings {known_mappings} not found in dataset dimensions {ds_dims}. Trying CF metadata.") |
| 580 | + |
| 581 | + # Strategy 2: Look for CF convention axis='Z' metadata |
| 582 | + z_dims = [] |
| 583 | + for dim in ds_dims: |
| 584 | + if dim in ds.coords: |
| 585 | + coord = ds.coords[dim] |
| 586 | + if coord.attrs.get("axis") == "Z": |
| 587 | + z_dims.append(dim) |
| 588 | + elif coord.attrs.get("positive") in ("up", "down"): |
| 589 | + z_dims.append(dim) |
| 590 | + elif "depth" in coord.attrs.get("standard_name", "").lower(): |
| 591 | + z_dims.append(dim) |
| 592 | + |
| 593 | + if len(z_dims) == 2: |
| 594 | + # Sort by size - interface has n+1 values, center has n |
| 595 | + z_dims_sorted = sorted(z_dims, key=lambda d: ds.sizes[d], reverse=True) |
| 596 | + interface_dim, center_dim = z_dims_sorted |
| 597 | + if ds.sizes[interface_dim] == ds.sizes[center_dim] + 1: |
| 598 | + logger.info( |
| 599 | + f"Detected vertical dimensions from CF metadata: {interface_dim!r} (interfaces) and {center_dim!r} (centers)." |
| 600 | + ) |
| 601 | + return interface_dim, center_dim |
| 602 | + |
| 603 | + # Strategy 3: Find dimension pairs where sizes differ by exactly 1 |
| 604 | + # Skip known non-vertical dimensions |
| 605 | + skip_dims = {"time", "n_face", "n_node", "n_edge", "n_max_face_nodes"} |
| 606 | + candidate_dims = [d for d in ds_dims if d not in skip_dims] |
| 607 | + |
| 608 | + for dim1 in candidate_dims: |
| 609 | + for dim2 in candidate_dims: |
| 610 | + if dim1 != dim2: |
| 611 | + size1, size2 = ds.sizes[dim1], ds.sizes[dim2] |
| 612 | + if size1 == size2 + 1: |
| 613 | + logger.info( |
| 614 | + f"Auto-detected vertical dimensions by size difference: {dim1!r} (interfaces, size={size1}) " |
| 615 | + f"and {dim2!r} (centers, size={size2})." |
| 616 | + ) |
| 617 | + return dim1, dim2 |
| 618 | + |
| 619 | + raise ValueError( |
| 620 | + f"Could not detect vertical coordinate dimensions in dataset with dims {list(ds_dims)}. " |
| 621 | + "Please ensure the dataset has vertical layer interface and center dimensions, " |
| 622 | + "or rename them manually to 'zf' (interfaces) and 'zc' (centers)." |
| 623 | + ) |
| 624 | + |
| 625 | + |
| 626 | +def _rename_vertical_dims( |
| 627 | + ds: ux.UxDataset, |
| 628 | + interface_dim: str, |
| 629 | + center_dim: str, |
| 630 | +) -> ux.UxDataset: |
| 631 | + """Rename vertical dimensions to zf (interfaces) and zc (centers). |
| 632 | +
|
| 633 | + Parameters |
| 634 | + ---------- |
| 635 | + ds : ux.UxDataset |
| 636 | + UxDataset with vertical dimensions to rename. |
| 637 | + interface_dim : str |
| 638 | + Current name of the interface dimension. |
| 639 | + center_dim : str |
| 640 | + Current name of the center dimension. |
| 641 | +
|
| 642 | + Returns |
| 643 | + ------- |
| 644 | + ux.UxDataset |
| 645 | + Dataset with renamed dimensions and indexed coordinates. |
| 646 | + """ |
| 647 | + rename_dict = {} |
| 648 | + if interface_dim != "zf": |
| 649 | + rename_dict[interface_dim] = "zf" |
| 650 | + if center_dim != "zc": |
| 651 | + rename_dict[center_dim] = "zc" |
| 652 | + |
| 653 | + if rename_dict: |
| 654 | + logger.info(f"Renaming vertical dimensions: {rename_dict}") |
| 655 | + ds = ds.rename(rename_dict) |
| 656 | + |
| 657 | + ds = ds.set_index(zf="zf", zc="zc") |
| 658 | + return ds |
| 659 | + |
| 660 | + |
| 661 | +def fesom_to_ugrid(ds: ux.UxDataset) -> ux.UxDataset: |
| 662 | + """Convert FESOM2 UxDataset to Parcels UGRID-compliant format. |
| 663 | +
|
| 664 | + Renames vertical dimensions: |
| 665 | + - nz -> zf (vertical layer faces/interfaces) |
| 666 | + - nz1 -> zc (vertical layer centers) |
| 667 | +
|
| 668 | + Parameters |
| 669 | + ---------- |
| 670 | + ds : ux.UxDataset |
| 671 | + FESOM2 UxDataset as obtained from uxarray. |
| 672 | +
|
| 673 | + Returns |
| 674 | + ------- |
| 675 | + ux.UxDataset |
| 676 | + UGRID-compliant dataset ready for FieldSet.from_ugrid_conventions(). |
| 677 | +
|
| 678 | + Examples |
| 679 | + -------- |
| 680 | + >>> import uxarray as ux |
| 681 | + >>> from parcels import FieldSet |
| 682 | + >>> from parcels.convert import fesom_to_ugrid |
| 683 | + >>> ds = ux.open_mfdataset(grid_path, data_path) |
| 684 | + >>> ds_ugrid = fesom_to_ugrid(ds) |
| 685 | + >>> fieldset = FieldSet.from_ugrid_conventions(ds_ugrid, mesh="flat") |
| 686 | + """ |
| 687 | + ds = ds.copy() |
| 688 | + interface_dim, center_dim = _detect_vertical_coordinates(ds, _FESOM2_VERTICAL_DIMS) |
| 689 | + return _rename_vertical_dims(ds, interface_dim, center_dim) |
| 690 | + |
| 691 | + |
| 692 | +def icon_to_ugrid(ds: ux.UxDataset) -> ux.UxDataset: |
| 693 | + """Convert ICON UxDataset to Parcels UGRID-compliant format. |
| 694 | +
|
| 695 | + Renames vertical dimensions: |
| 696 | + - depth_2 -> zf (vertical layer faces/interfaces) |
| 697 | + - depth -> zc (vertical layer centers) |
| 698 | +
|
| 699 | + Parameters |
| 700 | + ---------- |
| 701 | + ds : ux.UxDataset |
| 702 | + ICON UxDataset as obtained from uxarray. |
| 703 | +
|
| 704 | + Returns |
| 705 | + ------- |
| 706 | + ux.UxDataset |
| 707 | + UGRID-compliant dataset ready for FieldSet.from_ugrid_conventions(). |
| 708 | +
|
| 709 | + Examples |
| 710 | + -------- |
| 711 | + >>> import uxarray as ux |
| 712 | + >>> from parcels import FieldSet |
| 713 | + >>> from parcels.convert import icon_to_ugrid |
| 714 | + >>> ds = ux.open_mfdataset(grid_path, data_path) |
| 715 | + >>> ds_ugrid = icon_to_ugrid(ds) |
| 716 | + >>> fieldset = FieldSet.from_ugrid_conventions(ds_ugrid, mesh="flat") |
| 717 | + """ |
| 718 | + ds = ds.copy() |
| 719 | + interface_dim, center_dim = _detect_vertical_coordinates(ds, _ICON_VERTICAL_DIMS) |
| 720 | + return _rename_vertical_dims(ds, interface_dim, center_dim) |
0 commit comments