11import pickle as pkl
22from functools import partial
3+ from grp import struct_group
34from typing import Callable , Dict
45
56import jax
1516 get_n_state_choice_period ,
1617 validate_stochastic_transition ,
1718)
19+ from dcegm .interfaces .jit_large_arrays import (
20+ merg_non_jit_batch_info_and_jit_batch_info ,
21+ merge_non_jit_and_jit_model_structure ,
22+ split_structure_and_batch_info ,
23+ )
1824from dcegm .interfaces .sol_interface import model_solved
1925from dcegm .law_of_motion import calc_cont_grids_next_period
2026from dcegm .likelihood import create_individual_likelihood_function
@@ -124,51 +130,6 @@ def __init__(
124130 else :
125131 self .alternative_sim_funcs = None
126132
127- def set_alternative_sim_funcs (
128- self , alternative_sim_specifications , alternative_specs = None
129- ):
130- if alternative_specs is None :
131- self .alternative_sim_specs = self .model_specs
132- alternative_specs_without_jax = self .specs_without_jax
133- else :
134- self .alternative_sim_specs = jax .tree_util .tree_map (
135- try_jax_array , alternative_specs
136- )
137- alternative_specs_without_jax = alternative_specs
138-
139- alternative_sim_funcs = generate_alternative_sim_functions (
140- model_specs = alternative_specs_without_jax ,
141- model_specs_jax = self .alternative_sim_specs ,
142- ** alternative_sim_specifications ,
143- )
144- self .alternative_sim_funcs = alternative_sim_funcs
145-
146- def backward_induction_inner_jit (self , params ):
147- return backward_induction (
148- params = params ,
149- income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
150- income_shock_weights = self .income_shock_weights ,
151- model_config = self .model_config ,
152- batch_info = self .batch_info ,
153- model_funcs = self .model_funcs ,
154- model_structure = self .model_structure ,
155- )
156-
157- def get_fast_solve_func (self ):
158- backward_jit = jax .jit (
159- partial (
160- backward_induction ,
161- income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
162- income_shock_weights = self .income_shock_weights ,
163- model_config = self .model_config ,
164- batch_info = self .batch_info ,
165- model_funcs = self .model_funcs ,
166- model_structure = self .model_structure ,
167- )
168- )
169-
170- return backward_jit
171-
172133 def solve (self , params , load_sol_path = None , save_sol_path = None ):
173134 """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm.
174135
@@ -198,8 +159,14 @@ def solve(self, params, load_sol_path=None, save_sol_path=None):
198159 if load_sol_path is not None :
199160 sol_dict = pkl .load (open (load_sol_path , "rb" ))
200161 else :
201- value , policy , endog_grid = self .backward_induction_inner_jit (
202- params_processed
162+ value , policy , endog_grid = backward_induction (
163+ params = params_processed ,
164+ income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
165+ income_shock_weights = self .income_shock_weights ,
166+ model_config = self .model_config ,
167+ model_funcs = self .model_funcs ,
168+ model_structure = self .model_structure ,
169+ batch_info = self .batch_info ,
203170 )
204171 sol_dict = {
205172 "value" : value ,
@@ -245,8 +212,14 @@ def solve_and_simulate(
245212 if load_sol_path is not None :
246213 sol_dict = pkl .load (open (load_sol_path , "rb" ))
247214 else :
248- value , policy , endog_grid = self .backward_induction_inner_jit (
249- params_processed
215+ value , policy , endog_grid = backward_induction (
216+ params = params_processed ,
217+ income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
218+ income_shock_weights = self .income_shock_weights ,
219+ model_config = self .model_config ,
220+ model_funcs = self .model_funcs ,
221+ model_structure = self .model_structure ,
222+ batch_info = self .batch_info ,
250223 )
251224
252225 sol_dict = {
@@ -274,14 +247,69 @@ def solve_and_simulate(
274247 sim_df = create_simulation_df (sim_dict )
275248 return sim_df
276249
250+ def get_solve_func (self ):
251+ """Create a fast function for solving that is jit compiled in the first call."""
252+
253+ (
254+ model_structure_for_jit ,
255+ batch_info_for_jit ,
256+ model_structure_non_jit ,
257+ batch_info_non_jit ,
258+ ) = split_structure_and_batch_info (self .model_structure , self .batch_info )
259+
260+ def solve_function_to_jit (params , model_structure_jit , batch_info_jit ):
261+ params_processed = process_params (params , self .params_check_info )
262+
263+ # Merge back parts together. The non_jit objects are fixed in the closure.
264+ model_structure = merge_non_jit_and_jit_model_structure (
265+ model_structure_jit , model_structure_non_jit
266+ )
267+ batch_info = merg_non_jit_batch_info_and_jit_batch_info (
268+ batch_info_jit , batch_info_non_jit
269+ )
270+
271+ # Solve the model.
272+ value , policy , endog_grid = backward_induction (
273+ params = params_processed ,
274+ model_structure = model_structure ,
275+ batch_info = batch_info ,
276+ income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
277+ income_shock_weights = self .income_shock_weights ,
278+ model_config = self .model_config ,
279+ model_funcs = self .model_funcs ,
280+ )
281+
282+ return value , policy , endog_grid
283+
284+ solve_func = jax .jit (solve_function_to_jit )
285+
286+ # Generate the function. The user only needs to provide params, but we call with the objects for jit.
287+ def solve_function (params ):
288+ """Solve the model for given params."""
289+ value , policy , endog_grid = solve_func (
290+ params , model_structure_for_jit , batch_info_for_jit
291+ )
292+ model_solved_class = model_solved (
293+ model = self ,
294+ params = params ,
295+ value = value ,
296+ policy = policy ,
297+ endog_grid = endog_grid ,
298+ )
299+ return model_solved_class
300+
301+ return solve_function
302+
277303 def get_solve_and_simulate_func (
278304 self ,
279305 states_initial ,
280306 seed ,
281- slow_version = False ,
282307 ):
308+ """Create a fast function for solving and simulation that is jit compiled in the
309+ first call."""
283310
284- sim_func = lambda params , value , policy , endog_gid : simulate_all_periods (
311+ # Fix everything except params, solution of the model and model_structure which contains large arrays.
312+ sim_func = lambda params , value , policy , endog_gid , model_structure : simulate_all_periods (
285313 states_initial = states_initial ,
286314 n_periods = self .model_config ["n_periods" ],
287315 params = params ,
@@ -290,34 +318,59 @@ def get_solve_and_simulate_func(
290318 policy_solved = policy ,
291319 value_solved = value ,
292320 model_config = self .model_config ,
293- model_structure = self . model_structure ,
321+ model_structure = model_structure ,
294322 model_funcs = self .model_funcs ,
295323 alt_model_funcs_sim = self .alternative_sim_funcs ,
296324 )
297325
298- def solve_and_simulate_function_to_jit (params ):
326+ (
327+ model_structure_for_jit ,
328+ batch_info_for_jit ,
329+ model_structure_non_jit ,
330+ batch_info_non_jit ,
331+ ) = split_structure_and_batch_info (self .model_structure , self .batch_info )
332+
333+ def solve_and_simulate_function_to_jit (
334+ params , model_structure_jit , batch_info_jit
335+ ):
299336 params_processed = process_params (params , self .params_check_info )
300- # Solve the model
301- value , policy , endog_grid = self .backward_induction_inner_jit (
302- params_processed
337+
338+ # Merge back parts together. The non_jit objects are fixed in the closure.
339+ model_structure = merge_non_jit_and_jit_model_structure (
340+ model_structure_jit , model_structure_non_jit
341+ )
342+ batch_info = merg_non_jit_batch_info_and_jit_batch_info (
343+ batch_info_jit , batch_info_non_jit
344+ )
345+
346+ # Solve the model.
347+ value , policy , endog_grid = backward_induction (
348+ params = params_processed ,
349+ model_structure = model_structure ,
350+ batch_info = batch_info ,
351+ income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
352+ income_shock_weights = self .income_shock_weights ,
353+ model_config = self .model_config ,
354+ model_funcs = self .model_funcs ,
303355 )
304356
305357 sim_dict = sim_func (
306358 params = params_processed ,
307359 value = value ,
308360 policy = policy ,
309361 endog_gid = endog_grid ,
362+ model_structure = model_structure ,
310363 )
311364
312365 return sim_dict
313366
314- if slow_version :
315- solve_simulate_func = solve_and_simulate_function_to_jit
316- else :
317- solve_simulate_func = jax .jit (solve_and_simulate_function_to_jit )
367+ solve_simulate_func = jax .jit (solve_and_simulate_function_to_jit )
318368
369+ # Generate the function. The user only needs to provide params, but we call with the objects for jit.
319370 def solve_and_simulate_function (params ):
320- sim_dict = solve_simulate_func (params )
371+ sim_dict = solve_simulate_func (
372+ params , model_structure_for_jit , batch_info_for_jit
373+ )
321374 df = create_simulation_df (sim_dict )
322375 return df
323376
@@ -335,11 +388,13 @@ def create_experimental_ll_func(
335388 ):
336389
337390 return create_individual_likelihood_function (
391+ income_shock_draws_unscaled = self .income_shock_draws_unscaled ,
392+ income_shock_weights = self .income_shock_weights ,
393+ batch_info = self .batch_info ,
338394 model_structure = self .model_structure ,
339395 model_config = self .model_config ,
340396 model_funcs = self .model_funcs ,
341397 model_specs = self .model_specs ,
342- backwards_induction_inner_jit = self .backward_induction_inner_jit ,
343398 observed_states = observed_states ,
344399 observed_choices = observed_choices ,
345400 params_all = params_all ,
@@ -475,3 +530,22 @@ def solve_partially(self, params, n_periods, return_candidates=False):
475530 n_periods = n_periods ,
476531 return_candidates = return_candidates ,
477532 )
533+
534+ def set_alternative_sim_funcs (
535+ self , alternative_sim_specifications , alternative_specs = None
536+ ):
537+ if alternative_specs is None :
538+ self .alternative_sim_specs = self .model_specs
539+ alternative_specs_without_jax = self .specs_without_jax
540+ else :
541+ self .alternative_sim_specs = jax .tree_util .tree_map (
542+ try_jax_array , alternative_specs
543+ )
544+ alternative_specs_without_jax = alternative_specs
545+
546+ alternative_sim_funcs = generate_alternative_sim_functions (
547+ model_specs = alternative_specs_without_jax ,
548+ model_specs_jax = self .alternative_sim_specs ,
549+ ** alternative_sim_specifications ,
550+ )
551+ self .alternative_sim_funcs = alternative_sim_funcs
0 commit comments