Skip to content

Commit d23f80c

Browse files
committed
Remove aug_periods from public-facing functions.
1 parent 9a7e401 commit d23f80c

9 files changed

Lines changed: 116 additions & 61 deletions

CLAUDE.md

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,7 @@ The codebase uses:
174174
**Simulation:**
175175

176176
- `simulate_dataset(model_spec, params, n_obs=None, data=None, policies=None, seed=None)`
177-
— returns dict with `"unanchored_states"`, `"anchored_states"`,
178-
`"aug_unanchored_states"`, `"aug_measurements"`
177+
— returns dict with `"unanchored_states"`, `"anchored_states"`
179178
- `simulate_policy_effect(model_spec, params, data, policies, seed=None)` — returns
180179
DataFrame of factor mean differences between policy and baseline
181180

@@ -184,7 +183,7 @@ The codebase uses:
184183
- `plot_likelihood_contributions(model_spec, data, params, period=None)`
185184
- `plot_residual_boxplots(model_spec, data, params, period=None)`
186185
- `decompose_measurement_variance(model_spec, params, data)` — returns DataFrame indexed
187-
by `(aug_period, measurement, factor)` with signal/noise columns
186+
by `(period, measurement, factor)` with signal/noise columns
188187
- `summarize_measurement_reliability(variance_decomposition)`
189188
- `create_state_ranges(filtered_states, factors, quantile_cutoff=None)`
190189

@@ -239,18 +238,18 @@ These are not in `__all__` but are imported directly by application projects:
239238
### Period vs Aug_period
240239

241240
Models with endogenous factors split each calendar period into multiple **augmented
242-
periods** (`aug_period`). The public API should use `period` (user-facing); `aug_period`
243-
is an internal concept. Current status:
241+
periods** (`aug_period`). The public API uses `period` (user-facing); `aug_period` is
242+
strictly internal. All public functions now return `period`:
244243

245244
- `ModelSpec` — clean, no `aug_period` exposure.
246245
- `get_transition_plots()` — clean, accepts `period`/`periods`.
247-
- `get_filtered_states()` — **leaks `aug_period`**: returned states DataFrames have an
248-
`aug_period` column, not `period`.
249-
- `simulate_dataset()` — mixed: `"anchored_states"` / `"unanchored_states"` use
250-
`period`, but `"aug_unanchored_states"` / `"aug_measurements"` use `aug_period`.
251-
- `plot_residual_boxplots()` / `plot_likelihood_contributions()` — accept `period` but
252-
return figures keyed by `aug_period`.
253-
- `decompose_measurement_variance()` — returns DataFrame indexed by `aug_period`.
246+
- `get_filtered_states()` — clean, returns `period` column.
247+
- `simulate_dataset()` — clean, returns `period` in states DataFrames.
248+
- `plot_residual_boxplots()` / `plot_likelihood_contributions()` — clean, accept and
249+
return `period`.
250+
- `decompose_measurement_variance()` — clean, indexed by
251+
`(period, measurement, factor)`.
252+
- `simulate_policy_effect()` / `simulate_dataset()` policies — accept `"period"` key.
254253
- `ProcessedModel.labels` — exposes `aug_periods_to_periods` mapping (acceptable for
255254
internal/advanced use).
256255

src/skillmodels/diagnostic_plots.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,25 @@ def plot_residual_boxplots(
6161
how="left",
6262
)
6363

64+
# Map aug_period → period for the public API
65+
ap_to_p = processed_model.labels.aug_periods_to_periods
66+
6467
available_periods = sorted(residuals_df[period_col].unique())
6568

6669
if period is not None:
70+
# Find aug_period(s) matching the requested period
71+
aug_periods_for_period = [ap for ap, p in ap_to_p.items() if p == period]
72+
aug_period = aug_periods_for_period[0] if aug_periods_for_period else period
6773
return _create_residual_boxplot_for_period(
6874
residuals_df=residuals_df,
69-
period=period,
75+
period=aug_period,
7076
period_col=period_col,
7177
show_reference_line=show_reference_line,
7278
layout_kwargs=layout_kwargs,
7379
)
7480

