|
5 | 5 |
|
6 | 6 | import threading |
7 | 7 | import time |
8 | | -from concurrent.futures import ThreadPoolExecutor |
| 8 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
9 | 9 | from pathlib import Path |
10 | 10 | from typing import Callable, Optional, Sequence, TYPE_CHECKING |
11 | 11 |
|
12 | 12 | from ..core.workspace import Workspace |
13 | 13 | from ..process.runner import run_in_process |
| 14 | +from ..exceptions import ConflictError |
14 | 15 | from .result import SpeculationResult, SpeculationOutcome |
15 | 16 |
|
16 | 17 | if TYPE_CHECKING: |
@@ -1101,3 +1102,280 @@ def _on_scope(sp: Path, _i: int = index) -> None: |
1101 | 1102 | all_results=all_results, |
1102 | 1103 | committed=committed, |
1103 | 1104 | ) |
| 1105 | + |
| 1106 | + |
| 1107 | +class Cascaded: |
| 1108 | + """Cascaded speculation: start narrow, widen on failure with feedback. |
| 1109 | +
|
| 1110 | + Runs successive waves of parallel candidates with increasing fan-out. |
| 1111 | + Wave 0 starts with few candidates (typically one). If all fail, the |
| 1112 | + next wave runs more candidates, each informed by error context from |
| 1113 | + prior failures. Within each wave, first-winner-commit semantics |
| 1114 | + apply: the first successful candidate commits and siblings are aborted. |
| 1115 | +
|
| 1116 | + The task callable receives ``(path, feedback)`` and returns |
| 1117 | + ``(success, error_context)``. *feedback* is always a list of error |
| 1118 | + strings from prior failed attempts (empty on wave 0). |
| 1119 | +
|
| 1120 | + Example:: |
| 1121 | +
|
| 1122 | + def solve(path: Path, feedback: list[str]) -> tuple[bool, str]: |
| 1123 | + result = run_agent(path, prior_errors=feedback) |
| 1124 | + if result.tests_pass: |
| 1125 | + return True, "" |
| 1126 | + return False, result.error_output |
| 1127 | +
|
| 1128 | + outcome = Cascaded(solve, fan_out=(1, 2, 4))(ws) |
| 1129 | + """ |
| 1130 | + |
| 1131 | + def __init__( |
| 1132 | + self, |
| 1133 | + task: Callable[[Path, list[str]], tuple[bool, str]], |
| 1134 | + *, |
| 1135 | + fan_out: Sequence[int] = (1, 2, 4), |
| 1136 | + timeout: float | None = None, |
| 1137 | + wave_timeout: float | None = None, |
| 1138 | + resource_limits: ResourceLimits | None = None, |
| 1139 | + group_limits: ResourceLimits | None = None, |
| 1140 | + ): |
| 1141 | + """ |
| 1142 | + Args: |
| 1143 | + task: Callable(path, feedback) -> (success, error_context). |
| 1144 | + *feedback* is a list of error strings from all prior |
| 1145 | + failed attempts (empty list on the first wave). |
| 1146 | + Returns a tuple of (success_bool, error_string). The |
| 1147 | + error string is collected as feedback for subsequent |
| 1148 | + waves when the attempt fails; ignored on success. |
| 1149 | + fan_out: Number of parallel candidates per wave. |
| 1150 | + Default ``(1, 2, 4)`` runs 1 candidate in wave 0, |
| 1151 | + 2 in wave 1, 4 in wave 2. The number of waves equals |
| 1152 | + ``len(fan_out)``. |
| 1153 | + timeout: Overall timeout in seconds across all waves. |
| 1154 | + wave_timeout: Per-wave timeout in seconds. |
| 1155 | + resource_limits: Optional per-branch resource limits. |
| 1156 | + group_limits: Optional resource limits for the root cgroup. |
| 1157 | + """ |
| 1158 | + self._task = task |
| 1159 | + self._fan_out = list(fan_out) |
| 1160 | + self._timeout = timeout |
| 1161 | + self._wave_timeout = wave_timeout |
| 1162 | + self._resource_limits = resource_limits |
| 1163 | + self._group_limits = group_limits |
| 1164 | + |
| 1165 | + def __call__(self, workspace: Workspace) -> SpeculationOutcome: |
| 1166 | + import os as _os |
| 1167 | + |
| 1168 | + root_cgroup: Optional[Path] = None |
| 1169 | + if self._resource_limits is not None and self._group_limits is not None: |
| 1170 | + try: |
| 1171 | + from ..process._cgroup import create_group |
| 1172 | + root_cgroup = create_group( |
| 1173 | + f"cascaded-{_os.getpid()}", |
| 1174 | + limits=self._group_limits, |
| 1175 | + ) |
| 1176 | + except OSError: |
| 1177 | + root_cgroup = None |
| 1178 | + |
| 1179 | + try: |
| 1180 | + return self._run(workspace, root_cgroup) |
| 1181 | + finally: |
| 1182 | + if root_cgroup is not None: |
| 1183 | + from ..process._cgroup import kill_scope |
| 1184 | + kill_scope(root_cgroup) |
| 1185 | + |
| 1186 | + def _run( |
| 1187 | + self, workspace: Workspace, root_cgroup: Optional[Path], |
| 1188 | + ) -> SpeculationOutcome: |
| 1189 | + all_results: list[SpeculationResult] = [] |
| 1190 | + feedback: list[str] = [] |
| 1191 | + base_index = 0 |
| 1192 | + |
| 1193 | + deadline = ( |
| 1194 | + time.monotonic() + self._timeout |
| 1195 | + if self._timeout is not None |
| 1196 | + else None |
| 1197 | + ) |
| 1198 | + |
| 1199 | + for wave, width in enumerate(self._fan_out): |
| 1200 | + if deadline is not None and deadline - time.monotonic() <= 0: |
| 1201 | + break |
| 1202 | + |
| 1203 | + # Snapshot so concurrent/later mutations don't leak in. |
| 1204 | + feedback_snapshot = list(feedback) |
| 1205 | + wave_errors: list[str] = [] |
| 1206 | + |
| 1207 | + outcome = self._run_wave( |
| 1208 | + workspace, wave, width, feedback_snapshot, base_index, |
| 1209 | + root_cgroup, deadline, wave_errors, |
| 1210 | + ) |
| 1211 | + all_results.extend(outcome.all_results) |
| 1212 | + |
| 1213 | + if outcome.committed: |
| 1214 | + return SpeculationOutcome( |
| 1215 | + winner=outcome.winner, |
| 1216 | + all_results=all_results, |
| 1217 | + committed=True, |
| 1218 | + ) |
| 1219 | + |
| 1220 | + feedback.extend(wave_errors) |
| 1221 | + base_index += width |
| 1222 | + |
| 1223 | + return SpeculationOutcome(all_results=all_results, committed=False) |
| 1224 | + |
| 1225 | + def _run_wave( |
| 1226 | + self, |
| 1227 | + workspace: Workspace, |
| 1228 | + wave: int, |
| 1229 | + width: int, |
| 1230 | + feedback: list[str], |
| 1231 | + base_index: int, |
| 1232 | + root_cgroup: Optional[Path], |
| 1233 | + deadline: Optional[float], |
| 1234 | + wave_errors: list[str], |
| 1235 | + ) -> SpeculationOutcome: |
| 1236 | + """Run a single wave of parallel candidates with first-wins.""" |
| 1237 | + results: list[Optional[SpeculationResult]] = [None] * width |
| 1238 | + winner: Optional[SpeculationResult] = None |
| 1239 | + committed = False |
| 1240 | + cancel_event = threading.Event() |
| 1241 | + |
| 1242 | + branch_scopes: dict[int, Path] = {} |
| 1243 | + |
| 1244 | + def _kill_scopes(exclude: int = -1) -> None: |
| 1245 | + from ..process._cgroup import kill_scope |
| 1246 | + for idx, scope in list(branch_scopes.items()): |
| 1247 | + if idx != exclude: |
| 1248 | + kill_scope(scope) |
| 1249 | + |
| 1250 | + # Compute effective wave timeout. |
| 1251 | + wt = self._wave_timeout |
| 1252 | + if deadline is not None: |
| 1253 | + remaining = deadline - time.monotonic() |
| 1254 | + wt = min(wt, remaining) if wt is not None else remaining |
| 1255 | + |
| 1256 | + def _run_candidate(index: int) -> SpeculationResult: |
| 1257 | + global_index = base_index + index |
| 1258 | + branch_name = f"cascaded_w{wave}_{index}" |
| 1259 | + result = SpeculationResult( |
| 1260 | + branch_index=global_index, success=False, |
| 1261 | + ) |
| 1262 | + |
| 1263 | + if cancel_event.is_set(): |
| 1264 | + return result |
| 1265 | + |
| 1266 | + try: |
| 1267 | + with workspace.branch( |
| 1268 | + branch_name, on_success=None, on_error="abort" |
| 1269 | + ) as b: |
| 1270 | + result.branch_path = b.path |
| 1271 | + |
| 1272 | + if cancel_event.is_set(): |
| 1273 | + b.abort() |
| 1274 | + return result |
| 1275 | + |
| 1276 | + def _on_scope(scope_path: Path, _i: int = index) -> None: |
| 1277 | + branch_scopes[_i] = scope_path |
| 1278 | + |
| 1279 | + success, error_ctx = self._run_task( |
| 1280 | + b.path, feedback, root_cgroup, |
| 1281 | + scope_callback=_on_scope if self._resource_limits else None, |
| 1282 | + timeout=wt, |
| 1283 | + ) |
| 1284 | + |
| 1285 | + result.success = bool(success) |
| 1286 | + result.return_value = (success, error_ctx) |
| 1287 | + |
| 1288 | + if result.success and not cancel_event.is_set(): |
| 1289 | + cancel_event.set() |
| 1290 | + _kill_scopes(index) |
| 1291 | + try: |
| 1292 | + b.commit() |
| 1293 | + except ConflictError: |
| 1294 | + result.success = False |
| 1295 | + b.abort() |
| 1296 | + return result |
| 1297 | + else: |
| 1298 | + if ( |
| 1299 | + not cancel_event.is_set() |
| 1300 | + and error_ctx |
| 1301 | + ): |
| 1302 | + wave_errors.append(error_ctx) |
| 1303 | + b.abort() |
| 1304 | + return result |
| 1305 | + |
| 1306 | + except Exception as e: |
| 1307 | + result.exception = e |
| 1308 | + return result |
| 1309 | + |
| 1310 | + with ThreadPoolExecutor(max_workers=width) as pool: |
| 1311 | + futures: dict = {} |
| 1312 | + for i in range(width): |
| 1313 | + f = pool.submit(_run_candidate, i) |
| 1314 | + futures[f] = i |
| 1315 | + |
| 1316 | + try: |
| 1317 | + for f in as_completed(futures, timeout=wt): |
| 1318 | + idx = futures[f] |
| 1319 | + try: |
| 1320 | + result = f.result() |
| 1321 | + except Exception as e: |
| 1322 | + result = SpeculationResult( |
| 1323 | + branch_index=base_index + idx, |
| 1324 | + success=False, |
| 1325 | + exception=e, |
| 1326 | + ) |
| 1327 | + results[idx] = result |
| 1328 | + |
| 1329 | + if result.success and winner is None: |
| 1330 | + winner = result |
| 1331 | + committed = True |
| 1332 | + except TimeoutError: |
| 1333 | + cancel_event.set() |
| 1334 | + _kill_scopes() |
| 1335 | + |
| 1336 | + for f in futures: |
| 1337 | + if not f.done(): |
| 1338 | + try: |
| 1339 | + f.result(timeout=5.0) |
| 1340 | + except Exception: |
| 1341 | + pass |
| 1342 | + |
| 1343 | + for i, r in enumerate(results): |
| 1344 | + if r is None: |
| 1345 | + results[i] = SpeculationResult( |
| 1346 | + branch_index=base_index + i, success=False, |
| 1347 | + ) |
| 1348 | + |
| 1349 | + return SpeculationOutcome( |
| 1350 | + winner=winner, |
| 1351 | + all_results=results, |
| 1352 | + committed=committed, |
| 1353 | + ) |
| 1354 | + |
| 1355 | + def _run_task( |
| 1356 | + self, |
| 1357 | + path: Path, |
| 1358 | + feedback: list[str], |
| 1359 | + parent_cgroup: Optional[Path], |
| 1360 | + scope_callback=None, |
| 1361 | + timeout: Optional[float] = None, |
| 1362 | + ) -> tuple[bool, str]: |
| 1363 | + """Run the task in a forked child with resource limits.""" |
| 1364 | + from ..exceptions import ProcessBranchError |
| 1365 | + |
| 1366 | + try: |
| 1367 | + ret = run_in_process( |
| 1368 | + self._task, |
| 1369 | + (path, feedback), |
| 1370 | + workspace=path, |
| 1371 | + limits=self._resource_limits, |
| 1372 | + timeout=timeout, |
| 1373 | + parent_cgroup=parent_cgroup, |
| 1374 | + scope_callback=scope_callback, |
| 1375 | + ) |
| 1376 | + if isinstance(ret, (tuple, list)): |
| 1377 | + success, error_ctx = ret |
| 1378 | + return bool(success), str(error_ctx) if error_ctx else "" |
| 1379 | + return bool(ret), "" |
| 1380 | + except ProcessBranchError: |
| 1381 | + return False, "" |
0 commit comments