11"""Interface for the DC-EGM algorithm."""
22
3- from functools import partial
43from typing import Any , Callable , Dict , Tuple
54
5+ import jax
66import jax .lax
77import jax .numpy as jnp
8- import numpy as np
98
109from dcegm .final_periods import solve_last_two_periods
1110from dcegm .law_of_motion import calc_cont_grids_next_period
11+ from dcegm .pre_processing .sol_container import create_solution_container
1212from dcegm .solve_single_period import solve_single_period
1313
1414
@@ -25,117 +25,90 @@ def backward_induction(
2525
2626 Args:
2727 params (dict): Dictionary containing the model parameters.
28- options (dict): Dictionary containing the model options.
29- period_specific_state_objects (np.ndarray): Dictionary containing
30- period-specific state and state-choice objects, with the following keys:
31- - "state_choice_mat" (jnp.ndarray)
32- - "idx_state_of_state_choice" (jnp.ndarray)
33- - "reshape_state_choice_vec_to_mat" (callable)
34- - "transform_between_state_and_state_choice_vec" (callable)
35- exog_savings_grid (np.ndarray): 1d array of shape (n_grid_wealth,)
36- containing the exogenous savings grid.
37- has_second_continuous_state (bool): Boolean indicating whether the model
38- features a second continuous state variable. If False, the only
39- continuous state variable is consumption/savings.
40- state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1)
41- which serves as a collection of all possible states. By convention,
42- the first column must contain the period and the last column the
43- exogenous processes. Any other state variables are in between.
44- E.g. if the two state variables are period and lagged choice and all choices
45- are admissible in each period, the shape of the state space array is
46- (n_periods * n_choices, 3).
47- state_choice_space (np.ndarray): 2d array of shape
48- (n_feasible_states, n_state_and_exog_variables + 1) containing all
49- feasible state-choice combinations. By convention, the second to last
50- column contains the exogenous process. The last column always contains the
51- choice to be made (which is not a state variable).
5228 income_shock_draws_unscaled (np.ndarray): 1d array of shape (n_quad_points,)
5329 containing the Hermite quadrature points unscaled.
5430 income_shock_weights (np.ndarrray): 1d array of shape
5531 (n_stochastic_quad_points) with weights for each stoachstic shock draw.
56- n_periods (int): Number of periods.
57- model_funcs (dict): Dictionary containing following model functions:
58- - compute_marginal_utility (callable): User-defined function to compute the
59- agent's marginal utility. The input ```params``` is already partialled
60- in.
61- - compute_inverse_marginal_utility (Callable): Function for calculating the
62- inverse marginal utiFality, which takes the marginal utility as only
63- input.
64- - compute_next_period_wealth (callable): User-defined function to compute
65- the agent's wealth of the next period (t + 1). The inputs
66- ```saving```, ```shock```, ```params``` and ```options```
67- are already partialled in.
68- - transition_vector_by_state (Callable): Partialled transition function
69- return transition vector for each state.
70- - final_period_partial (Callable): Partialled function for calculating the
71- consumption as well as value function and marginal utility in the final
72- period.
73- compute_upper_envelope (Callable): Function for calculating the upper
74- envelope of the policy and value function. If the number of discrete
75- choices is 1, this function is a dummy function that returns the policy
76- and value function as is, without performing a fast upper envelope
77- scan.
32+ model_config (dict): Dictionary containing the model configuration.
33+ model_funcs (dict): Dictionary containing model functions.
34+ model_structure (dict): Dictionary containing model structure.
35+ batch_info (dict): Dictionary containing batch information.
7836
7937 Returns:
80- dict: Dictionary containing the period-specific endog_grid, policy, and value
38+ Tuple: Tuple containing the period-specific endog_grid, policy, and value
8139 from the backward induction.
8240
8341 """
8442 continuous_states_info = model_config ["continuous_states_info" ]
8543
86- cont_grids_next_period = calc_cont_grids_next_period (
87- model_structure = model_structure ,
88- model_config = model_config ,
89- income_shock_draws_unscaled = income_shock_draws_unscaled ,
90- params = params ,
91- model_funcs = model_funcs ,
44+ #
45+ calc_grids_jit = jax .jit (
46+ lambda income_shock_draws , params_inner : calc_cont_grids_next_period (
47+ model_structure = model_structure ,
48+ model_config = model_config ,
49+ income_shock_draws_unscaled = income_shock_draws ,
50+ params = params_inner ,
51+ model_funcs = model_funcs ,
52+ )
9253 )
9354
94- # Create solution containers. The 20 percent extra in wealth grid needs to go
95- # into tuning parameters
96- n_total_wealth_grid = model_config ["tuning_params" ]["n_total_wealth_grid" ]
55+ cont_grids_next_period = calc_grids_jit (income_shock_draws_unscaled , params )
56+
9757 (
9858 value_solved ,
9959 policy_solved ,
10060 endog_grid_solved ,
10161 ) = create_solution_container (
102- model_config = model_config ,
103- model_structure = model_structure ,
62+ continuous_states_info = model_config ["continuous_states_info" ],
63+ # Read out grid size
64+ n_total_wealth_grid = model_config ["tuning_params" ]["n_total_wealth_grid" ],
65+ n_state_choices = model_structure ["state_choice_space" ].shape [0 ],
66+ )
67+
68+ # Solve the last two periods using lambda to capture static arguments
69+ solve_last_two_period_jit = jax .jit (
70+ lambda params_inner , cont_grids , weights , val_solved , pol_solved , endog_solved : solve_last_two_periods (
71+ params = params_inner ,
72+ continuous_states_info = continuous_states_info ,
73+ cont_grids_next_period = cont_grids ,
74+ income_shock_weights = weights ,
75+ model_funcs = model_funcs ,
76+ last_two_period_batch_info = batch_info ["last_two_period_info" ],
77+ value_solved = val_solved ,
78+ policy_solved = pol_solved ,
79+ endog_grid_solved = endog_solved ,
80+ debug_info = None ,
81+ )
10482 )
10583
106- # Solve the last two periods. We do this separately as the marginal utility of
107- # the child states in the last period is calculated from the marginal utility
108- # function of the bequest function, which might differ.
10984 (
11085 value_solved ,
11186 policy_solved ,
11287 endog_grid_solved ,
113- ) = solve_last_two_periods (
114- params = params ,
115- continuous_states_info = continuous_states_info ,
116- cont_grids_next_period = cont_grids_next_period ,
117- income_shock_weights = income_shock_weights ,
118- model_funcs = model_funcs ,
119- last_two_period_batch_info = batch_info ["last_two_period_info" ],
120- value_solved = value_solved ,
121- policy_solved = policy_solved ,
122- endog_grid_solved = endog_grid_solved ,
88+ ) = solve_last_two_period_jit (
89+ params ,
90+ cont_grids_next_period ,
91+ income_shock_weights ,
92+ value_solved ,
93+ policy_solved ,
94+ endog_grid_solved ,
12395 )
12496
12597 # If it is a two period model we are done.
12698 if batch_info ["two_period_model" ]:
12799 return value_solved , policy_solved , endog_grid_solved
128100
129- def partial_single_period (carry , xs ):
130- return solve_single_period (
131- carry = carry ,
132- xs = xs ,
133- params = params ,
134- continuous_grids_info = continuous_states_info ,
135- cont_grids_next_period = cont_grids_next_period ,
136- model_funcs = model_funcs ,
137- income_shock_weights = income_shock_weights ,
138- )
101+ # Create JIT-compiled single period solver using lambda
102+ partial_single_period = lambda carry , xs : solve_single_period (
103+ carry = carry ,
104+ xs = xs ,
105+ params = params ,
106+ continuous_grids_info = continuous_states_info ,
107+ cont_grids_next_period = cont_grids_next_period ,
108+ model_funcs = model_funcs ,
109+ income_shock_weights = income_shock_weights ,
110+ debug_info = None ,
111+ )
139112
140113 for id_segment in range (batch_info ["n_segments" ]):
141114 segment_info = batch_info [f"batches_info_segment_{ id_segment } " ]
@@ -192,53 +165,3 @@ def partial_single_period(carry, xs):
192165 policy_solved ,
193166 endog_grid_solved ,
194167 )
195-
196-
197- def create_solution_container (
198- model_config : Dict [str , Any ],
199- model_structure : Dict [str , Any ],
200- ):
201- """Create solution containers for value, policy, and endog_grid."""
202-
203- # Read out grid size
204- n_total_wealth_grid = model_config ["tuning_params" ]["n_total_wealth_grid" ]
205- n_state_choices = model_structure ["state_choice_space" ].shape [0 ]
206-
207- # Check if second continuous state exists and read out array size
208- continuous_states_info = model_config ["continuous_states_info" ]
209- if continuous_states_info ["second_continuous_exists" ]:
210- n_second_continuous_grid = continuous_states_info ["n_second_continuous_grid" ]
211-
212- value_solved = jnp .full (
213- (n_state_choices , n_second_continuous_grid , n_total_wealth_grid ),
214- dtype = jnp .float64 ,
215- fill_value = jnp .nan ,
216- )
217- policy_solved = jnp .full (
218- (n_state_choices , n_second_continuous_grid , n_total_wealth_grid ),
219- dtype = jnp .float64 ,
220- fill_value = jnp .nan ,
221- )
222- endog_grid_solved = jnp .full (
223- (n_state_choices , n_second_continuous_grid , n_total_wealth_grid ),
224- dtype = jnp .float64 ,
225- fill_value = jnp .nan ,
226- )
227- else :
228- value_solved = jnp .full (
229- (n_state_choices , n_total_wealth_grid ),
230- dtype = jnp .float64 ,
231- fill_value = jnp .nan ,
232- )
233- policy_solved = jnp .full (
234- (n_state_choices , n_total_wealth_grid ),
235- dtype = jnp .float64 ,
236- fill_value = jnp .nan ,
237- )
238- endog_grid_solved = jnp .full (
239- (n_state_choices , n_total_wealth_grid ),
240- dtype = jnp .float64 ,
241- fill_value = jnp .nan ,
242- )
243-
244- return value_solved , policy_solved , endog_grid_solved
0 commit comments