Skip to content

Commit 84d3127

Browse files
committed
Add iterative comparative evaluation and corresponding plots
1 parent 7c360d3 commit 84d3127

4 files changed

Lines changed: 99 additions & 14 deletions

File tree

codes/benchmark/bench_fcts.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,8 @@ def evaluate_iterative_predictions(
483483
"mean_squared_error": mse,
484484
"mean_absolute_error": mae,
485485
"absolute_errors": abs_errors,
486+
"absolute_errors_log": abs_errors_log,
487+
"iteration_interval": iter_interval,
486488
}
487489

488490

@@ -1145,6 +1147,9 @@ def compare_models(metrics: dict, config: dict):
11451147
if config["losses"]:
11461148
compare_main_losses(metrics, config)
11471149

1150+
if config["iterative"]:
1151+
compare_iterative(metrics, config)
1152+
11481153
if config["gradients"]:
11491154
compare_gradients(metrics, config)
11501155

@@ -1274,6 +1279,44 @@ def compare_errors(metrics: dict[str, dict], config: dict) -> None:
12741279
plot_error_distribution_comparative(log_errors, config, mode="deltadex")
12751280

12761281

1282+
def compare_iterative(metrics: dict[str, dict], config: dict) -> None:
    """
    Compare the iterative prediction errors of different surrogate models.

    Surrogates without an "iterative" entry in their metrics are ignored. For
    the remaining ones, the log-space absolute errors are reduced with the
    mean and the median over axes 0 and 2 and passed to the comparative
    error-over-time plot and the comparative error-distribution plot (both in
    "iterative" mode).

    If no surrogate provides iterative metrics (including an empty ``metrics``
    dict), the function returns without plotting.

    Args:
        metrics (dict[str, dict]): dictionary containing the benchmark metrics for each surrogate model.
        config (dict): Configuration dictionary.

    Returns:
        None
    """
    iterative_errors = {}
    mean_iterative_errors = {}
    median_iterative_errors = {}
    # Taken from a surrogate that actually has iterative results, instead of
    # relying on the loop variable still being bound after the loop.
    timesteps = None
    iter_interval = None

    for surrogate, surrogate_metrics in metrics.items():
        if "iterative" not in surrogate_metrics:
            continue

        iterative_metrics = surrogate_metrics["iterative"]
        errors = iterative_metrics["absolute_errors_log"]
        iterative_errors[surrogate] = errors
        # NOTE(review): assumes a 3-D error array with the time axis at
        # position 1, so reducing over (0, 2) leaves one value per timestep
        # -- confirm against evaluate_iterative_predictions.
        mean_iterative_errors[surrogate] = np.mean(errors, axis=(0, 2))
        median_iterative_errors[surrogate] = np.median(errors, axis=(0, 2))

        # Keep the values of the last surrogate with iterative results; this
        # matches the previous behaviour whenever every surrogate has them.
        timesteps = surrogate_metrics["timesteps"]
        iter_interval = iterative_metrics["iteration_interval"]

    if not iterative_errors:
        # Nothing to compare: avoid NameError/KeyError on the plotting inputs.
        return

    plot_errors_over_time(
        mean_iterative_errors,
        median_iterative_errors,
        timesteps,
        config,
        mode="iterative",
        iter_interval=iter_interval,
    )
    plot_error_distribution_comparative(iterative_errors, config, mode="iterative")
1318+
1319+
12771320
def compare_inference_time(
12781321
metrics: dict[str, dict], config: dict, save: bool = True
12791322
) -> None:

