Skip to content

Commit db17dde

Browse files
committed
Enforce Mypy disallow any generics
1 parent 81c29df commit db17dde

29 files changed

Lines changed: 158 additions & 134 deletions

imod/mf6/aggregate/aggregate_schemes.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable
1+
from typing import Any, Callable
22

33
import numpy as np
44
from pydantic.dataclasses import dataclass
@@ -37,10 +37,10 @@ class RiverAggregationMethod(DataclassType):
3737
3838
"""
3939

40-
stage: Callable = np.nanmean
41-
conductance: Callable = np.nansum
42-
bottom_elevation: Callable = np.nanmean
43-
concentration: Callable = np.nanmean
40+
stage: Callable[..., Any] = np.nanmean
41+
conductance: Callable[..., Any] = np.nansum
42+
bottom_elevation: Callable[..., Any] = np.nanmean
43+
concentration: Callable[..., Any] = np.nanmean
4444

4545

4646
@dataclass(config=_CONFIG)
@@ -65,9 +65,9 @@ class DrainageAggregationMethod(DataclassType):
6565
6666
"""
6767

68-
elevation: Callable = np.nanmean
69-
conductance: Callable = np.nansum
70-
concentration: Callable = np.nanmean
68+
elevation: Callable[..., Any] = np.nanmean
69+
conductance: Callable[..., Any] = np.nansum
70+
concentration: Callable[..., Any] = np.nanmean
7171

7272

7373
@dataclass(config=_CONFIG)
@@ -92,9 +92,9 @@ class GeneralHeadBoundaryAggregationMethod(DataclassType):
9292
9393
"""
9494

95-
head: Callable = np.nanmean
96-
conductance: Callable = np.nansum
97-
concentration: Callable = np.nanmean
95+
head: Callable[..., Any] = np.nanmean
96+
conductance: Callable[..., Any] = np.nansum
97+
concentration: Callable[..., Any] = np.nanmean
9898

9999

100100
@dataclass(config=_CONFIG)
@@ -118,5 +118,5 @@ class RechargeAggregationMethod(DataclassType):
118118
119119
"""
120120

121-
rate: Callable = np.nansum
122-
concentration: Callable = np.nanmean
121+
rate: Callable[..., Any] = np.nansum
122+
concentration: Callable[..., Any] = np.nanmean

imod/mf6/ats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from imod.schemata import AllValueSchema, DimsSchema, DTypeSchema
99
from imod.typing import GridDataset
1010

11-
_PeriodDataType: TypeAlias = dict[np.int64, list]
11+
_PeriodDataType: TypeAlias = dict[np.int64, list[Any]]
1212
_PeriodDataVarNames: TypeAlias = tuple[str, str, str, str, str]
1313

1414

