1010from typing import Callable , Optional , Sequence
1111
1212from ..core .workspace import Workspace
13- from ..process .runner import run_in_process
13+ from ..process .runner import _default_runner , Runner
1414from ..exceptions import ConflictError
1515from .result import SpeculationResult , SpeculationOutcome
1616
@@ -51,12 +51,14 @@ def __init__(
5151 evaluate : Callable [[Path ], float ] | None = None ,
5252 timeout : float | None = None ,
5353 commit : bool = True ,
54+ runner : Runner | None = None ,
5455 ):
5556 self ._candidates = list (candidates )
5657 self ._scores = list (scores ) if scores is not None else None
5758 self ._evaluate = evaluate
5859 self ._timeout = timeout
5960 self ._commit = commit
61+ self ._runner = runner or _default_runner
6062
6163 def _score (self , ret , path , index ):
6264 """Parse candidate return and apply optional evaluator."""
@@ -87,9 +89,8 @@ def _run_candidate(index: int) -> None:
8789 ) as b :
8890 result .branch_path = b .path
8991 try :
90- ret = run_in_process (
92+ ret = self . _runner (
9193 self ._candidates [index ], (b .path ,),
92- workspace = b .path ,
9394 )
9495 success , score = self ._score (ret , b .path , index )
9596 result .success = success
@@ -182,17 +183,21 @@ def __init__(
182183 max_retries : int = 3 ,
183184 * ,
184185 critique : Optional [Callable [[Path ], str ]] = None ,
186+ runner : Runner | None = None ,
185187 ):
186188 """
187189 Args:
188190 task: Callable(path, attempt, feedback) -> success.
189191 feedback is None on first attempt, critique output thereafter.
190192 max_retries: Maximum number of attempts.
191193 critique: Optional callable(path) -> feedback_string.
194+ runner: Execution strategy for candidates. Default forks via
195+ BranchContext. Pass a sandlock runner for confinement.
192196 """
193197 self ._task = task
194198 self ._max_retries = max_retries
195199 self ._critique = critique
200+ self ._runner = runner or _default_runner
196201
197202 def __call__ (self , workspace : Workspace ) -> SpeculationOutcome :
198203 return self ._run (workspace )
@@ -211,9 +216,8 @@ def _run(self, workspace: Workspace) -> SpeculationOutcome:
211216 branch_name , on_success = None , on_error = "abort"
212217 ) as b :
213218 result .branch_path = b .path
214- success = run_in_process (
219+ success = self . _runner (
215220 self ._task , (b .path , attempt , feedback ),
216- workspace = b .path ,
217221 )
218222 result .success = bool (success )
219223 result .return_value = success
@@ -287,6 +291,7 @@ def __init__(
287291 ] | None = None ,
288292 max_depth : int = 1 ,
289293 timeout : float | None = None ,
294+ runner : Runner | None = None ,
290295 ):
291296 """
292297 Args:
@@ -299,12 +304,14 @@ def __init__(
299304 Only used when max_depth > 1.
300305 max_depth: Maximum exploration depth (1 = single level).
301306 timeout: Per-level timeout in seconds.
307+ runner: Execution strategy for candidates.
302308 """
303309 self ._strategies = list (strategies )
304310 self ._evaluate = evaluate
305311 self ._expand = expand
306312 self ._max_depth = max_depth
307313 self ._timeout = timeout
314+ self ._runner = runner or _default_runner
308315
309316 def __call__ (self , workspace : Workspace ) -> SpeculationOutcome :
310317 if self ._expand is None or self ._max_depth <= 1 :
@@ -338,9 +345,8 @@ def _run(index: int) -> None:
338345 ) as b :
339346 result .branch_path = b .path
340347 try :
341- ret = run_in_process (
348+ ret = self . _runner (
342349 strategies [index ], (b .path ,),
343- workspace = b .path ,
344350 )
345351 if isinstance (ret , (tuple , list )):
346352 success , score = ret
@@ -506,13 +512,15 @@ def __init__(
506512 beam_width : int = 3 ,
507513 max_depth : int = 2 ,
508514 timeout : float | None = None ,
515+ runner : Runner | None = None ,
509516 ):
510517 self ._strategies = list (strategies )
511518 self ._expand = expand
512519 self ._evaluate = evaluate
513520 self ._beam_width = beam_width
514521 self ._max_depth = max_depth
515522 self ._timeout = timeout
523+ self ._runner = runner or _default_runner
516524
517525 def _score (self , ret , path ):
518526 """Parse strategy return and apply optional evaluator."""
@@ -563,9 +571,8 @@ def _beam_worker(index: int) -> None:
563571 result .branch_path = b .path
564572 beam_branches [index ] = b
565573 try :
566- ret = run_in_process (
574+ ret = self . _runner (
567575 self ._strategies [index ], (b .path ,),
568- workspace = b .path ,
569576 )
570577 result .success , result .score = self ._score (
571578 ret , b .path
@@ -659,9 +666,8 @@ def _sub_worker(idx: int, _d: int = _depth) -> None:
659666 ) as sb :
660667 result .branch_path = sb .path
661668 try :
662- ret = run_in_process (
669+ ret = self . _runner (
663670 strategy , (sb .path ,),
664- workspace = sb .path ,
665671 )
666672 result .success , result .score = self ._score (
667673 ret , sb .path
@@ -775,6 +781,7 @@ def __init__(
775781 * ,
776782 judge : Callable [[Path , Path ], int ],
777783 timeout : float | None = None ,
784+ runner : Runner | None = None ,
778785 ):
779786 """
780787 Args:
@@ -784,10 +791,12 @@ def __init__(
784791 judge: Callable(path_a, path_b) -> 0 (a wins) or 1 (b wins).
785792 Compares two candidates' branches during elimination.
786793 timeout: Overall timeout in seconds.
794+ runner: Execution strategy for candidates.
787795 """
788796 self ._candidates = list (candidates )
789797 self ._judge = judge
790798 self ._timeout = timeout
799+ self ._runner = runner or _default_runner
791800
792801 @staticmethod
793802 def _run_bracket (
@@ -830,9 +839,8 @@ def _run_candidate(index: int) -> None:
830839 result .branch_path = b .path
831840 branch_paths [index ] = b .path
832841 try :
833- success = run_in_process (
842+ success = self . _runner (
834843 self ._candidates [index ], (b .path ,),
835- workspace = b .path ,
836844 )
837845 result .success = bool (success )
838846 result .return_value = success
@@ -940,6 +948,7 @@ def __init__(
940948 fan_out : Sequence [int ] = (1 , 2 , 4 ),
941949 timeout : float | None = None ,
942950 wave_timeout : float | None = None ,
951+ runner : Runner | None = None ,
943952 ):
944953 """
945954 Args:
@@ -955,11 +964,13 @@ def __init__(
955964 ``len(fan_out)``.
956965 timeout: Overall timeout in seconds across all waves.
957966 wave_timeout: Per-wave timeout in seconds.
967+ runner: Execution strategy for candidates.
958968 """
959969 self ._task = task
960970 self ._fan_out = list (fan_out )
961971 self ._timeout = timeout
962972 self ._wave_timeout = wave_timeout
973+ self ._runner = runner or _default_runner
963974
964975 def __call__ (self , workspace : Workspace ) -> SpeculationOutcome :
965976 return self ._run (workspace )
@@ -1123,14 +1134,13 @@ def _run_task(
11231134 feedback : list [str ],
11241135 timeout : Optional [float ] = None ,
11251136 ) -> tuple [bool , str ]:
1126- """Run the task in a forked child ."""
1137+ """Run the task via the configured runner ."""
11271138 from ..exceptions import ProcessBranchError
11281139
11291140 try :
1130- ret = run_in_process (
1141+ ret = self . _runner (
11311142 self ._task ,
11321143 (path , feedback ),
1133- workspace = path ,
11341144 timeout = timeout ,
11351145 )
11361146 if isinstance (ret , (tuple , list )):
0 commit comments