@@ -161,13 +161,13 @@ def _(mo):
161161
162162@app .cell
163163def _ (contact_frequency , np , osecir , t0 , tmax ):
164- def run_simulation (parameters , tmax = tmax ):
164+ def run_simulation (parameters , tmax = tmax ):
165165 # Create model and set parameters
166166 influenza_model = osecir .Model (1 )
167167 # TODO: We need to run the set_population and set_parameters functions:
168-
169-
170- influenza_model . parameters . ContactPatterns . cont_freq_mat [ 0 ]. baseline = np . ones (( 1 , 1 )) * contact_frequency
168+
169+ influenza_model . parameters . ContactPatterns . cont_freq_mat [ 0 ]. baseline = np . ones (
170+ ( 1 , 1 )) * contact_frequency
171171 # Check that the parameters are meaningful, i.e., no negative dwelling times
172172 influenza_model .check_constraints ()
173173
@@ -263,7 +263,7 @@ def _(mo):
263263 mo .md (r"""
264264 ## Defining the objective function
265265
266- The last step before running the fitting is the defintion of an objective (or distance) function. Here, we are given data for the ICU cases and deaths per day. Thus an obvious choice for the distance function is to calculate the difference between the simulated and the observed numbers per day and adding them up.
266+ The last step before running the fitting is the defintion of an objective (or distance) function. Here, we are given data for the ICU cases and deaths per day. Thus an obvious choice for the distance function is to calculate the difference between the simulated and the observed numbers per day and then add them up. As they live on different scales, we define the distance on ICU cases and deaths seperately and use `pyabc.AdaptiveAggregatedDistance` to scale them and aggregate .
267267
268268 We need a function that takes a `data` dictionary provided by our `run_simulation` function and an `observation` dictionary, given by our input data. As in the plotting sections of the previous tutorials, we have to access the correct columns of our simulation results by indexing as there is no name provided.
269269 """ )
@@ -272,15 +272,23 @@ def _(mo):
272272
273273app ._unparsable_cell (
274274 r"""
275- def distance_function (simulation, real_data):
275+ def distance_ICU (simulation, real_data):
276276 # TODO: Extract the correct columns from the data dataframe:
277277 real_ICU =
278- real_Deaths =
279278 sim = simulation['data']
280279 sim_ICU = sim[1+ int(osecir.InfectionState.InfectedCritical), :]
280+ # TODO: Add the code to calculate the absolute differences between data and simulation
281+ return np.sum()
282+
283+ def distance_Deaths(simulation, real_data):
284+ # TODO: Extract the correct columns from the data dataframe:
285+ real_Deaths =
286+ sim = simulation['data']
281287 sim_Death = sim[1 + int(osecir.InfectionState.Dead), :]
282288 # TODO: Add the code to calculate the absolute differences between data and simulation
283- return (np.sum( ) + np.sum( )) / 1000
289+ return np.sum()
290+
291+ distance = pyabc.AdaptiveAggregatedDistance([distance_ICU, distance_Deaths], adaptive=False, scale_function=pyabc.distance.median)
284292 """ ,
285293 name = "_"
286294)
@@ -325,7 +333,7 @@ def _(example_results):
325333
326334@app .cell
327335def _ ():
328- # TODO: Test the distance function on the example_results and the observation_data
336+ # TODO: Test the distance function on the example_results and the observation_data. You need to specify a time t=-1 to avoid an error later.
329337
330338 return
331339
@@ -362,15 +370,15 @@ def _(mo):
362370
363371@app .cell
364372def _ (
365- distance_function ,
373+ distance ,
366374 observation_data ,
367375 os ,
368376 prior ,
369377 pyabc ,
370378 run_simulation ,
371379 tempfile ,
372380):
373- abc = pyabc .ABCSMC (run_simulation , prior , distance_function , population_size = 400 )
381+ abc = pyabc .ABCSMC (run_simulation , prior , distance , population_size = 400 )
374382 db_path = "sqlite:///" + os .path .join (tempfile .gettempdir (), "tmp.db" )
375383 abc .new (db_path , observation_data )
376384 return
@@ -424,14 +432,16 @@ def _(mo):
424432@app .cell
425433def _ (np , osecir ):
426434 def plot_critical_data (sum_stat , weight , ax , ** kwargs ):
427- ax .plot (range (0 , 31 ), sum_stat ['data' ][1 + int (osecir .InfectionState .InfectedCritical ), :], color = 'grey' , alpha = 0.1 )
435+ ax .plot (range (0 , 31 ), sum_stat ['data' ][1 + int (
436+ osecir .InfectionState .InfectedCritical ), :], color = 'grey' , alpha = 0.1 )
428437
429438 def plot_critical_mean (sum_stats , weights , ax , ** kwargs ):
430439 weights = np .array (weights )
431440 weights /= weights .sum ()
432- data = np .array ([sum_stat ['data' ][1 + int (osecir .InfectionState .InfectedCritical ), :] for sum_stat in sum_stats ])
441+ data = np .array ([sum_stat ['data' ][1 + int (osecir .InfectionState .InfectedCritical ), :]
442+ for sum_stat in sum_stats ])
433443 mean = (data * weights .reshape ((- 1 , 1 ))).sum (axis = 0 )
434- ax .plot (range (0 , 31 ), mean , color = 'C2' , label = "Simulation mean" )
444+ ax .plot (range (0 , 31 ), mean , color = 'C2' , label = "Simulation mean" )
435445
436446 return plot_critical_data , plot_critical_mean
437447
@@ -446,9 +456,11 @@ def _(
446456 pyabc ,
447457):
448458 fig , ax = plt .subplots ()
449- ax = pyabc .visualization .plot_data_callback (history , plot_critical_data , plot_critical_mean , ax = ax )
459+ ax = pyabc .visualization .plot_data_callback (
460+ history , plot_critical_data , plot_critical_mean , ax = ax )
450461
451- plt .scatter (range (0 , 31 ), observation_data ["Critical" ], color = "C1" , label = "Data" , zorder = 2 )
462+ plt .scatter (
463+ range (0 , 31 ), observation_data ["Critical" ], color = "C1" , label = "Data" , zorder = 2 )
452464 plt .xlabel ("Time" )
453465 plt .ylabel ("# Cases" )
454466 plt .title ("Number of ICU patients" )
@@ -467,9 +479,11 @@ def _():
467479@app .cell
468480def _ (history , observation_data , plot_dead_data , plot_dead_mean , plt , pyabc ):
469481 fig_dead , ax_dead = plt .subplots ()
470- ax_dead = pyabc .visualization .plot_data_callback (history , plot_dead_data , plot_dead_mean , ax = ax_dead )
482+ ax_dead = pyabc .visualization .plot_data_callback (
483+ history , plot_dead_data , plot_dead_mean , ax = ax_dead )
471484
472- plt .scatter (range (0 , 31 ), observation_data ["Deaths" ], color = "C1" , label = "Data" , zorder = 2 )
485+ plt .scatter (
486+ range (0 , 31 ), observation_data ["Deaths" ], color = "C1" , label = "Data" , zorder = 2 )
473487 plt .xlabel ("Time" )
474488 plt .ylabel ("# Cases" )
475489 plt .title ("Cumulative number of dead patients" )
@@ -504,34 +518,40 @@ def _():
504518
505519@app .cell
506520def _ (df , run_simulation_with_params ):
507- predictions = df .apply (run_simulation_with_params , axis = 1 )
521+ predictions = df .apply (run_simulation_with_params , axis = 1 )
508522 return (predictions ,)
509523
510524
511525@app .cell
512526def _ (np , observation_data , osecir , plt , predictions , w ):
513- fig_prog , ax_prog = plt .subplots (1 , 2 , figsize = (12 ,4 ))
514- ax_prog [0 ].scatter (range (0 , 31 ), observation_data ["Deaths" ], color = "C1" , label = "Data" )
527+ fig_prog , ax_prog = plt .subplots (1 , 2 , figsize = (12 , 4 ))
528+ ax_prog [0 ].scatter (
529+ range (0 , 31 ), observation_data ["Deaths" ], color = "C1" , label = "Data" )
515530 for prediction in predictions :
516- ax_prog [0 ].plot (range (61 ), prediction ["data" ][1 + int (osecir .InfectionState .Dead )], color = "grey" , alpha = 0.2 )
531+ ax_prog [0 ].plot (range (61 ), prediction ["data" ]
532+ [1 + int (osecir .InfectionState .Dead )], color = "grey" , alpha = 0.2 )
517533 weights = np .array (w )
518534 weights /= weights .sum ()
519- data = np .array ([prediction ['data' ][1 + int (osecir .InfectionState .Dead ), :] for prediction in predictions ])
535+ data = np .array ([prediction ['data' ][1 + int (osecir .InfectionState .Dead ), :]
536+ for prediction in predictions ])
520537 mean = (data * weights .reshape ((- 1 , 1 ))).sum (axis = 0 )
521- ax_prog [0 ].plot (range (0 , 61 ), mean , color = 'C2' , label = "Simulation mean" )
538+ ax_prog [0 ].plot (range (0 , 61 ), mean , color = 'C2' , label = "Simulation mean" )
522539 ax_prog [0 ].legend ()
523540 ax_prog [0 ].set_title ("Number of Deaths" )
524541 ax_prog [0 ].set_ylabel ("Number" )
525542 ax_prog [0 ].set_xlabel ("Time [d]" )
526543
527- ax_prog [1 ].scatter (range (0 , 31 ), observation_data ["Critical" ], color = "C1" , label = "Data" )
544+ ax_prog [1 ].scatter (
545+ range (0 , 31 ), observation_data ["Critical" ], color = "C1" , label = "Data" )
528546 for prediction in predictions :
529- ax_prog [1 ].plot (range (61 ), prediction ["data" ][1 + int (osecir .InfectionState .InfectedCritical )], color = "grey" , alpha = 0.2 )
547+ ax_prog [1 ].plot (range (61 ), prediction ["data" ]
548+ [1 + int (osecir .InfectionState .InfectedCritical )], color = "grey" , alpha = 0.2 )
530549 weights = np .array (w )
531550 weights /= weights .sum ()
532- data = np .array ([prediction ['data' ][1 + int (osecir .InfectionState .InfectedCritical ), :] for prediction in predictions ])
551+ data = np .array ([prediction ['data' ][1 + int (osecir .InfectionState .InfectedCritical ), :]
552+ for prediction in predictions ])
533553 mean = (data * weights .reshape ((- 1 , 1 ))).sum (axis = 0 )
534- ax_prog [1 ].plot (range (0 , 61 ), mean , color = 'C2' , label = "Simulation mean" )
554+ ax_prog [1 ].plot (range (0 , 61 ), mean , color = 'C2' , label = "Simulation mean" )
535555 ax_prog [1 ].legend ()
536556 ax_prog [1 ].set_title ("Number of ICU cases" )
537557 ax_prog [1 ].set_ylabel ("Number" )
0 commit comments