Skip to content

Commit 00ad934

Browse files
committed
include AdaptiveAggregatedDistance in exercises as well
1 parent 616f211 commit 00ad934

4 files changed

Lines changed: 202 additions & 131 deletions

File tree

exercises/exercise04.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,13 @@ def _(mo):
161161

162162
@app.cell
163163
def _(contact_frequency, np, osecir, t0, tmax):
164-
def run_simulation(parameters, tmax = tmax):
164+
def run_simulation(parameters, tmax=tmax):
165165
# Create model and set parameters
166166
influenza_model = osecir.Model(1)
167167
# TODO: We need to run the set_population and set_parameters functions:
168-
169-
170-
influenza_model.parameters.ContactPatterns.cont_freq_mat[0].baseline = np.ones((1,1)) * contact_frequency
168+
169+
influenza_model.parameters.ContactPatterns.cont_freq_mat[0].baseline = np.ones(
170+
(1, 1)) * contact_frequency
171171
# Check that the parameters are meaningful, i.e., no negative dwelling times
172172
influenza_model.check_constraints()
173173

@@ -263,7 +263,7 @@ def _(mo):
263263
mo.md(r"""
264264
## Defining the objective function
265265
266-
The last step before running the fitting is the defintion of an objective (or distance) function. Here, we are given data for the ICU cases and deaths per day. Thus an obvious choice for the distance function is to calculate the difference between the simulated and the observed numbers per day and adding them up.
266+
The last step before running the fitting is the defintion of an objective (or distance) function. Here, we are given data for the ICU cases and deaths per day. Thus an obvious choice for the distance function is to calculate the difference between the simulated and the observed numbers per day and then add them up. As they live on different scales, we define the distance on ICU cases and deaths seperately and use `pyabc.AdaptiveAggregatedDistance` to scale them and aggregate.
267267
268268
We need a function that takes a `data` dictionary provided by our `run_simulation` function and an `observation` dictionary, given by our input data. As in the plotting sections of the previous tutorials, we have to access the correct columns of our simulation results by indexing as there is no name provided.
269269
""")
@@ -272,15 +272,23 @@ def _(mo):
272272

273273
app._unparsable_cell(
274274
r"""
275-
def distance_function(simulation, real_data):
275+
def distance_ICU(simulation, real_data):
276276
# TODO: Extract the correct columns from the data dataframe:
277277
real_ICU =
278-
real_Deaths =
279278
sim = simulation['data']
280279
sim_ICU = sim[1+ int(osecir.InfectionState.InfectedCritical), :]
280+
# TODO: Add the code to calculate the absolute differences between data and simulation
281+
return np.sum()
282+
283+
def distance_Deaths(simulation, real_data):
284+
# TODO: Extract the correct columns from the data dataframe:
285+
real_Deaths =
286+
sim = simulation['data']
281287
sim_Death = sim[1 + int(osecir.InfectionState.Dead), :]
282288
# TODO: Add the code to calculate the absolute differences between data and simulation
283-
return (np.sum( ) + np.sum( )) / 1000
289+
return np.sum()
290+
291+
distance = pyabc.AdaptiveAggregatedDistance([distance_ICU, distance_Deaths], adaptive=False, scale_function=pyabc.distance.median)
284292
""",
285293
name="_"
286294
)
@@ -325,7 +333,7 @@ def _(example_results):
325333

326334
@app.cell
327335
def _():
328-
# TODO: Test the distance function on the example_results and the observation_data
336+
# TODO: Test the distance function on the example_results and the observation_data. You need to specify a time t=-1 to avoid an error later.
329337

330338
return
331339

@@ -362,15 +370,15 @@ def _(mo):
362370

363371
@app.cell
364372
def _(
365-
distance_function,
373+
distance,
366374
observation_data,
367375
os,
368376
prior,
369377
pyabc,
370378
run_simulation,
371379
tempfile,
372380
):
373-
abc = pyabc.ABCSMC(run_simulation, prior, distance_function, population_size=400)
381+
abc = pyabc.ABCSMC(run_simulation, prior, distance, population_size=400)
374382
db_path = "sqlite:///" + os.path.join(tempfile.gettempdir(), "tmp.db")
375383
abc.new(db_path, observation_data)
376384
return
@@ -424,14 +432,16 @@ def _(mo):
424432
@app.cell
425433
def _(np, osecir):
426434
def plot_critical_data(sum_stat, weight, ax, **kwargs):
427-
ax.plot(range(0, 31), sum_stat['data'][1+ int(osecir.InfectionState.InfectedCritical), :], color = 'grey', alpha = 0.1)
435+
ax.plot(range(0, 31), sum_stat['data'][1 + int(
436+
osecir.InfectionState.InfectedCritical), :], color='grey', alpha=0.1)
428437

429438
def plot_critical_mean(sum_stats, weights, ax, **kwargs):
430439
weights = np.array(weights)
431440
weights /= weights.sum()
432-
data = np.array([sum_stat['data'][1+ int(osecir.InfectionState.InfectedCritical), :] for sum_stat in sum_stats])
441+
data = np.array([sum_stat['data'][1 + int(osecir.InfectionState.InfectedCritical), :]
442+
for sum_stat in sum_stats])
433443
mean = (data * weights.reshape((-1, 1))).sum(axis=0)
434-
ax.plot(range(0, 31), mean, color='C2', label = "Simulation mean")
444+
ax.plot(range(0, 31), mean, color='C2', label="Simulation mean")
435445

