Skip to content

Commit b077856

Browse files
committed
Add runner parameter to all agent patterns for pluggable execution
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent ffae1e7 commit b077856

6 files changed

Lines changed: 51 additions & 25 deletions

File tree

src/branching/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"Branch",
4545
# Process
4646
"BranchContext",
47+
"Runner",
4748
# Agent patterns
4849
"Speculate",
4950
"BestOfN",
@@ -78,6 +79,7 @@
7879
"Branch": ".core.branch",
7980
# Process layer
8081
"BranchContext": ".process.context",
82+
"Runner": ".process.runner",
8183
# Agent layer
8284
"Speculate": ".agent.speculate",
8385
"BestOfN": ".agent.patterns",

src/branching/agent/patterns.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Callable, Optional, Sequence
1111

1212
from ..core.workspace import Workspace
13-
from ..process.runner import run_in_process
13+
from ..process.runner import _default_runner, Runner
1414
from ..exceptions import ConflictError
1515
from .result import SpeculationResult, SpeculationOutcome
1616

@@ -51,12 +51,14 @@ def __init__(
5151
evaluate: Callable[[Path], float] | None = None,
5252
timeout: float | None = None,
5353
commit: bool = True,
54+
runner: Runner | None = None,
5455
):
5556
self._candidates = list(candidates)
5657
self._scores = list(scores) if scores is not None else None
5758
self._evaluate = evaluate
5859
self._timeout = timeout
5960
self._commit = commit
61+
self._runner = runner or _default_runner
6062

6163
def _score(self, ret, path, index):
6264
"""Parse candidate return and apply optional evaluator."""
@@ -87,9 +89,8 @@ def _run_candidate(index: int) -> None:
8789
) as b:
8890
result.branch_path = b.path
8991
try:
90-
ret = run_in_process(
92+
ret = self._runner(
9193
self._candidates[index], (b.path,),
92-
workspace=b.path,
9394
)
9495
success, score = self._score(ret, b.path, index)
9596
result.success = success
@@ -182,17 +183,21 @@ def __init__(
182183
max_retries: int = 3,
183184
*,
184185
critique: Optional[Callable[[Path], str]] = None,
186+
runner: Runner | None = None,
185187
):
186188
"""
187189
Args:
188190
task: Callable(path, attempt, feedback) -> success.
189191
feedback is None on first attempt, critique output thereafter.
190192
max_retries: Maximum number of attempts.
191193
critique: Optional callable(path) -> feedback_string.
194+
runner: Execution strategy for candidates. Default forks via
195+
BranchContext. Pass a sandlock runner for confinement.
192196
"""
193197
self._task = task
194198
self._max_retries = max_retries
195199
self._critique = critique
200+
self._runner = runner or _default_runner
196201

197202
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
198203
return self._run(workspace)
@@ -211,9 +216,8 @@ def _run(self, workspace: Workspace) -> SpeculationOutcome:
211216
branch_name, on_success=None, on_error="abort"
212217
) as b:
213218
result.branch_path = b.path
214-
success = run_in_process(
219+
success = self._runner(
215220
self._task, (b.path, attempt, feedback),
216-
workspace=b.path,
217221
)
218222
result.success = bool(success)
219223
result.return_value = success
@@ -287,6 +291,7 @@ def __init__(
287291
] | None = None,
288292
max_depth: int = 1,
289293
timeout: float | None = None,
294+
runner: Runner | None = None,
290295
):
291296
"""
292297
Args:
@@ -299,12 +304,14 @@ def __init__(
299304
Only used when max_depth > 1.
300305
max_depth: Maximum exploration depth (1 = single level).
301306
timeout: Per-level timeout in seconds.
307+
runner: Execution strategy for candidates.
302308
"""
303309
self._strategies = list(strategies)
304310
self._evaluate = evaluate
305311
self._expand = expand
306312
self._max_depth = max_depth
307313
self._timeout = timeout
314+
self._runner = runner or _default_runner
308315

