Skip to content

Commit f5da976

Browse files
committed
fixed_CRM_bug
1 parent 589eb10 commit f5da976

4 files changed

Lines changed: 1666 additions & 161 deletions

File tree

examples/CRM/examples-bayes-CRM.ipynb

Lines changed: 1582 additions & 72 deletions
Large diffs are not rendered by default.

examples/CRM/examples-sim-CRM.ipynb

Lines changed: 29 additions & 15 deletions
Large diffs are not rendered by default.

mimic/model_infer/infer_CRM_bayes.py

Lines changed: 54 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def plot_growth_curves(data, ax=None):
5151

5252
def CRM_inf_func(y, t, p):
5353
# Unpack parameters from the vector p
54-
nr = p[0].astype("int32") # Number of resources
55-
nsp = p[1].astype("int32") # Number of species
54+
nsp = p[0].astype("int32") # Number of resources
55+
nr = p[1].astype("int32") # Number of species
5656
tau = p[2:2 + nsp] # Species time scales
5757
w = p[2 + nsp:2 + nsp + nr] # Resource quality
5858
# Flattened resource preferences
@@ -69,18 +69,23 @@ def CRM_inf_func(y, t, p):
6969
N = y[:nsp] # Species populations
7070
R = y[nsp:] # Resource availability
7171

72+
7273
# Species growth equation (dN)
7374
growth_term = at.dot(c, w * R) # Matrix multiplication as tensor
7475
dN = (N / tau) * (growth_term - m) # Species growth equation
7576

7677
# Resource consumption equation (dR)
7778
consumption_term = at.dot(N, c) # Matrix multiplication as tensor
78-
dR = (1 / (r * K)) * (K - R) * R - consumption_term * \
79-
R # Resource consumption equation
79+
dR = (1 / (r * K)) * (K - R) * R - consumption_term * R # Resource consumption equation
80+
81+
82+
# Flatten array to 1D for concatenation
83+
dN_flat = at.flatten(dN)
84+
dR_flat = at.flatten(dR)
8085

8186
# Combine dN and dR into a single 1D array
82-
derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
83-
# derivatives = np.concatenate([dN, dR]) # Concatenate species and
87+
#derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
88+
derivatives = at.concatenate([dN_flat, dR_flat]) # Concatenate species and
8489
# resource derivatives
8590

8691
# Return the derivatives for both species and resources as a single array
@@ -426,6 +431,8 @@ def run_inference(self) -> None:
426431
n_states = nsp + nr
427432
n_theta = 2 + (2 * nsp) + (3 * nr) + (nsp * nr)
428433

