Skip to content

Commit d1e81f3

Browse files
authored
Introduce separation of large arrays for jit compiling (#192)
1 parent cdad768 commit d1e81f3

4 files changed

Lines changed: 304 additions & 114 deletions

File tree

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
def split_structure_and_batch_info(model_structure, batch_info):
2+
"""Splits the model structure and batch info into static parts, which we can not jit
3+
compile and (large) arrays that we want to include in the function call for
4+
jitting."""
5+
6+
struct_keys_not_for_jit = [
7+
"discrete_states_names",
8+
"state_names_without_stochastic",
9+
"stochastic_states_names",
10+
]
11+
model_structure_non_jit = {
12+
key: model_structure[key] for key in struct_keys_not_for_jit
13+
}
14+
model_structure_jit = model_structure.copy()
15+
# Remove non-jittable items
16+
for key in struct_keys_not_for_jit:
17+
model_structure_jit.pop(key, None)
18+
19+
# Remove non-jittable items from batch_info
20+
batch_info_jit = batch_info.copy()
21+
batch_info_non_jit = {
22+
"two_period_model": batch_info["two_period_model"],
23+
}
24+
batch_info_jit.pop("two_period_model", None)
25+
# If it is not a two period model, there is more
26+
if not batch_info["two_period_model"]:
27+
batch_info_non_jit["n_segments"] = batch_info["n_segments"]
28+
batch_info_jit.pop("n_segments", None)
29+
for batch_id in range(batch_info_non_jit["n_segments"]):
30+
batch_key = f"batches_info_segment_{batch_id}"
31+
batch_info_non_jit[batch_key] = {}
32+
batch_info_non_jit[batch_key]["batches_cover_all"] = batch_info[batch_key][
33+
"batches_cover_all"
34+
]
35+
batch_info_jit[batch_key].pop("batches_cover_all", None)
36+
37+
return (
38+
model_structure_jit,
39+
batch_info_jit,
40+
model_structure_non_jit,
41+
batch_info_non_jit,
42+
)
43+
44+
45+
def merge_non_jit_and_jit_model_structure(model_structure_jit, model_structure_non_jit):
46+
"""Generate one model_structure to handle inside the package functions."""
47+
model_structure = {
48+
**model_structure_jit,
49+
**model_structure_non_jit,
50+
}
51+
return model_structure
52+
53+
54+
def merg_non_jit_batch_info_and_jit_batch_info(batch_info_jit, batch_info_non_jit):
55+
batch_info = {
56+
**batch_info_jit,
57+
"two_period_model": batch_info_non_jit["two_period_model"],
58+
}
59+
if not batch_info_non_jit["two_period_model"]:
60+
batch_info["n_segments"] = batch_info_non_jit["n_segments"]
61+
for batch_id in range(batch_info_non_jit["n_segments"]):
62+
batch_key = f"batches_info_segment_{batch_id}"
63+
batch_info[batch_key]["batches_cover_all"] = batch_info_non_jit[batch_key][
64+
"batches_cover_all"
65+
]
66+
return batch_info

src/dcegm/interfaces/model_class.py

Lines changed: 136 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pickle as pkl
22
from functools import partial
3+
from grp import struct_group
34
from typing import Callable, Dict
45

56
import jax
@@ -15,6 +16,11 @@
1516
get_n_state_choice_period,
1617
validate_stochastic_transition,
1718
)
19+
from dcegm.interfaces.jit_large_arrays import (
20+
merg_non_jit_batch_info_and_jit_batch_info,
21+
merge_non_jit_and_jit_model_structure,
22+
split_structure_and_batch_info,
23+
)
1824
from dcegm.interfaces.sol_interface import model_solved
1925
from dcegm.law_of_motion import calc_cont_grids_next_period
2026
from dcegm.likelihood import create_individual_likelihood_function
@@ -124,51 +130,6 @@ def __init__(
124130
else:
125131
self.alternative_sim_funcs = None
126132

127-
def set_alternative_sim_funcs(
128-
self, alternative_sim_specifications, alternative_specs=None
129-
):
130-
if alternative_specs is None:
131-
self.alternative_sim_specs = self.model_specs
132-
alternative_specs_without_jax = self.specs_without_jax
133-
else:
134-
self.alternative_sim_specs = jax.tree_util.tree_map(
135-
try_jax_array, alternative_specs
136-
)
137-
alternative_specs_without_jax = alternative_specs
138-
139-
alternative_sim_funcs = generate_alternative_sim_functions(
140-
model_specs=alternative_specs_without_jax,
141-
model_specs_jax=self.alternative_sim_specs,
142-
**alternative_sim_specifications,
143-
)
144-
self.alternative_sim_funcs = alternative_sim_funcs
145-
146-
def backward_induction_inner_jit(self, params):
147-
return backward_induction(
148-
params=params,
149-
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
150-
income_shock_weights=self.income_shock_weights,
151-
model_config=self.model_config,
152-
batch_info=self.batch_info,
153-
model_funcs=self.model_funcs,
154-
model_structure=self.model_structure,
155-
)
156-
157-
def get_fast_solve_func(self):
158-
backward_jit = jax.jit(
159-
partial(
160-
backward_induction,
161-
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
162-
income_shock_weights=self.income_shock_weights,
163-
model_config=self.model_config,
164-
batch_info=self.batch_info,
165-
model_funcs=self.model_funcs,
166-
model_structure=self.model_structure,
167-
)
168-
)
169-
170-
return backward_jit
171-
172133
def solve(self, params, load_sol_path=None, save_sol_path=None):
173134
"""Solve a discrete-continuous life-cycle model using the DC-EGM algorithm.
174135
@@ -198,8 +159,14 @@ def solve(self, params, load_sol_path=None, save_sol_path=None):
198159
if load_sol_path is not None:
199160
sol_dict = pkl.load(open(load_sol_path, "rb"))
200161
else:
201-
value, policy, endog_grid = self.backward_induction_inner_jit(
202-
params_processed
162+
value, policy, endog_grid = backward_induction(
163+
params=params_processed,
164+
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
165+
income_shock_weights=self.income_shock_weights,
166+
model_config=self.model_config,
167+
model_funcs=self.model_funcs,
168+
model_structure=self.model_structure,
169+
batch_info=self.batch_info,
203170
)
204171
sol_dict = {
205172
"value": value,
@@ -245,8 +212,14 @@ def solve_and_simulate(
245212
if load_sol_path is not None:
246213
sol_dict = pkl.load(open(load_sol_path, "rb"))
247214
else:
248-
value, policy, endog_grid = self.backward_induction_inner_jit(
249-
params_processed
215+
value, policy, endog_grid = backward_induction(
216+
params=params_processed,
217+
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
218+
income_shock_weights=self.income_shock_weights,
219+
model_config=self.model_config,
220+
model_funcs=self.model_funcs,
221+
model_structure=self.model_structure,
222+
batch_info=self.batch_info,
250223
)
251224

252225
sol_dict = {
@@ -274,14 +247,69 @@ def solve_and_simulate(
274247
sim_df = create_simulation_df(sim_dict)
275248
return sim_df
276249

250+
def get_solve_func(self):
251+
"""Create a fast function for solving that is jit compiled in the first call."""
252+
253+
(
254+
model_structure_for_jit,
255+
batch_info_for_jit,
256+
model_structure_non_jit,
257+
batch_info_non_jit,
258+
) = split_structure_and_batch_info(self.model_structure, self.batch_info)
259+
260+
def solve_function_to_jit(params, model_structure_jit, batch_info_jit):
261+
params_processed = process_params(params, self.params_check_info)
262+
263+
# Merge back parts together. The non_jit objects are fixed in the closure.
264+
model_structure = merge_non_jit_and_jit_model_structure(
265+
model_structure_jit, model_structure_non_jit
266+
)
267+
batch_info = merg_non_jit_batch_info_and_jit_batch_info(
268+
batch_info_jit, batch_info_non_jit
269+
)
270+
271+
# Solve the model.
272+
value, policy, endog_grid = backward_induction(
273+
params=params_processed,
274+
model_structure=model_structure,
275+
batch_info=batch_info,
276+
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
277+
income_shock_weights=self.income_shock_weights,
278+
model_config=self.model_config,
279+
model_funcs=self.model_funcs,
280+
)
281+
282+
return value, policy, endog_grid
283+
284+
solve_func = jax.jit(solve_function_to_jit)
285+
286+
# Generate the function. The user only needs to provide params, but we call with the objects for jit.
287+
def solve_function(params):
288+
"""Solve the model for given params."""
289+
value, policy, endog_grid = solve_func(
290+
params, model_structure_for_jit, batch_info_for_jit
291+
)
292+
model_solved_class = model_solved(
293+
model=self,
294+
params=params,
295+
value=value,
296+
policy=policy,
297+
endog_grid=endog_grid,
298+
)
299+
return model_solved_class
300+
301+
return solve_function
302+
277303
def get_solve_and_simulate_func(
278304
self,
279305
states_initial,
280306
seed,
281-
slow_version=False,
282307
):
308+
"""Create a fast function for solving and simulation that is jit compiled in the
309+
first call."""
283310

284-
sim_func = lambda params, value, policy, endog_gid: simulate_all_periods(
311+
# Fix everything except params, solution of the model and model_structure which contains large arrays.
312+
sim_func = lambda params, value, policy, endog_gid, model_structure: simulate_all_periods(
285313
states_initial=states_initial,
286314
n_periods=self.model_config["n_periods"],
287315
params=params,
@@ -290,34 +318,59 @@ def get_solve_and_simulate_func(
290318
policy_solved=policy,
291319
value_solved=value,
292320
model_config=self.model_config,
293-
model_structure=self.model_structure,
321+
model_structure=model_structure,
294322
model_funcs=self.model_funcs,
295323
alt_model_funcs_sim=self.alternative_sim_funcs,
296324
)
297325

298-
def solve_and_simulate_function_to_jit(params):
326+
(
327+
model_structure_for_jit,
328+
batch_info_for_jit,
329+
model_structure_non_jit,
330+
batch_info_non_jit,
331+
) = split_structure_and_batch_info(self.model_structure, self.batch_info)
332+
333+
def solve_and_simulate_function_to_jit(
334+
params, model_structure_jit, batch_info_jit
335+
):
299336
params_processed = process_params(params, self.params_check_info)
300-
# Solve the model
301-
value, policy, endog_grid = self.backward_induction_inner_jit(
302-
params_processed
337+
338+
# Merge back parts together. The non_jit objects are fixed in the closure.
339+
model_structure = merge_non_jit_and_jit_model_structure(
340+
model_structure_jit, model_structure_non_jit
341+
)
342+
batch_info = merg_non_jit_batch_info_and_jit_batch_info(
343+
batch_info_jit, batch_info_non_jit
344+
)
345+
346+
# Solve the model.
347+
value, policy, endog_grid = backward_induction(
348+
params=params_processed,
349+
model_structure=model_structure,
350+
batch_info=batch_info,
351+
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
352+
income_shock_weights=self.income_shock_weights,
353+
model_config=self.model_config,
354+
model_funcs=self.model_funcs,
303355
)
304356

305357
sim_dict = sim_func(
306358
params=params_processed,
307359
value=value,
308360
policy=policy,
309361
endog_gid=endog_grid,
362+
model_structure=model_structure,
310363
)
311364

312365
return sim_dict
313366

314-
if slow_version:
315-
solve_simulate_func = solve_and_simulate_function_to_jit
316-
else:
317-
solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit)
367+
solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit)
318368

369+
# Generate the function. The user only needs to provide params, but we call with the objects for jit.
319370
def solve_and_simulate_function(params):
320-
sim_dict = solve_simulate_func(params)
371+
sim_dict = solve_simulate_func(
372+
params, model_structure_for_jit, batch_info_for_jit
373+
)
321374
df = create_simulation_df(sim_dict)
322375
return df
323376

@@ -335,11 +388,13 @@ def create_experimental_ll_func(
335388
):
336389

337390
return create_individual_likelihood_function(
391+
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
392+
income_shock_weights=self.income_shock_weights,
393+
batch_info=self.batch_info,
338394
model_structure=self.model_structure,
339395
model_config=self.model_config,
340396
model_funcs=self.model_funcs,
341397
model_specs=self.model_specs,
342-
backwards_induction_inner_jit=self.backward_induction_inner_jit,
343398
observed_states=observed_states,
344399
observed_choices=observed_choices,
345400
params_all=params_all,
@@ -475,3 +530,22 @@ def solve_partially(self, params, n_periods, return_candidates=False):
475530
n_periods=n_periods,
476531
return_candidates=return_candidates,
477532
)
533+
534+
def set_alternative_sim_funcs(
535+
self, alternative_sim_specifications, alternative_specs=None
536+
):
537+
if alternative_specs is None:
538+
self.alternative_sim_specs = self.model_specs
539+
alternative_specs_without_jax = self.specs_without_jax
540+
else:
541+
self.alternative_sim_specs = jax.tree_util.tree_map(
542+
try_jax_array, alternative_specs
543+
)
544+
alternative_specs_without_jax = alternative_specs
545+
546+
alternative_sim_funcs = generate_alternative_sim_functions(
547+
model_specs=alternative_specs_without_jax,
548+
model_specs_jax=self.alternative_sim_specs,
549+
**alternative_sim_specifications,
550+
)
551+
self.alternative_sim_funcs = alternative_sim_funcs

0 commit comments

Comments
 (0)