Skip to content

Commit 7dd4bed

Browse files
committed
scale ICU cases and deaths for pyabc fittings
1 parent 02e1be1 commit 7dd4bed

2 files changed

Lines changed: 119 additions & 105 deletions

File tree

tutorial04.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _():
6161
from memilio.simulation import AgeGroup, Damping, LogLevel, set_log_level
6262
# deactivate the log level to avoid warning messages from adaptive step sizing
6363
set_log_level(LogLevel.Warning)
64-
return AgeGroup, LogLevel, osecir, set_log_level
64+
return AgeGroup, osecir
6565

6666

6767
@app.cell
@@ -152,17 +152,7 @@ def _(mo):
152152

153153

154154
@app.cell
155-
def _(
156-
LogLevel,
157-
contact_frequency,
158-
np,
159-
osecir,
160-
set_log_level,
161-
set_parameters,
162-
set_population,
163-
t0,
164-
tmax,
165-
):
155+
def _(contact_frequency, np, osecir, set_parameters, set_population, t0, tmax):
166156
def run_simulation(parameters, tmax = tmax):
167157
# Create model and set parameters
168158
influenza_model = osecir.Model(1)
@@ -269,16 +259,22 @@ def _(mo):
269259

270260

271261
@app.cell
272-
def _(np, osecir):
273-
def distance_function(simulation, real_data):
262+
def _(np, osecir, pyabc):
263+
def distance_ICU(simulation, real_data):
274264
real_ICU = real_data['Critical']
275-
real_Deaths = real_data['Deaths']
276265
sim = simulation['data']
277266
sim_ICU = sim[1+ int(osecir.InfectionState.InfectedCritical), :]
267+
return np.sum(np.abs(real_ICU - sim_ICU))
268+
269+
def distance_Deaths(simulation, real_data):
270+
deaths_per_critical = 1
271+
real_Deaths = real_data['Deaths']
272+
sim = simulation['data']
278273
sim_Death = sim[1 + int(osecir.InfectionState.Dead), :]
279-
return (np.sum(np.abs(real_ICU - sim_ICU)) + np.sum(np.abs(real_Deaths - sim_Death))) / 1000
274+
return 1/deaths_per_critical * np.sum(np.abs(real_Deaths - sim_Death))
280275

281-
return (distance_function,)
276+
distance = pyabc.AdaptiveAggregatedDistance([distance_ICU, distance_Deaths], adaptive=False, scale_function=pyabc.distance.mean)
277+
return (distance,)
282278

283279

284280
@app.cell(hide_code=True)
@@ -316,8 +312,8 @@ def _(example_results):
316312

317313

318314
@app.cell
319-
def _(distance_function, example_results, observation_data):
320-
distance_function(example_results, observation_data)
315+
def _(distance, example_results, observation_data):
316+
distance(example_results, observation_data)
321317
return
322318

323319

@@ -352,24 +348,16 @@ def _(mo):
352348

353349

354350
@app.cell
355-
def _(
356-
distance_function,
357-
observation_data,
358-
os,
359-
prior,
360-
pyabc,
361-
run_simulation,
362-
tempfile,
363-
):
364-
abc = pyabc.ABCSMC(run_simulation, prior, distance_function, population_size=400)
351+
def _(distance, observation_data, os, prior, pyabc, run_simulation, tempfile):
352+
abc = pyabc.ABCSMC(run_simulation, prior, distance, population_size=400, sampler=pyabc.sampler.SingleCoreSampler())
365353
db_path = "sqlite:///" + os.path.join(tempfile.gettempdir(), "tmp.db")
366354
abc.new(db_path, observation_data)
367355
return (abc,)
368356

369357

370358
@app.cell
371359
def _(abc):
372-
history = abc.run(minimum_epsilon=0.1)
360+
history = abc.run(minimum_epsilon=1e-03)
373361
return (history,)
374362

375363

0 commit comments

Comments
 (0)