Skip to content

Commit 806c3bf

Browse files
committed
Update requirements.txt and format code
1 parent 968da0c commit 806c3bf

2 files changed

Lines changed: 91 additions & 39 deletions

File tree

mimic/model_infer/infer_CRM_bayes.py

Lines changed: 88 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,27 @@ def CRM_inf_func(y, t, p):
7575
N_safe = at.maximum(N, eps)
7676
R_safe = at.maximum(R, eps)
7777

78-
7978
# Species growth equation (dN)
8079
growth_term = at.dot(c, w * R_safe) # Matrix multiplication as tensor
8180
dN = (N_safe / tau) * (growth_term - m) # Species growth equation
8281

8382
# Resource consumption equation (dR)
8483
consumption_term = at.dot(N_safe, c) # Matrix multiplication as tensor
85-
dR = (1 / (r * K)) * (K - R_safe) * R_safe - consumption_term * R_safe # Resource consumption equation
86-
84+
dR = (1 / (r * K)) * (K - R_safe) * R_safe - \
85+
consumption_term * R_safe # Resource consumption equation
8786

88-
# If species population or resource concentration is smaller than eps *and* decreasing,
89-
# then set rate of change to zero to prevent negative values in the next step
87+
# If species population or resource concentration is smaller than eps *and* decreasing,
88+
# then set rate of change to zero to prevent negative values in the next
89+
# step
9090
dN = at.where((N < eps) & (dN < 0), 0.0, dN)
9191
dR = at.where((R < eps) & (dR < 0), 0.0, dR)
9292

93-
9493
# Flatten array to 1D for concatenation
9594
dN_flat = at.flatten(dN)
9695
dR_flat = at.flatten(dR)
9796

9897
# Combine dN and dR into a single 1D array
99-
#derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
98+
# derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
10099
derivatives = at.concatenate([dN_flat, dR_flat]) # Concatenate species and
101100
# resource derivatives
102101

@@ -443,7 +442,7 @@ def run_inference(self) -> None:
443442
n_states = nsp + nr
444443
n_theta = 2 + (2 * nsp) + (3 * nr) + (nsp * nr)
445444

446-
yobs_species_only = yobs[:, :nsp]
445+
yobs_species_only = yobs[:, :nsp]
447446

