Skip to content

Commit 0111ad4

Browse files
committed
Add Cascaded speculation pattern with adaptive fan-out
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 3458029 commit 0111ad4

4 files changed

Lines changed: 549 additions & 4 deletions

File tree

README.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ raises, everything is rolled back - the workspace is untouched.
7676

7777
## Agent patterns
7878

79-
BranchContext ships with six high-level patterns that cover the most common
79+
BranchContext ships with seven high-level patterns that cover the most common
8080
agent workflows. Each is a callable class: instantiate with config, call with
8181
a workspace.
8282

@@ -239,6 +239,39 @@ def judge(path_a: Path, path_b: Path) -> int:
239239
outcome = Tournament(generate_patch, n=8, judge=judge)(ws)
240240
```
241241

242+
### Cascaded speculation (adaptive fan-out)
243+
244+
Start with one attempt. If it fails, widen to more parallel candidates,
245+
each informed by error context from prior failures. Repeat with increasing
246+
fan-out until one succeeds or all waves are exhausted.
247+
248+
Inspired by [Cascade Speculative Drafting](https://arxiv.org/abs/2312.11462),
249+
which applies the same start-cheap-escalate-on-failure principle to LLM
250+
token generation.
251+
252+
Use when most tasks succeed on the first try and you want to minimize
253+
wasted compute: coding agents where one LLM call usually works but
254+
occasionally needs retries with error feedback, test-fix loops where the
255+
error log from a failed attempt is the best guide for the next one, or
256+
any workload with variable difficulty where paying for N parallel branches
257+
upfront is wasteful.
258+
259+
```python
260+
from branching import Cascaded
261+
262+
def solve(path: Path, feedback: list[str]) -> tuple[bool, str]:
263+
result = run_agent(path, prior_errors=feedback)
264+
if result.tests_pass:
265+
return True, ""
266+
return False, result.error_output
267+
268+
outcome = Cascaded(solve, fan_out=(1, 2, 4), timeout=120)(ws)
269+
```
270+
271+
The task returns `(success, error_context)`. On failure, the error string
272+
is collected and passed as feedback to subsequent waves. On success, it is
273+
ignored. Empty error strings are silently dropped.
274+
242275
## Lower-level usage
243276

244277
The patterns above are built on two lower-level primitives you can use
@@ -320,7 +353,7 @@ outcome = BestOfN(scored_task, n=5, resource_limits=limits)(ws)
320353
```
321354

322355
All patterns accept `resource_limits`: `Speculate`, `BestOfN`, `Reflexion`,
323-
`TreeOfThoughts`, `BeamSearch`, and `Tournament`. Fields default to `None`
356+
`TreeOfThoughts`, `BeamSearch`, `Tournament`, and `Cascaded`. Fields default to `None`
324357
(unlimited). A `ResourceLimits()` with all `None` fields triggers process
325358
isolation without applying any limits.
326359

src/branching/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"TreeOfThoughts",
5454
"BeamSearch",
5555
"Tournament",
56+
"Cascaded",
5657
# Results
5758
"SpeculationResult",
5859
"SpeculationOutcome",
@@ -88,6 +89,7 @@
8889
"TreeOfThoughts": ".agent.patterns",
8990
"BeamSearch": ".agent.patterns",
9091
"Tournament": ".agent.patterns",
92+
"Cascaded": ".agent.patterns",
9193
# Results
9294
"SpeculationResult": ".agent.result",
9395
"SpeculationOutcome": ".agent.result",

src/branching/agent/patterns.py

Lines changed: 279 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
import threading
77
import time
8-
from concurrent.futures import ThreadPoolExecutor
8+
from concurrent.futures import ThreadPoolExecutor, as_completed
99
from pathlib import Path
1010
from typing import Callable, Optional, Sequence, TYPE_CHECKING
1111

1212
from ..core.workspace import Workspace
1313
from ..process.runner import run_in_process
14+
from ..exceptions import ConflictError
1415
from .result import SpeculationResult, SpeculationOutcome
1516

