@@ -140,10 +140,10 @@ def combine_transition_plots(
140140 return fig
141141
142142
143- def get_transition_plots (
143+ def get_transition_plots ( # noqa: C901, PLR0912
144144 model_spec : ModelSpec ,
145145 params : pd .DataFrame ,
146- data : pd .DataFrame ,
146+ data : pd .DataFrame | None = None ,
147147 period : int | None = None ,
148148 periods : Sequence [int ] | None = None ,
149149 state_ranges : dict [str , pd .DataFrame ] | None = None ,
@@ -159,14 +159,16 @@ def get_transition_plots(
159159 state_range_quantile_cutoff : float | None = None ,
160160 layout_kwargs : dict [str , Any ] | None = None ,
161161 * ,
162+ states : pd .DataFrame | None = None ,
162163 include_correction_factors : bool = False ,
163164) -> dict [tuple [str , str ], go .Figure ]:
164165 """Get dictionary with individual plots of transition equations for each factor.
165166
166167 Args:
167168 model_spec: The model specification. See: :ref:`model_specs`
168169 params: Model parameters.
169- data: Empirical dataset used to estimate the model.
170+ data: Empirical dataset used to estimate the model. Required when `states`
171+ is not provided or when the model has observed factors.
170172 period: The start period of the transition equations that are plotted.
171173 Deprecated in favor of `periods`. If both are provided, `periods` is used.
172174 periods: List of periods to overlay on each plot. Each period gets a different
@@ -194,6 +196,8 @@ def get_transition_plots(
194196 layout_kwargs: Dictionary of key word arguments used to
195197 update layout of plotly image object. If None, the default kwargs
196198 defined in the function will be used.
199+ states: Pre-computed filtered states DataFrame (with an `aug_period`
200+ column). If provided, skip the internal `get_filtered_states` call.
197201 include_correction_factors: Whether to include correction factors in the
198202 plots. Default False.
199203
@@ -247,9 +251,13 @@ def get_transition_plots(
247251 if not processed_model .endogenous_factors_info .factor_info [lf ].is_correction
248252 ]
249253 all_factors = processed_model .labels .all_factors
250- states = get_filtered_states (model_spec = model_spec , data = data , params = params )[
251- "anchored_states"
252- ]["states" ]
254+ if states is None :
255+ if data is None :
256+ msg = "Either 'data' or 'states' must be provided."
257+ raise TypeError (msg )
258+ states = get_filtered_states (model_spec = model_spec , data = data , params = params )[
259+ "anchored_states"
260+ ]["states" ]
253261 return _get_dictionary_with_plots (
254262 model = processed_model ,
255263 data = data ,
@@ -270,7 +278,7 @@ def get_transition_plots(
270278
271279def _get_dictionary_with_plots (
272280 model : ProcessedModel ,
273- data : pd .DataFrame ,
281+ data : pd .DataFrame | None ,
274282 params : pd .DataFrame ,
275283 states : pd .DataFrame ,
276284 state_ranges : dict [str , pd .DataFrame ] | None ,
@@ -582,17 +590,17 @@ def _set_index_params(
582590def _get_states_data (
583591 model : ProcessedModel ,
584592 period : int ,
585- data : pd .DataFrame ,
593+ data : pd .DataFrame | None ,
586594 states : pd .DataFrame ,
587595 observed_factors : tuple [str , ...],
588596) -> pd .DataFrame :
589- if observed_factors and data is None :
590- raise ValueError (
591- "The model has observed factors. You must pass the empirical data to "
592- "'visualize_transition_equations' via the keyword *data*." ,
593- )
594-
595597 if observed_factors :
598+ if data is None :
599+ msg = (
600+ "The model has observed factors. You must pass the empirical data to "
601+ "'get_transition_plots' via the keyword 'data'."
602+ )
603+ raise TypeError (msg )
596604 _observed_arr = process_data (
597605 df = data ,
598606 has_endogenous_factors = model .endogenous_factors_info .has_endogenous_factors ,
0 commit comments