448447
# Define the DifferentialEquation model
449448
crm_model = DifferentialEquation(
@@ -459,32 +458,53 @@ def run_inference(self) -> None:
459458
with bayes_model:
460459
# Priors for unknown model parameters
461460

462-
463-
sigma = pm.HalfNormal('sigma', sigma = 0.1, shape=(1,)) # Same sigma for all responses
464-
461+
sigma = pm.HalfNormal(
462+
'sigma', sigma=0.1, shape=(
463+
1,)) # Same sigma for all responses
465464

466465
# Conditionally define parameters based on whether priors are
467466
# provided
468467

469468
# For tau parameter
470469
if prior_tau_mean is not None and prior_tau_sigma is not None:
471-
tau_hat = pm.TruncatedNormal('tau_hat',mu=prior_tau_mean,sigma=prior_tau_sigma,lower=0,shape=(nsp,))
470+
tau_hat = pm.TruncatedNormal(
471+
'tau_hat',
472+
mu=prior_tau_mean,
473+
sigma=prior_tau_sigma,
474+
lower=0,
475+
shape=(
476+
nsp,
477+
))
472478
print("tau_hat is inferred")
473479
else:
474480
tau_hat = at.as_tensor_variable(tau)
475481
print("tau_hat is fixed")
476482

477483
# For w parameter
478484
if prior_w_mean is not None and prior_w_sigma is not None:
479-
w_hat = pm.TruncatedNormal('w_hat',mu=prior_w_mean,sigma=prior_w_sigma,lower=0,shape=(nr,))
485+
w_hat = pm.TruncatedNormal(
486+
'w_hat',
487+
mu=prior_w_mean,
488+
sigma=prior_w_sigma,
489+
lower=0,
490+
shape=(
491+
nr,
492+
))
480493
print("w_hat is inferred")
481494
else:
482495
w_hat = at.as_tensor_variable(w)
483496
print("w_hat is fixed")
484497

485498
# For c parameter
486499
if prior_c_mean is not None and prior_c_sigma is not None:
487-
c_hat_vals = pm.TruncatedNormal('c_hat_vals',mu=prior_c_mean,sigma=prior_c_sigma,lower=0,shape=(nsp,nr))
500+
c_hat_vals = pm.TruncatedNormal(
501+
'c_hat_vals',
502+
mu=prior_c_mean,
503+
sigma=prior_c_sigma,
504+
lower=0,
505+
shape=(
506+
nsp,
507+
nr))
488508
c_hat = pm.Deterministic('c_hat', c_hat_vals)
489509
print("c_hat is inferred")
490510
else:
@@ -493,23 +513,44 @@ def run_inference(self) -> None:
493513

494514
# For m parameter
495515
if prior_m_mean is not None and prior_m_sigma is not None:
496-
m_hat = pm.TruncatedNormal('m_hat',mu=prior_m_mean,sigma=prior_m_sigma,lower=0,shape=(nsp, ))
516+
m_hat = pm.TruncatedNormal(
517+
'm_hat',
518+
mu=prior_m_mean,
519+
sigma=prior_m_sigma,
520+
lower=0,
521+
shape=(
522+
nsp,
523+
))
497524
print("m_hat is inferred")
498525
else:
499526
m_hat = at.as_tensor_variable(m)
500527
print("m_hat is fixed")
501528

502529
# For r parameter
503530
if prior_r_mean is not None and prior_r_sigma is not None:
504-
r_hat = pm.TruncatedNormal('r_hat',mu=prior_r_mean,sigma=prior_r_sigma,lower=0,shape=(nr,))
531+
r_hat = pm.TruncatedNormal(
532+
'r_hat',
533+
mu=prior_r_mean,
534+
sigma=prior_r_sigma,
535+
lower=0,
536+
shape=(
537+
nr,
538+
))
505539
print("r_hat is inferred")
506540
else:
507541
r_hat = at.as_tensor_variable(r)
508542
print("r_hat is fixed")
509543

510544
# For K parameter
511545
if prior_K_mean is not None and prior_K_sigma is not None:
512-
K_hat = pm.TruncatedNormal('K_hat', mu=prior_K_mean, sigma=prior_K_sigma,lower=0,shape=(nr,))
546+
K_hat = pm.TruncatedNormal(
547+
'K_hat',
548+
mu=prior_K_mean,
549+
sigma=prior_K_sigma,
550+
lower=0,
551+
shape=(
552+
nr,
553+
))
513554
print("K_hat is inferred")
514555
else:
515556
K_hat = at.as_tensor_variable(K)
@@ -518,23 +559,23 @@ def run_inference(self) -> None:
518559
# Flatten to read into CRM_inf_func as a single vector
519560
nsp_tensor = at.as_tensor_variable([nsp])
520561
nr_tensor = at.as_tensor_variable([nr])
521-
522562

523-
theta = at.concatenate([nsp_tensor, nr_tensor, tau_hat, w_hat, c_hat.flatten(), m_hat, r_hat, K_hat])
563+
theta = at.concatenate(
564+
[nsp_tensor, nr_tensor, tau_hat, w_hat, c_hat.flatten(), m_hat, r_hat, K_hat])
524565

525566
print("=== RSME ===")
526567
try:
527568
y0 = np.full(n_states, 10.0)
528569
test_curves = crm_model(y0=y0, theta=theta)
529570
test_pred = test_curves.eval()
530-
571+
531572
rmse = np.sqrt(np.mean((test_pred - yobs)**2))
532573
print(f"RMSE: {rmse:.6f}")
533574
print(f"Data scale: {np.mean(yobs):.3f}")
534575
print(f"Model scale: {np.mean(test_pred):.3f}")
535576
print(f"First few predictions: {test_pred[:3]}")
536577
print(f"First few observations: {yobs[:3]}")
537-
578+
538579
except Exception as e:
539580
print(f"MODEL FAILED: {e}")
540581

@@ -547,8 +588,8 @@ def run_inference(self) -> None:
547588
# Initial conditions for the ODE
548589
# initial_conditions = np.concatenate([(yobs[0,:nsp]), np.array([10.0, 10.0])])
549590
# Initial species and resource populations
550-
#y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
551-
#y0 = yobs[0, :]
591+
# y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
592+
# y0 = yobs[0, :]
552593
y0 = np.full(nsp + nr, 10.0)
553594
print(f"Initial conditions (y0): {y0}")
554595
# y0 = np.array([10.0, 10.0, 10.0, 10.0])
@@ -558,9 +599,10 @@ def run_inference(self) -> None:
558599
crm_curves = crm_model(y0=y0, theta=theta)
559600

560601
# Define the log-normal likelihood with log-transformed observed data
561-
#Y = pm.Lognormal( "Y",mu=at.log(crm_curves),sigma=sigma, observed=yobs)
602+
# Y = pm.Lognormal( "Y",mu=at.log(crm_curves),sigma=sigma, observed=yobs)
562603
Y = pm.Normal("Y", mu=crm_curves, sigma=sigma, observed=yobs)
563-
#Y = pm.Normal("Y", mu=crm_curves[:, :nsp], sigma=sigma, observed=yobs_species_only) # species only
604+
# Y = pm.Normal("Y", mu=crm_curves[:, :nsp], sigma=sigma,
605+
# observed=yobs_species_only) # species only
564606

565607
# For debugging:
566608
# print if `debug` is set to 'high' or 'low'
@@ -580,7 +622,12 @@ def run_inference(self) -> None:
580622
print("Shape of crm_curves:", crm_curves.shape.eval())
581623

582624
# Sample the posterior
583-
idata = pm.sample(draws=draws,tune=tune,chains=chains,cores=cores,progressbar=True)
625+
idata = pm.sample(
626+
draws=draws,
627+
tune=tune,
628+
chains=chains,
629+
cores=cores,
630+
progressbar=True)
584631

585632
return idata
586633

@@ -602,7 +649,7 @@ def plot_posterior(self, idata, true_params=None):
602649
for i, param in enumerate(param_names):
603650
if param in available_vars:
604651
print(f"Plotting posterior for {param}")
605-
652+
606653
# Extract the posterior mean for the parameter (as before)
607654
if param == "c_hat":
608655
# Special handling for c_hat due to its shape
@@ -613,35 +660,40 @@ def plot_posterior(self, idata, true_params=None):
613660
param_np = idata.posterior[param].mean(
614661
dim=('chain', 'draw')).values.flatten()
615662
ref_val = param_np.tolist()
616-
663+
617664
# Plot the posterior distribution (original behavior)
618665
az.plot_posterior(
619666
idata,
620667
var_names=[param],
621668
ref_val=ref_val
622669
)
623-
670+
624671
# Add true value as a vertical line if available
625672
true_param_name = true_param_names[i]
626673
if true_params and true_param_name in true_params:
627674
true_val = true_params[true_param_name]
628-
675+
629676
# Flatten the true values to match the subplot structure
630677
true_vals = true_val.flatten()
631-
678+
632679
# Get current axes
633680
axes = plt.gcf().get_axes()
634681
for j, ax in enumerate(axes):
635682
if j < len(true_vals):
636-
ax.axvline(true_vals[j], color='red', linestyle='--', linewidth=2,
637-
label=f'True value')
683+
ax.axvline(
684+
true_vals[j],
685+
color='red',
686+
linestyle='--',
687+
linewidth=2,
688+
label=f'True value')
638689
ax.legend()
639-
690+
640691
print(f"Added true value line for {param}: {true_val}")
641-
692+
642693
# Save the plot
643694
plt.savefig(f"plot-posterior-{param}.pdf")
644695
plt.show()
645696
plt.close()
646697
else:
647-
print(f"Parameter {param} not found in posterior samples, skipping plot.")
698+
print(
699+
f"Parameter {param} not found in posterior samples, skipping plot.")

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

0 commit comments

Comments
 (0)