Skip to content

Commit 0db8ff8

Browse files
committed
Fix wandb run reuse across models
1 parent 4cf171d commit 0db8ff8

2 files changed

Lines changed: 109 additions & 14 deletions

File tree

src/art/model.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
470470
id=self.name,
471471
config=self._wandb_config or None,
472472
resume="allow",
473+
reinit="create_new",
473474
settings=wandb.Settings(
474475
x_stats_open_metrics_endpoints={
475476
"vllm": "http://localhost:8000/metrics",
@@ -492,18 +493,18 @@ def _get_wandb_run(self) -> Optional["Run"]:
492493

493494
# Define training_step as the x-axis for all metrics.
494495
# This allows out-of-order logging (e.g., async validation for previous steps).
495-
wandb.define_metric("training_step")
496-
wandb.define_metric("time/wall_clock_sec")
497-
wandb.define_metric("reward/*", step_metric="training_step")
498-
wandb.define_metric("loss/*", step_metric="training_step")
499-
wandb.define_metric("throughput/*", step_metric="training_step")
500-
wandb.define_metric("costs/*", step_metric="training_step")
501-
wandb.define_metric("time/*", step_metric="training_step")
502-
wandb.define_metric("data/*", step_metric="training_step")
503-
wandb.define_metric("train/*", step_metric="training_step")
504-
wandb.define_metric("val/*", step_metric="training_step")
505-
wandb.define_metric("test/*", step_metric="training_step")
506-
wandb.define_metric("discarded/*", step_metric="training_step")
496+
run.define_metric("training_step")
497+
run.define_metric("time/wall_clock_sec")
498+
run.define_metric("reward/*", step_metric="training_step")
499+
run.define_metric("loss/*", step_metric="training_step")
500+
run.define_metric("throughput/*", step_metric="training_step")
501+
run.define_metric("costs/*", step_metric="training_step")
502+
run.define_metric("time/*", step_metric="training_step")
503+
run.define_metric("data/*", step_metric="training_step")
504+
run.define_metric("train/*", step_metric="training_step")
505+
run.define_metric("val/*", step_metric="training_step")
506+
run.define_metric("test/*", step_metric="training_step")
507+
run.define_metric("discarded/*", step_metric="training_step")
507508
self._sync_wandb_config(run)
508509
return self._wandb_run
509510

@@ -562,14 +563,16 @@ def _log_metrics(
562563
run.log(prefixed)
563564

564565
def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None:
565-
import wandb
566+
run = self._wandb_run
567+
if run is None or run._is_finished:
568+
return
566569

567570
for key in keys:
568571
if not key.startswith("costs/"):
569572
continue
570573
if key in self._wandb_defined_metrics:
571574
continue
572-
wandb.define_metric(key, step_metric="training_step")
575+
run.define_metric(key, step_metric="training_step")
573576
self._wandb_defined_metrics.add(key)
574577

575578
def _route_metrics_and_collect_non_costs(

tests/unit/test_wandb_multi_run.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import os
2+
import sys
3+
from pathlib import Path
4+
from unittest.mock import patch
5+
6+
from art import Model
7+
8+
9+
def test_wandb_creates_separate_runs_per_model(tmp_path: Path):
10+
class FakeRun:
11+
def __init__(self, name: str):
12+
self.name = name
13+
self.id = name
14+
self._is_finished = False
15+
self.defined_metrics: list[tuple[str, str | None]] = []
16+
17+
def define_metric(self, name: str, *, step_metric: str | None = None) -> None:
18+
self.defined_metrics.append((name, step_metric))
19+
20+
class FakeWandb:
21+
def __init__(self):
22+
self.init_calls: list[dict] = []
23+
self.runs: list[FakeRun] = []
24+
25+
@staticmethod
26+
def Settings(**kwargs):
27+
return kwargs
28+
29+
def init(self, **kwargs):
30+
self.init_calls.append(kwargs)
31+
run = FakeRun(kwargs["name"])
32+
self.runs.append(run)
33+
return run
34+
35+
def define_metric(self, *args, **kwargs) -> None:
36+
raise AssertionError("Model should define metrics on the run object")
37+
38+
fake_wandb = FakeWandb()
39+
model_one = Model(
40+
name="run-one",
41+
project="test-project",
42+
base_path=str(tmp_path),
43+
)
44+
model_two = Model(
45+
name="run-two",
46+
project="test-project",
47+
base_path=str(tmp_path),
48+
)
49+
50+
with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}):
51+
with patch.dict(sys.modules, {"wandb": fake_wandb}):
52+
run_one = model_one._get_wandb_run()
53+
run_two = model_two._get_wandb_run()
54+
model_one._define_wandb_step_metrics(["costs/train/custom"])
55+
56+
assert run_one is not None
57+
assert run_two is not None
58+
assert run_one is not run_two
59+
assert [call["name"] for call in fake_wandb.init_calls] == [
60+
"run-one",
61+
"run-two",
62+
]
63+
assert all(call["reinit"] == "create_new" for call in fake_wandb.init_calls)
64+
assert run_one.defined_metrics == [
65+
("training_step", None),
66+
("time/wall_clock_sec", None),
67+
("reward/*", "training_step"),
68+
("loss/*", "training_step"),
69+
("throughput/*", "training_step"),
70+
("costs/*", "training_step"),
71+
("time/*", "training_step"),
72+
("data/*", "training_step"),
73+
("train/*", "training_step"),
74+
("val/*", "training_step"),
75+
("test/*", "training_step"),
76+
("discarded/*", "training_step"),
77+
("costs/train/custom", "training_step"),
78+
]
79+
assert run_two.defined_metrics == [
80+
("training_step", None),
81+
("time/wall_clock_sec", None),
82+
("reward/*", "training_step"),
83+
("loss/*", "training_step"),
84+
("throughput/*", "training_step"),
85+
("costs/*", "training_step"),
86+
("time/*", "training_step"),
87+
("data/*", "training_step"),
88+
("train/*", "training_step"),
89+
("val/*", "training_step"),
90+
("test/*", "training_step"),
91+
("discarded/*", "training_step"),
92+
]

0 commit comments

Comments
 (0)