309316
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
310317
if self._expand is None or self._max_depth <= 1:
@@ -338,9 +345,8 @@ def _run(index: int) -> None:
338345
) as b:
339346
result.branch_path = b.path
340347
try:
341-
ret = run_in_process(
348+
ret = self._runner(
342349
strategies[index], (b.path,),
343-
workspace=b.path,
344350
)
345351
if isinstance(ret, (tuple, list)):
346352
success, score = ret
@@ -506,13 +512,15 @@ def __init__(
506512
beam_width: int = 3,
507513
max_depth: int = 2,
508514
timeout: float | None = None,
515+
runner: Runner | None = None,
509516
):
510517
self._strategies = list(strategies)
511518
self._expand = expand
512519
self._evaluate = evaluate
513520
self._beam_width = beam_width
514521
self._max_depth = max_depth
515522
self._timeout = timeout
523+
self._runner = runner or _default_runner
516524

517525
def _score(self, ret, path):
518526
"""Parse strategy return and apply optional evaluator."""
@@ -563,9 +571,8 @@ def _beam_worker(index: int) -> None:
563571
result.branch_path = b.path
564572
beam_branches[index] = b
565573
try:
566-
ret = run_in_process(
574+
ret = self._runner(
567575
self._strategies[index], (b.path,),
568-
workspace=b.path,
569576
)
570577
result.success, result.score = self._score(
571578
ret, b.path
@@ -659,9 +666,8 @@ def _sub_worker(idx: int, _d: int = _depth) -> None:
659666
) as sb:
660667
result.branch_path = sb.path
661668
try:
662-
ret = run_in_process(
669+
ret = self._runner(
663670
strategy, (sb.path,),
664-
workspace=sb.path,
665671
)
666672
result.success, result.score = self._score(
667673
ret, sb.path
@@ -775,6 +781,7 @@ def __init__(
775781
*,
776782
judge: Callable[[Path, Path], int],
777783
timeout: float | None = None,
784+
runner: Runner | None = None,
778785
):
779786
"""
780787
Args:
@@ -784,10 +791,12 @@ def __init__(
784791
judge: Callable(path_a, path_b) -> 0 (a wins) or 1 (b wins).
785792
Compares two candidates' branches during elimination.
786793
timeout: Overall timeout in seconds.
794+
runner: Execution strategy for candidates.
787795
"""
788796
self._candidates = list(candidates)
789797
self._judge = judge
790798
self._timeout = timeout
799+
self._runner = runner or _default_runner
791800

792801
@staticmethod
793802
def _run_bracket(
@@ -830,9 +839,8 @@ def _run_candidate(index: int) -> None:
830839
result.branch_path = b.path
831840
branch_paths[index] = b.path
832841
try:
833-
success = run_in_process(
842+
success = self._runner(
834843
self._candidates[index], (b.path,),
835-
workspace=b.path,
836844
)
837845
result.success = bool(success)
838846
result.return_value = success
@@ -940,6 +948,7 @@ def __init__(
940948
fan_out: Sequence[int] = (1, 2, 4),
941949
timeout: float | None = None,
942950
wave_timeout: float | None = None,
951+
runner: Runner | None = None,
943952
):
944953
"""
945954
Args:
@@ -955,11 +964,13 @@ def __init__(
955964
``len(fan_out)``.
956965
timeout: Overall timeout in seconds across all waves.
957966
wave_timeout: Per-wave timeout in seconds.
967+
runner: Execution strategy for candidates.
958968
"""
959969
self._task = task
960970
self._fan_out = list(fan_out)
961971
self._timeout = timeout
962972
self._wave_timeout = wave_timeout
973+
self._runner = runner or _default_runner
963974

964975
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
965976
return self._run(workspace)
@@ -1123,14 +1134,13 @@ def _run_task(
11231134
feedback: list[str],
11241135
timeout: Optional[float] = None,
11251136
) -> tuple[bool, str]:
1126-
"""Run the task in a forked child."""
1137+
"""Run the task via the configured runner."""
11271138
from ..exceptions import ProcessBranchError
11281139

11291140
try:
1130-
ret = run_in_process(
1141+
ret = self._runner(
11311142
self._task,
11321143
(path, feedback),
1133-
workspace=path,
11341144
timeout=timeout,
11351145
)
11361146
if isinstance(ret, (tuple, list)):

src/branching/agent/speculate.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import threading
1010

1111
from ..core.workspace import Workspace
12+
from ..process.runner import _default_runner, Runner
1213
from ..exceptions import ConflictError
1314
from .result import SpeculationResult, SpeculationOutcome
1415

@@ -32,6 +33,7 @@ def __init__(
3233
first_wins: bool = True,
3334
max_parallel: int | None = None,
3435
timeout: float | None = None,
36+
runner: Runner | None = None,
3537
):
3638
"""
3739
Args:
@@ -41,11 +43,14 @@ def __init__(
4143
abort siblings. If False, run all and commit the first success.
4244
max_parallel: Maximum parallel workers (default: len(candidates)).
4345
timeout: Overall timeout in seconds for all candidates.
46+
runner: Execution strategy for candidates. Default forks via
47+
BranchContext. Pass a sandlock runner for confinement.
4448
"""
4549
self._candidates = list(candidates)
4650
self._first_wins = first_wins
4751
self._max_parallel = max_parallel or len(self._candidates)
4852
self._timeout = timeout
53+
self._runner = runner or _default_runner
4954

5055
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
5156
return self._run(workspace)
@@ -140,8 +145,7 @@ def _run_candidate(index: int) -> SpeculationResult:
140145
)
141146

142147
def _run_in_branch(self, path: Path, index: int) -> bool:
143-
"""Run a candidate in a forked child."""
144-
from ..process.runner import run_in_process
148+
"""Run a candidate via the configured runner."""
145149
from ..exceptions import ProcessBranchError
146150

147151
per_candidate = (
@@ -151,10 +155,9 @@ def _run_in_branch(self, path: Path, index: int) -> bool:
151155
)
152156

153157
try:
154-
result = run_in_process(
158+
result = self._runner(
155159
self._candidates[index],
156160
(path,),
157-
workspace=path,
158161
timeout=per_candidate,
159162
)
160163
return bool(result)

src/branching/process/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .context import BranchContext
2+
from .runner import Runner, _default_runner
23

3-
__all__ = ["BranchContext"]
4+
__all__ = ["BranchContext", "Runner", "_default_runner"]

src/branching/process/runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111
from ..exceptions import ProcessBranchError
1212
from .context import BranchContext
1313

14+
# Runner protocol: (fn, args, *, timeout) -> Any
15+
# Used by agent patterns to execute candidates. Swappable to use
16+
# sandlock or run candidates directly without forking.
17+
Runner = Callable[..., Any]
18+
19+
20+
def _default_runner(fn: Callable, args: tuple, *, timeout: float | None = None) -> Any:
21+
"""Default runner: fork via BranchContext, pass result via pipe."""
22+
return run_in_process(fn, args, workspace=args[0], timeout=timeout)
23+
1424

1525
def run_in_process(
1626
fn: Callable,

tests/test_speculate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414

1515
@pytest.fixture(autouse=True)
16-
def mock_run_in_process():
16+
def mock_runner():
1717
"""Run tasks in-process for unit tests (skip fork)."""
18-
def _run_inline(fn, args, workspace, **kwargs):
18+
def _run_inline(fn, args, **kwargs):
1919
return fn(*args)
20-
with patch("branching.agent.patterns.run_in_process", _run_inline), \
21-
patch("branching.process.runner.run_in_process", _run_inline):
20+
with patch("branching.agent.patterns._default_runner", _run_inline), \
21+
patch("branching.agent.speculate._default_runner", _run_inline):
2222
yield
2323

2424

0 commit comments

Comments
 (0)