Skip to content

Commit eab1327

Browse files
committed
Switch BestOfN and Tournament from task+n to candidates list
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 604ca9e commit eab1327

5 files changed

Lines changed: 108 additions & 100 deletions

File tree

README.md

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,22 +109,27 @@ if outcome.committed:
109109

110110
### Best-of-N with scoring
111111

112-
Run the same task N times (e.g. with different random seeds or temperatures)
113-
and commit the highest-scoring success.
112+
Run N candidates in parallel and commit the highest-scoring success.
114113

115114
Use when quality matters more than speed: code generation where you want
116115
the cleanest output across multiple temperatures, translation with a BLEU
117-
scorer picking the best variant, or any task with a reliable quality metric
118-
where the same prompt can produce varying results.
116+
scorer picking the best variant, or any task with a reliable quality metric.
117+
Pairs naturally with the `n=` parameter in OpenAI's Chat Completions API
118+
to generate N variations in a single call, then test each in an isolated
119+
branch.
119120

120121
```python
121122
from branching import BestOfN
122123

123-
def scored_task(path: Path, attempt: int) -> tuple[bool, float]:
124-
result = run_agent(workdir=path, seed=attempt)
125-
return result.passed, result.quality_score
124+
def make_candidate(code: str):
125+
def candidate(path: Path) -> tuple[bool, float]:
126+
(path / "solution.py").write_text(code)
127+
passed = run_tests(path)
128+
return passed, evaluate_quality(path) if passed else 0.0
129+
return candidate
126130

127-
outcome = BestOfN(scored_task, n=5)(ws)
131+
candidates = [make_candidate(c) for c in generate_solutions(n=5)]
132+
outcome = BestOfN(candidates)(ws)
128133
```
129134

