Skip to content

Commit 0f21d49

Browse files
committed
Add BeamSearch pattern: multi-level beam search with top-K pruning
Inspired by the EnCompass paper, BeamSearch keeps the top-K branches alive at each depth level, interpolating between BestOfN (all parallel, one level) and TreeOfThoughts (one winner per level). Pruning happens globally across all beams' candidates at each level. Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 820baec commit 0f21d49

3 files changed

Lines changed: 496 additions & 1 deletion

File tree

src/branching/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"BestOfN",
5151
"Reflexion",
5252
"TreeOfThoughts",
53+
"BeamSearch",
5354
"Tournament",
5455
# Results
5556
"SpeculationResult",
@@ -83,6 +84,7 @@
8384
"BestOfN": ".agent.patterns",
8485
"Reflexion": ".agent.patterns",
8586
"TreeOfThoughts": ".agent.patterns",
87+
"BeamSearch": ".agent.patterns",
8688
"Tournament": ".agent.patterns",
8789
# Results
8890
"SpeculationResult": ".agent.result",

src/branching/agent/patterns.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,280 @@ def _multi_level(self, workspace: Workspace) -> SpeculationOutcome:
420420
)
421421

422422

423+
class BeamSearch:
    """Multi-level beam search: keep top-K branches alive at each depth.

    Interpolates between BestOfN (all parallel, one level) and
    TreeOfThoughts multi-level (one winner per level).  Multiple beams
    survive each level, each accumulating its own state independently.
    Pruning happens globally across all beams' candidates.

    Strategies return ``bool`` or ``(bool, float)`` — if a bare bool,
    the score defaults to 1.0 for success, 0.0 for failure.  When an
    ``evaluate`` callable is supplied, it replaces the score of every
    successful strategy.

    Args:
        strategies: Level-0 strategies; each one runs in its own branch of
            the workspace and receives that branch's path.
        expand: Called as ``expand(beam_path, depth)`` for every surviving
            beam at each deeper level; returns the sub-strategies to try
            on top of that beam.
        evaluate: Optional scorer called with the branch path after a
            successful strategy.
        beam_width: Maximum number of candidates kept per level.  Values
            ``<= 0`` keep nothing, so the search yields no winner.
        max_depth: Total number of levels, level 0 included.
        timeout: Overall time budget in seconds shared by every level, or
            ``None`` to wait indefinitely.

    Example:
        outcome = BeamSearch(
            [strat_a, strat_b, strat_c, strat_d],
            expand=lambda path, depth: [refine_x, refine_y],
            beam_width=2,
            max_depth=3,
        )(workspace)
    """

    def __init__(
        self,
        strategies: Sequence[Callable[[Path], bool | tuple[bool, float]]],
        *,
        expand: Callable[
            [Path, int],
            Sequence[Callable[[Path], bool | tuple[bool, float]]],
        ],
        evaluate: Callable[[Path], float] | None = None,
        beam_width: int = 3,
        max_depth: int = 2,
        timeout: float | None = None,
    ):
        self._strategies = list(strategies)
        self._expand = expand
        self._evaluate = evaluate
        self._beam_width = beam_width
        self._max_depth = max_depth
        self._timeout = timeout

    def _score(self, ret, path):
        """Parse a strategy return value and apply the optional evaluator.

        Returns a ``(success, score)`` pair.  A tuple return is unpacked
        as-is; a bare value is coerced to bool and scored 1.0 / 0.0.  The
        evaluator, when configured, overrides the score of successes only.
        """
        if isinstance(ret, tuple):
            success, score = ret
        else:
            success = bool(ret)
            score = 1.0 if success else 0.0
        if self._evaluate and success:
            score = self._evaluate(path)
        return bool(success), score

    def _top_k(self, results, k):
        """Return indices of the top-k successful results by score.

        ``None`` entries (workers that produced no result in time) and
        failed results are ignored.  Ties keep their original order.
        """
        # A non-positive k must select nothing; slicing with a negative k
        # would instead return "all but the last |k|" candidates.
        if k <= 0:
            return []
        scored = [
            (i, r) for i, r in enumerate(results)
            if r is not None and r.success
        ]
        scored.sort(key=lambda x: x[1].score, reverse=True)
        return [i for i, _ in scored[:k]]

    @staticmethod
    def _remaining(deadline):
        """Return seconds left until *deadline*, or None when unbounded."""
        if deadline is None:
            return None
        return max(0, deadline - time.monotonic())

    def __call__(self, workspace: Workspace) -> SpeculationOutcome:
        """Run the beam search on *workspace* and commit the best beam.

        Level 0 runs every strategy in its own branch of *workspace*; the
        top ``beam_width`` successful branches stay open as beams.  Each
        deeper level expands every surviving beam into sub-branches, prunes
        globally to the top ``beam_width`` candidates, commits those into
        their parent beams and aborts the rest.  Beams left without a
        selected candidate are aborted.  Finally the highest-scoring
        surviving beam is committed and all others are aborted.

        Any exception raised on the coordinating thread (for example by
        ``expand``) propagates to the caller, but only after every blocked
        worker has been released with an "abort" decision, so the thread
        pools can always shut down.
        """
        n = len(self._strategies)
        if n == 0:
            return SpeculationOutcome()

        K = self._beam_width
        all_results: list[SpeculationResult] = []

        # -- Level 0: create beam branches from workspace ----------------
        beam_branches: list[Optional[object]] = [None] * n
        level0_results: list[Optional[SpeculationResult]] = [None] * n
        task_done = [threading.Event() for _ in range(n)]
        final_decision = [threading.Event() for _ in range(n)]
        final_actions = ["abort"] * n

        def _beam_worker(index: int) -> None:
            result = SpeculationResult(branch_index=index, success=False)
            try:
                with workspace.branch(
                    f"beam_{index}", on_success=None, on_error=None
                ) as b:
                    result.branch_path = b.path
                    beam_branches[index] = b
                    try:
                        ret = self._strategies[index](b.path)
                        result.success, result.score = self._score(
                            ret, b.path
                        )
                        result.return_value = ret
                    except Exception as e:
                        result.exception = e

                    level0_results[index] = result
                    task_done[index].set()

                    # Hold branch open until final decision
                    final_decision[index].wait()
                    if final_actions[index] == "commit":
                        b.commit()
                    else:
                        b.abort()
            except Exception as e:
                result.exception = e
                level0_results[index] = result
                task_done[index].set()

        winner = None
        with ThreadPoolExecutor(max_workers=n) as pool:
            futures = [pool.submit(_beam_worker, i) for i in range(n)]

            try:
                deadline = (
                    time.monotonic() + self._timeout
                    if self._timeout is not None
                    else None
                )
                for ev in task_done:
                    ev.wait(timeout=self._remaining(deadline))

                # Select top-K beams
                survivors = set(self._top_k(level0_results, K))
                beam_scores: dict[int, float] = {
                    i: level0_results[i].score for i in survivors
                }

                # Workers that missed the deadline get a failed placeholder.
                all_results.extend(
                    r if r is not None
                    else SpeculationResult(branch_index=i, success=False)
                    for i, r in enumerate(level0_results)
                )

                # Abort non-survivors immediately
                for i in range(n):
                    if i not in survivors:
                        final_actions[i] = "abort"
                        final_decision[i].set()

                # -- Deeper levels -------------------------------------------
                for depth in range(1, self._max_depth):
                    if not survivors:
                        break

                    # (beam index, strategy index within beam, strategy)
                    sub_tasks: list[tuple[int, int, Callable]] = []
                    for beam_idx in sorted(survivors):
                        sub_strats = self._expand(
                            beam_branches[beam_idx].path, depth
                        )
                        for si, strat in enumerate(sub_strats):
                            sub_tasks.append((beam_idx, si, strat))

                    if not sub_tasks:
                        break

                    m = len(sub_tasks)
                    sub_results: list[Optional[SpeculationResult]] = (
                        [None] * m
                    )
                    sub_done = [threading.Event() for _ in range(m)]
                    sub_decision_ready = [
                        threading.Event() for _ in range(m)
                    ]
                    sub_decisions = ["abort"] * m

                    # ``_d`` binds the current depth at definition time.
                    def _sub_worker(idx: int, _d: int = depth) -> None:
                        beam_idx, strat_idx, strategy = sub_tasks[idx]
                        result = SpeculationResult(
                            branch_index=idx, success=False
                        )
                        try:
                            parent = beam_branches[beam_idx]
                            with parent.branch(
                                f"beam_{beam_idx}_d{_d}_{strat_idx}",
                                on_success=None,
                                on_error=None,
                            ) as sb:
                                result.branch_path = sb.path
                                try:
                                    ret = strategy(sb.path)
                                    result.success, result.score = (
                                        self._score(ret, sb.path)
                                    )
                                    result.return_value = ret
                                except Exception as e:
                                    result.exception = e

                                sub_results[idx] = result
                                sub_done[idx].set()
                                # Hold sub-branch open until pruning decides.
                                sub_decision_ready[idx].wait()

                                if sub_decisions[idx] == "commit":
                                    sb.commit()
                                else:
                                    sb.abort()
                        except Exception as e:
                            result.exception = e
                            sub_results[idx] = result
                            sub_done[idx].set()

                    with ThreadPoolExecutor(max_workers=m) as sub_pool:
                        sub_futures = [
                            sub_pool.submit(_sub_worker, i)
                            for i in range(m)
                        ]

                        try:
                            for ev in sub_done:
                                ev.wait(timeout=self._remaining(deadline))

                            top_k_indices = set(
                                self._top_k(sub_results, K)
                            )

                            all_results.extend(
                                r if r is not None
                                else SpeculationResult(
                                    branch_index=i, success=False
                                )
                                for i, r in enumerate(sub_results)
                            )

                            for i in top_k_indices:
                                sub_decisions[i] = "commit"
                        finally:
                            # Always release sub-workers; otherwise leaving
                            # the pool's ``with`` block would wait forever
                            # on threads parked in ``wait()``.
                            for ev in sub_decision_ready:
                                ev.set()
                        for f in sub_futures:
                            f.result()

                    # Update surviving beams: a beam stays alive if at least
                    # one of its candidates was selected, and carries the
                    # best selected candidate's score.
                    beams_alive: dict[int, float] = {}
                    for i in top_k_indices:
                        beam_idx = sub_tasks[i][0]
                        score = sub_results[i].score
                        if (
                            beam_idx not in beams_alive
                            or score > beams_alive[beam_idx]
                        ):
                            beams_alive[beam_idx] = score

                    for beam_idx in survivors - set(beams_alive):
                        final_actions[beam_idx] = "abort"
                        final_decision[beam_idx].set()

                    survivors = set(beams_alive)
                    beam_scores.update(beams_alive)

                # -- Final: pick best surviving beam -------------------------
                if survivors:
                    best = max(survivors, key=lambda i: beam_scores[i])
                    winner = SpeculationResult(
                        branch_index=best,
                        success=True,
                        score=beam_scores[best],
                        branch_path=(
                            level0_results[best].branch_path
                            if level0_results[best] is not None
                            else None
                        ),
                    )
                    # Flip the action only once the winner record exists, so
                    # an error above can never leave a beam marked "commit".
                    final_actions[best] = "commit"
            finally:
                # Release all remaining beam threads, also on error: every
                # beam not explicitly marked "commit" aborts, and the pool
                # can shut down instead of deadlocking.
                for ev in final_decision:
                    ev.set()
            for f in futures:
                f.result()

        # NOTE(review): ``committed`` reflects the decision taken here, not
        # whether the winner's ``commit()`` call itself succeeded.
        return SpeculationOutcome(
            winner=winner,
            all_results=all_results,
            committed=winner is not None,
        )
695+
696+
423697
class Tournament:
424698
"""Pairwise elimination bracket: generate N candidates, compare
425699
pairwise via a judge function, commit the final winner.

0 commit comments

Comments
 (0)