@@ -208,7 +208,7 @@ def _get_render_dictionary(
208208
d["perioddata"] = perioddata
209209
return d
210210

211-
def _validate(self, schemata: dict, **kwargs):
211+
def _validate(self, schemata: dict[str, Any], **kwargs):
212212
# Insert additional kwargs
213213
kwargs["dt_max"] = self["dt_max"]
214214
errors = super()._validate(schemata, **kwargs)

imod/mf6/boundary_condition.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
22
import pathlib
33
from copy import copy, deepcopy
4-
from typing import Mapping, Optional, Union
4+
from typing import Any, MutableMapping, Optional, Union
55

66
import numpy as np
77
import xarray as xr
@@ -69,13 +69,15 @@ class BoundaryCondition(Package, abc.ABC):
6969
not the array input which is used in :class:`Package`.
7070
"""
7171

72-
def __init__(self, allargs: Mapping[str, GridDataArray | float | int | bool | str]):
72+
def __init__(
73+
self, allargs: MutableMapping[str, GridDataArray | float | int | bool | str]
74+
):
7375
# Convert repeat_stress in dict to a xr.DataArray in the right shape if
7476
# necessary, which is required to merge it into the dataset.
7577
if "repeat_stress" in allargs.keys() and isinstance(
7678
allargs["repeat_stress"], dict
7779
):
78-
allargs["repeat_stress"] = get_repeat_stress(allargs["repeat_stress"]) # type: ignore
80+
allargs["repeat_stress"] = get_repeat_stress(allargs["repeat_stress"])
7981
# Call the Package constructor, this merges the arguments into a dataset.
8082
super().__init__(allargs)
8183
if "concentration" in allargs.keys() and allargs["concentration"] is None:
@@ -197,7 +199,9 @@ def _period_paths(
197199
return periods
198200

199201
def _get_unfiltered_pkg_options(
200-
self, predefined_options: dict, not_options: Optional[list] = None
202+
self,
203+
predefined_options: dict[str, Any],
204+
not_options: Optional[list[str]] = None,
201205
):
202206
options = copy(predefined_options)
203207

@@ -208,11 +212,13 @@ def _get_unfiltered_pkg_options(
208212
if varname in not_options:
209213
continue
210214
v = self.dataset[varname].values[()]
211-
options[varname] = v
215+
options[str(varname)] = v
212216
return options
213217

214218
def _get_pkg_options(
215-
self, predefined_options: dict, not_options: Optional[list] = None
219+
self,
220+
predefined_options: dict[str, Any],
221+
not_options: Optional[list[str]] = None,
216222
):
217223
unfiltered_options = self._get_unfiltered_pkg_options(
218224
predefined_options, not_options=not_options

imod/mf6/evt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Any, Optional
22

33
import numpy as np
44

@@ -250,7 +250,9 @@ def _validate(self, schemata, **kwargs):
250250
return errors
251251

252252
def _get_pkg_options(
253-
self, predefined_options: dict, not_options: Optional[list] = None
253+
self,
254+
predefined_options: dict[str, Any],
255+
not_options: Optional[list[str]] = None,
254256
):
255257
options = super()._get_pkg_options(predefined_options, not_options=not_options)
256258
# Add amount of segments

imod/mf6/hfb.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from copy import deepcopy
66
from enum import Enum
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Dict, List, Optional, Self, Tuple
8+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Self, Tuple
99

1010
import cftime
1111
import numpy as np
@@ -186,7 +186,7 @@ def to_connected_cells_dataset(
186186
idomain: GridDataArray,
187187
grid: xu.Ugrid2d,
188188
edge_index: np.ndarray,
189-
edge_values: dict,
189+
edge_values: dict[str, Any],
190190
) -> xr.Dataset:
191191
"""
192192
Converts a cell edge grid with values defined on the edges to a dataset with the cell ids of the connected cells,
@@ -416,7 +416,7 @@ def _prepare_barrier_dataset_for_mf6_adapter(dataset: xr.Dataset) -> xr.Dataset:
416416

417417
def _snap_to_grid_and_aggregate(
418418
barrier_dataframe: GeoDataFrameType, grid2d: xu.Ugrid2d, vardict_agg: dict[str, str]
419-
) -> tuple[xu.UgridDataset, npt.NDArray]:
419+
) -> tuple[xu.UgridDataset, npt.NDArray[Any]]:
420420
"""
421421
Snap barrier dataframe to grid and aggregate multiple lines with a list of
422422
methods per variable.
@@ -481,7 +481,7 @@ def __init__(
481481
geometry: "gpd.GeoDataFrame",
482482
print_input: bool = False,
483483
) -> None:
484-
dict_dataset = {"print_input": print_input}
484+
dict_dataset: dict[str, Any] = {"print_input": print_input}
485485
super().__init__(dict_dataset)
486486
self.line_data = geometry
487487

@@ -859,7 +859,7 @@ def _get_variable_name(self) -> str:
859859
raise NotImplementedError
860860

861861
@abc.abstractmethod
862-
def _get_vertical_variables(self) -> list:
862+
def _get_vertical_variables(self) -> list[str]:
863863
raise NotImplementedError
864864

865865
def clip_box(
@@ -1143,7 +1143,7 @@ def _get_barrier_type(self):
11431143
def _get_variable_name(self) -> str:
11441144
return "hydraulic_characteristic"
11451145

1146-
def _get_vertical_variables(self) -> list:
1146+
def _get_vertical_variables(self) -> list[str]:
11471147
return []
11481148

11491149
def _compute_barrier_values(
@@ -1220,7 +1220,7 @@ def _get_barrier_type(self):
12201220
def _get_variable_name(self) -> str:
12211221
return "hydraulic_characteristic"
12221222

1223-
def _get_vertical_variables(self) -> list:
1223+
def _get_vertical_variables(self) -> list[str]:
12241224
return ["layer"]
12251225

12261226
def _compute_barrier_values(
@@ -1294,7 +1294,7 @@ def _get_barrier_type(self):
12941294
def _get_variable_name(self) -> str:
12951295
return "multiplier"
12961296

1297-
def _get_vertical_variables(self) -> list:
1297+
def _get_vertical_variables(self) -> list[str]:
12981298
return []
12991299

13001300
def _compute_barrier_values(
@@ -1373,7 +1373,7 @@ def _get_barrier_type(self):
13731373
def _get_variable_name(self) -> str:
13741374
return "multiplier"
13751375

1376-
def _get_vertical_variables(self) -> list:
1376+
def _get_vertical_variables(self) -> list[str]:
13771377
return ["layer"]
13781378

13791379
def _compute_barrier_values(
@@ -1467,7 +1467,7 @@ def _get_barrier_type(self):
14671467
def _get_variable_name(self) -> str:
14681468
return "resistance"
14691469

1470-
def _get_vertical_variables(self) -> list:
1470+
def _get_vertical_variables(self) -> list[str]:
14711471
return []
14721472

14731473
def _compute_barrier_values(
@@ -1539,7 +1539,7 @@ def _get_barrier_type(self):
15391539
def _get_variable_name(self) -> str:
15401540
return "resistance"
15411541

1542-
def _get_vertical_variables(self) -> list:
1542+
def _get_vertical_variables(self) -> list[str]:
15431543
return ["layer"]
15441544

15451545
def _compute_barrier_values(

imod/mf6/mf6_hfb_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from copy import deepcopy
2-
from typing import Union
2+
from typing import Any, Union
33

44
import numpy as np
55
import xarray as xr
@@ -115,7 +115,7 @@ def __init__(
115115
print_input: Union[bool, xr.DataArray] = False,
116116
validate: Union[bool, xr.DataArray] = True,
117117
):
118-
dict_dataset = {
118+
dict_dataset: dict[str, Any] = {
119119
"cell_id1": cell_id1,
120120
"cell_id2": cell_id2,
121121
"layer": layer,

imod/mf6/model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
import warnings
88
from pathlib import Path
9-
from typing import Any, List, Optional, Tuple, Union
9+
from typing import Any, List, Optional, Tuple, Union, cast
1010

1111
import cftime
1212
import jinja2
@@ -90,7 +90,7 @@ def _create_boundary_condition_clipped_boundary(
9090
original_model: Modflow6Model,
9191
clipped_model: Modflow6Model,
9292
state_for_boundary: Optional[GridDataArray],
93-
clip_box_args: tuple,
93+
clip_box_args: tuple[Any, ...],
9494
) -> Optional[StateType]:
9595
# Create temporary boundary condition for the original model boundary. This
9696
# is used later to see which boundaries can be ignored as they were already
@@ -146,7 +146,7 @@ def _create_boundary_condition_clipped_boundary(
146146
return bc_constant_pkg
147147

148148

149-
class Modflow6Model(collections.UserDict, IModel, abc.ABC):
149+
class Modflow6Model(collections.UserDict[str, Package], IModel, abc.ABC):
150150
_mandatory_packages: tuple[str, ...] = ()
151151
_init_schemata: SchemataDict = {}
152152
_model_id: Optional[str] = None
@@ -165,7 +165,7 @@ def __init__(self):
165165

166166
@standard_log_decorator()
167167
def _validate_options(
168-
self, schemata: dict, **kwargs
168+
self, schemata: dict[str, Any], **kwargs
169169
) -> dict[str, list[ValidationError]]:
170170
return validate_schemata_dict(schemata, self._options, **kwargs)
171171

@@ -569,7 +569,7 @@ def _write(
569569
write_context=pkg_write_context,
570570
)
571571
elif issubclass(type(pkg), imod.mf6.HorizontalFlowBarrierBase):
572-
mf6_hfb_ls.append(pkg)
572+
mf6_hfb_ls.append(cast(HorizontalFlowBarrierBase, pkg))
573573
else:
574574
pkg._write(
575575
pkgname=pkg_name,
@@ -656,7 +656,7 @@ def dump(
656656
if statusinfo.has_errors():
657657
raise ValidationError(statusinfo.to_string())
658658

659-
toml_content: dict = collections.defaultdict(dict)
659+
toml_content: dict[str, Any] = collections.defaultdict(dict)
660660

661661
for pkgname, pkg in self.items():
662662
pkg_path = pkg.to_file(
@@ -696,7 +696,7 @@ def from_file(cls, toml_path):
696696
return instance
697697

698698
@property
699-
def options(self) -> dict:
699+
def options(self) -> dict[str, Any]:
700700
if self._options is None:
701701
raise ValueError("Model id has not been set")
702702
return self._options

imod/mf6/out/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def open_cbc(
6363
"disu": disu.open_hds,
6464
}
6565

66-
_OPEN_CBC: Dict[str, Callable] = {
66+
_OPEN_CBC: Dict[str, Callable[..., Any]] = {
6767
"dis": dis.open_cbc,
6868
"disv": disv.open_cbc,
6969
"disu": disu.open_cbc,
@@ -76,7 +76,7 @@ def open_cbc(
7676
}
7777

7878

79-
def _get_function(d: Dict[str, Callable], key: str) -> Callable:
79+
def _get_function(d: Dict[str, Callable[..., Any]], key: str) -> Callable[..., Any]:
8080
try:
8181
func = d[key]
8282
except KeyError:

imod/mf6/out/cbc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def read_imeth6_budgets_dense(
266266
dtype: np.dtype,
267267
pos: int,
268268
size: int,
269-
shape: tuple,
269+
shape: Tuple[int, ...],
270270
return_variable: str,
271271
indices: np.ndarray | None,
272272
) -> FloatArray:

imod/mf6/out/dis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def open_dvs(
208208

209209

210210
def open_imeth1_budgets(
211-
cbc_path: FilePath, grb_content: dict, header_list: List[cbc.Imeth1Header]
211+
cbc_path: FilePath, grb_content: Dict[str, Any], header_list: List[cbc.Imeth1Header]
212212
) -> xr.DataArray:
213213
"""
214214
Open the data for an imeth==1 budget section. Data is read lazily per
@@ -252,7 +252,7 @@ def open_imeth1_budgets(
252252

253253
def open_imeth6_budgets(
254254
cbc_path: FilePath,
255-
grb_content: dict,
255+
grb_content: Dict[str, Any],
256256
header_list: List[cbc.Imeth6Header],
257257
return_variable: str = "budget",
258258
indices: np.ndarray | None = None,
@@ -374,7 +374,7 @@ def dis_indices(
374374

375375

376376
def dis_to_right_front_lower_indices(
377-
grb_content: dict,
377+
grb_content: Dict[str, Any],
378378
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
379379
"""
380380
Infer the indices to extract right, front, and lower face flows from the
@@ -442,7 +442,7 @@ def dis_extract_face_budgets(
442442

443443

444444
def dis_open_face_budgets(
445-
cbc_path: FilePath, grb_content: dict, header_list: List[cbc.Imeth1Header]
445+
cbc_path: FilePath, grb_content: Dict[str, Any], header_list: List[cbc.Imeth1Header]
446446
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
447447
"""
448448
Open the flow-ja-face, and extract right, front, and lower face flows.

0 commit comments

Comments
 (0)