codes/benchmark/bench_plots.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,45 +1593,55 @@ def plot_errors_over_time(
15931593
config: dict,
15941594
save: bool = True,
15951595
show_title: bool = True,
1596-
mode: str = "relative", # "relative" or "deltadex"
1596+
mode: str = "relative", # "relative", "deltadex", or "iterative"
1597+
iter_interval: int | None = None,
15971598
) -> None:
15981599
"""
1599-
Plot errors over time for different surrogate models (relative or Δdex).
1600+
Plot errors over time for different surrogate models (relative, Δdex, or iterative Δdex).
16001601
16011602
Args:
16021603
mean_errors (dict): Mean errors for each surrogate.
16031604
median_errors (dict): Median errors for each surrogate.
16041605
timesteps (np.ndarray): Array of timesteps.
16051606
config (dict): Configuration dictionary.
1606-
save (bool): Whether to save the figure.
1607-
show_title (bool): Whether to add a title.
1608-
mode (str): "relative" (percentage errors) or "deltadex" (log-space abs. errors).
1607+
save (bool): Whether to save the figure.
1608+
show_title (bool): Whether to add a title.
1609+
mode (str):
1610+
- "relative": percentage errors (y-axis log scale)
1611+
- "deltadex": log-space absolute errors (Δdex)
1612+
- "iterative": like "deltadex" but also draws dashed vertical lines at every
1613+
n-th timestep to indicate the iterative retrigger interval.
1614+
iter_interval (int | None): Interval for vertical guide lines when mode == "iterative".
16091615
"""
16101616
plt.figure(figsize=(6, 4))
16111617
colors = plt.cm.viridis(np.linspace(0, 0.95, len(mean_errors)))
16121618
linestyles = ["-", "--"]
16131619

1620+
# Plot each surrogate's mean (solid) and median (dashed) error series; both inputs are dicts keyed by surrogate name
16141621
for i, surrogate in enumerate(mean_errors.keys()):
1615-
mean_val = np.mean(mean_errors[surrogate])
1616-
median_val = np.mean(median_errors[surrogate])
1622+
mean_series = mean_errors[surrogate]
1623+
median_series = median_errors[surrogate]
1624+
1625+
mean_val = float(np.mean(mean_series))
1626+
median_val = float(np.mean(median_series))
16171627

16181628
if mode == "relative":
16191629
mean_label = f"{surrogate}\nMean = {mean_val * 100:.2f}%"
16201630
median_label = f"{surrogate}\nMedian = {median_val * 100:.2f}%"
1621-
else: # deltadex
1631+
else: # deltadex or iterative
16221632
mean_label = f"{surrogate}\nMean = {mean_val:.3f} dex"
16231633
median_label = f"{surrogate}\nMedian = {median_val:.3f} dex"
16241634

