11"""Functions to visualize transition equations and production functions."""
22
33import itertools
4- from collections .abc import Callable , Sequence
4+ from collections .abc import Callable , Mapping , Sequence
55from copy import deepcopy
66from typing import Any , Literal
77
@@ -259,7 +259,10 @@ def get_transition_plots( # noqa: C901, PLR0912
259259 "anchored_states"
260260 ]["states" ]
261261
262- states = _normalize_states_columns (states )
262+ states = _normalize_states_columns (
263+ states ,
264+ aug_periods_to_periods = processed_model .labels .aug_periods_to_periods ,
265+ )
263266
264267 return _get_dictionary_with_plots (
265268 model = processed_model ,
@@ -644,14 +647,25 @@ def _get_states_data(
644647 return states_data
645648
646649
647- def _normalize_states_columns (states : pd .DataFrame ) -> pd .DataFrame :
650+ def _normalize_states_columns (
651+ states : pd .DataFrame ,
652+ aug_periods_to_periods : Mapping [int , int ] | None = None ,
653+ ) -> pd .DataFrame :
648654 """Ensure `aug_period` and `id` are columns, not index levels.
649655
650656 Pre-computed states DataFrames may carry period information as `period`
651657 (in the index or a column) instead of `aug_period`. Downstream code
652658 uniformly expects `aug_period` as a column, so this helper promotes
653- index levels to columns and renames `period` → `aug_period` when the
654- latter is absent.
659+ index levels to columns and, when a mapping is provided, expands each
660+ period row into one row per corresponding aug_period.
661+
662+ Args:
663+ states: DataFrame with latent factor columns and either `period` or
664+ `aug_period` identifying the time dimension.
665+ aug_periods_to_periods: Mapping from aug_period to period. When
666+ provided and the DataFrame has `period` but not `aug_period`,
667+ rows are expanded so that each period produces one row per
668+ aug_period that maps to it.
655669 """
656670 # Promote relevant index levels to columns.
657671 names_to_reset = [
@@ -660,8 +674,21 @@ def _normalize_states_columns(states: pd.DataFrame) -> pd.DataFrame:
660674 if names_to_reset :
661675 states = states .reset_index (level = names_to_reset )
662676
663- # Rename period → aug_period when aug_period is missing.
664- if "aug_period" not in states .columns and "period" in states .columns :
677+ if "aug_period" in states .columns :
678+ return states
679+
680+ if "period" not in states .columns :
681+ return states
682+
683+ # Expand period rows into aug_period rows using the mapping.
684+ if aug_periods_to_periods is not None :
685+ mapping_df = pd .DataFrame (
686+ list (aug_periods_to_periods .items ()),
687+ columns = ["aug_period" , "period" ],
688+ )
689+ states = states .merge (mapping_df , on = "period" , how = "left" )
690+ states = states .drop (columns = ["period" ])
691+ else :
665692 states = states .rename (columns = {"period" : "aug_period" })
666693
667694 return states
0 commit comments