Skip to content

Commit fc7304a

Browse files
Merge pull request #49 from robin-janssen/improve-surrogate-evaluation
Improve surrogate evaluation
2 parents c8a3836 + 3c2efef commit fc7304a

22 files changed

Lines changed: 1237 additions & 408 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# CODES Benchmark
22

3-
[![codecov](https://codecov.io/github/robin-janssen/CODES-Benchmark/graph/badge.svg?token=TNF9ISCAJK)](https://codecov.io/github/robin-janssen/CODES-Benchmark)
3+
[![codecov](https://codecov.io/github/robin-janssen/CODES-Benchmark/branch/main/graph/badge.svg?token=TNF9ISCAJK)](https://codecov.io/github/robin-janssen/CODES-Benchmark)
44
![Static Badge](https://img.shields.io/badge/license-GPLv3-blue)
55
![Static Badge](https://img.shields.io/badge/NeurIPS-2024-green)
66

codes/surrogates/AbstractSurrogate/abstract_surrogate.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.utils.data import DataLoader
1414
from tqdm import tqdm
1515

16-
from codes.utils import create_model_dir
16+
from codes.utils import create_model_dir, parse_hyperparameters
1717

1818

1919
class AbstractSurrogateModel(ABC, nn.Module):
@@ -334,6 +334,9 @@ def save(
334334
)
335335
hyperparameters["date"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
336336

337+
# Recursively parse hyperparameters to make them yaml-serializable
338+
hyperparameters = parse_hyperparameters(hyperparameters)
339+
337340
# Reduce the precision of the losses and accuracy
338341
for attribute in ["train_loss", "test_loss", "MAE"]:
339342
value = getattr(self, attribute)
@@ -518,10 +521,6 @@ def time_pruning(self, current_epoch: int, total_epochs: int) -> None:
518521
# Define warmup period based on 10% of total epochs.
519522
warmup_epochs = max(10, int(total_epochs * 0.02))
520523
if current_epoch < warmup_epochs:
521-
# Do not attempt to prune before the warmup period is complete.
522-
# print(
523-
# f"[time_pruning] Warmup period: {current_epoch}/{warmup_epochs} epochs completed. Skipping pruning check."
524-
# )
525524
return
526525

527526
elapsed = time.time() - self._trial_start_time
@@ -541,7 +540,7 @@ def time_pruning(self, current_epoch: int, total_epochs: int) -> None:
541540
if projected_total_time > threshold:
542541
if self.optuna_trial is not None:
543542
tqdm.write(
544-
f"[time_pruning] Projected total time {projected_total_time:.1f}s exceeds threshold {threshold:.1f}s. Pruning trial {self.optuna_trial.number}."
543+
f"[Trial {self.optuna_trial.number}] Projected total time {projected_total_time:.1f}s exceeds threshold {threshold:.1f}s. Pruning trial."
545544
)
546545
self.optuna_trial.set_user_attr(
547546
"prune_reason",

codes/surrogates/LatentNeuralODE/latent_neural_ode.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,6 @@ def fit_profile(
329329
loss.backward()
330330
optimizer.step()
331331

332-
# # renormalize once after 10 epochs
333-
# if epoch == 10 and i == 0:
334-
# with torch.no_grad():
335-
# self.model.renormalize_loss_weights(
336-
# x_true, x_pred, params, criterion
337-
# )
338-
339332
if not (profiled and epoch == 0):
340333
# Only step here if you didn't already step inside profiled block
341334
scheduler.step()

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,6 @@ def fit(
218218
loss.backward()
219219
optimizer.step()
220220

221-
if epoch == 10 and i == 0:
222-
with torch.no_grad():
223-
self.model.renormalize_loss_weights(
224-
x_true, x_pred, params, criterion
225-
)
226-
227221
scheduler.step()
228222

229223
self.validate(
@@ -461,6 +455,7 @@ def identity_loss(self, x_true: Tensor, params: Tensor = None):
461455
# only reconstruct the initial state
462456
x0 = x_true[:, 0, :]
463457
if not self.config.coeff_network and params is not None:
458+
params = params.to(self.device)
464459
enc_input = torch.cat([x0, params], dim=1)
465460
else:
466461
enc_input = x0

codes/tune/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,23 @@
55
plot_test_losses,
66
)
77
from .optuna_fcts import (
8+
MaxValidTrialsCallback,
9+
_count_valid_trials,
10+
build_fine_optuna_params,
811
create_objective,
912
load_yaml_config,
1013
make_optuna_params,
1114
maybe_set_runtime_threshold,
1215
training_run,
1316
)
1417
from .postgres_fcts import (
15-
_make_db_url,
16-
initialize_optuna_database,
1718
_check_postgres_running_local,
18-
_start_postgres_server_local,
1919
_check_remote_reachable,
2020
_initialize_postgres_local,
2121
_initialize_postgres_remote,
22+
_make_db_url,
23+
_start_postgres_server_local,
24+
initialize_optuna_database,
2225
)
2326
from .tune_utils import (
2427
build_study_names,
@@ -30,6 +33,8 @@
3033

3134
__all__ = [
3235
"create_objective",
36+
"MaxValidTrialsCallback",
37+
"build_fine_optuna_params",
3338
"load_yaml_config",
3439
"make_optuna_params",
3540
"maybe_set_runtime_threshold",
@@ -50,4 +55,5 @@
5055
"_check_remote_reachable",
5156
"_initialize_postgres_local",
5257
"_initialize_postgres_remote",
58+
"_count_valid_trials",
5359
]

codes/tune/evaluate_tuning.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,58 @@
2323
from codes.utils import nice_print
2424

2525

26+
def pareto_front(points: np.ndarray) -> np.ndarray:
    """Return the Pareto-optimal rows of *points* (both objectives minimized).

    A row is kept unless some other row is no worse in every objective and
    strictly better in at least one. Duplicated non-dominated rows are all
    retained, and the original row order is preserved.

    Args:
        points: Array of shape (N, D) of objective values (lower is better).

    Returns:
        Array containing only the non-dominated rows of ``points``.
    """

    def _is_dominated(i: int) -> bool:
        # p is dominated iff another row is <= in all dims and < in some dim.
        p = points[i]
        others = np.delete(points, i, axis=0)
        return bool(
            np.any(np.all(others <= p, axis=1) & np.any(others < p, axis=1))
        )

    keep = [i for i in range(points.shape[0]) if not _is_dominated(i)]
    return points[keep]
39+
40+
def hypervolume_2d(pareto_points: np.ndarray, reference: np.ndarray) -> float:
    """Compute the 2-D hypervolume (area) dominated by a Pareto front.

    Both objectives are minimized; ``reference`` is assumed to be worse than
    every point in ``pareto_points``. The area is accumulated as a staircase
    of rectangles after sorting the front by the first objective.

    Args:
        pareto_points: Array of shape (K, 2) of non-dominated points.
        reference: Reference point, shape (2,).

    Returns:
        The dominated area; ``0.0`` for an empty front.
    """
    if pareto_points.size == 0:
        return 0.0

    ordered = pareto_points[np.argsort(pareto_points[:, 0])]
    area = 0.0
    ceiling = reference[1]  # second-objective level of the previous step
    for f1, f2 in ordered:
        width = reference[0] - f1
        height = ceiling - f2
        if width > 0 and height > 0:
            area += width * height
        ceiling = f2
    return area
54+
55+
56+
def compute_hypervolume_over_time(study: optuna.Study, ref_slack=1.1):
    """Track the 2-D hypervolume of a study's Pareto front as trials complete.

    Args:
        study: A two-objective (minimize/minimize) Optuna study.
        ref_slack: Relative margin placing the reference point slightly worse
            than the worst observed value in each objective (default 1.1).

    Returns:
        tuple: ``(hypervolumes, reference)`` where ``hypervolumes[k-1]`` is
        the hypervolume of the first ``k`` completed trials and ``reference``
        is the fixed reference point used throughout. Returns ``([], None)``
        when the study has no completed trials.
    """
    from optuna.trial import TrialState

    completed = [t for t in study.trials if t.state == TrialState.COMPLETE]
    if not completed:
        return [], None

    # Order by completion time so the curve reflects actual tuning progress.
    completed.sort(key=lambda t: t.datetime_complete or t.datetime_start)
    all_vals = np.array([t.values for t in completed])  # shape (N, 2)

    # The reference point must be strictly worse than every observed value.
    # Plain ``max * ref_slack`` breaks when the worst value is <= 0 (scaling
    # a negative number makes it *better*), so offset by the absolute
    # magnitude instead; for positive values this equals ``max * ref_slack``.
    worst = all_vals.max(axis=0)
    margin = np.where(worst != 0, np.abs(worst) * (ref_slack - 1.0), ref_slack - 1.0)
    reference = worst + margin

    hypervolumes = []
    for k in range(1, len(completed) + 1):
        pts = np.array([t.values for t in completed[:k]])
        pareto = pareto_front(pts)
        hypervolumes.append(hypervolume_2d(pareto, reference))
    return hypervolumes, reference
76+
77+
2678
def load_loss_history(model_path: str) -> tuple[np.ndarray, np.ndarray, int]:
2779
"""
2880
Load loss histories from a saved model file (.pth).
@@ -218,6 +270,49 @@ def evaluate_tuning(
218270
print(f"Could not load study '{full_name}'")
219271
continue
220272

273+
# Compute hypervolume over time
274+
if len(study.directions) == 2:
275+
hvs, reference = compute_hypervolume_over_time(study)
276+
if hvs:
277+
# Normalize to final hypervolume for relative curve
278+
final_hv = hvs[-1]
279+
rel_hvs = [hv / final_hv if final_hv > 0 else 0 for hv in hvs]
280+
281+
# Plot absolute and relative hypervolume
282+
plt.figure(figsize=(6, 4))
283+
plt.plot(np.arange(1, len(hvs) + 1), hvs, label="Hypervolume")
284+
plt.xlabel("Completed Trials")
285+
plt.ylabel("Hypervolume")
286+
plt.title(f"{suffix} Hypervolume over trials")
287+
plt.grid(True)
288+
plt.tight_layout()
289+
plt.savefig(
290+
os.path.join(save_dir, f"hypervolume_{suffix}.png"), dpi=300
291+
)
292+
plt.close()
293+
294+
plt.figure(figsize=(6, 4))
295+
plt.plot(
296+
np.arange(1, len(rel_hvs) + 1),
297+
rel_hvs,
298+
label="Relative Hypervolume",
299+
)
300+
plt.xlabel("Completed Trials")
301+
plt.ylabel("Fraction of Final HV")
302+
plt.title(f"{suffix} Relative Hypervolume")
303+
plt.grid(True)
304+
plt.tight_layout()
305+
plt.savefig(
306+
os.path.join(save_dir, f"hypervolume_relative_{suffix}.png"),
307+
dpi=300,
308+
)
309+
plt.close()
310+
print(f"Saved hypervolume plots for {suffix} (final HV={final_hv:.3e})")
311+
else:
312+
print("No hypervolume computed (no complete trials).")
313+
else:
314+
print("Skipping hypervolume: study is not two-objective.")
315+
221316
best = get_best_trials(study, top_n)
222317
if not best:
223318
print(f"No completed trials in {full_name}")
@@ -285,19 +380,19 @@ def parse_args():
285380
p.add_argument(
286381
"--study_name",
287382
type=str,
288-
default="cloud_tuning_rough",
383+
default="cloud_tuning_fine",
289384
help="Main study prefix (e.g. lvparams5)",
290385
)
291386
p.add_argument(
292387
"--storage_name",
293388
type=str,
294-
default="optuna_cloud",
389+
default="optuna_cloud_2",
295390
help="Main study prefix (e.g. lvparams5)",
296391
)
297392
p.add_argument(
298393
"--top_n",
299394
type=int,
300-
default=10,
395+
default=20,
301396
help="Number of top trials to plot per surrogate",
302397
)
303398
return p.parse_args()

0 commit comments

Comments
 (0)