@@ -1661,6 +1661,7 @@ def plot_errors_over_time(
16611661 elif mode == "iterative" :
16621662 # Single backslash inside raw string to render the LaTeX Delta properly
16631663 plt .ylabel (r"Log-MAE ($\Delta dex$)" )
1664+ plt .ylim (bottom = 0 , top = min (np .max (list (mean_errors .values ())) * 1.1 , 5 ))
16641665 fname = "iterative_delta_dex_time.png"
16651666 title = "Comparison of Δdex Errors Over Time for Iterative Predictions"
16661667 # Add subtle dashed vertical lines at every n-th timestep if provided and valid
@@ -2035,8 +2036,6 @@ def inference_time_bar_plot(
20352036 # Calculate the upper y-limit to provide space for text
20362037 max_bar = max (means [i ] + stds [i ] for i in range (len (means )))
20372038 # min_bar = min(means[i] - stds[i] for i in range(len(means)))
2038- # Temp!
2039- # ax.set_ylim(min_bar * 0.3, max_bar * 2) # Set limits with some padding
20402039 ax .set_ylim (0 , max_bar * 1.2 ) # Set limits with some padding
20412040
20422041 # Add inference time as text to the bars using the format_time function
@@ -2053,8 +2052,6 @@ def inference_time_bar_plot(
20532052
20542053 ax .set_xlabel ("Surrogate Model" )
20552054 ax .set_ylabel ("Mean Inference Time per Run" )
2056- # Temp!
2057- # ax.set_yscale("log")
20582055 if show_title :
20592056 ax .set_title ("Surrogate Mean Inference Time Comparison" )
20602057
@@ -2507,11 +2504,16 @@ def plot_catastrophic_detection_curves(
25072504
25082505 xs , ys = [], []
25092506 for f in flag_fractions :
2510- unc_thr = np .percentile (u , 100.0 * (1.0 - float (f )))
2511- flagged = u >= unc_thr
2512- recall = (flagged & is_cat ).sum () / n_cat if n_cat > 0 else 0.0
2513- xs .append (100.0 * flagged .mean ())
2514- ys .append (100.0 * recall )
2507+ if f <= 0.0 :
2508+ xs .append (0.0 )
2509+ ys .append (0.0 )
2510+ recall = 0.0
2511+ else :
2512+ unc_thr = np .percentile (u , 100.0 * (1.0 - float (f )))
2513+ flagged = u >= unc_thr
2514+ recall = (flagged & is_cat ).sum () / n_cat if n_cat > 0 else 0.0
2515+ xs .append (100.0 * flagged .mean ())
2516+ ys .append (100.0 * recall )
25152517
25162518 ax .plot (
25172519 xs ,
@@ -2988,8 +2990,6 @@ def rel_errors_and_uq(
29882990
29892991 ax1 .set_xlabel ("Time" )
29902992 ax1 .set_xlim (timesteps [0 ], timesteps [- 1 ])
2991- # Temp!
2992- # ax1.set_ylim(3e-4, 1)
29932993 ax1 .set_ylabel ("Relative Error" )
29942994 ax1 .set_yscale ("log" )
29952995 ax1 .set_title ("Comparison of Relative Errors Over Time" )
@@ -3020,8 +3020,6 @@ def rel_errors_and_uq(
30203020
30213021 ax2 .set_xlabel ("Time" )
30223022 ax2 .set_xlim (timesteps [0 ], timesteps [- 1 ])
3023- # Temp!
3024- # ax2.set_ylim(0, 0.04)
30253023 ax2 .set_ylabel ("Uncertainty/Absolute Error" )
30263024 if show_title :
30273025 ax2 .set_title ("Comparison of Predictive Uncertainty Over Time" )
0 commit comments