1919 plot_dynamic_correlation_heatmap ,
2020 plot_error_correlation_heatmap ,
2121 plot_error_distribution_comparative ,
22- plot_error_distribution_per_chemical ,
22+ plot_error_distribution_per_quantity ,
2323 plot_example_mode_predictions ,
2424 plot_example_predictions_with_uncertainty ,
2525 plot_generalization_error_comparison ,
@@ -86,9 +86,9 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
8686 )
8787 model_config = get_model_config (surr_name , conf )
8888 n_timesteps = train_data .shape [1 ]
89- n_chemicals = train_data .shape [2 ]
89+ n_quantities = train_data .shape [2 ]
9090 n_test_samples = n_timesteps * val_data .shape [0 ]
91- model = surrogate_class (device , n_chemicals , n_timesteps , model_config )
91+ model = surrogate_class (device , n_quantities , n_timesteps , model_config )
9292
9393 # Placeholder for metrics
9494 metrics = {}
@@ -185,25 +185,25 @@ def evaluate_accuracy(
185185 labels : list | None = None ,
186186) -> dict [str , Any ]:
187187 """
188- Evaluate the accuracy of the surrogate model.
189-
190- Args:
191- model: Instance of the surrogate model class.
192- surr_name (str): The name of the surrogate model.
193- timesteps (np.ndarray): The timesteps array.
194- test_loader (DataLoader): The DataLoader object containing the test data.
195- conf (dict): The configuration dictionary.
196- labels (list, optional): The labels for the chemical species .
197-
198- Returns:
199- dict: A dictionary containing accuracy metrics.
188+ Evaluate the accuracy of the surrogate model.
189+ quantitiesquantities
190+ Args:
191+ model: Instance of the surrogate model class.
192+ surr_name (str): The name of the surrogate model.
193+ timesteps (np.ndarray): The timesteps array.
194+ test_loader (DataLoader): The DataLoader object containing the test data.
195+ conf (dict): The configuration dictionary.
196+ labels (list, optional): The labels for the quantities .
197+
198+ Returns:
199+ dict: A dictionary containing accuracy metrics.
200200 """
201201 training_id = conf ["training_id" ]
202202
203203 # Load the model
204204 model .load (training_id , surr_name , model_identifier = f"{ surr_name .lower ()} _main" )
205205 train_time = model .train_duration
206- num_chemicals = model .n_chemicals
206+ num_quantities = model .n_quantities
207207 model_index = conf ["surrogates" ].index (surr_name )
208208 n_epochs = conf ["epochs" ][model_index ]
209209
@@ -229,12 +229,12 @@ def evaluate_accuracy(
229229 show_title = TITLE ,
230230 )
231231
232- plot_error_distribution_per_chemical (
232+ plot_error_distribution_per_quantity (
233233 surr_name ,
234234 conf ,
235235 relative_errors ,
236- chemical_names = labels ,
237- num_chemicals = num_chemicals ,
236+ quantity_names = labels ,
237+ num_quantities = num_quantities ,
238238 save = True ,
239239 show_title = TITLE ,
240240 )
@@ -431,7 +431,7 @@ def evaluate_interpolation(
431431 test_loader (DataLoader): The DataLoader object containing the test data.
432432 timesteps (np.ndarray): The timesteps array.
433433 conf (dict): The configuration dictionary.
434- labels (list, optional): The labels for the chemical species .
434+ labels (list, optional): The labels for the quantities .
435435
436436 Returns:
437437 dict: A dictionary containing interpolation metrics.
@@ -529,7 +529,7 @@ def evaluate_extrapolation(
529529 test_loader (DataLoader): The DataLoader object containing the test data.
530530 timesteps (np.ndarray): The timesteps array.
531531 conf (dict): The configuration dictionary.
532- labels (list, optional): The labels for the chemical species .
532+ labels (list, optional): The labels for the quantities .
533533
534534 Returns:
535535 dict: A dictionary containing extrapolation metrics.
@@ -793,7 +793,7 @@ def evaluate_UQ(
793793 test_loader (DataLoader): The DataLoader object containing the test data.
794794 timesteps (np.ndarray): The timesteps array.
795795 conf (dict): The configuration dictionary.
796- labels (list, optional): The labels for the chemical species .
796+ labels (list, optional): The labels for the quantities .
797797
798798 Returns:
799799 dict: A dictionary containing UQ metrics.
@@ -952,9 +952,9 @@ def compare_main_losses(metrics: dict, config: dict) -> None:
952952 training_id = config ["training_id" ]
953953 surrogate_class = get_surrogate (surr_name )
954954 n_timesteps = metrics [surr_name ]["timesteps" ].shape [0 ]
955- n_chemicals = metrics [surr_name ]["accuracy" ]["absolute_errors" ].shape [2 ]
955+ n_quantities = metrics [surr_name ]["accuracy" ]["absolute_errors" ].shape [2 ]
956956 model_config = get_model_config (surr_name , config )
957- model = surrogate_class (device , n_chemicals , n_timesteps , model_config )
957+ model = surrogate_class (device , n_quantities , n_timesteps , model_config )
958958
959959 def load_losses (model_identifier : str ):
960960 model .load (training_id , surr_name , model_identifier = model_identifier )
@@ -1001,9 +1001,9 @@ def load_losses(model_identifier: str):
10011001# training_id = config["training_id"]
10021002# surrogate_class = get_surrogate(surr_name)
10031003# n_timesteps = metrics[surr_name]["timesteps"].shape[0]
1004- # n_chemicals = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
1004+ # n_quantities = metrics[surr_name]["accuracy"]["absolute_errors"].shape[2]
10051005# model_config = get_model_config(surr_name, config)
1006- # model = surrogate_class(device, n_chemicals , n_timesteps, model_config)
1006+ # model = surrogate_class(device, n_quantities , n_timesteps, model_config)
10071007# model_identifier = f"{surr_name.lower()}_main"
10081008# model.load(training_id, surr_name, model_identifier=model_identifier)
10091009# MAE.append(model.MAE)
0 commit comments