-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathmodelsplitter.py
More file actions
313 lines (262 loc) Β· 12.2 KB
/
modelsplitter.py
File metadata and controls
313 lines (262 loc) Β· 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
from typing import List, cast
from imod.common.interfaces.iagnosticpackage import IAgnosticPackage
from imod.common.interfaces.imodel import IModel
from imod.common.interfaces.ipackage import IPackage
from imod.common.utilities.clip import clip_by_grid
from imod.common.utilities.partitioninfo import PartitionInfo
from imod.mf6.boundary_condition import BoundaryCondition
from imod.mf6.ims import Solution
from imod.mf6.model_gwf import GroundwaterFlowModel
from imod.mf6.model_gwt import GroundwaterTransportModel
from imod.mf6.ssm import SourceSinkMixing
from imod.mf6.validation_settings import ValidationSettings, trim_time_dimension
from imod.typing import GridDataArray
from imod.typing.grid import get_non_spatial_dimension_names
class ModelSplitter:
    """
    Split models into one submodel per partition and keep references between
    the resulting submodels consistent.

    The splitter remembers every model it has split, both per original model
    name and per partition id, so that packages and solutions that refer to
    other models can be updated afterwards with
    :meth:`update_dependent_packages` and :meth:`update_solutions`.
    """

    # pkg_id to variable mapping. For boundary packages we need to check if the
    # package has any active cells in the partition. We do this based on the
    # variable that defines the active cells for that package. This mapping is
    # used to get that variable name based on the package id. If a package is
    # not in this mapping, we assume it does not need special treatment.
    _pkg_id_to_var_mapping = {
        "chd": "head",
        "cnc": "concentration",
        "evt": "rate",
        "dis": "idomain",
        "drn": "elevation",
        "ghb": "head",
        "src": "rate",
        "rch": "rate",
        "riv": "conductance",
        "uzf": "infiltration_rate",
        "wel": "rate",
    }

    # Some boundary packages don't have a variable that defines active cells.
    # For these packages we skip the check if the package has any active cells
    # in the partition.
    _pkg_id_skip_active_domain_check = ["ssm", "lak"]

    def __init__(self, partition_info: List[PartitionInfo]) -> None:
        """
        Parameters
        ----------
        partition_info : List[PartitionInfo]
            One entry per partition. The ``id`` of each entry is used as key
            for the per-partition bookkeeping and as suffix of the submodel
            names; its ``active_domain`` is used to clip packages.
        """
        self.partition_info = partition_info

        # Mapping from original model name to its partitioned models
        self._model_to_partitioned_model: dict[str, dict[str, IModel]] = {}

        # Mapping from partition id to the submodels living in that partition
        self._partition_id_to_models: dict[int, dict[str, IModel]] = {
            submodel_partition_info.id: {}
            for submodel_partition_info in self.partition_info
        }

    def split(
        self, model_name: str, model: IModel, ignore_time: bool = True
    ) -> dict[str, IModel]:
        """
        Split a model into multiple partitioned models based on partition
        information.

        Each partition creates a separate submodel containing:

        - All non-boundary packages from the original model, clipped to the
          partition's domain
        - Boundary packages that have active cells within the partition's
          domain, clipped accordingly
        - IAgnosticPackages are excluded if they contain no data after clipping

        Parameters
        ----------
        model_name : str
            Base name of the input model. Partition IDs are appended to create
            unique names for each submodel (e.g., "model_0", "model_1").
        model : IModel
            The input model instance to partition.
        ignore_time : bool, default True
            Forwarded to the active-domain computation of boundary packages
            and to the emptiness check of sliced agnostic packages.

        Returns
        -------
        dict[str, IModel]
            A mapping from generated submodel names to the corresponding
            partitioned model instances, each clipped to its respective active
            domain.
        """
        modelclass = type(model)
        partitioned_models: dict[str, IModel] = {}
        model_to_partition: dict[str, PartitionInfo] = {}

        # Create an empty model for each partition and register it
        for submodel_partition_info in self.partition_info:
            new_model_name = f"{model_name}_{submodel_partition_info.id}"
            new_model = modelclass(**model.options)
            partitioned_models[new_model_name] = new_model
            model_to_partition[new_model_name] = submodel_partition_info
            self._partition_id_to_models[submodel_partition_info.id][
                new_model_name
            ] = new_model
        self._model_to_partitioned_model[model_name] = partitioned_models

        # Add packages to models
        for pkg_name, package in model.items():
            # The active domain only depends on the package, so compute it
            # once instead of once per partition.
            active_package_domain = (
                self._get_package_domain(package, ignore_time=ignore_time)
                if isinstance(package, BoundaryCondition)
                else None
            )

            for new_model_name, new_model in partitioned_models.items():
                partition_info = model_to_partition[new_model_name]

                if not self._has_package_data_in_domain(
                    package, active_package_domain, partition_info
                ):
                    continue

                # Slice the package to the partition's domain
                sliced_package = clip_by_grid(package, partition_info.active_domain)

                # For agnostic packages the overlap can only be judged after
                # slicing: do not add them if nothing is left.
                if isinstance(package, IAgnosticPackage) and sliced_package.is_empty(
                    ignore_time=ignore_time
                ):
                    continue

                new_model[pkg_name] = sliced_package

        return partitioned_models

    def update_dependent_packages(self) -> None:
        """
        Update packages that reference other models after partitioning.

        This method performs two updates:

        1. Updates buoyancy packages in flow models to reference the correct
           partitioned transport model names.
        2. Recreates Source Sink Mixing (SSM) packages in transport models
           based on the partitioned flow model data. If no SSM package can be
           created from the partitioned flow model, the transport model is
           left without one.

        Raises
        ------
        ValueError
            If a partition does not contain a groundwater flow model.
        """
        # Update buoyancy packages
        for models in self._partition_id_to_models.values():
            flow_model = self._get_flow_model(models)
            transport_model_names = self._get_transport_model_names(models)
            flow_model._update_buoyancy_package(transport_model_names)

        # Update ssm packages
        for models in self._partition_id_to_models.values():
            flow_model = self._get_flow_model(models)
            for transport_model in self._get_transport_models(models):
                ssm_key = transport_model._get_pkgkey("ssm")
                if ssm_key is None:
                    continue

                # The old package is removed first; it is only replaced when
                # the flow model yields a new SSM package.
                old_ssm_package = transport_model.pop(ssm_key)
                state_variable_name = old_ssm_package.dataset[
                    "auxiliary_variable_name"
                ].values[0]
                ssm_package = SourceSinkMixing.from_flow_model(
                    flow_model, state_variable_name, is_split=True
                )
                if ssm_package is not None:
                    transport_model[ssm_key] = ssm_package

    def update_solutions(
        self, original_model_name_to_solution: dict[str, Solution]
    ) -> None:
        """
        Update solution objects to reference partitioned models instead of
        original models.

        For each original model that was split:

        1. Removes the original model reference from its solution (This was a
           deepcopy of the original solution and thus references the original
           model).
        2. Adds all partitioned submodel references to the same solution

        This ensures that the Solution objects correctly reference the new
        partitioned model names after splitting.

        Parameters
        ----------
        original_model_name_to_solution : dict[str, Solution]
            Solution per original model name. Must contain an entry for every
            model that was passed to :meth:`split`.
        """
        for model_name, new_models in self._model_to_partitioned_model.items():
            solution = original_model_name_to_solution[model_name]
            solution._remove_model_from_solution(model_name)
            for new_model_name in new_models:
                solution._add_model_to_solution(new_model_name)

    def _is_package_to_skip(self, package: IPackage) -> bool:
        """
        Determine if a package should be skipped in grid checks.

        Package can be skipped in the following cases:

        - Package is not a BoundaryCondition (non-boundary packages are always
          included)
        - Package is an IAgnosticPackage (overlap check deferred until after
          slicing)
        - Package is in _pkg_id_skip_active_domain_check (e.g., SSM, LAK
          packages)
        """
        if not isinstance(package, BoundaryCondition):
            return True
        if isinstance(package, IAgnosticPackage):
            return True
        return package.pkg_id in self._pkg_id_skip_active_domain_check

    def _get_package_domain(
        self, package: IPackage, ignore_time: bool = True
    ) -> GridDataArray | None:
        """
        Extract the active domain of a boundary condition package.

        For boundary condition packages, this method identifies which cells
        contain active boundary data by checking the package's defining
        variable (e.g., "head" for CHD, "rate" for WEL).

        The active domain is determined by:

        1. Retrieving the variable that defines active cells (from
           _pkg_id_to_var_mapping)
        2. Trimming the time dimension, if present
        3. Dropping the non-spatial dimension names reported by
           ``get_non_spatial_dimension_names``
        4. Creating a boolean mask where non-null values indicate active cells

        Returns None if the package should be skipped in grid checks or if its
        pkg_id has no entry in _pkg_id_to_var_mapping; callers treat None as
        "no domain check possible".
        """
        if self._is_package_to_skip(package):
            return None

        # Packages without a defining variable need no special treatment, so
        # do not fail on them.
        var_name = self._pkg_id_to_var_mapping.get(package.pkg_id)
        if var_name is None:
            return None

        da = cast(GridDataArray, package.dataset[var_name])

        # Trim the time dimension if present
        if "time" in da.dims:
            trim_settings = ValidationSettings(ignore_time=ignore_time)
            da = trim_time_dimension(da, validation_context=trim_settings)

        # Drop the remaining non-spatial dimensions
        dims_to_be_removed = get_non_spatial_dimension_names(da)
        da = cast(GridDataArray, da.drop_vars(dims_to_be_removed))

        return da.notnull()

    def _has_package_data_in_domain(
        self,
        package: IPackage,
        active_package_domain: GridDataArray | None,
        partition_info: PartitionInfo,
    ) -> bool:
        """
        Check if a package has any active data within a partition's domain.

        For boundary condition packages, this method determines whether the
        package should be included in a partitioned model by checking if its
        active cells overlap with the partition's active domain.

        The method returns True in the following cases:

        - No active domain is available for the package (None).
        - Package should be skipped in grid checks.
        - Package has at least one active cell overlapping with the partition
          domain
        """
        if active_package_domain is None or self._is_package_to_skip(package):
            return True

        overlap_grid = active_package_domain & partition_info.active_domain.astype(bool)
        return bool(overlap_grid.any())

    def _get_flow_model(self, models: dict[str, IModel]) -> GroundwaterFlowModel:
        """
        Return the first GroundwaterFlowModel found in ``models``.

        Raises
        ------
        ValueError
            If ``models`` contains no groundwater flow model.
        """
        for model in models.values():
            if isinstance(model, GroundwaterFlowModel):
                return model
        raise ValueError(
            "Could not find a groundwater flow model for updating the buoyancy package."
        )

    def _get_transport_model_names(self, models: dict[str, IModel]) -> List[str]:
        """Return the names of all GroundwaterTransportModels in ``models``."""
        return [
            model_name
            for model_name, model in models.items()
            if isinstance(model, GroundwaterTransportModel)
        ]

    def _get_transport_models(
        self, models: dict[str, IModel]
    ) -> List[GroundwaterTransportModel]:
        """Return all GroundwaterTransportModel instances in ``models``."""
        return [
            model
            for model in models.values()
            if isinstance(model, GroundwaterTransportModel)
        ]

    def _get_model_to_mpi_rank_mapping(self) -> dict[str, int]:
        """
        Map submodel names to the MPI rank of the partition they belong to.

        Only submodels of partitions with an MPI rank greater than zero are
        included in the mapping.
        """
        # Look partitions up by their id instead of by list position, so the
        # result does not depend on ids being 0..n-1 in list order.
        partition_info_by_id = {info.id: info for info in self.partition_info}

        mpi_mapping: dict[str, int] = {}
        for partition_id, model_dict in self._partition_id_to_models.items():
            mpi_rank = partition_info_by_id[partition_id].mpi_rank
            if mpi_rank > 0:
                mpi_mapping.update(dict.fromkeys(model_dict, mpi_rank))
        return mpi_mapping