|
1 | 1 | import abc |
2 | 2 | import numbers |
3 | 3 | import pathlib |
4 | | -from typing import Mapping |
| 4 | +from typing import Any, Mapping, Optional |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import xarray as xr |
8 | 8 | import xugrid as xu |
| 9 | +from xarray.core.utils import is_scalar |
9 | 10 |
|
10 | 11 | import imod |
11 | 12 | from imod.mf6.interfaces.ipackagebase import IPackageBase |
12 | | -from imod.typing.grid import GridDataArray, GridDataset, merge_with_dictionary |
| 13 | +from imod.typing.grid import ( |
| 14 | + GridDataArray, |
| 15 | + GridDataset, |
| 16 | + is_spatial_grid, |
| 17 | + merge_with_dictionary, |
| 18 | +) |
13 | 19 |
|
14 | 20 | TRANSPORT_PACKAGES = ("adv", "dsp", "ssm", "mst", "ist", "src") |
15 | 21 | EXCHANGE_PACKAGES = ("gwfgwf", "gwfgwt", "gwtgwt") |
16 | 22 |
|
17 | 23 |
|
| 24 | +def _is_scalar_nan(da: GridDataArray): |
| 25 | + """ |
| 26 | + Test if is_scalar_nan, carefully avoid loading grids in memory |
| 27 | + """ |
| 28 | + scalar_data: bool = is_scalar(da) |
| 29 | + if scalar_data: |
| 30 | + stripped_value = da.to_numpy()[()] |
| 31 | + return isinstance(stripped_value, numbers.Real) and np.isnan(stripped_value) # type: ignore[call-overload] |
| 32 | + return False |
| 33 | + |
| 34 | + |
18 | 35 | class PackageBase(IPackageBase, abc.ABC): |
19 | 36 | """ |
20 | 37 | This class is used for storing a collection of Xarray DataArrays or UgridDataArrays |
@@ -48,15 +65,48 @@ def __getitem__(self, key): |
48 | 65 | def __setitem__(self, key, value): |
49 | 66 | self.dataset.__setitem__(key, value) |
50 | 67 |
|
51 | | - def to_netcdf(self, *args, **kwargs): |
| 68 | + def to_netcdf( |
| 69 | + self, *args, mdal_compliant: bool = False, crs: Optional[Any] = None, **kwargs |
| 70 | + ): |
52 | 71 | """ |
53 | 72 |
|
54 | | - Write dataset contents to a netCDF file. |
55 | | - Custom encoding rules can be provided on package level by overriding the _netcdf_encoding in the package |
| 73 | + Write dataset contents to a netCDF file. Custom encoding rules can be |
| 74 | + provided on package level by overriding the _netcdf_encoding in the |
| 75 | + package |
| 76 | +
|
| 77 | + Parameters |
| 78 | + ---------- |
| 79 | + *args: |
| 80 | + Will be passed on to ``xr.Dataset.to_netcdf`` or |
| 81 | + ``xu.UgridDataset.to_netcdf``. |
| 82 | + mdal_compliant: bool, optional |
| 83 | + Convert data with |
| 84 | + :func:`imod.prepare.spatial.mdal_compliant_ugrid2d` to MDAL |
| 85 | + compliant unstructured grids. Defaults to False. |
| 86 | + crs: Any, optional |
| 87 | + Anything accepted by rasterio.crs.CRS.from_user_input |
| 88 | + Requires ``rioxarray`` installed. |
| 89 | + **kwargs: |
| 90 | + Will be passed on to ``xr.Dataset.to_netcdf`` or |
| 91 | + ``xu.UgridDataset.to_netcdf``. |
56 | 92 |
|
57 | 93 | """ |
58 | 94 | kwargs.update({"encoding": self._netcdf_encoding()}) |
59 | | - self.dataset.to_netcdf(*args, **kwargs) |
| 95 | + |
| 96 | + dataset = self.dataset |
| 97 | + if isinstance(dataset, xu.UgridDataset): |
| 98 | + if mdal_compliant: |
| 99 | + dataset = dataset.ugrid.to_dataset() |
| 100 | + mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d( |
| 101 | + dataset, crs=crs |
| 102 | + ) |
| 103 | + mdal_dataset.to_netcdf(*args, **kwargs) |
| 104 | + else: |
| 105 | + dataset.ugrid.to_netcdf(*args, **kwargs) |
| 106 | + else: |
| 107 | + if is_spatial_grid(dataset): |
| 108 | + dataset = imod.util.spatial.gdal_compliant_grid(dataset, crs=crs) |
| 109 | + dataset.to_netcdf(*args, **kwargs) |
60 | 110 |
|
61 | 111 | def _netcdf_encoding(self): |
62 | 112 | """ |
@@ -123,11 +173,14 @@ def from_file(cls, path, **kwargs): |
123 | 173 | if dataset.ugrid_roles.topology: |
124 | 174 | dataset = xu.UgridDataset(dataset) |
125 | 175 | dataset = imod.util.spatial.from_mdal_compliant_ugrid2d(dataset) |
| 176 | + # Drop node dim, as we don't need in the UgridDataset (it will be |
| 177 | + # preserved in the ``grid`` property, so topology stays intact!) |
| 178 | + node_dim = dataset.grid.node_dimension |
| 179 | + dataset = dataset.drop_dims(node_dim, errors="ignore") |
126 | 180 |
|
127 | 181 | # Replace NaNs by None |
128 | 182 | for key, value in dataset.items(): |
129 | | - stripped_value = value.values[()] |
130 | | - if isinstance(stripped_value, numbers.Real) and np.isnan(stripped_value): # type: ignore[call-overload] |
| 183 | + if _is_scalar_nan(value): |
131 | 184 | dataset[key] = None |
132 | 185 |
|
133 | 186 | return cls._from_dataset(dataset) |
0 commit comments