Skip to content

Commit 820baec

Browse files
committed
Add Tournament pattern: pairwise elimination bracket
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 8be1bc4 commit 820baec

4 files changed

Lines changed: 293 additions & 2 deletions

File tree

README.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ raises, everything is rolled back - the workspace is untouched.
3939

4040
## Agent patterns
4141

42-
BranchContext ships with four high-level patterns that cover the most common
42+
BranchContext ships with five high-level patterns that cover the most common
4343
agent workflows. Each is a callable class: instantiate with config, call with
4444
a workspace.
4545

@@ -142,6 +142,30 @@ outcome = TreeOfThoughts(
142142
)(ws)
143143
```
144144

145+
### Tournament (pairwise elimination)
146+
147+
Generate N candidates in parallel, then narrow to one through pairwise
148+
elimination via a judge function. It is the convergent dual of Tree of Thoughts:
149+
it starts wide and narrows to one.
150+
151+
Use when you have a reliable pairwise comparator but no absolute scoring
152+
function: patch selection where an LLM judge picks the better diff,
153+
A/B-style evaluation where candidates are compared head-to-head, or
154+
any setting where relative ranking is easier than absolute scoring.
155+
156+
```python
157+
from branching import Tournament
158+
159+
def generate_patch(path: Path, index: int) -> bool:
160+
return run_agent(workdir=path, seed=index)
161+
162+
def judge(path_a: Path, path_b: Path) -> int:
163+
# 0 = a wins, 1 = b wins
164+
return llm_compare(path_a / "diff.patch", path_b / "diff.patch")
165+
166+
outcome = Tournament(generate_patch, n=8, judge=judge)(ws)
167+
```
168+
145169
## Lower-level usage
146170

147171
The patterns above are built on two lower-level primitives you can use
@@ -289,6 +313,7 @@ All patterns: instantiate with config, call with a `Workspace`, get a
289313
| **`BestOfN`** | `(task, n=3, *, timeout=None)` | Run N copies; commit highest-scoring success |
290314
| **`Reflexion`** | `(task, max_retries=3, *, critique=None)` | Retry with critique feedback loop |
291315
| **`TreeOfThoughts`** | `(strategies, *, evaluate=None, expand=None, max_depth=1, timeout=None)` | Parallel strategy tree with optional depth expansion |
316+
| **`Tournament`** | `(task, n=4, *, judge, timeout=None)` | Generate N candidates; pairwise elimination picks winner |
292317

293318
### Result types
294319

src/branching/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"BestOfN",
5151
"Reflexion",
5252
"TreeOfThoughts",
53+
"Tournament",
5354
# Results
5455
"SpeculationResult",
5556
"SpeculationOutcome",
@@ -82,6 +83,7 @@
8283
"BestOfN": ".agent.patterns",
8384
"Reflexion": ".agent.patterns",
8485
"TreeOfThoughts": ".agent.patterns",
86+
"Tournament": ".agent.patterns",
8587
# Results
8688
"SpeculationResult": ".agent.result",
8789
"SpeculationOutcome": ".agent.result",

src/branching/agent/patterns.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,150 @@ def _multi_level(self, workspace: Workspace) -> SpeculationOutcome:
418418
all_results=all_results,
419419
committed=True,
420420
)
421+
422+
423+
class Tournament:
424+
"""Pairwise elimination bracket: generate N candidates, compare
425+
pairwise via a judge function, commit the final winner.
426+
427+
The convergent dual of TreeOfThoughts: starts wide, narrows to one.
428+
429+
Example:
430+
outcome = Tournament(task, n=4, judge=judge)(ws)
431+
# Commits the bracket winner
432+
"""
433+
434+
def __init__(
435+
self,
436+
task: Callable[[Path, int], bool],
437+
n: int = 4,
438+
*,
439+
judge: Callable[[Path, Path], int],
440+
timeout: float | None = None,
441+
):
442+
"""
443+
Args:
444+
task: Callable(branch_path, candidate_index) → success.
445+
Produces output in the branch directory.
446+
n: Number of candidates to generate.
447+
judge: Callable(path_a, path_b) → 0 (a wins) or 1 (b wins).
448+
Compares two candidates' branches during elimination.
449+
timeout: Overall timeout in seconds.
450+
"""
451+
self._task = task
452+
self._n = n
453+
self._judge = judge
454+
self._timeout = timeout
455+
456+
@staticmethod
457+
def _run_bracket(
458+
survivors: list[int],
459+
branch_paths: list[Path],
460+
judge: Callable[[Path, Path], int],
461+
) -> int:
462+
"""Single-elimination bracket. Returns the winning candidate index."""
463+
while len(survivors) > 1:
464+
next_round: list[int] = []
465+
i = 0
466+
while i < len(survivors) - 1:
467+
a, b = survivors[i], survivors[i + 1]
468+
pick = judge(branch_paths[a], branch_paths[b])
469+
next_round.append(b if pick else a)
470+
i += 2
471+
# Odd candidate gets a bye
472+
if len(survivors) % 2 == 1:
473+
next_round.append(survivors[-1])
474+
survivors = next_round
475+
return survivors[0]
476+
477+
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
478+
n = self._n
479+
results: list[Optional[SpeculationResult]] = [None] * n
480+
branch_paths: list[Optional[Path]] = [None] * n
481+
task_done = [threading.Event() for _ in range(n)]
482+
decision_ready = [threading.Event() for _ in range(n)]
483+
decisions = ["abort"] * n
484+
485+
def _run_candidate(index: int) -> None:
486+
result = SpeculationResult(branch_index=index, success=False)
487+
try:
488+
with workspace.branch(
489+
f"tournament_{index}", on_success=None, on_error=None
490+
) as b:
491+
result.branch_path = b.path
492+
branch_paths[index] = b.path
493+
try:
494+
success = self._task(b.path, index)
495+
result.success = bool(success)
496+
result.return_value = success
497+
except Exception as e:
498+
result.exception = e
499+
500+
results[index] = result
501+
task_done[index].set()
502+
503+
decision_ready[index].wait()
504+
505+
if decisions[index] == "commit":
506+
b.commit()
507+
else:
508+
b.abort()
509+
510+
except Exception as e:
511+
result.exception = e
512+
results[index] = result
513+
task_done[index].set()
514+
515+
with ThreadPoolExecutor(max_workers=n) as pool:
516+
futures = [pool.submit(_run_candidate, i) for i in range(n)]
517+
518+
# Wait for all tasks to finish
519+
deadline = (
520+
time.monotonic() + self._timeout
521+
if self._timeout is not None
522+
else None
523+
)
524+
for ev in task_done:
525+
remaining = (
526+
max(0, deadline - time.monotonic())
527+
if deadline is not None
528+
else None
529+
)
530+
ev.wait(timeout=remaining)
531+
532+
# Filter to successful survivors
533+
survivors = [
534+
i for i, r in enumerate(results)
535+
if r is not None and r.success
536+
]
537+
538+
winner_idx: Optional[int] = None
539+
if len(survivors) == 1:
540+
winner_idx = survivors[0]
541+
elif len(survivors) > 1:
542+
winner_idx = self._run_bracket(
543+
survivors, branch_paths, self._judge
544+
)
545+
546+
if winner_idx is not None:
547+
decisions[winner_idx] = "commit"
548+
549+
# Release all threads
550+
for ev in decision_ready:
551+
ev.set()
552+
553+
for f in futures:
554+
f.result()
555+
556+
committed = winner_idx is not None
557+
winner = results[winner_idx] if winner_idx is not None else None
558+
all_results = [
559+
r if r is not None else SpeculationResult(branch_index=i, success=False)
560+
for i, r in enumerate(results)
561+
]
562+
563+
return SpeculationOutcome(
564+
winner=winner,
565+
all_results=all_results,
566+
committed=committed,
567+
)

tests/test_speculate.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from branching.core.base import FSBackend
99
from branching.core.workspace import Workspace
1010
from branching.agent.speculate import Speculate
11-
from branching.agent.patterns import BestOfN, Reflexion, TreeOfThoughts
11+
from branching.agent.patterns import BestOfN, Reflexion, TreeOfThoughts, Tournament
1212
from branching.agent.result import SpeculationResult, SpeculationOutcome
1313

1414

@@ -381,6 +381,123 @@ def good(path):
381381
assert outcome.all_results[0].exception is not None
382382

383383

384+
class TestTournament:
    """Behavioural tests for the Tournament pairwise-elimination pattern."""

    @staticmethod
    def _succeed(path: Path, index: int) -> bool:
        """Candidate task that always reports success."""
        return True

    @staticmethod
    def _prefer_first(path_a, path_b) -> int:
        """Judge that always selects the first contestant."""
        return 0

    @staticmethod
    def _prefer_second(path_a, path_b) -> int:
        """Judge that always selects the second contestant."""
        return 1

    def test_basic_bracket(self):
        """Four candidates with a judge that picks the second: index 3 wins."""
        workspace = _make_workspace()

        tournament = Tournament(self._succeed, n=4, judge=self._prefer_second)
        outcome = tournament(workspace)

        assert outcome.committed
        assert outcome.winner is not None
        # Round 1: 0v1 -> 1 and 2v3 -> 3; final: 1v3 -> 3
        assert outcome.winner.branch_index == 3
        assert len(outcome.all_results) == 4

    def test_all_fail(self):
        """With no successful candidate nothing is committed."""
        workspace = _make_workspace()

        def always_fail(path: Path, index: int) -> bool:
            return False

        def forbidden_judge(path_a, path_b):
            raise AssertionError("judge should not be called")

        outcome = Tournament(always_fail, n=3, judge=forbidden_judge)(workspace)

        assert not outcome.committed
        assert outcome.winner is None
        assert len(outcome.all_results) == 3

    def test_single_survivor(self):
        """A lone successful candidate wins without consulting the judge."""
        workspace = _make_workspace()
        judge_calls = []

        def only_index_two(path: Path, index: int) -> bool:
            return index == 2

        def recording_judge(path_a, path_b):
            judge_calls.append(1)
            return 0

        outcome = Tournament(only_index_two, n=4, judge=recording_judge)(workspace)

        assert outcome.committed
        assert outcome.winner.branch_index == 2
        assert not judge_calls

    def test_odd_candidates(self):
        """Three candidates: the unpaired one receives a first-round bye."""
        workspace = _make_workspace()

        tournament = Tournament(self._succeed, n=3, judge=self._prefer_first)
        outcome = tournament(workspace)

        assert outcome.committed
        # Round 1: 0v1 -> 0, candidate 2 has a bye; final: 0v2 -> 0
        assert outcome.winner.branch_index == 0

    def test_commits_exactly_one(self):
        """The winner is the only branch committed; the rest are aborted."""
        workspace = _make_workspace()

        tournament = Tournament(self._succeed, n=4, judge=self._prefer_first)
        outcome = tournament(workspace)

        assert outcome.committed
        assert len(MockFSBackend._commits) == 1
        assert len(MockFSBackend._aborts) == 3

    def test_runs_in_parallel(self):
        """Candidates must overlap in time rather than run back to back."""
        import time

        workspace = _make_workspace()
        started_at = time.monotonic()

        def slow_task(path: Path, index: int) -> bool:
            time.sleep(0.2)
            return True

        outcome = Tournament(slow_task, n=3, judge=self._prefer_first)(workspace)
        elapsed = time.monotonic() - started_at

        assert outcome.committed
        # Three 0.2s tasks: roughly 0.2s when concurrent, 0.6s when serial
        assert elapsed < 0.5

    def test_exception_in_candidate(self):
        """A raising candidate is eliminated while the others continue."""
        workspace = _make_workspace()

        def first_one_explodes(path: Path, index: int) -> bool:
            if index == 0:
                raise RuntimeError("boom")
            return True

        outcome = Tournament(
            first_one_explodes, n=3, judge=self._prefer_first
        )(workspace)

        assert outcome.committed
        assert outcome.all_results[0].exception is not None
        assert outcome.winner.branch_index != 0
500+
384501
class TestSpeculationResult:
385502
def test_dataclass(self):
386503
r = SpeculationResult(branch_index=0, success=True, score=0.95)

0 commit comments

Comments
 (0)