7581
return {
76-
p: _create_residual_boxplot_for_period(
82+
ap_to_p.get(p, p): _create_residual_boxplot_for_period(
7783
residuals_df=residuals_df,
7884
period=p,
7985
period_col=period_col,
@@ -173,18 +179,24 @@ def plot_likelihood_contributions(
173179
how="left",
174180
)
175181

182+
# Map aug_period → period for the public API
183+
ap_to_p = processed_model.labels.aug_periods_to_periods
184+
176185
available_periods = sorted(contributions_df[period_col].unique())
177186

178187
if period is not None:
188+
# Find aug_period(s) matching the requested period
189+
aug_periods_for_period = [ap for ap, p in ap_to_p.items() if p == period]
190+
aug_period = aug_periods_for_period[0] if aug_periods_for_period else period
179191
return _create_likelihood_boxplot_for_period(
180192
contributions_df=contributions_df,
181-
period=period,
193+
period=aug_period,
182194
period_col=period_col,
183195
layout_kwargs=layout_kwargs,
184196
)
185197

186198
return {
187-
p: _create_likelihood_boxplot_for_period(
199+
ap_to_p.get(p, p): _create_likelihood_boxplot_for_period(
188200
contributions_df=contributions_df,
189201
period=p,
190202
period_col=period_col,

src/skillmodels/filtered_states.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,20 @@ def get_filtered_states(
3535
use_aug_period=True,
3636
)
3737

38+
# Map aug_period → period for the public API
39+
ap_to_p = processed_model.labels.aug_periods_to_periods
40+
for df in (anchored_states_df, unanchored_states_df):
41+
df["period"] = df["aug_period"].map(ap_to_p)
42+
df.drop(columns="aug_period", inplace=True) # noqa: PD002
43+
3844
anchored_ranges = create_state_ranges(
3945
filtered_states=anchored_states_df,
4046
factors=processed_model.labels.latent_factors,
4147
)
48+
unanchored_ranges = create_state_ranges(
49+
filtered_states=unanchored_states_df,
50+
factors=processed_model.labels.latent_factors,
51+
)
4252

4353
return {
4454
"anchored_states": {

src/skillmodels/process_debug_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def create_state_ranges(
182182
"""Compute minimum and maximum state values for each factor by period.
183183
184184
Args:
185-
filtered_states: DataFrame with filtered states. Must have a "period" or
186-
"aug_period" column.
185+
filtered_states: DataFrame with filtered states. Must have a "period"
186+
column.
187187
factors: List of factor names to compute ranges for.
188188
quantile_cutoff: If provided, use quantiles instead of min/max. The cutoff
189189
is applied symmetrically: the minimum is the `quantile_cutoff` quantile

src/skillmodels/simulate_data.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def simulate_dataset(
4444
data: Dataset in the same format as for estimation, containing
4545
information about observed factors and control variables.
4646
policies: Each dictionary specifies a stochastic shock to a latent factor
47-
AT THE END of "period" for "factor" with mean "effect_size" and
48-
"standard deviation".
47+
AT THE END of ``"period"`` for ``"factor"`` with mean
48+
``"effect_size"`` and ``"standard_deviation"``.
4949
seed: Random seed for reproducibility. If None, uses numpy's default random
5050
state.
5151
@@ -125,7 +125,14 @@ def simulate_dataset(
125125
n_obs=n_obs,
126126
)
127127

128-
aug_measurements, aug_latent_data = _simulate_dataset(
128+
# Convert "period" keys in policies to "aug_period" for internal use
129+
if policies is not None:
130+
policies = _convert_policy_periods(
131+
policies=policies,
132+
endogenous_factors_info=processed_model.endogenous_factors_info,
133+
)
134+
135+
_aug_measurements, aug_latent_data = _simulate_dataset(
129136
latent_states=states,
130137
covs=covs,
131138
log_weights=log_weights,
@@ -173,14 +180,6 @@ def simulate_dataset(
173180
factors=processed_model.labels.latent_factors,
174181
),
175182
},
176-
"aug_unanchored_states": {
177-
"states": aug_latent_data,
178-
"state_ranges": create_state_ranges(
179-
filtered_states=aug_latent_data,
180-
factors=processed_model.labels.latent_factors,
181-
),
182-
},
183-
"aug_measurements": aug_measurements,
184183
}
185184

186185

@@ -396,6 +395,31 @@ def _collapse_aug_periods_to_periods(
396395
)
397396

398397

398+
def _convert_policy_periods(
399+
policies: list[dict],
400+
endogenous_factors_info: EndogenousFactorsInfo,
401+
) -> list[dict]:
402+
"""Convert ``"period"`` keys in policy dicts to ``"aug_period"``.
403+
404+
Policies may specify either ``"period"`` (public API) or ``"aug_period"``
405+
(legacy/internal). This normalises to ``"aug_period"`` for the simulation loop.
406+
"""
407+
converted = []
408+
for policy in policies:
409+
if "aug_period" in policy:
410+
converted.append(policy)
411+
elif "period" in policy:
412+
p = dict(policy)
413+
period = p.pop("period")
414+
aug_periods = endogenous_factors_info.aug_periods_from_period(period)
415+
# Use the first aug_period for the given period
416+
p["aug_period"] = aug_periods[0]
417+
converted.append(p)
418+
else:
419+
raise ValueError("Each policy dict must contain a 'period' key.")
420+
return converted
421+
422+
399423
def _get_shock(
400424
rng: np.random.Generator,
401425
mean: float,
@@ -514,7 +538,7 @@ def simulate_policy_effect(
514538
data: Dataset with observed factors and control variables.
515539
policies: List of policy dictionaries. Each dictionary specifies a
516540
stochastic shock to a latent factor with keys:
517-
- "period" or "aug_period": When to apply the shock
541+
- "period": When to apply the shock
518542
- "factor": Which factor to shock
519543
- "effect_size": Mean of the shock
520544
- "standard_deviation": Standard deviation of the shock (use 0 for
@@ -564,8 +588,6 @@ def simulate_policy_effect(
564588
policy_means = policy_states.groupby("period").mean()
565589

566590
# Drop non-factor columns
567-
factor_cols = [
568-
c for c in baseline_means.columns if c not in ("id", "aug_period", "period")
569-
]
591+
factor_cols = [c for c in baseline_means.columns if c not in ("id", "period")]
570592

571593
return policy_means[factor_cols] - baseline_means[factor_cols]

src/skillmodels/variance_decomposition.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
Section 4.2.2.
66
"""
77

8+
from collections.abc import Mapping
9+
810
import pandas as pd
911

1012
from skillmodels.filtered_states import get_filtered_states
1113
from skillmodels.model_spec import ModelSpec
14+
from skillmodels.process_model import process_model
1215

1316

1417
def decompose_measurement_variance(
@@ -34,7 +37,7 @@ def decompose_measurement_variance(
3437
data: Empirical dataset used to estimate the model.
3538
3639
Returns:
37-
DataFrame indexed by (aug_period, measurement, factor) with columns:
40+
DataFrame indexed by (period, measurement, factor) with columns:
3841
- loading: The factor loading (L)
3942
- factor_variance: Var(F) for that period
4043
- meas_sd: The measurement error standard deviation
@@ -54,32 +57,50 @@ def decompose_measurement_variance(
5457
)
5558
filtered_states = filtered_result["anchored_states"]["states"]
5659

60+
processed_model = process_model(model_spec)
5761
return _compute_variance_decomposition(
5862
filtered_states=filtered_states,
5963
params=params,
64+
aug_periods_to_periods=processed_model.labels.aug_periods_to_periods,
6065
)
6166

6267

6368
def _compute_variance_decomposition(
6469
filtered_states: pd.DataFrame,
6570
params: pd.DataFrame,
71+
aug_periods_to_periods: Mapping[int, int],
6672
) -> pd.DataFrame:
6773
"""Compute variance decomposition from filtered states and parameters.
6874
6975
Args:
7076
filtered_states: DataFrame with filtered states, must have columns for
71-
each factor plus "aug_period" and "id".
77+
each factor plus "period" and "id".
7278
params: DataFrame with model parameters indexed by
7379
(category, aug_period, name1, name2).
80+
aug_periods_to_periods: Mapping from aug_period to period.
7481
7582
Returns:
7683
DataFrame with variance decomposition results.
7784
7885
"""
86+
# Build reverse mapping: period → aug_period (pick first aug_period per period)
87+
periods_to_aug_periods = {}
88+
for ap, p in aug_periods_to_periods.items():
89+
if p not in periods_to_aug_periods:
90+
periods_to_aug_periods[p] = ap
91+
92+
# Add aug_period column for internal merges with params
93+
filtered_states = filtered_states.copy()
94+
filtered_states["aug_period"] = filtered_states["period"].map(
95+
periods_to_aug_periods
96+
)
97+
7998
# Compute factor variances by period
8099
periods = filtered_states["aug_period"].unique()
81100
factor_cols = [
82-
c for c in filtered_states.columns if c not in ("aug_period", "id", "mixture")
101+
c
102+
for c in filtered_states.columns
103+
if c not in ("aug_period", "period", "id", "mixture")
83104
]
84105

85106
factor_variances = {}
@@ -133,8 +154,11 @@ def _compute_variance_decomposition(
133154
merged["fraction_noise"] = noise_var / total_var
134155
merged["signal_to_noise_ratio"] = signal_var / noise_var
135156

157+
# Map aug_period → period for the public API
158+
merged["period"] = merged["aug_period"].map(aug_periods_to_periods)
159+
136160
# Set index and select columns
137-
return merged.set_index(["aug_period", "measurement", "factor"])[
161+
return merged.set_index(["period", "measurement", "factor"])[
138162
[
139163
"loading",
140164
"factor_variance",

src/skillmodels/visualize_transition_equations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def get_transition_plots( # noqa: C901, PLR0912
196196
layout_kwargs: Dictionary of key word arguments used to
197197
update layout of plotly image object. If None, the default kwargs
198198
defined in the function will be used.
199-
states: Pre-computed filtered states DataFrame (with an `aug_period`
199+
states: Pre-computed filtered states DataFrame (with a `period`
200200
column). If provided, skip the internal `get_filtered_states` call.
201201
include_correction_factors: Whether to include correction factors in the
202202
plots. Default False.

tests/test_variance_decomposition.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def setup_variance_decomposition():
2222
"fac1": [0.1, 0.1, 0.1, 0.2],
2323
"fac2": [0.1, 0.1, 0.1, 0.1],
2424
"fac3": [0.2, 0.2, 0.2, 0.4],
25-
"aug_period": [0, 0, 0, 0],
25+
"period": [0, 0, 0, 0],
2626
"id": [0, 1, 2, 3],
2727
}
2828
)
@@ -60,7 +60,11 @@ def setup_variance_decomposition():
6060

6161
params = pd.concat([loadings_df, meas_sds_df], keys=["loadings", "meas_sds"])
6262

63-
return {"filtered_states": filtered_states, "params": params}
63+
return {
64+
"filtered_states": filtered_states,
65+
"params": params,
66+
"aug_periods_to_periods": {0: 0},
67+
}
6468

6569

6670
@pytest.fixture
@@ -74,7 +78,7 @@ def expected_variance_decomposition():
7478
"""
7579
index = pd.MultiIndex.from_tuples(
7680
[(0, "y1", "fac1"), (0, "y2", "fac2"), (0, "y3", "fac3")],
77-
names=["aug_period", "measurement", "factor"],
81+
names=["period", "measurement", "factor"],
7882
)
7983
return pd.DataFrame(
8084
{

0 commit comments

Comments
 (0)