Commit 9a7e401

hmgaudecker and claude committed
Fix aug_period/period bug in transition equation visualization
The viz code assumed states DataFrames always have `aug_period` as a column, but pre-computed states (e.g. from health-cognition) may carry `period` in the index instead. Add `_normalize_states_columns` to promote index levels and rename `period` → `aug_period` when needed. Also document the period vs aug_period convention in CLAUDE.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 67d7b44 commit 9a7e401

2 files changed

Lines changed: 48 additions & 5 deletions

CLAUDE.md

Lines changed: 21 additions & 0 deletions
@@ -236,6 +236,27 @@ These are not in `__all__` but are imported directly by application projects:
 - `ensure_containers_are_immutable()` recursively converts dict→MappingProxyType,
   list→tuple, set→frozenset
 
+### Period vs Aug_period
+
+Models with endogenous factors split each calendar period into multiple **augmented
+periods** (`aug_period`). The public API should use `period` (user-facing); `aug_period`
+is an internal concept. Current status:
+
+- `ModelSpec` — clean, no `aug_period` exposure.
+- `get_transition_plots()` — clean, accepts `period`/`periods`.
+- `get_filtered_states()` — **leaks `aug_period`**: returned states DataFrames have an
+  `aug_period` column, not `period`.
+- `simulate_dataset()` — mixed: `"anchored_states"` / `"unanchored_states"` use
+  `period`, but `"aug_unanchored_states"` / `"aug_measurements"` use `aug_period`.
+- `plot_residual_boxplots()` / `plot_likelihood_contributions()` — accept `period` but
+  return figures keyed by `aug_period`.
+- `decompose_measurement_variance()` — returns DataFrame indexed by `aug_period`.
+- `ProcessedModel.labels` — exposes `aug_periods_to_periods` mapping (acceptable for
+  internal/advanced use).
+
+When writing new public-facing code, always accept and return `period`. Convert to
+`aug_period` internally using `ProcessedModel.labels.aug_periods_to_periods`.
+
 ## Testing
 
 - pytest with markers: `wip`, `unit`, `integration`, `end_to_end`
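
The convention documented in this hunk (accept and return `period`, convert to `aug_period` internally) can be sketched as follows. This is a minimal illustration, not skillmodels code: it assumes `aug_periods_to_periods` behaves like a dict from augmented period to calendar period, and the mapping values as well as the helper names `aug_periods_for` and `regroup_by_period` are hypothetical.

```python
from collections.abc import Mapping

# Hypothetical mapping for illustration: two augmented periods per calendar
# period. In skillmodels the real mapping is assumed to come from
# ProcessedModel.labels.aug_periods_to_periods.
AUG_PERIODS_TO_PERIODS: Mapping[int, int] = {0: 0, 1: 0, 2: 1, 3: 1}


def aug_periods_for(periods: list[int], aug_to_period: Mapping[int, int]) -> list[int]:
    """Return the augmented periods that belong to the requested calendar periods."""
    wanted = set(periods)
    return [aug for aug, period in aug_to_period.items() if period in wanted]


def regroup_by_period(
    results_by_aug_period: Mapping[int, object],
    aug_to_period: Mapping[int, int],
) -> dict[int, dict[int, object]]:
    """Regroup results keyed by aug_period under their calendar period."""
    grouped: dict[int, dict[int, object]] = {}
    for aug, value in results_by_aug_period.items():
        grouped.setdefault(aug_to_period[aug], {})[aug] = value
    return grouped
```

A public function would call something like `aug_periods_for` on its `period` argument at the boundary and regroup any `aug_period`-keyed results before returning them, so callers never see augmented periods.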

src/skillmodels/visualize_transition_equations.py

Lines changed: 27 additions & 5 deletions
@@ -258,6 +258,9 @@ def get_transition_plots( # noqa: C901, PLR0912
         states = get_filtered_states(model_spec=model_spec, data=data, params=params)[
             "anchored_states"
         ]["states"]
+
+    states = _normalize_states_columns(states)
+
     return _get_dictionary_with_plots(
         model=processed_model,
         data=data,
@@ -371,8 +374,6 @@ def _get_dictionary_with_plots(
     else:
         colors = colorscale
 
-    period_col = "aug_period" if "aug_period" in states_data.columns else "period"
-
     plots_dict = {}
     for output_factor, input_factor in itertools.product(latent_factors, all_factors):
         combined_data = _prepare_plot_data_for_factor_pair(
@@ -381,7 +382,6 @@ def _get_dictionary_with_plots(
             state_ranges=state_ranges,
             parsed_params=parsed_params,
             periods=periods,
-            period_col=period_col,
             input_factor=input_factor,
             output_factor=output_factor,
             all_factors=all_factors,
@@ -414,7 +414,6 @@ def _prepare_plot_data_for_factor_pair(
     state_ranges: dict[str, pd.DataFrame],
     parsed_params: ParsedParams,
     periods: list[int],
-    period_col: str,
     input_factor: str,
     output_factor: str,
     all_factors: tuple[str, ...],
@@ -438,7 +437,7 @@ def _prepare_plot_data_for_factor_pair(
         transition_params = {
             output_factor: parsed_params.transition[output_factor][aug_period]
         }
-        period_states = states_data[states_data[period_col] == aug_period]
+        period_states = states_data[states_data["aug_period"] == aug_period]
 
         plot_data = _prepare_single_period_plot_data(
             states_data=period_states,
@@ -645,6 +644,29 @@ def _get_states_data(
     return states_data
 
 
+def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
+    """Ensure `aug_period` and `id` are columns, not index levels.
+
+    Pre-computed states DataFrames may carry period information as `period`
+    (in the index or a column) instead of `aug_period`. Downstream code
+    uniformly expects `aug_period` as a column, so this helper promotes
+    index levels to columns and renames `period` → `aug_period` when the
+    latter is absent.
+    """
+    # Promote relevant index levels to columns.
+    names_to_reset = [
+        n for n in states.index.names if n in ("period", "aug_period", "id")
+    ]
+    if names_to_reset:
+        states = states.reset_index(level=names_to_reset)
+
+    # Rename period → aug_period when aug_period is missing.
+    if "aug_period" not in states.columns and "period" in states.columns:
+        states = states.rename(columns={"period": "aug_period"})
+
+    return states
+
+
 def _prepare_data_for_one_plot_fixed_quantile_2d(
     states_data: pd.DataFrame,
     state_ranges: dict[str, pd.DataFrame],
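
To illustrate what the new helper does, the sketch below copies the body of `_normalize_states_columns` from the diff above and applies it to a hypothetical pre-computed states DataFrame that carries `id` and `period` in its index. The factor name `fac1` and all values are made up for the example.

```python
import pandas as pd


def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
    # Body copied from the diff; promotes index levels, then renames.
    names_to_reset = [
        n for n in states.index.names if n in ("period", "aug_period", "id")
    ]
    if names_to_reset:
        states = states.reset_index(level=names_to_reset)
    if "aug_period" not in states.columns and "period" in states.columns:
        states = states.rename(columns={"period": "aug_period"})
    return states


# Hypothetical pre-computed states: `id` and `period` live in the index.
index = pd.MultiIndex.from_product([[0, 1], [0, 1]], names=["id", "period"])
precomputed = pd.DataFrame({"fac1": [0.1, 0.2, 0.3, 0.4]}, index=index)

normalized = _normalize_states_columns(precomputed)

# Applying the helper a second time leaves the frame unchanged.
renormalized = _normalize_states_columns(normalized)
```

The rename only relabels the column. It does not translate calendar periods into augmented periods, so the values in `period` are used as augmented periods as-is.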
