77from matplotlib .gridspec import GridSpec
88from scipy .ndimage import gaussian_filter1d
99
10+ from codes .utils import batch_factor_to_float
11+
1012from .bench_utils import format_time
1113
1214# Utility functions for plotting
@@ -157,6 +159,9 @@ def plot_relative_errors_over_time(
157159 plt .xlabel ("Time" )
158160 plt .ylabel ("Relative Error" )
159161 plt .xlim (timesteps [0 ], timesteps [- 1 ])
162+ plt .ylim (bottom = 1e-8 )
163+ if conf ["dataset" ]["log_timesteps" ]:
164+ plt .xscale ("log" )
160165 if show_title :
161166 plt .title (title )
162167 plt .legend (loc = "center left" , bbox_to_anchor = (1 , 0.5 ))
@@ -332,6 +337,8 @@ def plot_average_errors_over_time(
332337 plt .xlim (timesteps [0 ], timesteps [- 1 ])
333338 plt .ylabel ("Mean Absolute Error" )
334339 plt .yscale ("log" )
340+ if conf ["dataset" ]["log_timesteps" ]:
341+ plt .xscale ("log" )
335342 title = f"Mean Absolute Errors over Time ({ mode .capitalize ()} , { surr_name } )"
336343 filename = f"{ mode } _errors_over_time.png"
337344
@@ -450,6 +457,8 @@ def plot_example_mode_predictions(
450457
451458 # Set the x-axis limits based on the timesteps array
452459 ax .set_xlim (timesteps .min (), timesteps .max ())
460+ if conf ["dataset" ]["log_timesteps" ]:
461+ ax .set_xscale ("log" )
453462
454463 # Add a single x-axis label to the bottom of the figure
455464 fig .text (0.5 , 0.04 , "Time" , ha = "center" , va = "center" , fontsize = 12 )
@@ -471,10 +480,10 @@ def plot_example_mode_predictions(
471480
472481 # Set the overall title with details depending on the mode
473482 if mode == "interpolation" :
474- title = f"DeepEnsemble : Example Predictions (Interpolation, { surr_name } )\n "
483+ title = f"Interpolation : Example Predictions (Interpolation, { surr_name } )\n "
475484 extra_info = f"Sample Index: { example_idx } , Training Interval: { metric } "
476485 elif mode == "extrapolation" :
477- title = f"DeepEnsemble : Example Predictions (Extrapolation, { surr_name } )\n "
486+ title = f"Extrapolation : Example Predictions (Extrapolation, { surr_name } )\n "
478487 extra_info = f"Sample Index: { example_idx } , Cutoff Timestep: { metric } "
479488 else :
480489 raise ValueError (
@@ -589,6 +598,8 @@ def plot_example_predictions_with_uncertainty(
589598
590599 # Set the x limit exactly from the lowest to the highest timestep
591600 ax .set_xlim (timesteps .min (), timesteps .max ())
601+ if conf ["dataset" ]["log_timesteps" ]:
602+ ax .set_xscale ("log" )
592603
593604 # Add a single x-axis label to the bottom plot
594605 fig .text (0.5 , 0.04 , "Time" , ha = "center" , va = "center" , fontsize = 12 )
@@ -656,6 +667,8 @@ def plot_average_uncertainty_over_time(
656667 plt .xlabel ("Time" )
657668 plt .ylabel ("Average Uncertainty / Mean Absolute Error" )
658669 plt .xlim (timesteps [0 ], timesteps [- 1 ])
670+ if conf ["dataset" ]["log_timesteps" ]:
671+ plt .xscale ("log" )
659672 if show_title :
660673 plt .title ("Average Uncertainty and Mean Absolute Error Over Time" )
661674 plt .legend ()
@@ -835,10 +848,16 @@ def load_losses(model_identifier: str):
835848
836849 # Batchsize losses
837850 if conf ["batch_scaling" ]["enabled" ]:
838- batch_sizes = conf ["batch_scaling" ]["sizes" ]
851+ batch_factors = conf ["batch_scaling" ]["sizes" ]
839852 batch_train_losses = []
840853 batch_test_losses = []
841- for batch_size in batch_sizes :
854+ batch_sizes = []
855+ surr_index = conf ["surrogates" ].index (surr_name )
856+ main_model_bs = conf ["batch_size" ][surr_index ]
857+ for batch_factor in batch_factors :
858+ batch_factor = batch_factor_to_float (batch_factor )
859+ batch_size = int (main_model_bs * batch_factor )
860+ batch_sizes .append (batch_size )
842861 train_loss , test_loss , epochs = load_losses (
843862 f"{ surr_name .lower ()} _batchsize_{ batch_size } "
844863 )
@@ -966,7 +985,9 @@ def plot_error_distribution_per_quantity(
966985 fig .align_ylabels ()
967986
968987 plt .xscale ("log" ) # Log scale for error magnitudes
969- plt .xlim (10 ** x_min , 10 ** x_max ) # Set x-axis range based on log-space calculations
988+ plt .xlim (
989+ np .maximum (10 ** x_min , 1e-8 ), 10 ** x_max
990+ ) # Set x-axis range based on log-space calculations
970991 plt .xlabel ("Relative Error" )
971992 if show_title :
972993 if num_plots > 1 :
@@ -1423,6 +1444,8 @@ def plot_relative_errors(
14231444 plt .yscale ("log" )
14241445 if show_title :
14251446 plt .title ("Comparison of Relative Errors Over Time" )
1447+ if config ["dataset" ]["log_timesteps" ]:
1448+ plt .xscale ("log" )
14261449 plt .legend (loc = "center left" , bbox_to_anchor = (1 , 0.5 ))
14271450
14281451 if save and config :
@@ -1483,6 +1506,8 @@ def plot_uncertainty_over_time_comparison(
14831506 plt .xlim (timesteps [0 ], timesteps [- 1 ])
14841507 plt .ylabel ("Uncertainty / MAE" )
14851508 plt .yscale ("log" )
1509+ if config ["dataset" ]["log_timesteps" ]:
1510+ plt .xscale ("log" )
14861511 if show_title :
14871512 plt .title ("Comparison of Predictive Uncertainty and True MAE over Time" )
14881513 plt .legend (loc = "center left" , bbox_to_anchor = (1 , 0.5 ))
@@ -2266,7 +2291,7 @@ def plot_error_distribution_comparative(
22662291 )
22672292
22682293 plt .xscale ("log" ) # Log scale for error magnitudes
2269- plt .xlim (10 ** x_min , 10 ** x_max ) # Set x-axis range based on log-space calculations
2294+ plt .xlim (np . maximum ( 10 ** x_min , 1e-8 ), 10 ** x_max ) # Set x-axis range
22702295
22712296 if mode == "main" :
22722297 title = "Distribution of Surrogate Relative Errors"
0 commit comments