Skip to content

Commit c1256f9

Browse files
committed
Merge branch 'output-formatting' into om.Constraints
2 parents a32a177 + d23f80c commit c1256f9

9 files changed

Lines changed: 155 additions & 57 deletions

CLAUDE.md

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,7 @@ The codebase uses:
174174
**Simulation:**
175175

176176
- `simulate_dataset(model_spec, params, n_obs=None, data=None, policies=None, seed=None)`
177-
— returns dict with `"unanchored_states"`, `"anchored_states"`,
178-
`"aug_unanchored_states"`, `"aug_measurements"`
177+
— returns dict with `"unanchored_states"`, `"anchored_states"`
179178
- `simulate_policy_effect(model_spec, params, data, policies, seed=None)` — returns
180179
DataFrame of factor mean differences between policy and baseline
181180

@@ -184,7 +183,7 @@ The codebase uses:
184183
- `plot_likelihood_contributions(model_spec, data, params, period=None)`
185184
- `plot_residual_boxplots(model_spec, data, params, period=None)`
186185
- `decompose_measurement_variance(model_spec, params, data)` — returns DataFrame indexed
187-
by `(aug_period, measurement, factor)` with signal/noise columns
186+
by `(period, measurement, factor)` with signal/noise columns
188187
- `summarize_measurement_reliability(variance_decomposition)`
189188
- `create_state_ranges(filtered_states, factors, quantile_cutoff=None)`
190189

@@ -237,6 +236,27 @@ These are not in `__all__` but are imported directly by application projects:
237236
- `ensure_containers_are_immutable()` recursively converts dict→MappingProxyType,
238237
list→tuple, set→frozenset
239238

239+
### Period vs Aug_period
240+
241+
Models with endogenous factors split each calendar period into multiple **augmented
242+
periods** (`aug_period`). The public API uses `period` (user-facing); `aug_period` is
243+
strictly internal. All public functions now return `period`:
244+
245+
- `ModelSpec` — clean, no `aug_period` exposure.
246+
- `get_transition_plots()` — clean, accepts `period`/`periods`.
247+
- `get_filtered_states()` — clean, returns `period` column.
248+
- `simulate_dataset()` — clean, returns `period` in states DataFrames.
249+
- `plot_residual_boxplots()` / `plot_likelihood_contributions()` — clean, accept and
250+
return `period`.
251+
- `decompose_measurement_variance()` — clean, indexed by
252+
`(period, measurement, factor)`.
253+
- `simulate_policy_effect()` / `simulate_dataset()` policies — accept `"period"` key.
254+
- `ProcessedModel.labels` — exposes `aug_periods_to_periods` mapping (acceptable for
255+
internal/advanced use).
256+
257+
When writing new public-facing code, always accept and return `period`. Convert to
258+
`aug_period` internally using `ProcessedModel.labels.aug_periods_to_periods`.
259+
240260
## Testing
241261

242262
- pytest with markers: `wip`, `unit`, `integration`, `end_to_end`

src/skillmodels/diagnostic_plots.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,25 @@ def plot_residual_boxplots(
6161
how="left",
6262
)
6363

64+
# Map aug_period → period for the public API
65+
ap_to_p = processed_model.labels.aug_periods_to_periods
66+
6467
available_periods = sorted(residuals_df[period_col].unique())
6568

6669
if period is not None:
70+
# Find aug_period(s) matching the requested period
71+
aug_periods_for_period = [ap for ap, p in ap_to_p.items() if p == period]
72+
aug_period = aug_periods_for_period[0] if aug_periods_for_period else period
6773
return _create_residual_boxplot_for_period(
6874
residuals_df=residuals_df,
69-
period=period,
75+
period=aug_period,
7076
period_col=period_col,
7177
show_reference_line=show_reference_line,
7278
layout_kwargs=layout_kwargs,
7379
)
7480