436446
return plot_critical_data, plot_critical_mean
437447

@@ -446,9 +456,11 @@ def _(
446456
pyabc,
447457
):
448458
fig, ax = plt.subplots()
449-
ax = pyabc.visualization.plot_data_callback(history, plot_critical_data, plot_critical_mean, ax=ax)
459+
ax = pyabc.visualization.plot_data_callback(
460+
history, plot_critical_data, plot_critical_mean, ax=ax)
450461

451-
plt.scatter(range(0, 31), observation_data["Critical"], color = "C1", label = "Data", zorder = 2)
462+
plt.scatter(
463+
range(0, 31), observation_data["Critical"], color="C1", label="Data", zorder=2)
452464
plt.xlabel("Time")
453465
plt.ylabel("# Cases")
454466
plt.title("Number of ICU patients")
@@ -467,9 +479,11 @@ def _():
467479
@app.cell
468480
def _(history, observation_data, plot_dead_data, plot_dead_mean, plt, pyabc):
469481
fig_dead, ax_dead = plt.subplots()
470-
ax_dead = pyabc.visualization.plot_data_callback(history, plot_dead_data, plot_dead_mean, ax=ax_dead)
482+
ax_dead = pyabc.visualization.plot_data_callback(
483+
history, plot_dead_data, plot_dead_mean, ax=ax_dead)
471484

472-
plt.scatter(range(0, 31), observation_data["Deaths"], color = "C1", label = "Data", zorder = 2)
485+
plt.scatter(
486+
range(0, 31), observation_data["Deaths"], color="C1", label="Data", zorder=2)
473487
plt.xlabel("Time")
474488
plt.ylabel("# Cases")
475489
plt.title("Cumulative number of dead patients")
@@ -504,34 +518,40 @@ def _():
504518

505519
@app.cell
506520
def _(df, run_simulation_with_params):
507-
predictions = df.apply(run_simulation_with_params, axis = 1)
521+
predictions = df.apply(run_simulation_with_params, axis=1)
508522
return (predictions,)
509523

510524

511525
@app.cell
512526
def _(np, observation_data, osecir, plt, predictions, w):
513-
fig_prog, ax_prog = plt.subplots(1, 2, figsize = (12,4))
514-
ax_prog[0].scatter(range(0, 31), observation_data["Deaths"], color = "C1", label = "Data")
527+
fig_prog, ax_prog = plt.subplots(1, 2, figsize=(12, 4))
528+
ax_prog[0].scatter(
529+
range(0, 31), observation_data["Deaths"], color="C1", label="Data")
515530
for prediction in predictions:
516-
ax_prog[0].plot(range(61), prediction["data"][1+int(osecir.InfectionState.Dead)], color = "grey", alpha = 0.2)
531+
ax_prog[0].plot(range(61), prediction["data"]
532+
[1+int(osecir.InfectionState.Dead)], color="grey", alpha=0.2)
517533
weights = np.array(w)
518534
weights /= weights.sum()
519-
data = np.array([prediction['data'][1+ int(osecir.InfectionState.Dead), :] for prediction in predictions])
535+
data = np.array([prediction['data'][1 + int(osecir.InfectionState.Dead), :]
536+
for prediction in predictions])
520537
mean = (data * weights.reshape((-1, 1))).sum(axis=0)
521-
ax_prog[0].plot(range(0, 61), mean, color='C2', label = "Simulation mean")
538+
ax_prog[0].plot(range(0, 61), mean, color='C2', label="Simulation mean")
522539
ax_prog[0].legend()
523540
ax_prog[0].set_title("Number of Deaths")
524541
ax_prog[0].set_ylabel("Number")
525542
ax_prog[0].set_xlabel("Time [d]")
526543

527-
ax_prog[1].scatter(range(0, 31), observation_data["Critical"], color = "C1", label = "Data")
544+
ax_prog[1].scatter(
545+
range(0, 31), observation_data["Critical"], color="C1", label="Data")
528546
for prediction in predictions:
529-
ax_prog[1].plot(range(61), prediction["data"][1+int(osecir.InfectionState.InfectedCritical)], color = "grey", alpha = 0.2)
547+
ax_prog[1].plot(range(61), prediction["data"]
548+
[1+int(osecir.InfectionState.InfectedCritical)], color="grey", alpha=0.2)
530549
weights = np.array(w)
531550
weights /= weights.sum()
532-
data = np.array([prediction['data'][1+ int(osecir.InfectionState.InfectedCritical), :] for prediction in predictions])
551+
data = np.array([prediction['data'][1 + int(osecir.InfectionState.InfectedCritical), :]
552+
for prediction in predictions])
533553
mean = (data * weights.reshape((-1, 1))).sum(axis=0)
534-
ax_prog[1].plot(range(0, 61), mean, color='C2', label = "Simulation mean")
554+
ax_prog[1].plot(range(0, 61), mean, color='C2', label="Simulation mean")
535555
ax_prog[1].legend()
536556
ax_prog[1].set_title("Number of ICU cases")
537557
ax_prog[1].set_ylabel("Number")

0 commit comments

Comments
 (0)