434+
yobs_species_only = yobs[:, :nsp]
435+
429436
# Define the DifferentialEquation model
430437
crm_model = DifferentialEquation(
431438
func=CRM_inf_func, # The ODE function
@@ -440,53 +447,30 @@ def run_inference(self) -> None:
440447
with bayes_model:
441448
# Priors for unknown model parameters
442449

443-
sigma = pm.HalfNormal(
444-
'sigma', sigma = 0.1, shape=(
445-
1,)) # Same sigma for all responses
450+
sigma = pm.HalfNormal('sigma', sigma = 0.1, shape=(1,)) # Same sigma for all responses
446451

447452
# Conditionally define parameters based on whether priors are
448453
# provided
449454

450455
# For tau parameter
451456
if prior_tau_mean is not None and prior_tau_sigma is not None:
452-
tau_hat = pm.TruncatedNormal(
453-
'tau_hat',
454-
mu=prior_tau_mean,
455-
sigma=prior_tau_sigma,
456-
lower=0.1,
457-
shape=(
458-
nsp,
459-
))
457+
tau_hat = pm.TruncatedNormal('tau_hat',mu=prior_tau_mean,sigma=prior_tau_sigma,lower=0,shape=(nsp,))
460458
print("tau_hat is inferred")
461459
else:
462460
tau_hat = at.as_tensor_variable(tau)
463461
print("tau_hat is fixed")
464462

465463
# For w parameter
466464
if prior_w_mean is not None and prior_w_sigma is not None:
467-
w_hat = pm.TruncatedNormal(
468-
'w_hat',
469-
mu=prior_w_mean,
470-
sigma=prior_w_sigma,
471-
lower=0.1,
472-
shape=(
473-
nr,
474-
))
465+
w_hat = pm.TruncatedNormal('w_hat',mu=prior_w_mean,sigma=prior_w_sigma,lower=0,shape=(nr,))
475466
print("w_hat is inferred")
476467
else:
477468
w_hat = at.as_tensor_variable(w)
478469
print("w_hat is fixed")
479470

480471
# For c parameter
481472
if prior_c_mean is not None and prior_c_sigma is not None:
482-
c_hat_vals = pm.TruncatedNormal(
483-
'c_hat_vals',
484-
mu=prior_c_mean,
485-
sigma=prior_c_sigma,
486-
lower=0,
487-
shape=(
488-
nsp,
489-
nr))
473+
c_hat_vals = pm.TruncatedNormal('c_hat_vals',mu=prior_c_mean,sigma=prior_c_sigma,lower=0,shape=(nsp,nr))
490474
c_hat = pm.Deterministic('c_hat', c_hat_vals)
491475
print("c_hat is inferred")
492476
else:
@@ -495,60 +479,64 @@ def run_inference(self) -> None:
495479

496480
# For m parameter
497481
if prior_m_mean is not None and prior_m_sigma is not None:
498-
m_hat = pm.TruncatedNormal(
499-
'm_hat',
500-
mu=prior_m_mean,
501-
sigma=prior_m_sigma,
502-
lower=0.1,
503-
shape=(
504-
nsp,
505-
))
482+
m_hat = pm.TruncatedNormal('m_hat',mu=prior_m_mean,sigma=prior_m_sigma,lower=0,shape=(nsp, ))
506483
print("m_hat is inferred")
507484
else:
508485
m_hat = at.as_tensor_variable(m)
509486
print("m_hat is fixed")
510487

511488
# For r parameter
512489
if prior_r_mean is not None and prior_r_sigma is not None:
513-
r_hat = pm.TruncatedNormal(
514-
'r_hat',
515-
mu=prior_r_mean,
516-
sigma=prior_r_sigma,
517-
lower=0,
518-
shape=(
519-
nr,
520-
))
490+
r_hat = pm.TruncatedNormal('r_hat',mu=prior_r_mean,sigma=prior_r_sigma,lower=0,shape=(nr,))
521491
print("r_hat is inferred")
522492
else:
523493
r_hat = at.as_tensor_variable(r)
524494
print("r_hat is fixed")
525495

526496
# For K parameter
527497
if prior_K_mean is not None and prior_K_sigma is not None:
528-
K_hat = pm.TruncatedNormal(
529-
'K_hat',
530-
mu=prior_K_mean,
531-
sigma=prior_K_sigma,
532-
lower=1.0,
533-
shape=(
534-
nr,
535-
))
498+
K_hat = pm.TruncatedNormal('K_hat', mu=prior_K_mean, sigma=prior_K_sigma,lower=0,shape=(nr,))
536499
print("K_hat is inferred")
537500
else:
538501
K_hat = at.as_tensor_variable(K)
539502
print("K_hat is fixed")
540503

541504
# Flatten to read into CRM_inf_func as a single vector
542-
nr_tensor = at.as_tensor_variable([nr])
543505
nsp_tensor = at.as_tensor_variable([nsp])
544-
545-
theta = at.concatenate(
546-
[nr_tensor, nsp_tensor, tau_hat, w_hat, c_hat.flatten(), m_hat, r_hat, K_hat])
506+
nr_tensor = at.as_tensor_variable([nr])
507+
508+
509+
theta = at.concatenate([nsp_tensor, nr_tensor, tau_hat, w_hat, c_hat.flatten(), m_hat, r_hat, K_hat])
510+
511+
print("=== CRITICAL: TESTING IF MODEL STRUCTURE IS CORRECT ===")
512+
try:
513+
y0 = np.full(n_states, 10.0)
514+
test_curves = crm_model(y0=y0, theta=theta)
515+
test_pred = test_curves.eval()
516+
517+
rmse = np.sqrt(np.mean((test_pred - yobs)**2))
518+
print(f"RMSE with near-true parameters: {rmse:.6f}")
519+
print(f"Data scale: {np.mean(yobs):.3f}")
520+
print(f"Model scale: {np.mean(test_pred):.3f}")
521+
print(f"First few predictions: {test_pred[:3]}")
522+
print(f"First few observations: {yobs[:3]}")
523+
524+
except Exception as e:
525+
print(f"MODEL FAILED: {e}")
526+
527+
print(f"nsp_tensor: {nsp_tensor.eval()}, nr_tensor: {nr_tensor.eval()}")
528+
print(f"tau_hat: {tau_hat.eval()}, w_hat: {w_hat.eval()}")
529+
print(f"c_hat: {c_hat.eval()}, m_hat: {m_hat.eval()}")
530+
print(f"r_hat: {r_hat.eval()}, K_hat: {K_hat.eval()}")
531+
print(f"theta: {theta.eval()}")
547532

548533
# Initial conditions for the ODE
549534
# initial_conditions = np.concatenate([(yobs[0,:nsp]), np.array([10.0, 10.0])])
550535
# Initial species and resource populations
551-
y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
536+
#y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
537+
#y0 = yobs[0, :]
538+
y0 = np.full(nsp + nr, 10.0)
539+
print(f"Initial conditions (y0): {y0}")
552540
# y0 = np.array([10.0, 10.0, 10.0, 10.0])
553541
# y0 = np.full(n_states, 10.0)
554542

@@ -557,11 +545,9 @@ def run_inference(self) -> None:
557545

558546
# Define the log-normal likelihood with log-transformed observed data
559547
# Y = pm.Lognormal("Y", mu=pm.math.log(crm_curves), sigma=sigma, observed=yobs)
560-
Y = pm.Lognormal(
561-
"Y",
562-
mu=at.log(crm_curves),
563-
sigma=sigma,
564-
observed=yobs)
548+
#Y = pm.Lognormal( "Y",mu=at.log(crm_curves),sigma=sigma, observed=yobs)
549+
#Y = pm.Normal("Y", mu=crm_curves, sigma=sigma, observed=yobs)
550+
Y = pm.Normal("Y", mu=crm_curves[:, :nsp], sigma=sigma, observed=yobs_species_only) # species only
565551

566552
# For debugging:
567553
# print if `debug` is set to 'high' or 'low'
@@ -581,12 +567,7 @@ def run_inference(self) -> None:
581567
print("Shape of crm_curves:", crm_curves.shape.eval())
582568

583569
# Sample the posterior
584-
idata = pm.sample(
585-
draws=draws,
586-
tune=tune,
587-
chains=chains,
588-
cores=cores,
589-
progressbar=True)
570+
idata = pm.sample(draws=draws,tune=tune,chains=chains,cores=cores,progressbar=True)
590571

591572
return idata
592573

mimic/model_simulate/sim_CRM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,6 @@ def CRM(sy, t, nsp, nr, tau, w, c, m, r, K) -> numpy.ndarray:
183183
dN = (N / tau) * (c @ (w * R) - m)
184184

185185
# dR_a/dt = 1/(r_a * K_a) * (K_a - R_a) * R_a - Sum_i(N_i * c_ia * R_a)
186-
dR = (1 / r * K) * (K - R) * R - (N @ c * R)
186+
dR = (1 / (r * K)) * (K - R) * R - (N @ c * R)
187187

188188
return numpy.hstack((dN, dR))

0 commit comments

Comments
 (0)