Skip to content

Commit cdad768

Browse files
authored
Add more options to debug models (#179)
1 parent d8860f4 commit cdad768

48 files changed

Lines changed: 2286 additions & 1260 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
hooks:
1111
- id: yamllint
1212
- repo: https://github.com/lyz-code/yamlfix
13-
rev: 1.19.0
13+
rev: 1.19.1
1414
hooks:
1515
- id: yamlfix
1616
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -59,10 +59,10 @@ repos:
5959
# hooks:
6060
# - id: setup-cfg-fmt
6161
- repo: https://github.com/psf/black-pre-commit-mirror
62-
rev: 25.11.0
62+
rev: 25.12.0
6363
hooks:
6464
- id: black
65-
language_version: python3.12
65+
language_version: python3.13
6666
# - repo: https://github.com/charliermarsh/ruff-pre-commit
6767
# rev: v0.0.282
6868
# hooks:

docs/source/background/two_period_model_tutorial.ipynb

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -749,15 +749,10 @@
749749
]
750750
},
751751
{
752+
"metadata": {},
752753
"cell_type": "code",
753-
"execution_count": 17,
754-
"metadata": {
755-
"ExecuteTime": {
756-
"end_time": "2025-06-30T09:22:14.977528Z",
757-
"start_time": "2025-06-30T09:22:14.323938Z"
758-
}
759-
},
760754
"outputs": [],
755+
"execution_count": null,
761756
"source": [
762757
"state_dict = {\n",
763758
" \"ltc\": initial_condition[\"health\"],\n",
@@ -767,9 +762,9 @@
767762
"}\n",
768763
"\n",
769764
"\n",
770-
"cons_calc, value = solved_model.value_and_policy_for_state_and_choice(\n",
771-
" state=state_dict,\n",
772-
" choice=choice_in_period_0,\n",
765+
"cons_calc, value = solved_model.policy_and_value_for_states_and_choices(\n",
766+
" states=state_dict,\n",
767+
" choices=choice_in_period_0,\n",
773768
")"
774769
]
775770
},

environment.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ dependencies:
2222
- flake8
2323
- jupyterlab
2424
- matplotlib
25-
- pdbpp
2625
- pre-commit
2726
- setuptools_scm
2827
- toml

src/dcegm/asset_correction.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
from jax import vmap
33

44
from dcegm.law_of_motion import (
5-
calc_assets_beginning_of_period_2cont_vec,
6-
calc_beginning_of_period_assets_1cont_vec,
5+
calc_beginning_of_period_assets_for_single_state,
76
)
87

98

10-
def adjust_observed_assets(observed_states_dict, params, model_class):
9+
def adjust_observed_assets(observed_states_dict, params, model_class, aux_outs=False):
1110
"""Correct observed beginning of period assets data for likelihood estimation.
1211
1312
Assets in empirical survey data is observed without the income of last period's
@@ -37,30 +36,23 @@ def adjust_observed_assets(observed_states_dict, params, model_class):
3736
second_cont_state_vars = observed_states_dict[second_cont_state_name]
3837
observed_states_dict_int.pop(second_cont_state_name)
3938

40-
adjusted_assets = vmap(
41-
calc_assets_beginning_of_period_2cont_vec,
42-
in_axes=(0, 0, 0, None, None, None, None),
43-
)(
44-
observed_states_dict_int,
45-
second_cont_state_vars,
46-
assets_end_last_period,
47-
jnp.array(0.0, dtype=jnp.float64),
48-
params,
49-
model_funcs["compute_assets_begin_of_period"],
50-
False,
51-
)
52-
39+
all_states = {
40+
**observed_states_dict_int,
41+
"continuous_state": second_cont_state_vars,
42+
}
5343
else:
54-
adjusted_assets = vmap(
55-
calc_beginning_of_period_assets_1cont_vec,
56-
in_axes=(0, 0, None, None, None, None),
57-
)(
58-
observed_states_dict,
59-
assets_end_last_period,
60-
jnp.array(0.0, dtype=jnp.float64),
61-
params,
62-
model_funcs["compute_assets_begin_of_period"],
63-
False,
64-
)
44+
all_states = observed_states_dict_int
45+
46+
adjusted_assets = vmap(
47+
calc_beginning_of_period_assets_for_single_state,
48+
in_axes=(0, 0, None, None, None, None),
49+
)(
50+
all_states,
51+
assets_end_last_period,
52+
jnp.array(0.0, dtype=jnp.float64),
53+
params,
54+
model_funcs["compute_assets_begin_of_period"],
55+
aux_outs,
56+
)
6557

6658
return adjusted_assets

src/dcegm/backward_induction.py

Lines changed: 56 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Interface for the DC-EGM algorithm."""
22