1617
if TYPE_CHECKING:
@@ -1101,3 +1102,280 @@ def _on_scope(sp: Path, _i: int = index) -> None:
11011102
all_results=all_results,
11021103
committed=committed,
11031104
)
1105+
1106+
1107+
class Cascaded:
1108+
"""Cascaded speculation: start narrow, widen on failure with feedback.
1109+
1110+
Runs successive waves of parallel candidates with increasing fan-out.
1111+
Wave 0 starts with few candidates (typically one). If all fail, the
1112+
next wave runs more candidates, each informed by error context from
1113+
prior failures. Within each wave, first-winner-commit semantics
1114+
apply: the first successful candidate commits and siblings are aborted.
1115+
1116+
The task callable receives ``(path, feedback)`` and returns
1117+
``(success, error_context)``. *feedback* is always a list of error
1118+
strings from prior failed attempts (empty on wave 0).
1119+
1120+
Example::
1121+
1122+
def solve(path: Path, feedback: list[str]) -> tuple[bool, str]:
1123+
result = run_agent(path, prior_errors=feedback)
1124+
if result.tests_pass:
1125+
return True, ""
1126+
return False, result.error_output
1127+
1128+
outcome = Cascaded(solve, fan_out=(1, 2, 4))(ws)
1129+
"""
1130+
1131+
def __init__(
1132+
self,
1133+
task: Callable[[Path, list[str]], tuple[bool, str]],
1134+
*,
1135+
fan_out: Sequence[int] = (1, 2, 4),
1136+
timeout: float | None = None,
1137+
wave_timeout: float | None = None,
1138+
resource_limits: ResourceLimits | None = None,
1139+
group_limits: ResourceLimits | None = None,
1140+
):
1141+
"""
1142+
Args:
1143+
task: Callable(path, feedback) -> (success, error_context).
1144+
*feedback* is a list of error strings from all prior
1145+
failed attempts (empty list on the first wave).
1146+
Returns a tuple of (success_bool, error_string). The
1147+
error string is collected as feedback for subsequent
1148+
waves when the attempt fails; ignored on success.
1149+
fan_out: Number of parallel candidates per wave.
1150+
Default ``(1, 2, 4)`` runs 1 candidate in wave 0,
1151+
2 in wave 1, 4 in wave 2. The number of waves equals
1152+
``len(fan_out)``.
1153+
timeout: Overall timeout in seconds across all waves.
1154+
wave_timeout: Per-wave timeout in seconds.
1155+
resource_limits: Optional per-branch resource limits.
1156+
group_limits: Optional resource limits for the root cgroup.
1157+
"""
1158+
self._task = task
1159+
self._fan_out = list(fan_out)
1160+
self._timeout = timeout
1161+
self._wave_timeout = wave_timeout
1162+
self._resource_limits = resource_limits
1163+
self._group_limits = group_limits
1164+
1165+
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
1166+
import os as _os
1167+
1168+
root_cgroup: Optional[Path] = None
1169+
if self._resource_limits is not None and self._group_limits is not None:
1170+
try:
1171+
from ..process._cgroup import create_group
1172+
root_cgroup = create_group(
1173+
f"cascaded-{_os.getpid()}",
1174+
limits=self._group_limits,
1175+
)
1176+
except OSError:
1177+
root_cgroup = None
1178+
1179+
try:
1180+
return self._run(workspace, root_cgroup)
1181+
finally:
1182+
if root_cgroup is not None:
1183+
from ..process._cgroup import kill_scope
1184+
kill_scope(root_cgroup)
1185+
1186+
def _run(
1187+
self, workspace: Workspace, root_cgroup: Optional[Path],
1188+
) -> SpeculationOutcome:
1189+
all_results: list[SpeculationResult] = []
1190+
feedback: list[str] = []
1191+
base_index = 0
1192+
1193+
deadline = (
1194+
time.monotonic() + self._timeout
1195+
if self._timeout is not None
1196+
else None
1197+
)
1198+
1199+
for wave, width in enumerate(self._fan_out):
1200+
if deadline is not None and deadline - time.monotonic() <= 0:
1201+
break
1202+
1203+
# Snapshot so concurrent/later mutations don't leak in.
1204+
feedback_snapshot = list(feedback)
1205+
wave_errors: list[str] = []
1206+
1207+
outcome = self._run_wave(
1208+
workspace, wave, width, feedback_snapshot, base_index,
1209+
root_cgroup, deadline, wave_errors,
1210+
)
1211+
all_results.extend(outcome.all_results)
1212+
1213+
if outcome.committed:
1214+
return SpeculationOutcome(
1215+
winner=outcome.winner,
1216+
all_results=all_results,
1217+
committed=True,
1218+
)
1219+
1220+
feedback.extend(wave_errors)
1221+
base_index += width
1222+
1223+
return SpeculationOutcome(all_results=all_results, committed=False)
1224+
1225+
def _run_wave(
1226+
self,
1227+
workspace: Workspace,
1228+
wave: int,
1229+
width: int,
1230+
feedback: list[str],
1231+
base_index: int,
1232+
root_cgroup: Optional[Path],
1233+
deadline: Optional[float],
1234+
wave_errors: list[str],
1235+
) -> SpeculationOutcome:
1236+
"""Run a single wave of parallel candidates with first-wins."""
1237+
results: list[Optional[SpeculationResult]] = [None] * width
1238+
winner: Optional[SpeculationResult] = None
1239+
committed = False
1240+
cancel_event = threading.Event()
1241+
1242+
branch_scopes: dict[int, Path] = {}
1243+
1244+
def _kill_scopes(exclude: int = -1) -> None:
1245+
from ..process._cgroup import kill_scope
1246+
for idx, scope in list(branch_scopes.items()):
1247+
if idx != exclude:
1248+
kill_scope(scope)
1249+
1250+
# Compute effective wave timeout.
1251+
wt = self._wave_timeout
1252+
if deadline is not None:
1253+
remaining = deadline - time.monotonic()
1254+
wt = min(wt, remaining) if wt is not None else remaining
1255+
1256+
def _run_candidate(index: int) -> SpeculationResult:
1257+
global_index = base_index + index
1258+
branch_name = f"cascaded_w{wave}_{index}"
1259+
result = SpeculationResult(
1260+
branch_index=global_index, success=False,
1261+
)
1262+
1263+
if cancel_event.is_set():
1264+
return result
1265+
1266+
try:
1267+
with workspace.branch(
1268+
branch_name, on_success=None, on_error="abort"
1269+
) as b:
1270+
result.branch_path = b.path
1271+
1272+
if cancel_event.is_set():
1273+
b.abort()
1274+
return result
1275+
1276+
def _on_scope(scope_path: Path, _i: int = index) -> None:
1277+
branch_scopes[_i] = scope_path
1278+
1279+
success, error_ctx = self._run_task(
1280+
b.path, feedback, root_cgroup,
1281+
scope_callback=_on_scope if self._resource_limits else None,
1282+
timeout=wt,
1283+
)
1284+
1285+
result.success = bool(success)
1286+
result.return_value = (success, error_ctx)
1287+
1288+
if result.success and not cancel_event.is_set():
1289+
cancel_event.set()
1290+
_kill_scopes(index)
1291+
try:
1292+
b.commit()
1293+
except ConflictError:
1294+
result.success = False
1295+
b.abort()
1296+
return result
1297+
else:
1298+
if (
1299+
not cancel_event.is_set()
1300+
and error_ctx
1301+
):
1302+
wave_errors.append(error_ctx)
1303+
b.abort()
1304+
return result
1305+
1306+
except Exception as e:
1307+
result.exception = e
1308+
return result
1309+
1310+
with ThreadPoolExecutor(max_workers=width) as pool:
1311+
futures: dict = {}
1312+
for i in range(width):
1313+
f = pool.submit(_run_candidate, i)
1314+
futures[f] = i
1315+
1316+
try:
1317+
for f in as_completed(futures, timeout=wt):
1318+
idx = futures[f]
1319+
try:
1320+
result = f.result()
1321+
except Exception as e:
1322+
result = SpeculationResult(
1323+
branch_index=base_index + idx,
1324+
success=False,
1325+
exception=e,
1326+
)
1327+
results[idx] = result
1328+
1329+
if result.success and winner is None:
1330+
winner = result
1331+
committed = True
1332+
except TimeoutError:
1333+
cancel_event.set()
1334+
_kill_scopes()
1335+
1336+
for f in futures:
1337+
if not f.done():
1338+
try:
1339+
f.result(timeout=5.0)
1340+
except Exception:
1341+
pass
1342+
1343+
for i, r in enumerate(results):
1344+
if r is None:
1345+
results[i] = SpeculationResult(
1346+
branch_index=base_index + i, success=False,
1347+
)
1348+
1349+
return SpeculationOutcome(
1350+
winner=winner,
1351+
all_results=results,
1352+
committed=committed,
1353+
)
1354+
1355+
def _run_task(
1356+
self,
1357+
path: Path,
1358+
feedback: list[str],
1359+
parent_cgroup: Optional[Path],
1360+
scope_callback=None,
1361+
timeout: Optional[float] = None,
1362+
) -> tuple[bool, str]:
1363+
"""Run the task in a forked child with resource limits."""
1364+
from ..exceptions import ProcessBranchError
1365+
1366+
try:
1367+
ret = run_in_process(
1368+
self._task,
1369+
(path, feedback),
1370+
workspace=path,
1371+
limits=self._resource_limits,
1372+
timeout=timeout,
1373+
parent_cgroup=parent_cgroup,
1374+
scope_callback=scope_callback,
1375+
)
1376+
if isinstance(ret, (tuple, list)):
1377+
success, error_ctx = ret
1378+
return bool(success), str(error_ctx) if error_ctx else ""
1379+
return bool(ret), ""
1380+
except ProcessBranchError:
1381+
return False, ""

0 commit comments

Comments
 (0)