Skip to content

Commit 46f65ef

Browse files
committed
implement hyperparameter finetuning mode
1 parent 184a566 commit 46f65ef

3 files changed

Lines changed: 208 additions & 28 deletions

File tree

codes/tune/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
)
77
from .optuna_fcts import (
88
MaxValidTrialsCallback,
9+
_count_valid_trials,
10+
build_fine_optuna_params,
911
create_objective,
1012
load_yaml_config,
1113
make_optuna_params,
@@ -31,6 +33,8 @@
3133

3234
__all__ = [
3335
"create_objective",
36+
"MaxValidTrialsCallback",
37+
"build_fine_optuna_params",
3438
"load_yaml_config",
3539
"make_optuna_params",
3640
"maybe_set_runtime_threshold",
@@ -51,4 +55,5 @@
5155
"_check_remote_reachable",
5256
"_initialize_postgres_local",
5357
"_initialize_postgres_remote",
58+
"_count_valid_trials",
5459
]

codes/tune/optuna_fcts.py

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,60 @@ def objective(trial):
223223
return objective
224224

225225

226+
def create_objective(
    config: dict, study_name: str, device_queue: queue.Queue
) -> callable:
    """
    Build the Optuna objective closure for a single study.

    Args:
        config (dict): Configuration dictionary.
        study_name (str): Name of the study.
        device_queue (queue.Queue): Queue of available devices.

    Returns:
        function: Objective function for Optuna.
    """

    def objective(trial):
        # Reserve a device slot for the duration of this trial.
        device, slot_id = device_queue.get()
        try:
            return training_run(trial, device, slot_id, config, study_name)
        except torch.cuda.OutOfMemoryError as exc:
            torch.cuda.empty_cache()
            reason = repr(exc).strip() or "CUDA Out of Memory (no details provided)."
            trial.set_user_attr("exception", reason)
            tqdm.write(f"[Trial {trial.number}] resulted in an OOM error.")
            # raise optuna.TrialPruned(f"OOM error in trial {trial.number}")
            capped = float(config.get("loss_cap", 20))
            if config.get("multi_objective", False):
                # Multi-objective mode expects a (loss, runtime) tuple.
                return capped, float(10)
            # Single-objective mode reports the capped loss alone.
            return capped
        except optuna.TrialPruned as exc:
            # Record why the trial was pruned, then let Optuna handle it.
            trial.set_user_attr("exception", repr(exc).strip())
            raise
        except Exception as exc:
            torch.cuda.empty_cache()
            reason = repr(exc).strip() or "Unknown error occurred."
            tqdm.write(
                f"Trial {trial.number} failed due to an unexpected error: {reason}"
            )
            trial.set_user_attr("exception", reason)
            raise optuna.TrialPruned(f"Error in trial {trial.number}: {reason}")
        finally:
            # Always hand the device slot back, whatever happened above.
            device_queue.put((device, slot_id))

    return objective
278+
279+
226280
def training_run(
227281
trial: optuna.Trial, device: str, slot_id: int, config: dict, study_name: str
228282
) -> float | tuple[float, float]:
@@ -244,7 +298,6 @@ def training_run(
244298

245299
download_data(config["dataset"]["name"], verbose=False)
246300

247-
# Load full data and parameters
248301
(
249302
(train_data, test_data, _),
250303
(train_params, test_params, _),
@@ -263,21 +316,29 @@ def training_run(
263316
)
264317

265318
subset_factor = config["dataset"].get("subset_factor", 1)
266-
# Get the appropriate subset of the training data
267-
# We nevertheless use the full test data to measure performance.
268319
train_data = train_data[::subset_factor]
269320
train_params = train_params[::subset_factor] if train_params is not None else None
270321

271322
set_random_seeds(config["seed"], device=device)
272323
surr_name = config["surrogate"]["name"]
273-
suggested_params = make_optuna_params(trial, config["optuna_params"])
274-
n_params = train_params.shape[1] if train_params is not None else 0
275324

325+
# Load base (best) config from disk as you already do
326+
model_config = get_model_config(surr_name, config)
327+
328+
# Decide search space
329+
if config.get("fine", False):
330+
fine_space = config.get("fine_space")
331+
suggested_params = make_optuna_params(trial, fine_space)
332+
else:
333+
suggested_params = make_optuna_params(trial, config["optuna_params"])
334+
335+
n_params = train_params.shape[1] if train_params is not None else 0
276336
n_timesteps = train_data.shape[1]
277337
n_quantities = train_data.shape[2]
278338
surrogate_class = get_surrogate(surr_name)
279-
model_config = get_model_config(surr_name, config)
339+
280340
model_config.update(suggested_params)
341+
281342
model = surrogate_class(
282343
device=device,
283344
n_quantities=n_quantities,
@@ -312,30 +373,21 @@ def training_run(
312373
multi_objective=config["multi_objective"],
313374
)
314375

315-
# criterion = torch.nn.MSELoss()
316376
preds, targets = model.predict(test_loader, leave_log=True)
317377
p99_dex = torch.quantile(
318378
(preds - targets).abs().flatten(), float(config["target_percentile"])
319379
).item()
320-
# cap the loss to prevent exploding values
321380
p99_dex = min(p99_dex, config.get("loss_cap", 20))
322381

323-
# Extract the study name without the timestamp/suffix part
324382
parts = study_name.split("_")
325383
sname = "_".join(parts[:-1]) if len(parts) > 1 else study_name
326384

327385
savepath = os.path.join("tuned", sname, "models")
328386
os.makedirs(savepath, exist_ok=True)
329387
model_name = f"{surr_name.lower()}_{trial.number}"
330-
model.save(
331-
model_name=model_name,
332-
base_dir="",
333-
training_id=savepath,
334-
)
388+
model.save(model_name=model_name, base_dir="", training_id=savepath)
335389

336-
# Check if we're running multi-objective optimisation
337390
if config["multi_objective"]:
338-
# Measure inference time
339391
with _inference_time_lock:
340392
inference_times = measure_inference_time(model, test_loader)
341393
return p99_dex, np.mean(inference_times)
@@ -405,3 +457,54 @@ def is_bad(tr):
405457
f"\n[Study] Warmup complete. Runtime threshold set to {threshold:.1f}s "
406458
f"(mean = {mean_:.1f}s, std = {std_:.1f}s) over trials {used_trial_numbers}."
407459
)
460+
461+
462+
def _bounds_around(
463+
v: float, factor: float = 10.0, lo: float | None = None, hi: float | None = None
464+
) -> tuple[float, float]:
465+
low, high = float(v) / factor, float(v) * factor
466+
if lo is not None:
467+
low = max(low, lo)
468+
if hi is not None:
469+
high = min(high, hi)
470+
# avoid degenerate ranges
471+
if high <= low:
472+
eps = max(abs(v) * 1e-3, 1e-12)
473+
low, high = float(v) - eps, float(v) + eps
474+
return low, high
475+
476+
477+
def build_fine_optuna_params(model_config: dict) -> dict:
    """
    Derive a narrow Optuna search space around a previously tuned config.

    Only a fixed whitelist of continuous optimizer/scheduler hyperparameters
    is considered; each present, positive numeric value is turned into a
    log-uniform interval spanning one decade on either side of it.

    Args:
        model_config (dict): Best-known model configuration loaded from disk.

    Returns:
        dict: Mapping of parameter name to an Optuna float-space spec
            ({"type", "low", "high", "log"}); empty if nothing qualifies.
    """
    keys = (
        "learning_rate",
        "beta",
        "poly_power",
        "eta_min",
        "regularization_factor",
        "momentum",
    )
    space: dict[str, dict] = {}
    for k in keys:
        if k not in model_config:
            continue
        val = model_config[k]
        # bool is a subclass of int, but True/False are switches, not tunable
        # magnitudes — exclude them explicitly before the numeric check.
        if isinstance(val, bool) or not isinstance(val, (int, float)):
            continue
        # Non-positive values cannot anchor a log-scaled interval
        # (Optuna requires low > 0 when log=True), so skip them.
        if val <= 0:
            continue
        lo, hi = _bounds_around(
            val,
            factor=10.0,
            lo=1e-12 if k != "momentum" else 0.0,
            hi=0.999 if k == "momentum" else None,
        )
        space[k] = {"type": "float", "low": lo, "high": hi, "log": True}
    return space
501+
502+
503+
def _is_valid_trial(t: optuna.trial.FrozenTrial) -> bool:
    """A trial is valid if it finished (complete or pruned) and never
    recorded an exception in its user attributes."""
    finished = t.state in (TrialState.COMPLETE, TrialState.PRUNED)
    clean = "exception" not in t.user_attrs
    return finished and clean
507+
508+
509+
def _count_valid_trials(study: optuna.Study) -> int:
    """Count the trials in *study* that `_is_valid_trial` accepts."""
    valid = 0
    for trial in study.get_trials(deepcopy=False):
        if _is_valid_trial(trial):
            valid += 1
    return valid

run_tuning.py

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import argparse
2+
import os
23
import queue
34
import sys
45
import time
56
from pathlib import Path
67

78
import optuna
9+
import yaml
810
from optuna.trial import TrialState
911
from tqdm import tqdm
1012

13+
from codes.benchmark import get_model_config
1114
from codes.tune import (
1215
MaxValidTrialsCallback,
16+
_count_valid_trials,
17+
build_fine_optuna_params,
1318
create_objective,
1419
initialize_optuna_database,
1520
load_yaml_config,
@@ -23,6 +28,16 @@ def run_single_study(config: dict, study_name: str, db_url: str):
2328
if not config.get("optuna_logging", False):
2429
optuna.logging.set_verbosity(optuna.logging.WARNING)
2530

31+
if config.get("fine", False):
32+
try:
33+
base_cfg = get_model_config(config["surrogate"]["name"], config)
34+
finetune_space = build_fine_optuna_params(base_cfg)
35+
n_fine = len(finetune_space)
36+
except Exception:
37+
n_fine = 0 # conservative fallback
38+
39+
config["n_trials"] = max(5 * n_fine, 5) # disregard YAML trials
40+
2641
if config["multi_objective"]:
2742
sampler = optuna.samplers.NSGAIISampler(
2843
seed=config["seed"], population_size=config["population_size"]
@@ -56,6 +71,13 @@ def run_single_study(config: dict, study_name: str, db_url: str):
5671
load_if_exists=True,
5772
)
5873

74+
have = _count_valid_trials(study)
75+
if have >= config["n_trials"]:
76+
print(
77+
f"[skip] {study_name}: already has {have} valid trials (target {config['n_trials']}). Skipping optimize()."
78+
)
79+
return
80+
5981
device_queue = queue.Queue()
6082
for slot_id, dev in enumerate(config["devices"]):
6183
device_queue.put((dev, slot_id))
@@ -110,7 +132,11 @@ def trial_complete_callback(study_: optuna.Study, trial_: optuna.trial.FrozenTri
110132

111133
def run_all_studies(config: dict, main_study_name: str, db_url: str):
112134
surrogates = config["surrogates"]
113-
global_params = config.get("global_optuna_params", {})
135+
global_params = (
136+
{} if config.get("fine", False) else config.get("global_optuna_params", {})
137+
)
138+
139+
fine_report: dict[str, dict] = {}
114140

115141
total_sub_studies = len(surrogates)
116142
with tqdm(
@@ -122,15 +148,47 @@ def run_all_studies(config: dict, main_study_name: str, db_url: str):
122148
)
123149

124150
for surr in surrogates:
125-
local = surr.get("optuna_params", {})
126-
for name, opts in global_params.items():
127-
if name in local:
128-
print(
129-
f"⚠️ Hyperparameter '{name}' defined globally and locally for {surr['name']}; using local."
130-
)
131-
else:
132-
local[name] = opts
133-
surr["optuna_params"] = local
151+
arch_name = surr["name"]
152+
if config.get("fine", False):
153+
# ignore manual search spaces
154+
surr["optuna_params"] = {}
155+
156+
# derive fine space from previously best config
157+
base_cfg = get_model_config(arch_name, config)
158+
fine_space = build_fine_optuna_params(base_cfg)
159+
n_fine = len(fine_space)
160+
n_trials_override = max(5 * n_fine, 5)
161+
162+
# CLI confirmation
163+
print(
164+
f"[fine] {arch_name}: found fine-tunable parameters: {list(fine_space.keys()) or 'none'}"
165+
)
166+
for k, spec in fine_space.items():
167+
print(f" - {k}: [{spec['low']:.3g}, {spec['high']:.3g}] (log)")
168+
print(f" -> running for {n_trials_override} trials\n")
169+
170+
# stash for YAML and pass along to run_single_study
171+
fine_report[arch_name] = {
172+
"trials": int(n_trials_override),
173+
"params": {
174+
k: {
175+
"low": float(v["low"]),
176+
"high": float(v["high"]),
177+
"log": bool(v.get("log", False)),
178+
}
179+
for k, v in fine_space.items()
180+
},
181+
}
182+
else:
183+
local = surr.get("optuna_params", {})
184+
for name, opts in global_params.items():
185+
if name in local:
186+
print(
187+
f"⚠️ Hyperparameter '{name}' defined globally and locally for {surr['name']}; using local."
188+
)
189+
else:
190+
local[name] = opts
191+
surr["optuna_params"] = local
134192

135193
arch_name = surr["name"]
136194
study_name = f"{main_study_name}_{arch_name.lower()}"
@@ -142,22 +200,36 @@ def run_all_studies(config: dict, main_study_name: str, db_url: str):
142200
"dataset": config["dataset"],
143201
"devices": config["devices"],
144202
"epochs": surr["epochs"],
145-
"n_trials": trials,
203+
"n_trials": trials if not n_trials_override else n_trials_override,
146204
"seed": config["seed"],
147205
"surrogate": {"name": arch_name},
148-
"optuna_params": surr["optuna_params"],
206+
"optuna_params": surr.get("optuna_params", {}),
149207
"prune": config.get("prune", True),
150208
"optuna_logging": config.get("optuna_logging", False),
151209
"use_optimal_params": config.get("use_optimal_params", False),
152210
"multi_objective": config.get("multi_objective", False),
153211
"population_size": config.get("population_size", 50),
154212
"target_percentile": config.get("target_percentile", 0.95),
213+
"fine": config.get("fine", False), # pass through
214+
"loss_cap": config.get("loss_cap", 20),
155215
}
156216

217+
if config.get("fine", False):
218+
sub_config["fine_space"] = fine_space
219+
157220
run_single_study(sub_config, study_name, db_url)
158221
arch_pbar.update(1)
159222
arch_pbar.set_postfix({"done": study_name})
160223

224+
# Write YAML summary once per main study (only in fine mode)
225+
if config.get("fine", False):
226+
out_dir = os.path.join("tuned", main_study_name)
227+
os.makedirs(out_dir, exist_ok=True)
228+
out_path = os.path.join(out_dir, "fine_summary.yaml")
229+
with open(out_path, "w", encoding="utf-8") as f:
230+
yaml.safe_dump(fine_report, f, sort_keys=True, default_flow_style=False)
231+
print(f"[fine] Wrote summary: {out_path}")
232+
161233

162234
def parse_arguments():
163235
parser = argparse.ArgumentParser(description="Run Optuna tuning studies.")

0 commit comments

Comments (0)