3-
from functools import partial
43
from typing import Any, Callable, Dict, Tuple
54

5+
import jax
66
import jax.lax
77
import jax.numpy as jnp
8-
import numpy as np
98

109
from dcegm.final_periods import solve_last_two_periods
1110
from dcegm.law_of_motion import calc_cont_grids_next_period
11+
from dcegm.pre_processing.sol_container import create_solution_container
1212
from dcegm.solve_single_period import solve_single_period
1313

1414

@@ -25,117 +25,90 @@ def backward_induction(
2525
2626
Args:
2727
params (dict): Dictionary containing the model parameters.
28-
options (dict): Dictionary containing the model options.
29-
period_specific_state_objects (np.ndarray): Dictionary containing
30-
period-specific state and state-choice objects, with the following keys:
31-
- "state_choice_mat" (jnp.ndarray)
32-
- "idx_state_of_state_choice" (jnp.ndarray)
33-
- "reshape_state_choice_vec_to_mat" (callable)
34-
- "transform_between_state_and_state_choice_vec" (callable)
35-
exog_savings_grid (np.ndarray): 1d array of shape (n_grid_wealth,)
36-
containing the exogenous savings grid.
37-
has_second_continuous_state (bool): Boolean indicating whether the model
38-
features a second continuous state variable. If False, the only
39-
continuous state variable is consumption/savings.
40-
state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1)
41-
which serves as a collection of all possible states. By convention,
42-
the first column must contain the period and the last column the
43-
exogenous processes. Any other state variables are in between.
44-
E.g. if the two state variables are period and lagged choice and all choices
45-
are admissible in each period, the shape of the state space array is
46-
(n_periods * n_choices, 3).
47-
state_choice_space (np.ndarray): 2d array of shape
48-
(n_feasible_states, n_state_and_exog_variables + 1) containing all
49-
feasible state-choice combinations. By convention, the second to last
50-
column contains the exogenous process. The last column always contains the
51-
choice to be made (which is not a state variable).
5228
income_shock_draws_unscaled (np.ndarray): 1d array of shape (n_quad_points,)
5329
containing the Hermite quadrature points unscaled.
5430
income_shock_weights (np.ndarrray): 1d array of shape
5531
(n_stochastic_quad_points) with weights for each stoachstic shock draw.
56-
n_periods (int): Number of periods.
57-
model_funcs (dict): Dictionary containing following model functions:
58-
- compute_marginal_utility (callable): User-defined function to compute the
59-
agent's marginal utility. The input ```params``` is already partialled
60-
in.
61-
- compute_inverse_marginal_utility (Callable): Function for calculating the
62-
inverse marginal utiFality, which takes the marginal utility as only
63-
input.
64-
- compute_next_period_wealth (callable): User-defined function to compute
65-
the agent's wealth of the next period (t + 1). The inputs
66-
```saving```, ```shock```, ```params``` and ```options```
67-
are already partialled in.
68-
- transition_vector_by_state (Callable): Partialled transition function
69-
return transition vector for each state.
70-
- final_period_partial (Callable): Partialled function for calculating the
71-
consumption as well as value function and marginal utility in the final
72-
period.
73-
compute_upper_envelope (Callable): Function for calculating the upper
74-
envelope of the policy and value function. If the number of discrete
75-
choices is 1, this function is a dummy function that returns the policy
76-
and value function as is, without performing a fast upper envelope
77-
scan.
32+
model_config (dict): Dictionary containing the model configuration.
33+
model_funcs (dict): Dictionary containing model functions.
34+
model_structure (dict): Dictionary containing model structure.
35+
batch_info (dict): Dictionary containing batch information.
7836
7937
Returns:
80-
dict: Dictionary containing the period-specific endog_grid, policy, and value
38+
Tuple: Tuple containing the period-specific endog_grid, policy, and value
8139
from the backward induction.
8240
8341
"""
8442
continuous_states_info = model_config["continuous_states_info"]
8543

86-
cont_grids_next_period = calc_cont_grids_next_period(
87-
model_structure=model_structure,
88-
model_config=model_config,
89-
income_shock_draws_unscaled=income_shock_draws_unscaled,
90-
params=params,
91-
model_funcs=model_funcs,
44+
#
45+
calc_grids_jit = jax.jit(
46+
lambda income_shock_draws, params_inner: calc_cont_grids_next_period(
47+
model_structure=model_structure,
48+
model_config=model_config,
49+
income_shock_draws_unscaled=income_shock_draws,
50+
params=params_inner,
51+
model_funcs=model_funcs,
52+
)
9253
)
9354

94-
# Create solution containers. The 20 percent extra in wealth grid needs to go
95-
# into tuning parameters
96-
n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"]
55+
cont_grids_next_period = calc_grids_jit(income_shock_draws_unscaled, params)
56+
9757
(
9858
value_solved,
9959
policy_solved,
10060
endog_grid_solved,
10161
) = create_solution_container(
102-
model_config=model_config,
103-
model_structure=model_structure,
62+
continuous_states_info=model_config["continuous_states_info"],
63+
# Read out grid size
64+
n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"],
65+
n_state_choices=model_structure["state_choice_space"].shape[0],
66+
)
67+
68+
# Solve the last two periods using lambda to capture static arguments
69+
solve_last_two_period_jit = jax.jit(
70+
lambda params_inner, cont_grids, weights, val_solved, pol_solved, endog_solved: solve_last_two_periods(
71+
params=params_inner,
72+
continuous_states_info=continuous_states_info,
73+
cont_grids_next_period=cont_grids,
74+
income_shock_weights=weights,
75+
model_funcs=model_funcs,
76+
last_two_period_batch_info=batch_info["last_two_period_info"],
77+
value_solved=val_solved,
78+
policy_solved=pol_solved,
79+
endog_grid_solved=endog_solved,
80+
debug_info=None,
81+
)
10482
)
10583

106-
# Solve the last two periods. We do this separately as the marginal utility of
107-
# the child states in the last period is calculated from the marginal utility
108-
# function of the bequest function, which might differ.
10984
(
11085
value_solved,
11186
policy_solved,
11287
endog_grid_solved,
113-
) = solve_last_two_periods(
114-
params=params,
115-
continuous_states_info=continuous_states_info,
116-
cont_grids_next_period=cont_grids_next_period,
117-
income_shock_weights=income_shock_weights,
118-
model_funcs=model_funcs,
119-
last_two_period_batch_info=batch_info["last_two_period_info"],
120-
value_solved=value_solved,
121-
policy_solved=policy_solved,
122-
endog_grid_solved=endog_grid_solved,
88+
) = solve_last_two_period_jit(
89+
params,
90+
cont_grids_next_period,
91+
income_shock_weights,
92+
value_solved,
93+
policy_solved,
94+
endog_grid_solved,
12395
)
12496

12597
# If it is a two period model we are done.
12698
if batch_info["two_period_model"]:
12799
return value_solved, policy_solved, endog_grid_solved
128100

129-
def partial_single_period(carry, xs):
130-
return solve_single_period(
131-
carry=carry,
132-
xs=xs,
133-
params=params,
134-
continuous_grids_info=continuous_states_info,
135-
cont_grids_next_period=cont_grids_next_period,
136-
model_funcs=model_funcs,
137-
income_shock_weights=income_shock_weights,
138-
)
101+
# Create JIT-compiled single period solver using lambda
102+
partial_single_period = lambda carry, xs: solve_single_period(
103+
carry=carry,
104+
xs=xs,
105+
params=params,
106+
continuous_grids_info=continuous_states_info,
107+
cont_grids_next_period=cont_grids_next_period,
108+
model_funcs=model_funcs,
109+
income_shock_weights=income_shock_weights,
110+
debug_info=None,
111+
)
139112

140113
for id_segment in range(batch_info["n_segments"]):
141114
segment_info = batch_info[f"batches_info_segment_{id_segment}"]
@@ -192,53 +165,3 @@ def partial_single_period(carry, xs):
192165
policy_solved,
193166
endog_grid_solved,
194167
)
195-
196-
197-
def create_solution_container(
198-
model_config: Dict[str, Any],
199-
model_structure: Dict[str, Any],
200-
):
201-
"""Create solution containers for value, policy, and endog_grid."""
202-
203-
# Read out grid size
204-
n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"]
205-
n_state_choices = model_structure["state_choice_space"].shape[0]
206-
207-
# Check if second continuous state exists and read out array size
208-
continuous_states_info = model_config["continuous_states_info"]
209-
if continuous_states_info["second_continuous_exists"]:
210-
n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"]
211-
212-
value_solved = jnp.full(
213-
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
214-
dtype=jnp.float64,
215-
fill_value=jnp.nan,
216-
)
217-
policy_solved = jnp.full(
218-
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
219-
dtype=jnp.float64,
220-
fill_value=jnp.nan,
221-
)
222-
endog_grid_solved = jnp.full(
223-
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
224-
dtype=jnp.float64,
225-
fill_value=jnp.nan,
226-
)
227-
else:
228-
value_solved = jnp.full(
229-
(n_state_choices, n_total_wealth_grid),
230-
dtype=jnp.float64,
231-
fill_value=jnp.nan,
232-
)
233-
policy_solved = jnp.full(
234-
(n_state_choices, n_total_wealth_grid),
235-
dtype=jnp.float64,
236-
fill_value=jnp.nan,
237-
)
238-
endog_grid_solved = jnp.full(
239-
(n_state_choices, n_total_wealth_grid),
240-
dtype=jnp.float64,
241-
fill_value=jnp.nan,
242-
)
243-
244-
return value_solved, policy_solved, endog_grid_solved

src/dcegm/egm/aggregate_marginal_utility.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from typing import Tuple
22

33
import jax.numpy as jnp
4-
import numpy as np
54

65

76
def aggregate_marg_utils_and_exp_values(
87
value_state_choice_specific: jnp.ndarray,
98
marg_util_state_choice_specific: jnp.ndarray,
10-
reshape_state_choice_vec_to_mat: np.ndarray,
9+
reshape_state_choice_vec_to_mat: jnp.ndarray,
1110
taste_shock_scale,
1211
taste_shock_scale_is_scalar,
1312
income_shock_weights: jnp.ndarray,
@@ -47,11 +46,12 @@ def aggregate_marg_utils_and_exp_values(
4746
mode="fill",
4847
fill_value=jnp.nan,
4948
)
49+
5050
# If taste shock is not scalar, we select from the array,
5151
# where we have for each choice a taste shock scale one. They are by construction
5252
# the same for all choices in a state
5353
if not taste_shock_scale_is_scalar:
54-
one_choice_per_state = np.min(reshape_state_choice_vec_to_mat, axis=1)
54+
one_choice_per_state = jnp.min(reshape_state_choice_vec_to_mat, axis=1)
5555
taste_shock_scale = jnp.take(
5656
taste_shock_scale,
5757
one_choice_per_state,

0 commit comments

Comments
 (0)