Commit 20cee90

mhheger and HenrZu authored
1139 introduce improved GNN surrogate models (#1378)
Co-authored-by: HenrZu <69154294+HenrZu@users.noreply.github.com>
1 parent a0c1752 commit 20cee90

10 files changed

Lines changed: 2909 additions & 16 deletions

docs/source/python/m-surrogate.rst

Lines changed: 315 additions & 15 deletions
@@ -1,14 +1,18 @@
+.. include:: ../literature.rst
+
 MEmilio Surrogate Model
 ========================
 
 MEmilio Surrogate Model contains machine learning based surrogate models that make predictions based on the MEmilio simulation models.
 Currently there are only surrogate models for ODE-type models. The simulations of these models are used for data generation.
 The goal is to create a powerful tool that predicts the infection dynamics faster than a simulation of an expert model,
 e.g., a metapopulation or agent-based model while still having acceptable errors with respect to the original simulations.
-
+
 The package can be found in `pycode/memilio-surrogatemodel <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel>`_.
 
-For more details, we refer to: Schmidt A, Zunker H, Heinlein A, Kühn MJ. (2025). *Graph Neural Network Surrogates to leverage Mechanistic Expert Knowledge towards Reliable and Immediate Pandemic Response*. Submitted for publication. `arXiv:2411.06500 <https://arxiv.org/abs/2411.06500>`_
+For more details, we refer to:
+
+|Graph_Neural_Network_Surrogates|
 
 Dependencies
 ------------
@@ -30,17 +34,18 @@ Usage
 The package currently provides the following modules:
 
 - `models`: models for different tasks
-  Currently we have the following models:
-  - `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection states not stratified by age groups.
-  - `ode_secir_groups`: A model allowing for asymptomatic as well as symptomatic infection states stratified by age groups and including one damping.
 
-  Each model folder contains the following files:
-  - `data_generation`: Data generation from expert model simulation outputs.
-  - `model`: Training and evaluation of the model.
-  - `network_architectures`: Contains multiple network architectures.
-  - `grid_search`: Utilities for hyperparameter optimization.
-  - `hyperparameter_tuning`: Scripts for tuning model hyperparameters.
+  Currently we have the following models:
+
+  - `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection states not stratified by age groups.
+  - `ode_secir_groups`: A model allowing for asymptomatic as well as symptomatic infection states stratified by age groups and including one damping.
 
+  Each model folder contains the following files:
+
+  - `data_generation`: Data generation from expert model simulation outputs.
+  - `model`: Training and evaluation of the model.
+  - `network_architectures`: Contains multiple network architectures.
+  - `grid_search`: Utilities for hyperparameter optimization.
 
 - `tests`: This file contains all tests.
 
@@ -146,7 +151,7 @@ Example usage:
 Hyperparameter Optimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for systematic hyperparameter optimization:
+The `grid_search.py` module provides tools for systematic hyperparameter optimization:
 
 1. **Cross-validation**:
 
@@ -163,7 +168,302 @@ The `grid_search.py` and `hyperparameter_tuning.py` modules provide tools for sy
    - Visualization of hyperparameter importance
    - Selection of optimal model configurations
 
-SECIR Groups Model
-------------------
 
-To be added...
+
+Graph Neural Network (GNN) Surrogate Models
+--------------------------------------------
+
+The Graph Neural Network (GNN) module provides advanced surrogate models that leverage spatial connectivity and age-stratified epidemiological dynamics. These models are designed for immediate and reliable pandemic response by combining mechanistic expert knowledge with machine learning efficiency.
+
+Overview and Scientific Foundation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The GNN surrogate models are based on the research presented in:
+
+|Graph_Neural_Network_Surrogates|
+
+The implementation leverages the mechanistic ODE-SECIR model (see :doc:`ODE-SECIR documentation <cpp/models/osecir>`) as the underlying expert model, using Python bindings to the C++ backend for efficient simulation during data generation.
+
+Module Structure
+~~~~~~~~~~~~~~~~
+
+The GNN module is located in `pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_ and consists of:
+
+- **data_generation.py**: Generates training and evaluation data by simulating epidemiological scenarios with the mechanistic SECIR model
+- **network_architectures.py**: Defines various GNN architectures (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv) with configurable depth and channels
+- **evaluate_and_train.py**: Implements training and evaluation pipelines for GNN models
+- **grid_search.py**: Provides hyperparameter optimization through systematic grid search
+- **GNN_utils.py**: Contains utility functions for data preprocessing, graph construction, and population data handling
+
+Data Generation
+~~~~~~~~~~~~~~~
+
+The data generation process in ``data_generation.py`` creates graph-structured training data through mechanistic simulations. Use ``generate_data`` to run multiple simulations and persist a pickle with inputs, labels, damping info, and contact matrices:
+
+.. code-block:: python
+
+    from memilio.surrogatemodel.GNN import data_generation
+    import memilio.simulation as mio
+
+    data = data_generation.generate_data(
+        num_runs=5,
+        data_dir="/path/to/memilio/data",
+        output_path="/tmp/generated_datasets",
+        input_width=5,
+        label_width=30,
+        start_date=mio.Date(2020, 10, 1),
+        end_date=mio.Date(2021, 10, 31),
+        mobility_file="commuter_mobility.txt",  # or commuter_mobility_2022.txt
+        transform=True,
+        save_data=True
+    )
+
+**Data Generation Workflow:**
+
+1. **Parameter Sampling**: Randomly sample epidemiological parameters (transmission rates, incubation periods, recovery rates) from predefined distributions to create diverse scenarios.
+
+2. **Compartment Initialization**: Initialize epidemic compartments for each age group in each region based on realistic demographic data. Compartments are initialized using shared base factors.
+
+3. **Mobility Graph Construction**: Build a spatial graph where:
+
+   - Nodes represent geographic regions (e.g., German counties)
+   - Edges represent mobility connections with weights from commuting data
+   - Node features include age-stratified population sizes
+
+4. **Contact Matrix Configuration**: Load and configure baseline contact matrices for different location types (home, school, work, other) stratified by age groups.
+
+5. **Damping Application**: Apply time-varying dampings to contact matrices to simulate NPIs:
+
+   - Multiple damping periods with random start days
+   - Location-specific damping factors (e.g., stronger school closures, moderate workplace restrictions)
+   - Realistic parameter ranges based on observed intervention strengths
+
+6. **Simulation Execution**: Run the mechanistic ODE-SECIR model using MEmilio's C++ backend through Python bindings to generate the dataset.
+
+7. **Data Processing**: Transform simulation results into graph-structured format:
+
+   - Extract compartment time series for each node (region) and age group
+   - Apply logarithmic transformation for numerical stability
+   - Store graph topology, node features, and temporal sequences
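The `input_width` and `label_width` arguments suggest that each simulated trajectory is cut into an observed input window followed by a prediction window. A minimal sketch of that split for a single time series, assuming a plain list of daily values; `split_window` is a hypothetical helper for illustration and not part of `data_generation.py`:

```python
def split_window(series, input_width, label_width):
    """Split one time series into an input window and the label window that follows it."""
    needed = input_width + label_width
    if len(series) < needed:
        raise ValueError(f"need at least {needed} time points, got {len(series)}")
    inputs = series[:input_width]          # days the surrogate sees
    labels = series[input_width:needed]    # days the surrogate has to predict
    return inputs, labels


# 35 daily values: 5 days of input followed by 30 days to predict
daily_values = list(range(35))
inputs, labels = split_window(daily_values, input_width=5, label_width=30)
print(len(inputs), len(labels))
```

With `input_width=5` and `label_width=30`, as in the call above, each sample would need at least 35 simulated days.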
+
+Network Architectures
+~~~~~~~~~~~~~~~~~~~~~
+
+The ``network_architectures.py`` module provides flexible GNN model construction for supported layer types (ARMAConv, GCSConv, GATConv, GCNConv, APPNPConv).
+
+.. code-block:: python
+
+    from memilio.surrogatemodel.GNN import network_architectures
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=3,
+        num_channels=64,
+        activation="relu",
+        num_output=48  # outputs per node
+    )
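To give an intuition for what a graph convolution such as `GCNConv` computes, the following sketch implements the standard GCN propagation rule, relu(D^-1/2 (A + I) D^-1/2 X W), in plain Python. It is an illustration under that assumption only: `gcn_layer` is a hypothetical name, and the layers used by the module may differ in preprocessing, bias terms, and other details.

```python
import math


def gcn_layer(adjacency, features, weights):
    """One GCN propagation step: relu(D^-1/2 (A + I) D^-1/2 X W)."""
    n = len(adjacency)
    # Add self-loops so every node keeps its own features
    a_hat = [[adjacency[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
             for i in range(n)]
    # Symmetric normalization by node degree
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_hat]
    norm = [[inv_sqrt_deg[i] * a_hat[i][j] * inv_sqrt_deg[j] for j in range(n)]
            for i in range(n)]
    # Aggregate neighbor features, then apply the weight matrix
    aggregated = [[sum(norm[i][k] * features[k][f] for k in range(n))
                   for f in range(len(features[0]))] for i in range(n)]
    out = [[sum(aggregated[i][f] * weights[f][c] for f in range(len(weights)))
            for c in range(len(weights[0]))] for i in range(n)]
    # ReLU activation
    return [[max(0.0, v) for v in row] for row in out]


# Three regions in a line (0 - 1 - 2), one feature per region, identity weight
adjacency = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
features = [[1.0], [0.0], [0.0]]
hidden = gcn_layer(adjacency, features, weights=[[1.0]])
```

After one layer, the signal placed on region 0 reaches its direct neighbor (region 1) but not region 2, which is why stacking several layers (`num_layers`) widens the spatial neighborhood each node can draw on.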
+
+
+Training and Evaluation
+~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``evaluate_and_train.py`` module provides the training functionality:
+
+.. code-block:: python
+
+    from tensorflow.keras.losses import MeanAbsolutePercentageError
+    from tensorflow.keras.optimizers import Adam
+    from memilio.surrogatemodel.GNN import evaluate_and_train, network_architectures
+
+    dataset = evaluate_and_train.load_gnn_dataset(
+        "/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle",
+        "/path/to/memilio/data/Germany/mobility",
+        number_of_nodes=400
+    )
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=3,
+        num_channels=32,
+        activation="relu",
+        num_output=48
+    )
+
+    results = evaluate_and_train.train_and_evaluate(
+        data=dataset,
+        batch_size=32,
+        epochs=50,
+        model=model,
+        loss_fn=MeanAbsolutePercentageError(),
+        optimizer=Adam(learning_rate=0.001),
+        es_patience=10,
+        save_dir="/tmp/model_results",
+        save_name="gnn_model"
+    )
+
+**Training Features:**
+
+1. **Mini-batch Training**: Graph batching for efficient training on large datasets
+2. **Custom Loss Functions**: MSE, MAE, MAPE, or custom compartment-weighted losses
+3. **Early Stopping**: Monitors validation loss to prevent overfitting
+4. **Save Best Weights**: Saves best model weights based on validation performance
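Early stopping and best-weight saving interact through the patience parameter (`es_patience` above). The sketch below shows the generic logic on a list of validation losses; `early_stopping_epochs` is a hypothetical helper and not the module's actual training loop.

```python
def early_stopping_epochs(val_losses, patience):
    """Return (best_epoch, best_loss, stopped_epoch) for a sequence of validation losses."""
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    stopped_epoch = len(val_losses) - 1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # A real training loop would save the model weights here
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                stopped_epoch = epoch
                break
    return best_epoch, best_loss, stopped_epoch


losses = [0.9, 0.7, 0.6, 0.65, 0.64, 0.63, 0.5]
best_epoch, best_loss, stopped_epoch = early_stopping_epochs(losses, patience=3)
```

Here training stops after three epochs without improvement, so the later loss of 0.5 is never reached; a larger patience trades training time for a lower chance of stopping too early.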
+
+**Evaluation Metrics:**
+
+- **Mean Absolute Error (MAE)**: Average absolute prediction error per compartment
+- **Mean Absolute Percentage Error (MAPE)**: Average absolute error relative to the true value, expressed as a percentage
+- **R² Score**: Coefficient of determination for prediction quality
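For reference, the three metrics can be written out with their textbook definitions. This is a plain-Python sketch on flat lists of values; the function names are illustrative, and the module's own evaluation code may aggregate over compartments, nodes, and days differently.

```python
def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mape(y_true, y_pred):
    """Mean absolute percentage error in percent; undefined if a true value is zero."""
    return 100.0 * sum(abs((t - p) / t) for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [100.0, 200.0, 400.0]
y_pred = [110.0, 180.0, 400.0]
print(mae(y_true, y_pred), mape(y_true, y_pred), r2_score(y_true, y_pred))
```

MAPE weights errors in small compartments as heavily as errors in large ones, which is one reason to report it alongside MAE.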
+
+**Data Splitting:**
+
+- **Training Set (70%)**: For model parameter optimization
+- **Validation Set (15%)**: For hyperparameter tuning and early stopping
+- **Test Set (15%)**: For final performance evaluation
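A 70/15/15 split of this kind can be sketched as a shuffled index partition. The helper `split_indices` and the fixed seed are assumptions for illustration; the source does not say how the module shuffles or seeds its split.

```python
import random


def split_indices(num_samples, train_frac=0.7, val_frac=0.15, seed=42):
    """Shuffle sample indices and partition them into train, validation, and test sets."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = round(train_frac * num_samples)
    n_val = round(val_frac * num_samples)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]      # remainder goes to the test set
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))
```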
+
+Hyperparameter Optimization
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``grid_search.py`` module enables systematic exploration of hyperparameter space:
+
+.. code-block:: python
+
+    from pathlib import Path
+    from memilio.surrogatemodel.GNN import grid_search, evaluate_and_train
+
+    data = evaluate_and_train.create_dataset(
+        "/tmp/generated_datasets/GNN_data_30days_3dampings_classic5.pickle",
+        "/path/to/memilio/data/Germany/mobility",
+        number_of_nodes=400
+    )
+
+    parameter_grid = grid_search.generate_parameter_grid(
+        layer_types=["GCNConv", "GATConv"],
+        num_layers_options=[2, 3],
+        num_channels_options=[16, 32],
+        activation_functions=["relu", "elu"]
+    )
+
+    grid_search.perform_grid_search(
+        data=data,
+        parameter_grid=parameter_grid,
+        save_dir=str(Path("/tmp/grid_results")),
+        batch_size=32,
+        max_epochs=50,
+        es_patience=10,
+        learning_rate=0.001
+    )
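A grid search trains one model per combination of options, so the number of runs grows multiplicatively. The sketch below expands the four option lists from the example with `itertools.product`; `expand_grid` and the dictionary layout are hypothetical, and the actual return format of `generate_parameter_grid` may differ.

```python
from itertools import product


def expand_grid(layer_types, num_layers_options, num_channels_options,
                activation_functions):
    """Build one configuration dict per combination of the given options."""
    return [
        {"layer_type": layer, "num_layers": depth,
         "num_channels": channels, "activation": activation}
        for layer, depth, channels, activation in product(
            layer_types, num_layers_options,
            num_channels_options, activation_functions)
    ]


grid = expand_grid(
    layer_types=["GCNConv", "GATConv"],
    num_layers_options=[2, 3],
    num_channels_options=[16, 32],
    activation_functions=["relu", "elu"],
)
print(len(grid))  # 2 * 2 * 2 * 2 = 16 configurations
```

With up to 50 epochs per configuration, even this small grid implies 16 training runs, which is worth keeping in mind when widening the option lists.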
+
+Utility Functions
+~~~~~~~~~~~~~~~~~
+
+The ``GNN_utils.py`` module provides essential helper functions used throughout the GNN workflow:
+
+**Data Preprocessing:**
+
+.. code-block:: python
+
+    from memilio.surrogatemodel.GNN import GNN_utils
+
+    # Remove confirmed compartments (simplify model)
+    simplified_data = GNN_utils.remove_confirmed_compartments(
+        dataset_entries=dataset,
+        num_groups=6
+    )
+
+    # Apply logarithmic scaling
+    scaled_inputs, scaled_labels = GNN_utils.scale_data(
+        data=dataset,
+        transform=True
+    )
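Compartment sizes span several orders of magnitude, which is the usual motivation for logarithmic scaling. The sketch below assumes a log(1 + x) transform, which stays defined for empty compartments and is inverted by exp(x) - 1; whether `scale_data` uses exactly this variant is not stated here, and both helper names are hypothetical.

```python
import math


def log_scale(values):
    """Compress large counts with log(1 + x); zero stays zero."""
    return [math.log1p(v) for v in values]


def inverse_log_scale(values):
    """Map scaled values back to counts with exp(x) - 1."""
    return [math.expm1(v) for v in values]


compartment_sizes = [0.0, 9.0, 99999.0]
scaled = log_scale(compartment_sizes)
restored = inverse_log_scale(scaled)
```

Predictions made on the scaled data would need the inverse transform before they can be compared with simulated counts.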
+
+**Graph Construction:**
+
+.. code-block:: python
+
+    # Create mobility graph from commuting data
+    graph = GNN_utils.create_mobility_graph(
+        mobility_dir='path/to/mobility',
+        num_regions=401,  # German counties
+        county_ids=county_list,
+        models=models_per_region  # SECIR models for each region
+    )
+
+    # Get baseline contact matrix
+    contact_matrix = GNN_utils.get_baseline_contact_matrix(
+        data_dir='path/to/contact_matrices'
+    )
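Conceptually, a commuter matrix (entry `[i][j]` = commuters from region `i` to region `j`) can be turned into a weighted graph for the GNN. The sketch below sums both directions into a symmetric adjacency and lists the resulting edges. The symmetrization and both helper names are assumptions for illustration; `create_mobility_graph` builds a MEmilio simulation graph and may treat edges differently.

```python
def commuters_to_adjacency(commuters):
    """Symmetric weighted adjacency: commuters in both directions, no self-loops."""
    n = len(commuters)
    return [[0 if i == j else commuters[i][j] + commuters[j][i] for j in range(n)]
            for i in range(n)]


def edge_list(adjacency):
    """List each undirected edge once as (source, target, weight)."""
    return [(i, j, weight)
            for i, row in enumerate(adjacency)
            for j, weight in enumerate(row)
            if i < j and weight > 0]


# Toy commuter matrix for three regions
commuters = [[0, 120, 0],
             [30, 0, 10],
             [0, 5, 0]]
adjacency = commuters_to_adjacency(commuters)
edges = edge_list(adjacency)
print(edges)
```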
+
+Practical Usage Example
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Here is a complete example workflow from data generation to model evaluation:
+
+.. code-block:: python
+
+    import memilio.simulation as mio
+    from tensorflow.keras.losses import MeanAbsolutePercentageError
+    from tensorflow.keras.optimizers import Adam
+    from memilio.surrogatemodel.GNN import (
+        data_generation,
+        network_architectures,
+        evaluate_and_train
+    )
+
+    # Step 1: Generate and save training data
+    data_generation.generate_data(
+        num_runs=100,
+        data_dir="/path/to/memilio/data",
+        output_path="/tmp/generated_datasets",
+        input_width=5,
+        label_width=30,
+        start_date=mio.Date(2020, 10, 1),
+        end_date=mio.Date(2021, 10, 31),
+        save_data=True,
+        mobility_file="commuter_mobility.txt"
+    )
+
+    # Step 2: Load dataset and build model
+    dataset = evaluate_and_train.load_gnn_dataset(
+        "/tmp/generated_datasets/GNN_data_30days_3dampings_classic100.pickle",
+        "/path/to/memilio/data/Germany/mobility",
+        number_of_nodes=400
+    )
+
+    model = network_architectures.get_model(
+        layer_type="GCNConv",
+        num_layers=4,
+        num_channels=128,
+        activation="relu",
+        num_output=48
+    )
+
+    # Step 3: Train and evaluate
+    results = evaluate_and_train.train_and_evaluate(
+        data=dataset,
+        batch_size=32,
+        epochs=100,
+        model=model,
+        loss_fn=MeanAbsolutePercentageError(),
+        optimizer=Adam(learning_rate=0.001),
+        es_patience=20,
+        save_dir="/tmp/model_results",
+        save_name="gnn_weights_best"
+    )
+
+**GPU Acceleration:**
+
+- TensorFlow automatically uses a GPU when one is available
+- Spektral layers are optimized for GPU execution
+- Training time can be reduced substantially with appropriate GPU hardware
+
+Additional Resources
+~~~~~~~~~~~~~~~~~~~~
+
+**Code and Examples:**
+
+- `GNN Module <https://github.com/SciCompMod/memilio/tree/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN>`_
+- `GNN README <https://github.com/SciCompMod/memilio/blob/main/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/README.md>`_
+
+**Related Documentation:**
+
+- :doc:`MEmilio Simulation Package <m-simulation>`