Skip to content

Commit 5913eae

Browse files
committed
Add commit=False to BestOfN for RL training rollouts
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 19f77f0 commit 5913eae

3 files changed

Lines changed: 47 additions & 2 deletions

File tree

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,27 @@ candidates = [make_test(c.message.content) for c in resp.choices]
153153
outcome = BestOfN(candidates, scores=logprob_scores)(ws)
154154
```
155155

156+
#### RL training rollouts
157+
158+
Pass ``commit=False`` to collect scores from all candidates without
159+
modifying the workspace. Every branch runs to completion and is then aborted, so
160+
the base stays pristine for the next batch. This gives you cheap,
161+
isolated rollout environments for policy gradient methods like GRPO.
162+
163+
```python
164+
from branching import Workspace, BestOfN
165+
166+
ws = Workspace("/mnt/workspace")
167+
168+
for prompt in training_batch:
169+
candidates = [make_candidate(prompt) for _ in range(N)]
170+
outcome = BestOfN(candidates, commit=False)(ws)
171+
172+
# All N results available -- extract (success, score) for training
173+
rewards = [(r.success, r.score) for r in outcome.all_results]
174+
trainer.step(prompt, rewards)
175+
```
176+
156177
### Reflexion (retry with feedback)
157178

158179
Run a task, and if it fails, generate a critique and feed it back into the

src/branching/agent/patterns.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ class BestOfN:
2525
finishing. The main thread picks the winner based on score, then
2626
signals each thread to commit (winner) or abort (losers).
2727
28+
Pass ``commit=False`` to abort all branches and return results
29+
without modifying the workspace. Useful for RL training rollouts
30+
where you need scores from every candidate but don't want to
31+
commit any of them. The best-scoring candidate is still reported as the winner.
32+
2833
Each candidate callable receives (path,) and returns ``bool`` or
2934
``(success: bool, score: float)``. A bare ``bool`` defaults to
3035
score 1.0/0.0 unless overridden by *scores* or *evaluate*.
@@ -50,13 +55,15 @@ def __init__(
5055
timeout: float | None = None,
5156
resource_limits: ResourceLimits | None = None,
5257
group_limits: ResourceLimits | None = None,
58+
commit: bool = True,
5359
):
5460
self._candidates = list(candidates)
5561
self._scores = list(scores) if scores is not None else None
5662
self._evaluate = evaluate
5763
self._timeout = timeout
5864
self._resource_limits = resource_limits
5965
self._group_limits = group_limits
66+
self._commit = commit
6067

6168
def _score(self, ret, path, index):
6269
"""Parse candidate return and apply optional evaluator."""
@@ -173,7 +180,7 @@ def _on_scope(sp: Path, _i: int = index) -> None:
173180
best_score = r.score
174181
best_idx = i
175182

176-
if best_idx is not None:
183+
if best_idx is not None and self._commit:
177184
decisions[best_idx] = "commit"
178185

179186
# Kill still-running tasks (only useful when timeout left
@@ -189,7 +196,7 @@ def _on_scope(sp: Path, _i: int = index) -> None:
189196
for f in futures:
190197
f.result()
191198

192-
committed = best_idx is not None
199+
committed = best_idx is not None and self._commit
193200
winner = results[best_idx] if best_idx is not None else None
194201
all_results = [
195202
r if r is not None else SpeculationResult(branch_index=i, success=False)

tests/test_speculate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,23 @@ def test_commits_exactly_one(self):
185185
assert len(MockFSBackend._commits) == 1
186186
assert len(MockFSBackend._aborts) == 2
187187

188+
def test_commit_false_aborts_all(self):
189+
"""commit=False aborts all branches and returns results."""
190+
ws = _make_workspace()
191+
192+
candidates = [
193+
lambda p, s=s: (True, float(s)) for s in range(3)
194+
]
195+
196+
outcome = BestOfN(candidates, commit=False)(ws)
197+
assert not outcome.committed
198+
assert outcome.winner.branch_index == 2 # best still identified
199+
assert outcome.winner.score == 2.0
200+
assert len(outcome.all_results) == 3
201+
# All aborted, none committed
202+
assert len(MockFSBackend._commits) == 0
203+
assert len(MockFSBackend._aborts) == 3
204+
188205
def test_runs_in_parallel(self):
189206
"""Verify candidates actually run concurrently."""
190207
import time

0 commit comments

Comments
 (0)