1- import time
21from contextlib import redirect_stdout
32from typing import Any
43
4039 get_model_config ,
4140 get_surrogate ,
4241 make_comparison_csv ,
42+ measure_inference_time ,
4343 measure_memory_footprint ,
4444 save_table_csv ,
4545 write_metrics_to_yaml ,
@@ -349,13 +349,12 @@ def time_inference(
349349 n_runs : int = 5 ,
350350) -> dict [str , Any ]:
351351 """
352- Time the inference of the surrogate model.
352+ Time the inference of the surrogate model (full version with metrics) .
353353
354354 Args:
355355 model: Instance of the surrogate model class.
356356 surr_name (str): The name of the surrogate model.
357357 test_loader (DataLoader): The DataLoader object containing the test data.
358- timesteps (np.ndarray): The timesteps array.
359358 conf (dict): The configuration dictionary.
360359 n_test_samples (int): The number of test samples.
361360 n_runs (int, optional): Number of times to run the inference for timing.
@@ -366,35 +365,19 @@ def time_inference(
366365 training_id = conf ["training_id" ]
367366 model .load (training_id , surr_name , model_identifier = f"{ surr_name .lower ()} _main" )
368367
369- # Run inference multiple times and record the durations
370- inference_times = []
371- for _ in range (n_runs ):
372- # _, _ = model.predict(data_loader=test_loader)
373- total_time = 0
374- with torch .inference_mode ():
375- for inputs in test_loader :
376- start_time = time .perf_counter ()
377- _ , _ = model .forward (inputs )
378- end_time = time .perf_counter ()
379- total_time += end_time - start_time
380- # total_time /= n_test_samples
381- inference_times .append (total_time )
382-
383- # Calculate metrics
368+ inference_times = measure_inference_time (model , test_loader , n_runs = n_runs )
369+
384370 mean_inference_time = np .mean (inference_times )
385371 std_inference_time = np .std (inference_times )
386372
387- # Store metrics
388- timing_metrics = {
373+ return {
389374 "mean_inference_time_per_run" : mean_inference_time ,
390375 "std_inference_time_per_run" : std_inference_time ,
391376 "num_predictions" : n_test_samples ,
392377 "mean_inference_time_per_prediction" : mean_inference_time / n_test_samples ,
393378 "std_inference_time_per_prediction" : std_inference_time / n_test_samples ,
394379 }
395380
396- return timing_metrics
397-
398381
399382def evaluate_compute (
400383 model , surr_name : str , test_loader : DataLoader , conf : dict
0 commit comments