 MEmilio Surrogate Model contains machine learning based surrogate models that make predictions based on the MEmilio simulation models.
 Currently there are only surrogate models for ODE-type models. The simulations of these models are used for data generation.
 The goal is to create a powerful tool that predicts the infection dynamics faster than a simulation of an expert model,
 e.g., a metapopulation or agent-based model while still having acceptable errors with respect to the original simulations.
-
+
 The package can be found in `pycode/memilio-surrogatemodel <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel>`_.

-For more details, we refer to: Schmidt A, Zunker H, Heinlein A, Kühn MJ. (2025). *Graph Neural Network Surrogates to leverage Mechanistic Expert Knowledge towards Reliable and Immediate Pandemic Response*. Submitted for publication. `arXiv:2411.06500 <https://arxiv.org/abs/2411.06500>`_
+For more details, we refer to:
+
+|Graph_Neural_Network_Surrogates|

 Dependencies
 ------------
@@ -30,17 +34,18 @@ Usage
 The package currently provides the following modules:

 - `models`: models for different tasks
-Currently we have the following models:
-- `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection states not stratified by age groups.
-- `ode_secir_groups`: A model allowing for asymptomatic as well as symptomatic infection states stratified by age groups and including one damping.

-Each model folder contains the following files:
-- `data_generation`: Data generation from expert model simulation outputs.
- `grid_search`: Utilities for hyperparameter optimization.

 - `tests`: This file contains all tests.

@@ -146,7 +151,7 @@ Example usage:
 Hyperparameter Optimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~

-The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for systematic hyperparameter optimization:
+The `grid_search.py` module provides tools for systematic hyperparameter optimization:

 1. **Cross-validation**:

@@ -163,7 +168,302 @@ The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for sy
    - Visualization of hyperparameter importance
    - Selection of optimal model configurations
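
As an illustration of what a systematic grid search enumerates, the following minimal sketch builds every combination of a small search space. The search space, its keys and values, and the helper name ``enumerate_grid`` are assumptions made for this example; they are not taken from ``grid_search.py``.

```python
from itertools import product

# Hypothetical search space; keys and values are illustrative only
# and are not taken from grid_search.py.
search_space = {
    "layer_type": ["GCNConv", "ARMAConv"],
    "num_layers": [1, 2, 3],
    "num_channels": [32, 64],
}


def enumerate_grid(space):
    """Return every combination of the given hyperparameter values as a dict."""
    keys = list(space)
    value_lists = [space[key] for key in keys]
    return [dict(zip(keys, values)) for values in product(*value_lists)]


configurations = enumerate_grid(search_space)
print(len(configurations), "configurations to evaluate")
```

Each resulting dictionary could then be passed to a training routine and scored with cross-validation; how the module actually does this is not shown here.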

-SECIR Groups Model
-------------------

-To be added...
+
+Graph Neural Network (GNN) Surrogate Models
+--------------------------------------------
+
+The Graph Neural Network (GNN) module provides advanced surrogate models that leverage spatial connectivity and age-stratified epidemiological dynamics. These models are designed for immediate and reliable pandemic response by combining mechanistic expert knowledge with machine learning efficiency.
+
+Overview and Scientific Foundation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The GNN surrogate models are based on the research presented in:
+
+|Graph_Neural_Network_Surrogates|
+
+The implementation leverages the mechanistic ODE-SECIR model (see :doc:`ODE-SECIR documentation <cpp/models/osecir>`) as the underlying expert model, using Python bindings to the C++ backend for efficient simulation during data generation.
+
+Module Structure
+~~~~~~~~~~~~~~~~
+
+The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_ and consists of:
+
+- **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model
+- **network_architectures.py**: Defines various GNN architectures (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv) with configurable depth and channels
+- **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models
+- **grid_search.py**: Provides hyperparameter optimization through systematic grid search
+- **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling
+
+Data Generation
+~~~~~~~~~~~~~~~
+
+The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations. Use ``generate_data`` to run multiple simulations and persist a pickle with inputs, labels, damping info, and contact matrices:
+
+.. code-block:: python
+
+    from memilio.surrogatemodel.GNN import data_generation
+    import memilio.simulation as mio
+
+    data = data_generation.generate_data(
+        num_runs=5,
+        data_dir="/path/to/memilio/data",
+        output_path="/tmp/generated_datasets",
+        input_width=5,
+        label_width=30,
+        start_date=mio.Date(2020, 10, 1),
+        end_date=mio.Date(2021, 10, 31),
+        mobility_file="commuter_mobility.txt",  # or commuter_mobility_2022.txt
+        transform=True,
+        save_data=True
+    )
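
The ``input_width`` and ``label_width`` arguments suggest that each simulated time series is split into an observed input window and a prediction window. The sketch below shows one plausible way to do such a split with NumPy. It assumes that both widths count time steps (days) and that the label window directly follows the input window; the function name ``split_window`` and the toy array sizes are hypothetical and not part of ``data_generation.py``.

```python
import numpy as np


def split_window(series, input_width, label_width):
    """Split the first input_width + label_width time steps into inputs and labels.

    Assumption: axis 0 of `series` is time, and the label window
    starts right after the input window.
    """
    if series.shape[0] < input_width + label_width:
        raise ValueError("time series is shorter than input_width + label_width")
    inputs = series[:input_width]
    labels = series[input_width:input_width + label_width]
    return inputs, labels


# Toy series: 35 time steps with 8 values each (sizes chosen for illustration).
series = np.arange(35 * 8, dtype=float).reshape(35, 8)
inputs, labels = split_window(series, input_width=5, label_width=30)
print(inputs.shape, labels.shape)
```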
+
+**Data Generation Workflow:**
+
+1. **Parameter Sampling**: Randomly sample epidemiological parameters (transmission rates, incubation periods, recovery rates) from predefined distributions to create diverse scenarios.
+
+2. **Compartment Initialization**: Initialize epidemic compartments for each age group in each region based on realistic demographic data. Compartments are initialized using shared base factors.
+
+3. **Mobility Graph Construction**: Build a spatial graph where:
+
+   - Nodes represent geographic regions (e.g., German counties)
+   - Edges represent mobility connections with weights from commuting data
+   - Node features include age-stratified population sizes
+
+4. **Contact Matrix Configuration**: Load and configure baseline contact matrices for different location types (home, school, work, other) stratified by age groups.
+
+5. **Damping Application**: Apply time-varying dampings to contact matrices to simulate NPIs:
+   - Realistic parameter ranges based on observed intervention strengths
+
+6. **Simulation Execution**: Run the mechanistic ODE-SECIR model using MEmilio's C++ backend through Python bindings to generate the dataset.
+
+7. **Data Processing**: Transform simulation results into graph-structured format:
+
+   - Extract compartment time series for each node (region) and age group
+   - Apply logarithmic transformation for numerical stability
+   - Store graph topology, node features, and temporal sequences
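
The logarithmic transformation in step 7 can be sketched as follows. This is a minimal example that assumes a ``log(1 + x)`` transform, which stays finite for empty compartments; the exact transform used in ``data_generation.py`` is not stated here, and the function names and toy values are hypothetical.

```python
import numpy as np


def log_transform(compartments):
    """Compress compartment counts with log(1 + x); zero stays zero (assumed variant)."""
    return np.log1p(compartments)


def inverse_log_transform(transformed):
    """Map transformed values back to compartment counts."""
    return np.expm1(transformed)


# Toy compartment sizes spanning several orders of magnitude.
raw = np.array([[0.0, 12.0, 3500.0], [1.0, 250.0, 82000.0]])
scaled = log_transform(raw)
restored = inverse_log_transform(scaled)
print(scaled.round(2))
```

Compressing the value range in this way keeps large and small compartments on a comparable scale during training, and model predictions can be mapped back with the inverse transform.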
+
+Network Architectures
+~~~~~~~~~~~~~~~~~~~~~
+
+The ``network_architectures.py`` module provides flexible GNN model construction for supported layer types (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv).
+
+.. code-block:: python
+
+    from memilio.surrogatemodel.GNN import network_architectures
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=3,
+        num_channels=64,
+        activation="relu",
+        num_output=48  # outputs per node
+    )
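
For background on the ``GCNConv`` layer type, graph convolutional layers in the sense of Kipf and Welling operate on a symmetrically normalized adjacency matrix with self-loops, :math:`D^{-1/2}(A + I)D^{-1/2}`. The sketch below computes this matrix with NumPy for a small, symmetric toy graph. It is an illustration only: the function name ``normalized_adjacency`` is hypothetical, and how the module preprocesses its (weighted) mobility graph is not shown here.

```python
import numpy as np


def normalized_adjacency(adjacency):
    """Return D^{-1/2} (A + I) D^{-1/2} for a square adjacency matrix A."""
    a_tilde = adjacency + np.eye(adjacency.shape[0])  # add self-loops
    degree = a_tilde.sum(axis=1)                      # node degrees incl. self-loop
    d_inv_sqrt = 1.0 / np.sqrt(degree)
    # Scale rows and columns by the inverse square root of the degrees.
    return a_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


# Toy graph: three regions connected in a line (0 - 1 - 2).
adjacency = np.array([[0.0, 1.0, 0.0],
                      [1.0, 0.0, 1.0],
                      [0.0, 1.0, 0.0]])
a_hat = normalized_adjacency(adjacency)
print(a_hat.round(3))
```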
+
+
+Training and Evaluation
+~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``evaluate_and_train.py`` module provides the training functionality:
+
+.. code-block:: python
+
+    from tensorflow.keras.losses import MeanAbsolutePercentageError
+    from tensorflow.keras.optimizers import Adam
+    from memilio.surrogatemodel.GNN import evaluate_and_train, network_architectures