@@ -1593,45 +1593,55 @@ def plot_errors_over_time(
15931593 config : dict ,
15941594 save : bool = True ,
15951595 show_title : bool = True ,
1596- mode : str = "relative" , # "relative" or "deltadex"
1596+ mode : str = "relative" , # "relative", "deltadex", or "iterative"
1597+ iter_interval : int | None = None ,
15971598) -> None :
15981599 """
1599- Plot errors over time for different surrogate models (relative or Δdex).
1600+ Plot errors over time for different surrogate models (relative, Δdex, or iterative Δdex).
16001601
16011602 Args:
16021603 mean_errors (dict): Mean errors for each surrogate.
16031604 median_errors (dict): Median errors for each surrogate.
16041605 timesteps (np.ndarray): Array of timesteps.
16051606 config (dict): Configuration dictionary.
1606- save (bool): Whether to save the figure.
1607- show_title (bool): Whether to add a title.
1608- mode (str): "relative" (percentage errors) or "deltadex" (log-space abs. errors).
1607+ save (bool): Whether to save the figure.
1608+ show_title (bool): Whether to add a title.
1609+ mode (str):
1610+ - "relative": percentage errors (y-axis log scale)
1611+ - "deltadex": log-space absolute errors (Δdex)
1612+ - "iterative": like "deltadex" but also draws dashed vertical lines at every
1613+ n-th timestep to indicate the iterative retrigger interval.
1614+ iter_interval (int | None): Interval for vertical guide lines when mode == "iterative".
16091615 """
16101616 plt .figure (figsize = (6 , 4 ))
16111617 colors = plt .cm .viridis (np .linspace (0 , 0.95 , len (mean_errors )))
16121618 linestyles = ["-" , "--" ]
16131619
1620+ # Support both dict inputs and array-like median_errors
16141621 for i , surrogate in enumerate (mean_errors .keys ()):
1615- mean_val = np .mean (mean_errors [surrogate ])
1616- median_val = np .mean (median_errors [surrogate ])
1622+ mean_series = mean_errors [surrogate ]
1623+ median_series = median_errors [surrogate ]
1624+
1625+ mean_val = float (np .mean (mean_series ))
1626+ median_val = float (np .mean (median_series ))
16171627
16181628 if mode == "relative" :
16191629 mean_label = f"{ surrogate } \n Mean = { mean_val * 100 :.2f} %"
16201630 median_label = f"{ surrogate } \n Median = { median_val * 100 :.2f} %"
1621- else : # deltadex
1631+ else : # deltadex or iterative
16221632 mean_label = f"{ surrogate } \n Mean = { mean_val :.3f} dex"
16231633 median_label = f"{ surrogate } \n Median = { median_val :.3f} dex"
16241634
16251635 plt .plot (
16261636 timesteps ,
1627- mean_errors [ surrogate ] ,
1637+ mean_series ,
16281638 label = mean_label ,
16291639 color = colors [i ],
16301640 linestyle = linestyles [0 ],
16311641 )
16321642 plt .plot (
16331643 timesteps ,
1634- median_errors [ surrogate ] ,
1644+ median_series ,
16351645 label = median_label ,
16361646 color = colors [i ],
16371647 linestyle = linestyles [1 ],
@@ -1644,10 +1654,23 @@ def plot_errors_over_time(
16441654 plt .yscale ("log" )
16451655 fname = "accuracy_rel_errors_time_models.png"
16461656 title = "Comparison of Relative Errors Over Time"
1647- else :
1657+ elif mode == "deltadex" :
16481658 plt .ylabel (r"Log-MAE ($\Delta dex$)" )
16491659 fname = "accuracy_delta_dex_time.png"
16501660 title = "Comparison of Δdex Errors Over Time"
1661+ elif mode == "iterative" :
1662+ # Single backslash inside raw string to render the LaTeX Delta properly
1663+ plt .ylabel (r"Log-MAE ($\Delta dex$)" )
1664+ fname = "iterative_delta_dex_time.png"
1665+ title = "Comparison of Δdex Errors Over Time for Iterative Predictions"
1666+ # Add subtle dashed vertical lines at every n-th timestep if provided and valid
1667+ if isinstance (iter_interval , int ) and iter_interval > 0 :
1668+ # start at iter_interval to avoid drawing a line at the very first x-limit
1669+ for idx in range (iter_interval , len (timesteps ), iter_interval ):
1670+ x = timesteps [idx ]
1671+ plt .axvline (x = x , linestyle = "--" , color = "gray" , alpha = 0.3 , linewidth = 0.8 )
1672+ else :
1673+ raise ValueError (f"Unknown mode: { mode } " )
16511674
16521675 if config ["dataset" ]["log_timesteps" ]:
16531676 plt .xscale ("log" )
@@ -2252,7 +2275,7 @@ def plot_error_distribution_comparative(
22522275 conf : dict ,
22532276 save : bool = True ,
22542277 show_title : bool = True ,
2255- mode : str = "relative" , # "relative" or "deltadex "
2278+ mode : str = "relative" , # "relative", "deltadex", or "iterative "
22562279) -> None :
22572280 """
22582281 Plot comparative error distributions for each surrogate model.
@@ -2262,7 +2285,8 @@ def plot_error_distribution_comparative(
22622285 conf (dict): Configuration dictionary.
22632286 save (bool): Whether to save the figure.
22642287 show_title (bool): Whether to add a title.
2265- mode (str): "relative" (unitless %) or "deltadex" (log-space abs. errors).
2288+ mode (str): "relative" (unitless %), "deltadex" (log-space abs. errors), or
2289+ "iterative" (same as deltadex plotting, different title/filename for iterative context).
22662290 """
22672291 model_names = list (errors .keys ())
22682292 num_models = len (model_names )
@@ -2321,10 +2345,17 @@ def plot_error_distribution_comparative(
23212345 xlabel = "Relative Error Magnitude"
23222346 title = "Distribution of Surrogate Relative Errors"
23232347 fname = "accuracy_error_dist_relative.png"
2324- else :
2348+ elif mode == "deltadex" :
23252349 xlabel = r"Log-MAE ($\Delta dex$)"
23262350 title = "Distribution of Surrogate Δdex Errors"
23272351 fname = "accuracy_error_dist_deltadex.png"
2352+ elif mode == "iterative" :
2353+ # Plot identical to deltadex but labeled for iterative evaluation context
2354+ xlabel = r"Log-MAE ($\Delta dex$)"
2355+ title = "Distribution of Surrogate Relative Errors for Iterative Prediction"
2356+ fname = "iterative_error_dist_deltadex.png"
2357+ else :
2358+ raise ValueError (f"Unknown mode: { mode } " )
23282359
23292360 plt .xlabel (xlabel )
23302361 plt .ylabel ("Smoothed Histogram Count" )
0 commit comments