1919
2020import xarray as xr
2121
22+ from parcels ._python import repr_from_dunder_dict
23+
2224RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)"
2325
2426Dim = str
@@ -31,12 +33,21 @@ class Padding(enum.Enum):
3133 BOTH = "both"
3234
3335
34- class SGridMetadataProtocol (Protocol ):
36+ SGRID_PADDING_TO_XGCM_POSITION = {
37+ Padding .LOW : "right" ,
38+ Padding .HIGH : "left" ,
39+ Padding .BOTH : "inner" ,
40+ Padding .NONE : "outer" ,
41+ # "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves
42+ }
43+
44+
45+ class AttrsSerializable (Protocol ):
3546 def to_attrs (self ) -> dict [str , str | int ]: ...
3647 def from_attrs (cls , d : dict [str , Hashable ]) -> Self : ...
3748
3849
39- class Grid2DMetadata (SGridMetadataProtocol ):
50+ class Grid2DMetadata (AttrsSerializable ):
4051 def __init__ (
4152 self ,
4253 cf_role : Literal ["grid_topology" ],
@@ -94,16 +105,13 @@ def __init__(
94105 #! Important optional attribute for 2D grids with vertical layering
95106 self .vertical_dimensions = vertical_dimensions
96107
108+ def __repr__ (self ) -> str :
109+ return repr_from_dunder_dict (self )
110+
97111 def __eq__ (self , other : Any ) -> bool :
98112 if not isinstance (other , Grid2DMetadata ):
99113 return NotImplemented
100- return (
101- self .cf_role == other .cf_role
102- and self .topology_dimension == other .topology_dimension
103- and self .node_dimensions == other .node_dimensions
104- and self .face_dimensions == other .face_dimensions
105- and self .vertical_dimensions == other .vertical_dimensions
106- )
114+ return self .to_attrs () == other .to_attrs ()
107115
108116 @classmethod
109117 def from_attrs (cls , attrs ):
@@ -129,8 +137,11 @@ def to_attrs(self) -> dict[str, str | int]:
129137 d ["vertical_dimensions" ] = dump_mappings (self .vertical_dimensions )
130138 return d
131139
140+ def rename_dims (self , dims_dict : dict [str , str ]) -> Self :
141+ return _metadata_rename_dims (self , dims_dict )
142+
132143
133- class Grid3DMetadata (SGridMetadataProtocol ):
144+ class Grid3DMetadata (AttrsSerializable ):
134145 def __init__ (
135146 self ,
136147 cf_role : Literal ["grid_topology" ],
@@ -180,15 +191,13 @@ def __init__(
180191 # face *i_coordinates*
181192 # volume_coordinates
182193
194+ def __repr__ (self ) -> str :
195+ return repr_from_dunder_dict (self )
196+
183197 def __eq__ (self , other : Any ) -> bool :
184198 if not isinstance (other , Grid3DMetadata ):
185199 return NotImplemented
186- return (
187- self .cf_role == other .cf_role
188- and self .topology_dimension == other .topology_dimension
189- and self .node_dimensions == other .node_dimensions
190- and self .volume_dimensions == other .volume_dimensions
191- )
200+ return self .to_attrs () == other .to_attrs ()
192201
193202 @classmethod
194203 def from_attrs (cls , attrs ):
@@ -210,6 +219,9 @@ def to_attrs(self) -> dict[str, str | int]:
210219 volume_dimensions = dump_mappings (self .volume_dimensions ),
211220 )
212221
222+ def rename_dims (self , dims_dict : dict [str , str ]) -> Self :
223+ return _metadata_rename_dims (self , dims_dict )
224+
213225
214226@dataclass
215227class DimDimPadding :
@@ -318,15 +330,6 @@ def maybe_load_mappings(s):
318330 return load_mappings (s )
319331
320332
321- SGRID_PADDING_TO_XGCM_POSITION = {
322- Padding .LOW : "right" ,
323- Padding .HIGH : "left" ,
324- Padding .BOTH : "inner" ,
325- Padding .NONE : "outer" ,
326- # "center" position is not used in SGrid, in SGrid this would just be the edges/faces themselves
327- }
328-
329-
330333class SGridParsingException (Exception ):
331334 """Exception raised when parsing SGrid attributes fails."""
332335
@@ -378,3 +381,95 @@ def parse_sgrid(ds: xr.Dataset):
378381 xgcm_coords [axis ] = {"center" : dim_dim_padding .dim2 , xgcm_position : dim_dim_padding .dim1 }
379382
380383 return (ds , {"coords" : xgcm_coords })
384+
385+
386+ def rename_dims (ds : xr .Dataset , dims_dict : dict [str , str ]) -> xr .Dataset :
387+ grid_da = get_grid_topology (ds )
388+ if grid_da is None :
389+ raise ValueError (
390+ "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions."
391+ )
392+
393+ ds = ds .rename_dims (dims_dict )
394+
395+ # Update the metadata
396+ grid = parse_grid_attrs (grid_da .attrs )
397+ ds [grid_da .name ].attrs = grid .rename_dims (dims_dict ).to_attrs ()
398+ return ds
399+
400+
401+ def get_unique_dim_names (grid : Grid2DMetadata | Grid3DMetadata ) -> set [str ]:
402+ dims = set ()
403+ dims .update (set (grid .node_dimensions ))
404+
405+ for key , value in grid .__dict__ .items ():
406+ if key in ("cf_role" , "topology_dimension" ) or value is None :
407+ continue
408+ assert isinstance (value , tuple ), (
409+ f"Expected sgrid metadata attribute to be represented as a tuple, got { value !r} . This is an internal error to Parcels - please post an issue if you encounter this."
410+ )
411+ for item in value :
412+ if isinstance (item , DimDimPadding ):
413+ dims .add (item .dim1 )
414+ dims .add (item .dim2 )
415+ else :
416+ assert isinstance (item , str )
417+ dims .add (item )
418+ return dims
419+
420+
421+ @overload
422+ def _metadata_rename_dims (grid : Grid2DMetadata , dims_dict : dict [str , str ]) -> Grid2DMetadata : ...
423+
424+
425+ @overload
426+ def _metadata_rename_dims (grid : Grid3DMetadata , dims_dict : dict [str , str ]) -> Grid3DMetadata : ...
427+
428+
429+ def _metadata_rename_dims (grid , dims_dict ):
430+ """
431+ Renames dimensions in SGrid metadata.
432+
433+ Similar in API to xr.Dataset.rename_dims. Renames dimensions according to dims_dict mapping
434+ of old dimension names to new dimension names.
435+ """
436+ dims_dict = dims_dict .copy ()
437+ assert len (dims_dict ) == len (set (dims_dict .values ())), "dims_dict contains duplicate target dimension names"
438+
439+ existing_dims = get_unique_dim_names (grid )
440+ for dim in dims_dict .keys ():
441+ if dim not in existing_dims :
442+ raise ValueError (f"Dimension { dim !r} not found in SGrid metadata dimensions { existing_dims !r} " )
443+
444+ for dim in existing_dims :
445+ if dim not in dims_dict :
446+ dims_dict [dim ] = dim # identity mapping for dimensions not being renamed
447+
448+ kwargs = {}
449+ for key , value in grid .__dict__ .items ():
450+ if isinstance (value , tuple ):
451+ new_value = []
452+ for item in value :
453+ if isinstance (item , DimDimPadding ):
454+ new_item = DimDimPadding (
455+ dim1 = dims_dict [item .dim1 ],
456+ dim2 = dims_dict [item .dim2 ],
457+ padding = item .padding ,
458+ )
459+ new_value .append (new_item )
460+ else :
461+ assert isinstance (item , str )
462+ new_value .append (dims_dict [item ])
463+ kwargs [key ] = tuple (new_value )
464+ continue
465+
466+ if key in ("cf_role" , "topology_dimension" ) or value is None :
467+ kwargs [key ] = value
468+ continue
469+
470+ if isinstance (value , str ):
471+ kwargs [key ] = dims_dict [value ]
472+ continue
473+
474+ raise ValueError (f"Unexpected attribute { key !r} on { grid !r} " )
475+ return type (grid )(** kwargs )
0 commit comments