Skip to content

Commit e933fc0

Browse files
committed
Merge branch 'improve-surrogate-evaluation' of https://github.com/robin-janssen/CODES-Benchmark into improve-surrogate-evaluation
2 parents 363fb3e + 9e0ca37 commit e933fc0

7 files changed

Lines changed: 350 additions & 94 deletions

File tree

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def total_loss(
399399
traj_loss = criterion(x_pred, x_true)
400400

401401
# identity loss (reconstruct x0)
402-
identity = self.identity_loss(x_true, params.to(self.device))
402+
identity = self.identity_loss(x_true, params)
403403

404404
# derivative losses: compute once
405405
d_pred = self.first_derivative(x_pred)
@@ -455,6 +455,7 @@ def identity_loss(self, x_true: Tensor, params: Tensor = None):
455455
# only reconstruct the initial state
456456
x0 = x_true[:, 0, :]
457457
if not self.config.coeff_network and params is not None:
458+
params = params.to(self.device)
458459
enc_input = torch.cat([x0, params], dim=1)
459460
else:
460461
enc_input = x0

codes/tune/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
)
77
from .optuna_fcts import (
88
MaxValidTrialsCallback,
9+
_count_valid_trials,
10+
build_fine_optuna_params,
911
create_objective,
1012
load_yaml_config,
1113
make_optuna_params,
@@ -31,6 +33,8 @@
3133

3234
__all__ = [
3335
"create_objective",
36+
"MaxValidTrialsCallback",
37+
"build_fine_optuna_params",
3438
"load_yaml_config",
3539
"make_optuna_params",
3640
"maybe_set_runtime_threshold",
@@ -51,4 +55,5 @@
5155
"_check_remote_reachable",
5256
"_initialize_postgres_local",
5357
"_initialize_postgres_remote",
58+
"_count_valid_trials",
5459
]

codes/tune/optuna_fcts.py

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,60 @@ def objective(trial):
223223
return objective
224224

225225

226+
def create_objective(
    config: dict, study_name: str, device_queue: queue.Queue
) -> callable:
    """
    Create the objective function for Optuna.

    The returned closure checks a ``(device, slot_id)`` pair out of
    ``device_queue`` for the duration of one trial and always returns it to
    the queue afterwards, even when the trial fails.

    Args:
        config (dict): Configuration dictionary.
        study_name (str): Name of the study.
        device_queue (queue.Queue): Queue of available (device, slot_id) pairs.

    Returns:
        function: Objective function for Optuna.
    """

    def objective(trial):
        device, slot_id = device_queue.get()
        try:
            try:
                return training_run(trial, device, slot_id, config, study_name)
            except torch.cuda.OutOfMemoryError as e:
                torch.cuda.empty_cache()
                # str(e) (unlike repr(e), which is never empty) can be an
                # empty string, so the fallback message is actually reachable.
                msg = str(e).strip() or "CUDA Out of Memory (no details provided)."
                trial.set_user_attr("exception", msg)
                tqdm.write(f"[Trial {trial.number}] resulted in an OOM error.")
                # Return a capped, clearly-bad score instead of pruning so the
                # sampler learns that this region of the space is infeasible.
                if config.get("multi_objective", False):
                    # Multi-objective mode returns (loss, inference_time);
                    # 10 s is a deliberately pessimal inference-time penalty.
                    return float(config.get("loss_cap", 20)), float(10)
                # Single-objective mode returns just the capped loss.
                return float(config.get("loss_cap", 20))
            except optuna.TrialPruned as e:
                # Record why the trial was pruned, then let Optuna handle it.
                trial.set_user_attr("exception", repr(e).strip())
                raise
            except Exception as e:
                torch.cuda.empty_cache()
                msg = str(e).strip() or "Unknown error occurred."
                tqdm.write(
                    f"Trial {trial.number} failed due to an unexpected error: {msg}"
                )
                trial.set_user_attr("exception", msg)
                # Convert unexpected failures into pruned trials so the study
                # keeps running instead of aborting.
                raise optuna.TrialPruned(f"Error in trial {trial.number}: {msg}")
        finally:
            # Always hand the device slot back, whatever happened above.
            device_queue.put((device, slot_id))

    return objective
278+
279+
226280
def training_run(
227281
trial: optuna.Trial, device: str, slot_id: int, config: dict, study_name: str
228282
) -> float | tuple[float, float]:
@@ -244,7 +298,6 @@ def training_run(
244298

245299
download_data(config["dataset"]["name"], verbose=False)
246300

247-
# Load full data and parameters
248301
(
249302
(train_data, test_data, _),
250303
(train_params, test_params, _),
@@ -263,21 +316,29 @@ def training_run(
263316
)
264317

265318
subset_factor = config["dataset"].get("subset_factor", 1)
266-
# Get the appropriate subset of the training data
267-
# We nevertheless use the full test data to measure performance.
268319
train_data = train_data[::subset_factor]
269320
train_params = train_params[::subset_factor] if train_params is not None else None
270321

271322
set_random_seeds(config["seed"], device=device)
272323
surr_name = config["surrogate"]["name"]
273-
suggested_params = make_optuna_params(trial, config["optuna_params"])
274-
n_params = train_params.shape[1] if train_params is not None else 0
275324

325+
# Load base (best) config from disk as you already do
326+
model_config = get_model_config(surr_name, config)
327+
328+
# Decide search space
329+
if config.get("fine", False):
330+
fine_space = config.get("fine_space")
331+
suggested_params = make_optuna_params(trial, fine_space)
332+
else:
333+
suggested_params = make_optuna_params(trial, config["optuna_params"])
334+
335+
n_params = train_params.shape[1] if train_params is not None else 0
276336
n_timesteps = train_data.shape[1]
277337
n_quantities = train_data.shape[2]
278338
surrogate_class = get_surrogate(surr_name)
279-
model_config = get_model_config(surr_name, config)
339+
280340
model_config.update(suggested_params)
341+
281342
model = surrogate_class(
282343
device=device,
283344
n_quantities=n_quantities,
@@ -312,30 +373,21 @@ def training_run(
312373
multi_objective=config["multi_objective"],
313374
)
314375

315-
# criterion = torch.nn.MSELoss()
316376
preds, targets = model.predict(test_loader, leave_log=True)
317377
p99_dex = torch.quantile(
318378
(preds - targets).abs().flatten(), float(config["target_percentile"])
319379
).item()
320-
# cap the loss to prevent exploding values
321380
p99_dex = min(p99_dex, config.get("loss_cap", 20))
322381

323-
# Extract the study name without the timestamp/suffix part
324382
parts = study_name.split("_")
325383
sname = "_".join(parts[:-1]) if len(parts) > 1 else study_name
326384

327385
savepath = os.path.join("tuned", sname, "models")
328386
os.makedirs(savepath, exist_ok=True)
329387
model_name = f"{surr_name.lower()}_{trial.number}"
330-
model.save(
331-
model_name=model_name,
332-
base_dir="",
333-
training_id=savepath,
334-
)
388+
model.save(model_name=model_name, base_dir="", training_id=savepath)
335389

336-
# Check if we're running multi-objective optimisation
337390
if config["multi_objective"]:
338-
# Measure inference time
339391
with _inference_time_lock:
340392
inference_times = measure_inference_time(model, test_loader)
341393
return p99_dex, np.mean(inference_times)
@@ -405,3 +457,54 @@ def is_bad(tr):
405457
f"\n[Study] Warmup complete. Runtime threshold set to {threshold:.1f}s "
406458
f"(mean = {mean_:.1f}s, std = {std_:.1f}s) over trials {used_trial_numbers}."
407459
)
460+
461+
462+
def _bounds_around(
463+
v: float, factor: float = 10.0, lo: float | None = None, hi: float | None = None
464+
) -> tuple[float, float]:
465+
low, high = float(v) / factor, float(v) * factor
466+
if lo is not None:
467+
low = max(low, lo)
468+
if hi is not None:
469+
high = min(high, hi)
470+
# avoid degenerate ranges
471+
if high <= low:
472+
eps = max(abs(v) * 1e-3, 1e-12)
473+
low, high = float(v) - eps, float(v) + eps
474+
return low, high
475+
476+
477+
def build_fine_optuna_params(model_config: dict) -> dict:
    """
    Build a narrow ("fine") Optuna search space around a known-good config.

    For each recognised numeric hyperparameter present in ``model_config``,
    a log-uniform float range spanning one order of magnitude on either side
    of the stored value is generated via ``_bounds_around``.

    Args:
        model_config (dict): Base (best) model configuration loaded from disk.

    Returns:
        dict: Mapping of parameter name to an Optuna-style space spec of the
            form ``{"type": "float", "low": ..., "high": ..., "log": True}``.
    """
    keys = (
        "learning_rate",
        "beta",
        "poly_power",
        "eta_min",
        "regularization_factor",
        "momentum",
    )
    space: dict[str, dict] = {}
    for k in keys:
        if k not in model_config:
            continue
        val = model_config[k]
        # bool is a subclass of int, and a log-scaled range needs a strictly
        # positive centre value, so skip flags and non-positive entries
        # (log=True with a non-positive bound would crash suggest_float).
        if isinstance(val, bool) or not isinstance(val, (int, float)) or val <= 0:
            continue
        lo, hi = _bounds_around(
            val,
            factor=10.0,
            lo=1e-12 if k != "momentum" else 0.0,
            hi=0.999 if k == "momentum" else None,
        )
        space[k] = {"type": "float", "low": lo, "high": hi, "log": True}
    return space
501+
502+
503+
def _is_valid_trial(t: optuna.trial.FrozenTrial) -> bool:
    """Return True when *t* finished (COMPLETE or PRUNED) without a recorded exception."""
    if t.state not in (TrialState.COMPLETE, TrialState.PRUNED):
        return False
    return "exception" not in t.user_attrs
507+
508+
509+
def _count_valid_trials(study: optuna.Study) -> int:
    """Count the trials in *study* that qualify per ``_is_valid_trial``."""
    total = 0
    for t in study.get_trials(deepcopy=False):
        if _is_valid_trial(t):
            total += 1
    return total

codes/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def create_model_dir(
6363

6464
# Check if the directory exists, and create it if it doesn't
6565
if not os.path.exists(full_path):
66-
os.makedirs(full_path)
66+
os.makedirs(full_path, exist_ok=True)
6767

6868
return full_path
6969

0 commit comments

Comments
 (0)