7581
return {
76-
p: _create_residual_boxplot_for_period(
82+
ap_to_p.get(p, p): _create_residual_boxplot_for_period(
7783
residuals_df=residuals_df,
7884
period=p,
7985
period_col=period_col,
@@ -173,18 +179,24 @@ def plot_likelihood_contributions(
173179
how="left",
174180
)
175181

182+
# Map aug_period → period for the public API
183+
ap_to_p = processed_model.labels.aug_periods_to_periods
184+
176185
available_periods = sorted(contributions_df[period_col].unique())
177186

178187
if period is not None:
188+
# Find aug_period(s) matching the requested period
189+
aug_periods_for_period = [ap for ap, p in ap_to_p.items() if p == period]
190+
aug_period = aug_periods_for_period[0] if aug_periods_for_period else period
179191
return _create_likelihood_boxplot_for_period(
180192
contributions_df=contributions_df,
181-
period=period,
193+
period=aug_period,
182194
period_col=period_col,
183195
layout_kwargs=layout_kwargs,
184196
)
185197

186198
return {
187-
p: _create_likelihood_boxplot_for_period(
199+
ap_to_p.get(p, p): _create_likelihood_boxplot_for_period(
188200
contributions_df=contributions_df,
189201
period=p,
190202
period_col=period_col,

src/skillmodels/filtered_states.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,20 @@ def get_filtered_states(
3535
use_aug_period=True,
3636
)
3737

38+
# Map aug_period → period for the public API
39+
ap_to_p = processed_model.labels.aug_periods_to_periods
40+
for df in (anchored_states_df, unanchored_states_df):
41+
df["period"] = df["aug_period"].map(ap_to_p)
42+
df.drop(columns="aug_period", inplace=True) # noqa: PD002
43+
3844
anchored_ranges = create_state_ranges(
3945
filtered_states=anchored_states_df,
4046
factors=processed_model.labels.latent_factors,
4147
)
48+
unanchored_ranges = create_state_ranges(
49+
filtered_states=unanchored_states_df,
50+
factors=processed_model.labels.latent_factors,
51+
)
4252

4353
return {
4454
"anchored_states": {

src/skillmodels/process_debug_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def create_state_ranges(
182182
"""Compute minimum and maximum state values for each factor by period.
183183
184184
Args:
185-
filtered_states: DataFrame with filtered states. Must have a "period" or
186-
"aug_period" column.
185+
filtered_states: DataFrame with filtered states. Must have a "period"
186+
column.
187187
factors: List of factor names to compute ranges for.
188188
quantile_cutoff: If provided, use quantiles instead of min/max. The cutoff
189189
is applied symmetrically: the minimum is the `quantile_cutoff` quantile

src/skillmodels/simulate_data.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def simulate_dataset(
4444
data: Dataset in the same format as for estimation, containing
4545
information about observed factors and control variables.
4646
policies: Each dictionary specifies a stochastic shock to a latent factor
47-
AT THE END of "period" for "factor" with mean "effect_size" and
48-
"standard deviation".
47+
AT THE END of ``"period"`` for ``"factor"`` with mean
48+
``"effect_size"`` and ``"standard_deviation"``.
4949
seed: Random seed for reproducibility. If None, uses numpy's default random
5050
state.
5151
@@ -125,7 +125,14 @@ def simulate_dataset(
125125
n_obs=n_obs,
126126
)
127127

128-
aug_measurements, aug_latent_data = _simulate_dataset(
128+
# Convert "period" keys in policies to "aug_period" for internal use
129+
if policies is not None:
130+
policies = _convert_policy_periods(
131+
policies=policies,
132+
endogenous_factors_info=processed_model.endogenous_factors_info,
133+
)
134+
135+
_aug_measurements, aug_latent_data = _simulate_dataset(
129136
latent_states=states,
130137
covs=covs,
131138
log_weights=log_weights,
@@ -173,14 +180,6 @@ def simulate_dataset(
173180
factors=processed_model.labels.latent_factors,
174181
),
175182
},
176-
"aug_unanchored_states": {
177-
"states": aug_latent_data,
178-
"state_ranges": create_state_ranges(
179-
filtered_states=aug_latent_data,
180-
factors=processed_model.labels.latent_factors,
181-
),
182-
},
183-
"aug_measurements": aug_measurements,
184183
}
185184

186185

@@ -396,6 +395,31 @@ def _collapse_aug_periods_to_periods(
396395
)
397396

398397

398+
def _convert_policy_periods(
399+
policies: list[dict],
400+
endogenous_factors_info: EndogenousFactorsInfo,
401+
) -> list[dict]:
402+
"""Convert ``"period"`` keys in policy dicts to ``"aug_period"``.
403+
404+
Policies may specify either ``"period"`` (public API) or ``"aug_period"``
405+
(legacy/internal). This normalises to ``"aug_period"`` for the simulation loop.
406+
"""
407+
converted = []
408+
for policy in policies:
409+
if "aug_period" in policy:
410+
converted.append(policy)
411+
elif "period" in policy:
412+
p = dict(policy)
413+
period = p.pop("period")
414+
aug_periods = endogenous_factors_info.aug_periods_from_period(period)
415+
# Use the first aug_period for the given period
416+
p["aug_period"] = aug_periods[0]
417+
converted.append(p)
418+
else:
419+
raise ValueError("Each policy dict must contain a 'period' key.")
420+
return converted
421+
422+
399423
def _get_shock(
400424
rng: np.random.Generator,
401425
mean: float,
@@ -514,7 +538,7 @@ def simulate_policy_effect(
514538
data: Dataset with observed factors and control variables.
515539
policies: List of policy dictionaries. Each dictionary specifies a
516540
stochastic shock to a latent factor with keys:
517-
- "period" or "aug_period": When to apply the shock
541+
- "period": When to apply the shock
518542
- "factor": Which factor to shock
519543
- "effect_size": Mean of the shock
520544
- "standard_deviation": Standard deviation of the shock (use 0 for
@@ -564,8 +588,6 @@ def simulate_policy_effect(
564588
policy_means = policy_states.groupby("period").mean()
565589

566590
# Drop non-factor columns
567-
factor_cols = [
568-
c for c in baseline_means.columns if c not in ("id", "aug_period", "period")
569-
]
591+
factor_cols = [c for c in baseline_means.columns if c not in ("id", "period")]
570592

571593
return policy_means[factor_cols] - baseline_means[factor_cols]

src/skillmodels/variance_decomposition.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
Section 4.2.2.
66
"""
77

8+
from collections.abc import Mapping
9+
810
import pandas as pd
911

1012
from skillmodels.filtered_states import get_filtered_states
1113
from skillmodels.model_spec import ModelSpec
14+
from skillmodels.process_model import process_model
1215

1316

1417
def decompose_measurement_variance(
@@ -34,7 +37,7 @@ def decompose_measurement_variance(
3437
data: Empirical dataset used to estimate the model.
3538
3639
Returns:
37-
DataFrame indexed by (aug_period, measurement, factor) with columns:
40+
DataFrame indexed by (period, measurement, factor) with columns:
3841
- loading: The factor loading (L)
3942
- factor_variance: Var(F) for that period
4043
- meas_sd: The measurement error standard deviation
@@ -54,32 +57,50 @@ def decompose_measurement_variance(
5457
)
5558
filtered_states = filtered_result["anchored_states"]["states"]
5659

60+
processed_model = process_model(model_spec)
5761
return _compute_variance_decomposition(
5862
filtered_states=filtered_states,
5963
params=params,
64+
aug_periods_to_periods=processed_model.labels.aug_periods_to_periods,
6065
)
6166

6267

6368
def _compute_variance_decomposition(
6469
filtered_states: pd.DataFrame,
6570
params: pd.DataFrame,
71+
aug_periods_to_periods: Mapping[int, int],
6672
) -> pd.DataFrame:
6773
"""Compute variance decomposition from filtered states and parameters.
6874
6975
Args:
7076
filtered_states: DataFrame with filtered states, must have columns for
71-
each factor plus "aug_period" and "id".
77+
each factor plus "period" and "id".
7278
params: DataFrame with model parameters indexed by
7379
(category, aug_period, name1, name2).
80+
aug_periods_to_periods: Mapping from aug_period to period.
7481
7582
Returns:
7683
DataFrame with variance decomposition results.
7784
7885
"""
86+
# Build reverse mapping: period → aug_period (pick first aug_period per period)
87+
periods_to_aug_periods = {}
88+
for ap, p in aug_periods_to_periods.items():
89+
if p not in periods_to_aug_periods:
90+
periods_to_aug_periods[p] = ap
91+
92+
# Add aug_period column for internal merges with params
93+
filtered_states = filtered_states.copy()
94+
filtered_states["aug_period"] = filtered_states["period"].map(
95+
periods_to_aug_periods
96+
)
97+
7998
# Compute factor variances by period
8099
periods = filtered_states["aug_period"].unique()
81100
factor_cols = [
82-
c for c in filtered_states.columns if c not in ("aug_period", "id", "mixture")
101+
c
102+
for c in filtered_states.columns
103+
if c not in ("aug_period", "period", "id", "mixture")
83104
]
84105

85106
factor_variances = {}
@@ -133,8 +154,11 @@ def _compute_variance_decomposition(
133154
merged["fraction_noise"] = noise_var / total_var
134155
merged["signal_to_noise_ratio"] = signal_var / noise_var
135156

157+
# Map aug_period → period for the public API
158+
merged["period"] = merged["aug_period"].map(aug_periods_to_periods)
159+
136160
# Set index and select columns
137-
return merged.set_index(["aug_period", "measurement", "factor"])[
161+
return merged.set_index(["period", "measurement", "factor"])[
138162
[
139163
"loading",
140164
"factor_variance",

src/skillmodels/visualize_transition_equations.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def get_transition_plots( # noqa: C901, PLR0912
196196
layout_kwargs: Dictionary of key word arguments used to
197197
update layout of plotly image object. If None, the default kwargs
198198
defined in the function will be used.
199-
states: Pre-computed filtered states DataFrame (with an `aug_period`
199+
states: Pre-computed filtered states DataFrame (with a `period`
200200
column). If provided, skip the internal `get_filtered_states` call.
201201
include_correction_factors: Whether to include correction factors in the
202202
plots. Default False.
@@ -258,6 +258,9 @@ def get_transition_plots( # noqa: C901, PLR0912
258258
states = get_filtered_states(model_spec=model_spec, data=data, params=params)[
259259
"anchored_states"
260260
]["states"]
261+
262+
states = _normalize_states_columns(states)
263+
261264
return _get_dictionary_with_plots(
262265
model=processed_model,
263266
data=data,
@@ -371,8 +374,6 @@ def _get_dictionary_with_plots(
371374
else:
372375
colors = colorscale
373376

374-
period_col = "aug_period" if "aug_period" in states_data.columns else "period"
375-
376377
plots_dict = {}
377378
for output_factor, input_factor in itertools.product(latent_factors, all_factors):
378379
combined_data = _prepare_plot_data_for_factor_pair(
@@ -381,7 +382,6 @@ def _get_dictionary_with_plots(
381382
state_ranges=state_ranges,
382383
parsed_params=parsed_params,
383384
periods=periods,
384-
period_col=period_col,
385385
input_factor=input_factor,
386386
output_factor=output_factor,
387387
all_factors=all_factors,
@@ -414,7 +414,6 @@ def _prepare_plot_data_for_factor_pair(
414414
state_ranges: dict[str, pd.DataFrame],
415415
parsed_params: ParsedParams,
416416
periods: list[int],
417-
period_col: str,
418417
input_factor: str,
419418
output_factor: str,
420419
all_factors: tuple[str, ...],
@@ -438,7 +437,7 @@ def _prepare_plot_data_for_factor_pair(
438437
transition_params = {
439438
output_factor: parsed_params.transition[output_factor][aug_period]
440439
}
441-
period_states = states_data[states_data[period_col] == aug_period]
440+
period_states = states_data[states_data["aug_period"] == aug_period]
442441

443442
plot_data = _prepare_single_period_plot_data(
444443
states_data=period_states,
@@ -645,6 +644,29 @@ def _get_states_data(
645644
return states_data
646645

647646

647+
def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
648+
"""Ensure `aug_period` and `id` are columns, not index levels.
649+
650+
Pre-computed states DataFrames may carry period information as `period`
651+
(in the index or a column) instead of `aug_period`. Downstream code
652+
uniformly expects `aug_period` as a column, so this helper promotes
653+
index levels to columns and renames `period` → `aug_period` when the
654+
latter is absent.
655+
"""
656+
# Promote relevant index levels to columns.
657+
names_to_reset = [
658+
n for n in states.index.names if n in ("period", "aug_period", "id")
659+
]
660+
if names_to_reset:
661+
states = states.reset_index(level=names_to_reset)
662+
663+
# Rename period → aug_period when aug_period is missing.
664+
if "aug_period" not in states.columns and "period" in states.columns:
665+
states = states.rename(columns={"period": "aug_period"})
666+
667+
return states
668+
669+
648670
def _prepare_data_for_one_plot_fixed_quantile_2d(
649671
states_data: pd.DataFrame,
650672
state_ranges: dict[str, pd.DataFrame],

0 commit comments

Comments
 (0)