Skip to content

Commit 350c3b9

Browse files
committed
Merge branch 'output-formatting' into om.Constraints
2 parents c1256f9 + 447719a commit 350c3b9

3 files changed

Lines changed: 42 additions & 8 deletions

File tree

src/skillmodels/simulate_data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ def simulate_dataset(
180180
factors=processed_model.labels.latent_factors,
181181
),
182182
},
183+
"aug_unanchored_states": {
184+
"states": aug_latent_data,
185+
"state_ranges": create_state_ranges(
186+
filtered_states=aug_latent_data,
187+
factors=processed_model.labels.latent_factors,
188+
),
189+
},
183190
}
184191

185192

src/skillmodels/visualize_transition_equations.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Functions to visualize transition equations and production functions."""
22

33
import itertools
4-
from collections.abc import Callable, Sequence
4+
from collections.abc import Callable, Mapping, Sequence
55
from copy import deepcopy
66
from typing import Any, Literal
77

@@ -259,7 +259,10 @@ def get_transition_plots( # noqa: C901, PLR0912
259259
"anchored_states"
260260
]["states"]
261261

262-
states = _normalize_states_columns(states)
262+
states = _normalize_states_columns(
263+
states,
264+
aug_periods_to_periods=processed_model.labels.aug_periods_to_periods,
265+
)
263266

264267
return _get_dictionary_with_plots(
265268
model=processed_model,
@@ -644,14 +647,25 @@ def _get_states_data(
644647
return states_data
645648

646649

647-
def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
650+
def _normalize_states_columns(
651+
states: pd.DataFrame,
652+
aug_periods_to_periods: Mapping[int, int] | None = None,
653+
) -> pd.DataFrame:
648654
"""Ensure `aug_period` and `id` are columns, not index levels.
649655
650656
Pre-computed states DataFrames may carry period information as `period`
651657
(in the index or a column) instead of `aug_period`. Downstream code
652658
uniformly expects `aug_period` as a column, so this helper promotes
653-
index levels to columns and renames `period` → `aug_period` when the
654-
latter is absent.
659+
index levels to columns and, when a mapping is provided, expands each
660+
period row into one row per corresponding aug_period.
661+
662+
Args:
663+
states: DataFrame with latent factor columns and either `period` or
664+
`aug_period` identifying the time dimension.
665+
aug_periods_to_periods: Mapping from aug_period to period. When
666+
provided and the DataFrame has `period` but not `aug_period`,
667+
rows are expanded so that each period produces one row per
668+
aug_period that maps to it.
655669
"""
656670
# Promote relevant index levels to columns.
657671
names_to_reset = [
@@ -660,8 +674,21 @@ def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
660674
if names_to_reset:
661675
states = states.reset_index(level=names_to_reset)
662676

663-
# Rename period → aug_period when aug_period is missing.
664-
if "aug_period" not in states.columns and "period" in states.columns:
677+
if "aug_period" in states.columns:
678+
return states
679+
680+
if "period" not in states.columns:
681+
return states
682+
683+
# Expand period rows into aug_period rows using the mapping.
684+
if aug_periods_to_periods is not None:
685+
mapping_df = pd.DataFrame(
686+
list(aug_periods_to_periods.items()),
687+
columns=["aug_period", "period"],
688+
)
689+
states = states.merge(mapping_df, on="period", how="left")
690+
states = states.drop(columns=["period"])
691+
else:
665692
states = states.rename(columns={"period": "aug_period"})
666693

667694
return states

tests/test_visualize_factor_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_visualize_factor_distributions_runs_with_simulated_states() -> None:
6868
params = params.loc[max_inputs["params_template"].index]
6969

7070
latent_data = simulate_dataset(model, params, data=data, policies=None)[
71-
"unanchored_states"
71+
"aug_unanchored_states"
7272
]["states"]
7373

7474
kde = univariate_densities(

0 commit comments

Comments
 (0)