Skip to content

Commit 67d7b44

Browse files
committed
Allow passing states directly to get_transition_plots (health-cognition use case).
1 parent 4a836ae commit 67d7b44

1 file changed

Lines changed: 22 additions & 14 deletions

File tree

src/skillmodels/visualize_transition_equations.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,10 @@ def combine_transition_plots(
140140
return fig
141141

142142

143-
def get_transition_plots(
143+
def get_transition_plots( # noqa: C901, PLR0912
144144
model_spec: ModelSpec,
145145
params: pd.DataFrame,
146-
data: pd.DataFrame,
146+
data: pd.DataFrame | None = None,
147147
period: int | None = None,
148148
periods: Sequence[int] | None = None,
149149
state_ranges: dict[str, pd.DataFrame] | None = None,
@@ -159,14 +159,16 @@ def get_transition_plots(
159159
state_range_quantile_cutoff: float | None = None,
160160
layout_kwargs: dict[str, Any] | None = None,
161161
*,
162+
states: pd.DataFrame | None = None,
162163
include_correction_factors: bool = False,
163164
) -> dict[tuple[str, str], go.Figure]:
164165
"""Get dictionary with individual plots of transition equations for each factor.
165166
166167
Args:
167168
model_spec: The model specification. See: :ref:`model_specs`
168169
params: Model parameters.
169-
data: Empirical dataset used to estimate the model.
170+
data: Empirical dataset used to estimate the model. Required when `states`
171+
is not provided or when the model has observed factors.
170172
period: The start period of the transition equations that are plotted.
171173
Deprecated in favor of `periods`. If both are provided, `periods` is used.
172174
periods: List of periods to overlay on each plot. Each period gets a different
@@ -194,6 +196,8 @@ def get_transition_plots(
194196
layout_kwargs: Dictionary of key word arguments used to
195197
update layout of plotly image object. If None, the default kwargs
196198
defined in the function will be used.
199+
states: Pre-computed filtered states DataFrame (with an `aug_period`
200+
column). If provided, skip the internal `get_filtered_states` call.
197201
include_correction_factors: Whether to include correction factors in the
198202
plots. Default False.
199203
@@ -247,9 +251,13 @@ def get_transition_plots(
247251
if not processed_model.endogenous_factors_info.factor_info[lf].is_correction
248252
]
249253
all_factors = processed_model.labels.all_factors
250-
states = get_filtered_states(model_spec=model_spec, data=data, params=params)[
251-
"anchored_states"
252-
]["states"]
254+
if states is None:
255+
if data is None:
256+
msg = "Either 'data' or 'states' must be provided."
257+
raise TypeError(msg)
258+
states = get_filtered_states(model_spec=model_spec, data=data, params=params)[
259+
"anchored_states"
260+
]["states"]
253261
return _get_dictionary_with_plots(
254262
model=processed_model,
255263
data=data,
@@ -270,7 +278,7 @@ def get_transition_plots(
270278

271279
def _get_dictionary_with_plots(
272280
model: ProcessedModel,
273-
data: pd.DataFrame,
281+
data: pd.DataFrame | None,
274282
params: pd.DataFrame,
275283
states: pd.DataFrame,
276284
state_ranges: dict[str, pd.DataFrame] | None,
@@ -582,17 +590,17 @@ def _set_index_params(
582590
def _get_states_data(
583591
model: ProcessedModel,
584592
period: int,
585-
data: pd.DataFrame,
593+
data: pd.DataFrame | None,
586594
states: pd.DataFrame,
587595
observed_factors: tuple[str, ...],
588596
) -> pd.DataFrame:
589-
if observed_factors and data is None:
590-
raise ValueError(
591-
"The model has observed factors. You must pass the empirical data to "
592-
"'visualize_transition_equations' via the keyword *data*.",
593-
)
594-
595597
if observed_factors:
598+
if data is None:
599+
msg = (
600+
"The model has observed factors. You must pass the empirical data to "
601+
"'get_transition_plots' via the keyword 'data'."
602+
)
603+
raise TypeError(msg)
596604
_observed_arr = process_data(
597605
df=data,
598606
has_endogenous_factors=model.endogenous_factors_info.has_endogenous_factors,

0 commit comments

Comments
 (0)