16251635
plt.plot(
16261636
timesteps,
1627-
mean_errors[surrogate],
1637+
mean_series,
16281638
label=mean_label,
16291639
color=colors[i],
16301640
linestyle=linestyles[0],
16311641
)
16321642
plt.plot(
16331643
timesteps,
1634-
median_errors[surrogate],
1644+
median_series,
16351645
label=median_label,
16361646
color=colors[i],
16371647
linestyle=linestyles[1],
@@ -1644,10 +1654,23 @@ def plot_errors_over_time(
16441654
plt.yscale("log")
16451655
fname = "accuracy_rel_errors_time_models.png"
16461656
title = "Comparison of Relative Errors Over Time"
1647-
else:
1657+
elif mode == "deltadex":
16481658
plt.ylabel(r"Log-MAE ($\Delta dex$)")
16491659
fname = "accuracy_delta_dex_time.png"
16501660
title = "Comparison of Δdex Errors Over Time"
1661+
elif mode == "iterative":
1662+
# Single backslash inside raw string to render the LaTeX Delta properly
1663+
plt.ylabel(r"Log-MAE ($\Delta dex$)")
1664+
fname = "iterative_delta_dex_time.png"
1665+
title = "Comparison of Δdex Errors Over Time for Iterative Predictions"
1666+
# Add subtle dashed vertical lines at every n-th timestep if provided and valid
1667+
if isinstance(iter_interval, int) and iter_interval > 0:
1668+
# start at iter_interval to avoid drawing a line at the very first x-limit
1669+
for idx in range(iter_interval, len(timesteps), iter_interval):
1670+
x = timesteps[idx]
1671+
plt.axvline(x=x, linestyle="--", color="gray", alpha=0.3, linewidth=0.8)
1672+
else:
1673+
raise ValueError(f"Unknown mode: {mode}")
16511674

16521675
if config["dataset"]["log_timesteps"]:
16531676
plt.xscale("log")
@@ -2252,7 +2275,7 @@ def plot_error_distribution_comparative(
22522275
conf: dict,
22532276
save: bool = True,
22542277
show_title: bool = True,
2255-
mode: str = "relative", # "relative" or "deltadex"
2278+
mode: str = "relative", # "relative", "deltadex", or "iterative"
22562279
) -> None:
22572280
"""
22582281
Plot comparative error distributions for each surrogate model.
@@ -2262,7 +2285,8 @@ def plot_error_distribution_comparative(
22622285
conf (dict): Configuration dictionary.
22632286
save (bool): Whether to save the figure.
22642287
show_title (bool): Whether to add a title.
2265-
mode (str): "relative" (unitless %) or "deltadex" (log-space abs. errors).
2288+
mode (str): "relative" (unitless %), "deltadex" (log-space abs. errors), or
2289+
"iterative" (same as deltadex plotting, different title/filename for iterative context).
22662290
"""
22672291
model_names = list(errors.keys())
22682292
num_models = len(model_names)
@@ -2321,10 +2345,17 @@ def plot_error_distribution_comparative(
23212345
xlabel = "Relative Error Magnitude"
23222346
title = "Distribution of Surrogate Relative Errors"
23232347
fname = "accuracy_error_dist_relative.png"
2324-
else:
2348+
elif mode == "deltadex":
23252349
xlabel = r"Log-MAE ($\Delta dex$)"
23262350
title = "Distribution of Surrogate Δdex Errors"
23272351
fname = "accuracy_error_dist_deltadex.png"
2352+
elif mode == "iterative":
2353+
# Plot identical to deltadex but labeled for iterative evaluation context
2354+
xlabel = r"Log-MAE ($\Delta dex$)"
2355+
title = "Distribution of Surrogate Relative Errors for Iterative Prediction"
2356+
fname = "iterative_error_dist_deltadex.png"
2357+
else:
2358+
raise ValueError(f"Unknown mode: {mode}")
23282359

23292360
plt.xlabel(xlabel)
23302361
plt.ylabel("Smoothed Histogram Count")

codes/benchmark/bench_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def clean_metrics(metrics: dict, conf: dict) -> dict:
415415

416416
if conf["iterative"]:
417417
write_metrics["iterative"].pop("absolute_errors", None)
418+
write_metrics["iterative"].pop("absolute_errors_log", None)
418419
if conf["gradients"]:
419420
write_metrics["gradients"].pop("gradients", None)
420421
write_metrics["gradients"].pop("max_counts", None)

test/test_model_comparison.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def record_calls(monkeypatch):
1212
calls = []
1313
names = [
1414
"compare_errors",
15+
"compare_iterative",
1516
"compare_main_losses",
1617
"compare_gradients",
1718
"compare_inference_time",
@@ -68,6 +69,10 @@ def make_dummy_metrics():
6869
"correlation_metrics": None,
6970
"weighted_diff": None,
7071
},
72+
"iterative": {
73+
"iteration_interval": 10,
74+
"absolute_errors_log": None,
75+
},
7176
}
7277
}
7378

@@ -79,6 +84,7 @@ def make_dummy_metrics():
7984
(
8085
{
8186
"losses": True,
87+
"iterative": True,
8288
"gradients": True,
8389
"timing": True,
8490
"interpolation": {"enabled": True},
@@ -90,6 +96,7 @@ def make_dummy_metrics():
9096
[
9197
"compare_errors",
9298
"compare_main_losses",
99+
"compare_iterative",
93100
"compare_gradients",
94101
"compare_inference_time",
95102
"compare_interpolation",
@@ -105,6 +112,7 @@ def make_dummy_metrics():
105112
(
106113
{
107114
"losses": False,
115+
"iterative": False,
108116
"gradients": False,
109117
"timing": False,
110118
"interpolation": {"enabled": False},
@@ -122,6 +130,7 @@ def make_dummy_metrics():
122130
(
123131
{
124132
"losses": True,
133+
"iterative": False,
125134
"gradients": False,
126135
"timing": False,
127136
"interpolation": {"enabled": False},
@@ -143,6 +152,7 @@ def test_compare_models_branching(record_calls, flags, expected_sequence):
143152
"training_id": "test",
144153
"devices": ["cpu"], # for compare_main_losses
145154
"losses": flags["losses"],
155+
"iterative": flags["iterative"],
146156
"gradients": flags["gradients"],
147157
"timing": flags["timing"],
148158
"interpolation": flags["interpolation"],

0 commit comments

Comments
 (0)