
Commit 7b37f42

second round of PR feedback
1 parent: 1f1b54e · commit 7b37f42

9 files changed: 51 additions & 20 deletions


.github/workflows/ci.yaml

Lines changed: 1 addition & 2 deletions
```diff
@@ -13,7 +13,6 @@ on:
       - opened
       - synchronize
       - reopened
-      - closed

 permissions:
   contents: write
@@ -172,4 +171,4 @@ jobs:
       publish_branch: gh-pages
       user_name: "GitHub Actions"
       user_email: "actions@github.com"
-      force_orphan: false
+      force_orphan: false
```

codes/benchmark/bench_fcts.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -103,7 +103,7 @@ def run_benchmark(surr_name: str, surrogate_class, conf: dict) -> dict[str, Any]
         log=dataset_conf.get("log10_transform", True),
         log_params=dataset_conf.get("log10_transform_params", True),
         normalisation_mode=dataset_conf.get("normalise", "minmax"),
-        tolerance=dataset_conf.get("tolerance", 1e-25),
+        tolerance=dataset_conf.get("tolerance", None),
         per_species=dataset_conf.get("normalise_per_species", False),
     )

```

codes/tune/optuna_fcts.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -258,7 +258,7 @@ def training_run(
         log=dataset_cfg.get("log10_transform", True),
         log_params=dataset_cfg.get("log10_transform_params", True),
         normalisation_mode=dataset_cfg.get("normalise", "minmax"),
-        tolerance=dataset_cfg.get("tolerance", 1e-25),
+        tolerance=dataset_cfg.get("tolerance", None),
         per_species=dataset_cfg.get("normalise_per_species", False),
     )

```

codes/tune/tune_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -96,7 +96,7 @@ def apply_tuning_defaults(config: dict) -> dict:
     dataset.setdefault("log10_transform_params", True)
     dataset.setdefault("normalise", "minmax")
     dataset.setdefault("normalise_per_species", False)
-    dataset.setdefault("tolerance", 1e-25)
+    dataset.setdefault("tolerance", None)
     dataset.setdefault("subset_factor", 1)
     dataset.setdefault("log_timesteps", False)
     cfg["dataset"] = dataset
```

docs/source/reference/configuration.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -91,7 +91,7 @@ dataset:
   log10_transform_params: true
   normalise: "minmax"
   normalise_per_species: false
-  tolerance: 1e-25
+  tolerance: null
   subset_factor: 1
   log_timesteps: false
   use_optimal_params: true
```
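
For completeness, YAML's `null` maps to Python `None` when the config file is parsed, which matches the new in-code default. A quick check, assuming the config is read with PyYAML (`yaml.safe_load`):

```python
import yaml  # assumes PyYAML; shown only to illustrate the null -> None mapping

snippet = """
dataset:
  normalise: "minmax"
  tolerance: null
"""
cfg = yaml.safe_load(snippet)
assert cfg["dataset"]["tolerance"] is None
```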

docs/source/tutorials/index.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -11,7 +11,7 @@ Interactive tutorials live in this folder so they can be rendered directly on th
 To execute a notebook locally:

 ```bash
-pip install -r requirements.txt # ensures docs dependencies
+uv pip install --group dev # installs docs + notebook deps
 jupyter lab docs/source/tutorials/benchmark_quickstart.ipynb
 ```

````

run_training.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -35,8 +35,10 @@ def main(args):
     torch.use_deterministic_algorithms(True)
     config = load_and_save_config(config_path=args.config, save=False)

-    if torch.cuda.is_available():
-        print(torch.cuda.device_count(), torch.cuda.current_device())
+    if torch.cuda.is_available() and config.get("verbose", False):
+        nice_print(
+            f"CUDA devices detected: count={torch.cuda.device_count()}, current={torch.cuda.current_device()}"
+        )

     download_data(config["dataset"]["name"], verbose=config.get("verbose", False))
     task_list_filepath, copy_config = check_training_status(config)
```
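
The device report is now gated on both CUDA availability and the `verbose` flag, and routed through the project's `nice_print` helper. A self-contained sketch of the same guard, with a plain `print` standing in for `nice_print`:

```python
import torch

def report_cuda_devices(config: dict) -> None:
    # Same guard as in the hunk above; plain print stands in for nice_print.
    if torch.cuda.is_available() and config.get("verbose", False):
        print(
            f"CUDA devices detected: count={torch.cuda.device_count()}, "
            f"current={torch.cuda.current_device()}"
        )

report_cuda_devices({"verbose": True})   # prints only on a CUDA-capable machine
report_cuda_devices({"verbose": False})  # stays silent either way
```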

run_tuning.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -56,7 +56,9 @@ def run_single_study(config: dict, study_name: str, db_url: str, sqlite_backend:
     if not config.get("optuna_logging", False):
         optuna.logging.set_verbosity(optuna.logging.WARNING)

-    if config.get("multi_objective", False):
+    multi_objective = config.get("multi_objective", False)
+
+    if multi_objective:
         sampler = optuna.samplers.NSGAIISampler(
             seed=config["seed"], population_size=config["population_size"]
         )
@@ -151,7 +153,7 @@ def trial_complete_callback(study_: optuna.Study, trial_: optuna.trial.FrozenTrial
     try:
         study.optimize(
             objective_fn,
-            n_trials=n_trials * 2 if config["multi_objective"] else n_trials,
+            n_trials=n_trials,
             n_jobs=n_jobs,
             callbacks=[
                 MaxValidTrialsCallback(n_trials),
```
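
Hoisting `multi_objective` into a local pairs with the second hunk, where the trial budget is no longer doubled for multi-objective studies. A sketch of the sampler selection this feeds into; the single-objective branch is not part of this diff, so the `TPESampler` fallback below is an assumption:

```python
import optuna

def make_sampler(config: dict) -> optuna.samplers.BaseSampler:
    multi_objective = config.get("multi_objective", False)
    if multi_objective:
        # Branch shown in the diff: NSGA-II for multi-objective studies.
        return optuna.samplers.NSGAIISampler(
            seed=config["seed"], population_size=config["population_size"]
        )
    # Assumed fallback for single-objective studies (not part of this diff).
    return optuna.samplers.TPESampler(seed=config["seed"])

sampler = make_sampler({"multi_objective": True, "seed": 42, "population_size": 20})
```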

test/test_bench_main.py

Lines changed: 37 additions & 9 deletions
```diff
@@ -147,18 +147,47 @@ def test_time_inference(monkeypatch):
     assert result["mean_inference_time_per_prediction"] == pytest.approx(3.0 / 2)


-def test_evaluate_compute(monkeypatch):
-    if not torch.cuda.is_available():
-        pytest.skip("CUDA unavailable, skipping measure_memory_footprint test.")
+def test_evaluate_compute_cpu_path(monkeypatch):
+    # Force non-CUDA execution
+    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

-    # patch memory footprint and parameter count
+    # measure_memory_footprint should not be called
+    called = {"mem": False}
+
+    def fake_measure(model, inputs, device):
+        called["mem"] = True
+        return {}, model
+
+    monkeypatch.setattr(bf, "measure_memory_footprint", fake_measure)
+    monkeypatch.setattr(bf, "count_trainable_parameters", lambda m: 12345)
+
+    class DummyLoader:
+        def __iter__(self):
+            yield ("inp",)
+
+    loader = DummyLoader()
+    model = FakeModel()
+    surr = "SurrB"
+    conf = {"training_id": "TID"}
+    out = bf.evaluate_compute(model, surr, test_loader=loader, conf=conf)
+    assert model.load_calls == [("TID", surr, f"{surr.lower()}_main")]
+    assert out["num_trainable_parameters"] == 12345
+    assert out["memory_footprint"] == {}
+    assert not called["mem"]
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="CUDA unavailable for memory profiling test."
+)
+def test_evaluate_compute_cuda_path(monkeypatch):
     fake_mem = {"model_memory": 100, "forward_memory_nograd": 50}
-    monkeypatch.setattr(
-        bf, "measure_memory_footprint", lambda m, inp, device: (fake_mem, m)
-    )
+
+    def fake_measure(model, inputs, device):
+        return fake_mem, model
+
+    monkeypatch.setattr(bf, "measure_memory_footprint", fake_measure)
     monkeypatch.setattr(bf, "count_trainable_parameters", lambda m: 12345)

-    # test_loader yields one tuple of inputs
     class DummyLoader:
         def __iter__(self):
             yield ("inp",)
@@ -168,7 +197,6 @@ def __iter__(self):
     surr = "SurrB"
     conf = {"training_id": "TID"}
     out = bf.evaluate_compute(model, surr, test_loader=loader, conf=conf)
-    # load main was invoked
     assert model.load_calls == [("TID", surr, f"{surr.lower()}_main")]
     assert out["num_trainable_parameters"] == 12345
     assert out["memory_footprint"] is fake_mem
```
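
The split into `test_evaluate_compute_cpu_path` and `test_evaluate_compute_cuda_path` leans on pytest's `monkeypatch` fixture to pin down which hardware branch runs. A standalone illustration of that idiom, independent of the benchmark code:

```python
import torch

def test_cpu_branch_forced(monkeypatch):
    # Patch torch.cuda.is_available for the duration of this test only;
    # pytest's monkeypatch fixture undoes the patch automatically afterwards.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    assert torch.cuda.is_available() is False
```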
