@@ -44,8 +44,8 @@ def simulate_dataset(
4444 data: Dataset in the same format as for estimation, containing
4545 information about observed factors and control variables.
4646 policies: Each dictionary specifies a stochastic shock to a latent factor
47- AT THE END of "period" for "factor" with mean "effect_size" and
48- "standard deviation" .
47+ AT THE END of `` "period"`` for `` "factor"`` with mean
48+ ``"effect_size"`` and ``"standard_deviation"`` .
4949 seed: Random seed for reproducibility. If None, uses numpy's default random
5050 state.
5151
@@ -125,7 +125,14 @@ def simulate_dataset(
125125 n_obs = n_obs ,
126126 )
127127
128- aug_measurements , aug_latent_data = _simulate_dataset (
128+ # Convert "period" keys in policies to "aug_period" for internal use
129+ if policies is not None :
130+ policies = _convert_policy_periods (
131+ policies = policies ,
132+ endogenous_factors_info = processed_model .endogenous_factors_info ,
133+ )
134+
135+ _aug_measurements , aug_latent_data = _simulate_dataset (
129136 latent_states = states ,
130137 covs = covs ,
131138 log_weights = log_weights ,
@@ -173,14 +180,6 @@ def simulate_dataset(
173180 factors = processed_model .labels .latent_factors ,
174181 ),
175182 },
176- "aug_unanchored_states" : {
177- "states" : aug_latent_data ,
178- "state_ranges" : create_state_ranges (
179- filtered_states = aug_latent_data ,
180- factors = processed_model .labels .latent_factors ,
181- ),
182- },
183- "aug_measurements" : aug_measurements ,
184183 }
185184
186185
@@ -396,6 +395,31 @@ def _collapse_aug_periods_to_periods(
396395 )
397396
398397
398+ def _convert_policy_periods (
399+ policies : list [dict ],
400+ endogenous_factors_info : EndogenousFactorsInfo ,
401+ ) -> list [dict ]:
402+ """Convert ``"period"`` keys in policy dicts to ``"aug_period"``.
403+
404+ Policies may specify either ``"period"`` (public API) or ``"aug_period"``
405+ (legacy/internal). This normalises to ``"aug_period"`` for the simulation loop.
406+ """
407+ converted = []
408+ for policy in policies :
409+ if "aug_period" in policy :
410+ converted .append (policy )
411+ elif "period" in policy :
412+ p = dict (policy )
413+ period = p .pop ("period" )
414+ aug_periods = endogenous_factors_info .aug_periods_from_period (period )
415+ # Use the first aug_period for the given period
416+ p ["aug_period" ] = aug_periods [0 ]
417+ converted .append (p )
418+ else :
419+ raise ValueError ("Each policy dict must contain a 'period' key." )
420+ return converted
421+
422+
399423def _get_shock (
400424 rng : np .random .Generator ,
401425 mean : float ,
@@ -514,7 +538,7 @@ def simulate_policy_effect(
514538 data: Dataset with observed factors and control variables.
515539 policies: List of policy dictionaries. Each dictionary specifies a
516540 stochastic shock to a latent factor with keys:
517- - "period" or "aug_period" : When to apply the shock
541+ - "period": When to apply the shock
518542 - "factor": Which factor to shock
519543 - "effect_size": Mean of the shock
520544 - "standard_deviation": Standard deviation of the shock (use 0 for
@@ -564,8 +588,6 @@ def simulate_policy_effect(
564588 policy_means = policy_states .groupby ("period" ).mean ()
565589
566590 # Drop non-factor columns
567- factor_cols = [
568- c for c in baseline_means .columns if c not in ("id" , "aug_period" , "period" )
569- ]
591+ factor_cols = [c for c in baseline_means .columns if c not in ("id" , "period" )]
570592
571593 return policy_means [factor_cols ] - baseline_means [factor_cols ]
0 commit comments