
Commit f1df561

Remove temporary code snippets

1 parent 8dde262 commit f1df561

2 files changed: 0 additions & 37 deletions

codes/benchmark/bench_fcts.py (0 additions & 17 deletions)
@@ -96,8 +96,6 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
         tolerance=conf["dataset"]["tolerance"],
         per_species=conf["dataset"].get("normalise_per_species", False),
     )
-    # TEMP
-    print(conf["dataset"]["name"], train_data.shape, val_data.shape, test_data.shape)
 
     model_config = get_model_config(surr_name, conf)
     n_timesteps = train_data.shape[1]
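
A note on the removed snippet: rather than deleting shape checks outright, debug output like this can be routed through Python's logging module so it can be re-enabled without editing the code. A minimal sketch under that assumption; report_shapes is a hypothetical helper, not part of this repo:

import logging

import numpy as np

logger = logging.getLogger("benchmark")


def report_shapes(name: str, train: np.ndarray, val: np.ndarray, test: np.ndarray) -> None:
    # Emitted only when logging is configured at DEBUG level,
    # so nothing leaks into normal benchmark runs.
    logger.debug("dataset=%s train=%s val=%s test=%s", name, train.shape, val.shape, test.shape)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    rng = np.random.default_rng(0)
    report_shapes("demo", rng.random((8, 100, 4)), rng.random((2, 100, 4)), rng.random((2, 100, 4)))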
@@ -1282,10 +1280,6 @@ def compare_errors(metrics: dict[str, dict], config: dict) -> None:
     if log_errors:
         plot_errors_over_time(mean_log, median_log, timesteps, config, mode="deltadex")
         plot_error_distribution_comparative(log_errors, config, mode="deltadex")
-        # TEMP
-        dataset = config["dataset"]["name"]
-        os.makedirs(f"scripts/pp/{dataset}", exist_ok=True)
-        np.savez(f"scripts/pp/{dataset}/all_log_errors.npz", log_errors)
 
 
 def compare_iterative(metrics: dict[str, dict], config: dict) -> None:
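
One detail about the dump removed here: np.savez(path, log_errors) passes the dict positionally, so NumPy coerces it to a 0-d object array stored under the key arr_0 and pickles it. Assuming log_errors maps surrogate names to arrays (as the plotting calls suggest), expanding it as keyword arguments stores each array under its own key instead. A sketch with made-up data:

import numpy as np

# Hypothetical stand-in for the benchmark's per-surrogate error arrays.
log_errors = {
    "ModelA": np.random.rand(10, 100),
    "ModelB": np.random.rand(10, 100),
}

# Positional: the whole dict is pickled into one object array under "arr_0".
np.savez("all_log_errors_pickled.npz", log_errors)

# Keyword expansion: one named array per surrogate, no pickling needed.
np.savez("all_log_errors.npz", **log_errors)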
@@ -1315,11 +1309,6 @@ def compare_iterative(metrics: dict[str, dict], config: dict) -> None:
         iterative_errors[surrogate], axis=(0, 2)
     )
 
-    # TEMP
-    dataset = config["dataset"]["name"]
-    os.makedirs(f"scripts/pp/{dataset}", exist_ok=True)
-    np.savez(f"scripts/pp/{dataset}/all_iterative_errors.npz", iterative_errors)
-
     plot_errors_over_time(
         mean_iterative_errors,
         median_iterative_errors,
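
The mirror image applies when reading such a dump back: a positionally saved dict loads only with allow_pickle=True and has to be unwrapped with .item(). A sketch, assuming a file written positionally as in the previous note:

import numpy as np

with np.load("all_log_errors_pickled.npz", allow_pickle=True) as data:
    # The positional dump lands under "arr_0" as a 0-d object array;
    # .item() unwraps it back into the original dict.
    errors = data["arr_0"].item()
    print(sorted(errors))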
@@ -1602,12 +1591,6 @@ def compare_UQ(all_metrics: dict, config: dict) -> None:
         show_title=True,
     )
 
-    # TEMP
-    dataset = config["dataset"]["name"]
-    os.makedirs(f"scripts/pp/{dataset}", exist_ok=True)
-    np.savez(f"scripts/pp/{dataset}/all_uq_errors.npz", ensemble_errors)
-    np.savez(f"scripts/pp/{dataset}/all_uq_std.npz", ensemble_std)
-
 
 def tabular_comparison(all_metrics: dict, config: dict) -> None:
     """

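The two UQ dumps wrote errors and standard deviations to separate files; a single .npz archive with named keys keeps the pair in sync and loads in one call. A minimal sketch with placeholder arrays (if ensemble_errors and ensemble_std are dicts in the benchmark, the keyword-expansion pattern from the earlier note applies instead):

import numpy as np

ensemble_errors = np.random.rand(5, 100)  # placeholder shapes
ensemble_std = np.random.rand(5, 100)

# One archive, two named arrays, retrieved as data["errors"] and data["std"].
np.savez("all_uq.npz", errors=ensemble_errors, std=ensemble_std)

with np.load("all_uq.npz") as data:
    assert data["errors"].shape == data["std"].shape
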
codes/benchmark/bench_plots.py (0 additions & 20 deletions)
@@ -2036,8 +2036,6 @@ def inference_time_bar_plot(
     # Calculate the upper y-limit to provide space for text
     max_bar = max(means[i] + stds[i] for i in range(len(means)))
     # min_bar = min(means[i] - stds[i] for i in range(len(means)))
-    # Temp!
-    # ax.set_ylim(min_bar * 0.3, max_bar * 2)  # Set limits with some padding
     ax.set_ylim(0, max_bar * 1.2)  # Set limits with some padding
 
     # Add inference time as text to the bars using the format_time function
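
The limit kept here reserves headroom above the tallest error bar so the per-bar time labels fit inside the axes; the factor 1.2 leaves a 20% margin. A self-contained sketch of the pattern with invented values:

import matplotlib.pyplot as plt
import numpy as np

means = np.array([1.2, 0.8, 2.5])
stds = np.array([0.2, 0.1, 0.4])

fig, ax = plt.subplots()
ax.bar(range(len(means)), means, yerr=stds)
max_bar = (means + stds).max()
ax.set_ylim(0, max_bar * 1.2)  # 20% headroom for the text above the tallest bar
for i in range(len(means)):
    ax.text(i, means[i] + stds[i], f"{means[i]:.1f}s", ha="center", va="bottom")
plt.show()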
@@ -2054,8 +2052,6 @@ def inference_time_bar_plot(
 
     ax.set_xlabel("Surrogate Model")
     ax.set_ylabel("Mean Inference Time per Run")
-    # Temp!
-    # ax.set_yscale("log")
     if show_title:
         ax.set_title("Surrogate Mean Inference Time Comparison")
 
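Worth noting why this commented alternative could not simply coexist with the first hunk's ax.set_ylim(0, ...): a log axis cannot include zero, so restoring the log scale would also mean restoring a positive lower bound. A short illustration:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.bar([0, 1, 2], [0.01, 0.5, 3.0])
ax.set_yscale("log")
# On a log axis a lower limit of 0 is invalid (matplotlib warns and ignores it),
# so a small positive floor stands in for it.
ax.set_ylim(1e-3, 10)
plt.show()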
@@ -2490,10 +2486,6 @@ def plot_catastrophic_detection_curves(
     names = list(errors_log.keys())
     colors = plt.cm.viridis(np.linspace(0, 0.95, len(names)))
     summary: dict[str, dict[float, dict[str, float]]] = {}
-    # TEMP
-    recall_99 = np.zeros((len(names), len(flag_fractions)))
-    recall_90 = np.zeros((len(names), len(flag_fractions)))
-    dataset_name = conf["dataset"]["name"]
 
     # --- Recall vs fraction flagged (per catastrophic percentile) ---
     for ax, perc in zip(axes[:-1], percentiles):
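
For context on what the deleted arrays recorded: recall here is the fraction of truly catastrophic runs (those above a given Δdex percentile) that land inside the flagged fraction, mirroring the (flagged & is_cat) computation visible in the next hunk. A standalone sketch with synthetic errors and a deliberately imperfect ranking signal:

import numpy as np

rng = np.random.default_rng(0)
errors = rng.lognormal(size=1000)                        # stand-in for per-run delta-dex errors
uncertainty = errors + rng.normal(scale=0.3, size=1000)  # noisy proxy used for flagging

perc = 99.0
is_cat = errors >= np.percentile(errors, perc)  # "catastrophic" = top 1% of errors
frac = 0.05                                     # flag the 5% most uncertain runs
flagged = uncertainty >= np.quantile(uncertainty, 1 - frac)
recall = (flagged & is_cat).sum() / is_cat.sum()
print(f"recall at {perc:.0f}th percentile with {frac:.0%} flagged: {recall:.2f}")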
@@ -2522,11 +2514,6 @@ def plot_catastrophic_detection_curves(
             recall = (flagged & is_cat).sum() / n_cat if n_cat > 0 else 0.0
             xs.append(100.0 * flagged.mean())
             ys.append(100.0 * recall)
-            # TEMP
-            if perc == 99.0:
-                recall_99[i, flag_fractions.index(f)] = recall
-            if perc == 90.0:
-                recall_90[i, flag_fractions.index(f)] = recall
 
         ax.plot(
             xs,
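
A small aside on the removed bookkeeping: flag_fractions.index(f) rescans the list on every iteration. If per-fraction recording were ever reinstated, enumerating the fractions yields the column index for free. A sketch of the loop shape with placeholder values:

import numpy as np

flag_fractions = [0.01, 0.05, 0.10]
recall_99 = np.zeros((3, len(flag_fractions)))  # (n_models, n_fractions)

for i in range(3):  # one row per model
    for j, f in enumerate(flag_fractions):  # j replaces flag_fractions.index(f)
        recall_99[i, j] = 1.0 - f           # placeholder for the real recall value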
@@ -2553,9 +2540,6 @@ def plot_catastrophic_detection_curves(
                 f"Detection @ {perc}th percentile (Top {100 - perc:.0f}% Δdex)"
             )
 
-    np.savez(f"scripts/pp/{dataset_name}/catastrophic_recall_99.npz", recall_99)
-    np.savez(f"scripts/pp/{dataset_name}/catastrophic_recall_90.npz", recall_90)
-
     # MAE improvement plot
     ax_mae = axes[-1]
     for i, name in enumerate(names):
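
Unlike the dumps removed from bench_fcts.py, these two writes did not create scripts/pp/{dataset_name} first and so depended on an earlier snippet having made the directory. When a scratch dump is unavoidable, creating the parent path at the write site keeps it self-sufficient; a sketch with a placeholder dataset name:

import os

import numpy as np

dataset_name = "demo"  # placeholder
recall_99 = np.zeros((2, 3))

out_dir = f"scripts/pp/{dataset_name}"
os.makedirs(out_dir, exist_ok=True)  # no-op if the directory already exists
np.savez(os.path.join(out_dir, "catastrophic_recall_99.npz"), recall_99)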
@@ -3006,8 +2990,6 @@ def rel_errors_and_uq(
 
     ax1.set_xlabel("Time")
     ax1.set_xlim(timesteps[0], timesteps[-1])
-    # Temp!
-    # ax1.set_ylim(3e-4, 1)
     ax1.set_ylabel("Relative Error")
     ax1.set_yscale("log")
     ax1.set_title("Comparison of Relative Errors Over Time")
@@ -3038,8 +3020,6 @@ def rel_errors_and_uq(
 
     ax2.set_xlabel("Time")
     ax2.set_xlim(timesteps[0], timesteps[-1])
-    # Temp!
-    # ax2.set_ylim(0, 0.04)
     ax2.set_ylabel("Uncertainty/Absolute Error")
     if show_title:
         ax2.set_title("Comparison of Predictive Uncertainty Over Time")
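
Both removed limits pinned the panels to fixed windows (3e-4 to 1 for the relative error, 0 to 0.04 for the uncertainty ratio), which makes figures comparable across datasets at the cost of clipping outliers; dropping them restores autoscaling. If fixed ranges are wanted again, passing them as a parameter avoids hard-coding. A sketch, with plot_rel_error as a hypothetical helper:

import matplotlib.pyplot as plt
import numpy as np


def plot_rel_error(ax, t, err, ylim=None):
    # ylim=None autoscales; a (lo, hi) tuple pins the window for
    # cross-dataset comparability, as the removed hard-coded limits did.
    ax.plot(t, err)
    ax.set_yscale("log")
    if ylim is not None:
        ax.set_ylim(*ylim)


t = np.linspace(0, 1, 100)
fig, ax = plt.subplots()
plot_rel_error(ax, t, 1e-2 + 1e-3 * np.random.rand(100), ylim=(3e-4, 1))
plt.show()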