130135
### Reflexion (retry with feedback)
@@ -217,7 +222,7 @@ outcome = BeamSearch(
217222

218223
### Tournament (pairwise elimination)
219224

220-
Generate N candidates in parallel, then narrow to one through pairwise
225+
Run N candidates in parallel, then narrow to one through pairwise
221226
elimination via a judge function. The convergent dual of Tree of Thoughts:
222227
starts wide, narrows to one.
223228

@@ -229,14 +234,19 @@ any setting where relative ranking is easier than absolute scoring.
229234
```python
230235
from branching import Tournament
231236

232-
def generate_patch(path: Path, index: int) -> bool:
233-
return run_agent(workdir=path, seed=index)
237+
def make_patch(code: str):
238+
def candidate(path: Path) -> bool:
239+
(path / "fix.patch").write_text(code)
240+
return apply_and_test(path)
241+
return candidate
242+
243+
candidates = [make_patch(p) for p in generate_patches(n=8)]
234244

235245
def judge(path_a: Path, path_b: Path) -> int:
236246
# 0 = a wins, 1 = b wins
237247
return llm_compare(path_a / "fix.patch", path_b / "fix.patch")
238248

239-
outcome = Tournament(generate_patch, n=8, judge=judge)(ws)
249+
outcome = Tournament(candidates, judge=judge)(ws)
240250
```
241251

242252
### Cascaded speculation (adaptive fan-out)
@@ -349,7 +359,7 @@ from branching import ResourceLimits, BestOfN
349359

350360
limits = ResourceLimits(memory=512 * 1024 * 1024, cpu=0.5) # 512 MB, 50% CPU
351361

352-
outcome = BestOfN(scored_task, n=5, resource_limits=limits)(ws)
362+
outcome = BestOfN(candidates, resource_limits=limits)(ws)
353363
```
354364

355365
All patterns accept `resource_limits`: `Speculate`, `BestOfN`, `Reflexion`,

src/branching/agent/patterns.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,30 @@
1919

2020

2121
class BestOfN:
22-
"""Run N copies of a task in parallel, commit the highest-scoring one.
22+
"""Run N candidates in parallel, commit the highest-scoring one.
2323
2424
All candidates run concurrently. Each holds its branch open after
2525
finishing. The main thread picks the winner based on score, then
2626
signals each thread to commit (winner) or abort (losers).
2727
28-
The task callable receives (path, attempt_index) and returns
28+
Each candidate callable receives the branch path (a Path) and returns
2929
(success: bool, score: float).
3030
3131
Example:
32-
outcome = BestOfN(scored_task, n=5)(ws)
33-
# Commits the highest-scoring successful attempt
32+
candidates = [lambda p: (run_tests(p), score(p)) for _ in range(5)]
33+
outcome = BestOfN(candidates)(ws)
34+
# Commits the highest-scoring successful candidate
3435
"""
3536

3637
def __init__(
3738
self,
38-
task: Callable[[Path, int], tuple[bool, float]],
39-
n: int = 3,
39+
candidates: Sequence[Callable[[Path], tuple[bool, float]]],
4040
*,
4141
timeout: float | None = None,
4242
resource_limits: ResourceLimits | None = None,
4343
group_limits: ResourceLimits | None = None,
4444
):
45-
self._task = task
46-
self._n = n
45+
self._candidates = list(candidates)
4746
self._timeout = timeout
4847
self._resource_limits = resource_limits
4948
self._group_limits = group_limits
@@ -70,7 +69,7 @@ def __call__(self, workspace: Workspace) -> SpeculationOutcome:
7069
kill_scope(root_cgroup)
7170

7271
def _run(self, workspace: Workspace, root_cgroup: Optional[Path]) -> SpeculationOutcome:
73-
n = self._n
72+
n = len(self._candidates)
7473
results: list[Optional[SpeculationResult]] = [None] * n
7574
task_done = [threading.Event() for _ in range(n)]
7675
decision_ready = [threading.Event() for _ in range(n)]
@@ -96,7 +95,7 @@ def _on_scope(sp: Path, _i: int = index) -> None:
9695
branch_scopes[_i] = sp
9796

9897
ret = run_in_process(
99-
self._task, (b.path, index),
98+
self._candidates[index], (b.path,),
10099
workspace=b.path,
101100
limits=self._resource_limits,
102101
parent_cgroup=root_cgroup,
@@ -911,20 +910,19 @@ def _on_sub_scope(
911910

912911

913912
class Tournament:
914-
"""Pairwise elimination bracket: generate N candidates, compare
913+
"""Pairwise elimination bracket: run N candidates, compare
915914
pairwise via a judge function, commit the final winner.
916915
917916
The convergent dual of TreeOfThoughts: starts wide, narrows to one.
918917
919918
Example:
920-
outcome = Tournament(task, n=4, judge=judge)(ws)
919+
outcome = Tournament(candidates, judge=judge)(ws)
921920
# Commits the bracket winner
922921
"""
923922

924923
def __init__(
925924
self,
926-
task: Callable[[Path, int], bool],
927-
n: int = 4,
925+
candidates: Sequence[Callable[[Path], bool]],
928926
*,
929927
judge: Callable[[Path, Path], int],
930928
timeout: float | None = None,
@@ -933,17 +931,16 @@ def __init__(
933931
):
934932
"""
935933
Args:
936-
task: Callable(branch_path, candidate_index) → success.
937-
Produces output in the branch directory.
938-
n: Number of candidates to generate.
934+
candidates: Callables that take a Path (branch working dir)
935+
and return True on success. Each produces output in
936+
the branch directory for the judge to compare.
939937
judge: Callable(path_a, path_b) → 0 (a wins) or 1 (b wins).
940938
Compares two candidates' branches during elimination.
941939
timeout: Overall timeout in seconds.
942940
resource_limits: Optional per-branch resource limits.
943941
group_limits: Optional resource limits for the root cgroup.
944942
"""
945-
self._task = task
946-
self._n = n
943+
self._candidates = list(candidates)
947944
self._judge = judge
948945
self._timeout = timeout
949946
self._resource_limits = resource_limits
@@ -992,7 +989,7 @@ def __call__(self, workspace: Workspace) -> SpeculationOutcome:
992989
kill_scope(root_cgroup)
993990

994991
def _run(self, workspace: Workspace, root_cgroup: Optional[Path]) -> SpeculationOutcome:
995-
n = self._n
992+
n = len(self._candidates)
996993
results: list[Optional[SpeculationResult]] = [None] * n
997994
branch_paths: list[Optional[Path]] = [None] * n
998995
task_done = [threading.Event() for _ in range(n)]
@@ -1020,7 +1017,7 @@ def _on_scope(sp: Path, _i: int = index) -> None:
10201017
branch_scopes[_i] = sp
10211018

10221019
success = run_in_process(
1023-
self._task, (b.path, index),
1020+
self._candidates[index], (b.path,),
10241021
workspace=b.path,
10251022
limits=self._resource_limits,
10261023
parent_cgroup=root_cgroup,

src/cli/best_of_n.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from . import _parse_group_limits, _parse_resource_limits, _print_error, _resolve_workspace
1212

1313

14-
def _make_task(cmd: list[str]):
15-
"""Wrap a command into a BestOfN task callable.
14+
def _make_candidate(cmd: list[str], index: int):
15+
"""Wrap a command into a BestOfN candidate callable.
1616
17-
Returns a callable(path, index) -> (success, score).
17+
Returns a callable(path) -> (success, score).
1818
The child process can write a score float to fd 3.
1919
"""
2020

21-
def task(workdir: Path, index: int) -> tuple[bool, float]:
21+
def candidate(workdir: Path) -> tuple[bool, float]:
2222
# Create a pipe for the child to report its score.
2323
# Python 3.4+ creates pipe fds with CLOEXEC, so they are
2424
# automatically closed on exec — only fd 3 (dup2 clears
@@ -57,7 +57,7 @@ def _preexec():
5757

5858
return (success, score)
5959

60-
return task
60+
return candidate
6161

6262

6363
def cmd_best_of_n(args) -> int:
@@ -68,10 +68,10 @@ def cmd_best_of_n(args) -> int:
6868
ws_path = _resolve_workspace(args)
6969
ws = Workspace(ws_path)
7070

71-
task = _make_task(args.cmd)
71+
candidates = [_make_candidate(args.cmd, i) for i in range(args.n)]
7272
limits = _parse_resource_limits(args)
7373
group_limits = _parse_group_limits(args)
74-
best = BestOfN(task, n=args.n, timeout=args.timeout, resource_limits=limits, group_limits=group_limits)
74+
best = BestOfN(candidates, timeout=args.timeout, resource_limits=limits, group_limits=group_limits)
7575
outcome = best(ws)
7676

7777
results_summary = []

tests/test_resource_limits.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def test_speculate_accepts_resource_limits(self):
415415
def test_best_of_n_accepts_resource_limits(self):
416416
from branching.agent.patterns import BestOfN
417417
rl = ResourceLimits(cpu=0.5)
418-
bon = BestOfN(lambda p, i: (True, 1.0), n=2, resource_limits=rl)
418+
bon = BestOfN([lambda p: (True, 1.0)] * 2, resource_limits=rl)
419419
assert bon._resource_limits is rl
420420

421421
def test_reflexion_accepts_resource_limits(self):
@@ -444,7 +444,7 @@ def test_tournament_accepts_resource_limits(self):
444444
from branching.agent.patterns import Tournament
445445
rl = ResourceLimits(memory=4096)
446446
t = Tournament(
447-
lambda p, i: True, n=2,
447+
[lambda p: True] * 2,
448448
judge=lambda a, b: 0,
449449
resource_limits=rl,
450450
)
@@ -468,7 +468,7 @@ def test_speculate_accepts_group_limits(self):
468468
def test_best_of_n_accepts_group_limits(self):
469469
from branching.agent.patterns import BestOfN
470470
gl = ResourceLimits(cpu=2.0)
471-
bon = BestOfN(lambda p, i: (True, 1.0), n=2, group_limits=gl)
471+
bon = BestOfN([lambda p: (True, 1.0)] * 2, group_limits=gl)
472472
assert bon._group_limits is gl
473473

474474
def test_reflexion_accepts_group_limits(self):
@@ -497,7 +497,7 @@ def test_tournament_accepts_group_limits(self):
497497
from branching.agent.patterns import Tournament
498498
gl = ResourceLimits(memory=4096, cpu=2.0)
499499
t = Tournament(
500-
lambda p, i: True, n=2,
500+
[lambda p: True] * 2,
501501
judge=lambda a, b: 0,
502502
group_limits=gl,
503503
)
@@ -923,15 +923,14 @@ def mock_rip(fn, args, *, workspace, limits,
923923
scope = workspace / ".scope"
924924
if scope_callback:
925925
scope_callback(scope)
926-
idx = args[1]
927-
return (True, float(idx))
926+
return fn(*args)
928927

929928
ws = _mock_workspace()
930929
with patch("branching.agent.patterns.run_in_process", side_effect=mock_rip), \
931930
patch("branching.process._cgroup.kill_scope",
932931
side_effect=lambda s: killed.append(s)):
933932
bon = BestOfN(
934-
lambda p, i: (True, float(i)), n=2,
933+
[lambda p, i=i: (True, float(i)) for i in range(2)],
935934
resource_limits=ResourceLimits(memory=1024),
936935
)
937936
bon._run(ws, None)
@@ -949,17 +948,21 @@ def mock_rip(fn, args, *, workspace, limits,
949948
scope = workspace / ".scope"
950949
if scope_callback:
951950
scope_callback(scope)
952-
idx = args[1]
953-
if idx == 1:
954-
time.sleep(2) # simulate stuck task
955-
return (True, float(idx))
951+
return fn(*args)
952+
953+
def fast(p):
954+
return (True, 0.0)
955+
956+
def stuck(p):
957+
time.sleep(2)
958+
return (True, 1.0)
956959

957960
ws = _mock_workspace()
958961
with patch("branching.agent.patterns.run_in_process", side_effect=mock_rip), \
959962
patch("branching.process._cgroup.kill_scope",
960963
side_effect=lambda s: killed.append(s)):
961964
bon = BestOfN(
962-
lambda p, i: (True, float(i)), n=2,
965+
[fast, stuck],
963966
timeout=0.05,
964967
resource_limits=ResourceLimits(memory=1024),
965968
)
@@ -1008,14 +1011,14 @@ def mock_rip(fn, args, *, workspace, limits,
10081011
scope = workspace / ".scope"
10091012
if scope_callback:
10101013
scope_callback(scope)
1011-
return True
1014+
return fn(*args)
10121015

10131016
ws = _mock_workspace()
10141017
with patch("branching.agent.patterns.run_in_process", side_effect=mock_rip), \
10151018
patch("branching.process._cgroup.kill_scope",
10161019
side_effect=lambda s: killed.append(s)):
10171020
t = Tournament(
1018-
lambda p, i: True, n=2,
1021+
[lambda p: True] * 2,
10191022
judge=lambda a, b: 0,
10201023
resource_limits=ResourceLimits(memory=1024),
10211024
)
@@ -1030,7 +1033,7 @@ def test_no_kill_without_resource_limits(self):
10301033
with patch("branching.process._cgroup.kill_scope",
10311034
side_effect=lambda s: killed.append(s)):
10321035
t = Tournament(
1033-
lambda p, i: True, n=2,
1036+
[lambda p: True] * 2,
10341037
judge=lambda a, b: 0,
10351038
)
10361039
t._run(ws, None)

0 commit comments

Comments
 (0)