@@ -420,6 +420,280 @@ def _multi_level(self, workspace: Workspace) -> SpeculationOutcome:
420420 )
421421
422422
class BeamSearch:
    """Multi-level beam search: keep top-K branches alive at each depth.

    Interpolates between BestOfN (all parallel, one level) and
    TreeOfThoughts multi-level (one winner per level). Multiple beams
    survive each level, each accumulating its own state independently.
    Pruning happens globally across all beams' candidates.

    Strategies return ``bool`` or ``(bool, float)`` — if a bare bool,
    the score defaults to 1.0 for success, 0.0 for failure.

    ``timeout`` (seconds) is a single deadline shared by every level. It
    only bounds how long the coordinator *waits* for results before
    selecting; it does not interrupt strategies that are still running,
    and the call still joins every worker thread before returning.

    Example:
        outcome = BeamSearch(
            [strat_a, strat_b, strat_c, strat_d],
            expand=lambda path, depth: [refine_x, refine_y],
            beam_width=2,
            max_depth=3,
        )(workspace)
    """

    def __init__(
        self,
        strategies: Sequence[Callable[[Path], bool | tuple[bool, float]]],
        *,
        expand: Callable[
            [Path, int],
            Sequence[Callable[[Path], bool | tuple[bool, float]]],
        ],
        evaluate: Callable[[Path], float] | None = None,
        beam_width: int = 3,
        max_depth: int = 2,
        timeout: float | None = None,
    ):
        """Store the search configuration.

        Args:
            strategies: Level-0 callables, one beam per callable. Each is
                called with the path of its own branch.
            expand: Called as ``expand(beam_path, depth)`` for every
                surviving beam at depths ``1 .. max_depth - 1``; returns
                the candidate strategies to try on top of that beam.
            evaluate: Optional scorer. When given, it replaces the score
                of every *successful* candidate with ``evaluate(path)``.
            beam_width: Maximum number of candidates kept per level.
            max_depth: Total number of levels, counting level 0.
            timeout: Optional overall wait budget in seconds.
        """
        self._strategies = list(strategies)
        self._expand = expand
        self._evaluate = evaluate
        self._beam_width = beam_width
        self._max_depth = max_depth
        self._timeout = timeout

    def _score(self, ret, path):
        """Parse a strategy return value and apply the optional evaluator.

        Returns a ``(success, score)`` pair. A tuple return is unpacked
        as-is; anything else is coerced with ``bool`` and scored 1.0/0.0.
        """
        if isinstance(ret, tuple):
            success, score = ret
        else:
            success = bool(ret)
            score = 1.0 if success else 0.0
        if self._evaluate and success:
            score = self._evaluate(path)
        return bool(success), score

    def _top_k(self, results, k):
        """Return indices of the top-k successful results, best score first.

        ``None`` entries (no result recorded) and unsuccessful results are
        ignored. A non-positive ``k`` selects nothing; without the clamp a
        negative ``k`` would slice ``[:k]`` and keep all but the last few.
        """
        scored = [
            (i, r) for i, r in enumerate(results)
            if r is not None and r.success
        ]
        scored.sort(key=lambda pair: pair[1].score, reverse=True)
        return [i for i, _ in scored[:max(0, k)]]

    @staticmethod
    def _wait_all(events, deadline):
        """Wait for each event in turn, never past ``deadline`` (monotonic)."""
        for ev in events:
            remaining = (
                max(0, deadline - time.monotonic())
                if deadline is not None
                else None
            )
            ev.wait(timeout=remaining)

    def __call__(self, workspace: Workspace) -> SpeculationOutcome:
        """Run the search on ``workspace`` and commit the best beam.

        Returns a ``SpeculationOutcome`` whose ``all_results`` holds one
        entry per candidate tried at every level, and whose ``winner`` is
        the best surviving level-0 beam (``None`` if nothing survived).
        """
        n = len(self._strategies)
        if n == 0:
            return SpeculationOutcome()

        K = self._beam_width
        all_results: list[SpeculationResult] = []
        winner = None

        # -- Level 0: create beam branches from workspace ----------------
        beam_branches: list[Optional[object]] = [None] * n
        level0_results: list[Optional[SpeculationResult]] = [None] * n
        task_done = [threading.Event() for _ in range(n)]
        final_decision = [threading.Event() for _ in range(n)]
        # Every beam aborts unless it is explicitly promoted to "commit".
        final_actions = ["abort"] * n

        def _beam_worker(index: int) -> None:
            result = SpeculationResult(branch_index=index, success=False)
            try:
                with workspace.branch(
                    f"beam_{index}", on_success=None, on_error=None
                ) as b:
                    result.branch_path = b.path
                    beam_branches[index] = b
                    try:
                        ret = self._strategies[index](b.path)
                        result.success, result.score = self._score(
                            ret, b.path
                        )
                        result.return_value = ret
                    except Exception as e:
                        result.exception = e

                    level0_results[index] = result
                    task_done[index].set()

                    # Hold branch open until final decision
                    final_decision[index].wait()
                    if final_actions[index] == "commit":
                        b.commit()
                    else:
                        b.abort()
            except Exception as e:
                result.exception = e
                level0_results[index] = result
                task_done[index].set()

        with ThreadPoolExecutor(max_workers=n) as pool:
            futures = [pool.submit(_beam_worker, i) for i in range(n)]

            # Workers park on ``final_decision`` while their branch is held
            # open. If anything below raises (e.g. ``expand``), the pool's
            # ``__exit__`` would join them forever, so the ``finally``
            # always releases them.
            try:
                deadline = (
                    time.monotonic() + self._timeout
                    if self._timeout is not None
                    else None
                )
                self._wait_all(task_done, deadline)

                # One snapshot, so selection and reporting agree even if a
                # late worker publishes its result after the timeout.
                level0_snapshot = list(level0_results)

                # Select top-K beams
                survivors = set(self._top_k(level0_snapshot, K))
                beam_scores: dict[int, float] = {
                    i: level0_snapshot[i].score for i in survivors
                }

                all_results.extend(
                    r if r is not None
                    else SpeculationResult(branch_index=i, success=False)
                    for i, r in enumerate(level0_snapshot)
                )

                # Abort non-survivors immediately
                for i in range(n):
                    if i not in survivors:
                        final_actions[i] = "abort"
                        final_decision[i].set()

                # -- Deeper levels ---------------------------------------
                for depth in range(1, self._max_depth):
                    if not survivors:
                        break

                    # (beam index, strategy index within beam, strategy)
                    sub_tasks: list[tuple[int, int, Callable]] = []
                    for beam_idx in sorted(survivors):
                        sub_strats = self._expand(
                            beam_branches[beam_idx].path, depth
                        )
                        for si, strat in enumerate(sub_strats):
                            sub_tasks.append((beam_idx, si, strat))

                    if not sub_tasks:
                        break

                    m = len(sub_tasks)
                    sub_results: list[Optional[SpeculationResult]] = (
                        [None] * m
                    )
                    sub_done = [threading.Event() for _ in range(m)]
                    sub_decision_ready = [
                        threading.Event() for _ in range(m)
                    ]
                    sub_decisions = ["abort"] * m

                    # Per-level state is bound through defaults so each
                    # worker uses this iteration's lists, not a later one.
                    def _sub_worker(
                        idx: int,
                        _d: int = depth,
                        _tasks=sub_tasks,
                        _results=sub_results,
                        _done=sub_done,
                        _ready=sub_decision_ready,
                        _decisions=sub_decisions,
                    ) -> None:
                        beam_idx, strat_idx, strategy = _tasks[idx]
                        result = SpeculationResult(
                            branch_index=idx, success=False
                        )
                        try:
                            parent = beam_branches[beam_idx]
                            with parent.branch(
                                f"beam_{beam_idx}_d{_d}_{strat_idx}",
                                on_success=None,
                                on_error=None,
                            ) as sb:
                                result.branch_path = sb.path
                                try:
                                    ret = strategy(sb.path)
                                    result.success, result.score = (
                                        self._score(ret, sb.path)
                                    )
                                    result.return_value = ret
                                except Exception as e:
                                    result.exception = e

                                _results[idx] = result
                                _done[idx].set()
                                _ready[idx].wait()

                                if _decisions[idx] == "commit":
                                    sb.commit()
                                else:
                                    sb.abort()
                        except Exception as e:
                            result.exception = e
                            _results[idx] = result
                            _done[idx].set()

                    with ThreadPoolExecutor(max_workers=m) as sub_pool:
                        sub_futures = [
                            sub_pool.submit(_sub_worker, i)
                            for i in range(m)
                        ]
                        # Same release guarantee as for the beam workers.
                        try:
                            self._wait_all(sub_done, deadline)
                            sub_snapshot = list(sub_results)
                            top_k_indices = set(
                                self._top_k(sub_snapshot, K)
                            )

                            all_results.extend(
                                r if r is not None
                                else SpeculationResult(
                                    branch_index=i, success=False
                                )
                                for i, r in enumerate(sub_snapshot)
                            )

                            # NOTE(review): several top-K children of the
                            # same beam are all committed into that one
                            # parent. Confirm the workspace API allows
                            # sibling commits; otherwise keep only the
                            # best child per beam.
                            for i in top_k_indices:
                                sub_decisions[i] = "commit"
                        finally:
                            for ev in sub_decision_ready:
                                ev.set()
                        for f in sub_futures:
                            f.result()

                    # Update surviving beams: a beam lives on if at least
                    # one of its children made the global top-K, and takes
                    # the best such child's score.
                    beams_alive: dict[int, float] = {}
                    for i in top_k_indices:
                        beam_idx = sub_tasks[i][0]
                        score = sub_snapshot[i].score
                        if (
                            beam_idx not in beams_alive
                            or score > beams_alive[beam_idx]
                        ):
                            beams_alive[beam_idx] = score

                    for beam_idx in survivors - set(beams_alive):
                        final_actions[beam_idx] = "abort"
                        final_decision[beam_idx].set()

                    survivors = set(beams_alive)
                    beam_scores.update(beams_alive)

                # -- Final: pick best surviving beam ---------------------
                if survivors:
                    best = max(survivors, key=lambda i: beam_scores[i])
                    final_actions[best] = "commit"
                    winner = SpeculationResult(
                        branch_index=best,
                        success=True,
                        score=beam_scores[best],
                        branch_path=(
                            level0_results[best].branch_path
                            if level0_results[best] is not None
                            else None
                        ),
                    )
            finally:
                # Release all remaining beam threads (commit for the
                # winner, abort for everything else, including on error).
                for ev in final_decision:
                    ev.set()

            for f in futures:
                f.result()

        return SpeculationOutcome(
            winner=winner,
            all_results=all_results,
            committed=winner is not None,
        )
695+
696+
423697class Tournament :
424698 """Pairwise elimination bracket: generate N candidates, compare
425699 pairwise via a judge function, commit the final winner.
0 commit comments