Skip to content

Commit f24f5a8

Browse files
committed
Add scores and evaluate params to BestOfN for logprobs integration
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent eab1327 commit f24f5a8

3 files changed

Lines changed: 158 additions & 9 deletions

File tree

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ Pairs naturally with the ``n=`` parameter in OpenAI's Chat Completions API
118118
to generate N variations in a single call, then test each in an isolated
119119
branch.
120120

121+
Candidates can return ``bool`` or ``(bool, float)``. Scoring is flexible:
122+
pass pre-computed ``scores`` (e.g. from logprobs), provide an ``evaluate``
123+
callback for post-execution scoring, or let candidates score themselves.
124+
121125
```python
122126
from branching import BestOfN
123127

@@ -132,6 +136,32 @@ candidates = [make_candidate(c) for c in generate_solutions(n=5)]
132136
outcome = BestOfN(candidates)(ws)
133137
```
134138

139+
**Logprobs workflow** — score candidates externally using model confidence,
140+
then let BestOfN pick the highest-scoring one that passes:
141+
142+
```python
143+
from branching import BestOfN
144+
import openai
145+
146+
client = openai.OpenAI()
147+
resp = client.chat.completions.create(
148+
model="gpt-4o", n=5, logprobs=True, top_logprobs=1,
149+
messages=[{"role": "user", "content": prompt}],
150+
)
151+
152+
# Pre-computed confidence scores from logprobs
153+
logprob_scores = [
154+
sum(t.logprob for t in c.logprobs.content) / len(c.logprobs.content)
155+
for c in resp.choices
156+
]
157+
158+
# Candidates just apply code and test — return bare bool
159+
candidates = [make_test(c.message.content) for c in resp.choices]
160+
161+
# BestOfN picks the highest-logprob passing candidate
162+
outcome = BestOfN(candidates, scores=logprob_scores)(ws)
163+
```
164+
135165
### Reflexion (retry with feedback)
136166

137167
Run a task, and if it fails, generate a critique and feed it back into the

src/branching/agent/patterns.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,15 @@ class BestOfN:
2525
finishing. The main thread picks the winner based on score, then
2626
signals each thread to commit (winner) or abort (losers).
2727
28-
Each candidate callable receives (path,) and returns
29-
(success: bool, score: float).
28+
Each candidate callable receives (path,) and returns ``bool`` or
29+
``(success: bool, score: float)``. A bare ``bool`` defaults to
30+
score 1.0/0.0 unless overridden by *scores* or *evaluate*.
31+
32+
Score resolution per successful candidate (highest priority first):
33+
1. evaluate(path) — external scorer callback
34+
2. candidate (bool, f) — candidate's own score
35+
3. scores[i] — pre-computed score (e.g. logprobs)
36+
4. 1.0 — default
3037
3138
Example:
3239
candidates = [lambda p: (run_tests(p), score(p)) for _ in range(5)]
@@ -36,17 +43,32 @@ class BestOfN:
3643

3744
def __init__(
    self,
    candidates: Sequence[Callable[[Path], bool | tuple[bool, float]]],
    *,
    scores: Sequence[float] | None = None,
    evaluate: Callable[[Path], float] | None = None,
    timeout: float | None = None,
    resource_limits: ResourceLimits | None = None,
    group_limits: ResourceLimits | None = None,
):
    """Configure a best-of-N speculative run.

    Args:
        candidates: Callables receiving a branch path and returning either
            a bare ``bool`` or a ``(success, score)`` tuple.
        scores: Optional pre-computed per-candidate scores (e.g. mean
            logprobs); used for candidates that return a bare ``bool``.
        evaluate: Optional external scorer called with each successful
            candidate's path; takes priority over all other scores.
        timeout: Per-candidate timeout in seconds, or ``None`` for no limit.
        resource_limits: Per-branch resource limits.
        group_limits: Resource limits shared across the candidate group.

    Raises:
        ValueError: If *scores* is given but its length does not match the
            number of candidates.
    """
    self._candidates = list(candidates)
    self._scores = list(scores) if scores is not None else None
    # Fail fast on a length mismatch here, rather than with an opaque
    # IndexError inside a worker thread when the extra candidate is scored.
    if self._scores is not None and len(self._scores) != len(self._candidates):
        raise ValueError(
            f"scores has {len(self._scores)} entries but there are "
            f"{len(self._candidates)} candidates"
        )
    self._evaluate = evaluate
    self._timeout = timeout
    self._resource_limits = resource_limits
    self._group_limits = group_limits
4960

61+
def _score(self, ret, path, index):
    """Normalize a candidate's return value into ``(success, score)``.

    Resolution for a *successful* candidate, highest priority first:
    ``evaluate(path)`` > candidate's own ``(bool, float)`` > ``scores[index]``
    > default ``1.0``. A failed candidate scores ``0.0`` unless it reported
    its own score via a tuple — pre-computed *scores* never apply to
    failures, so a high external score cannot inflate a losing branch.
    """
    if isinstance(ret, (tuple, list)):
        # Candidate supplied its own score; the scores param is ignored.
        success, score = ret
        success = bool(success)
    else:
        success = bool(ret)
        if success and self._scores is not None:
            # Pre-computed scores (e.g. logprobs) apply only to passing
            # bool candidates; failures keep the 0.0 default.
            score = self._scores[index]
        else:
            score = 1.0 if success else 0.0
    if self._evaluate is not None and success:
        # External scorer has the highest priority; run on successes only.
        score = self._evaluate(path)
    return success, score
71+
5072
def __call__(self, workspace: Workspace) -> SpeculationOutcome:
5173
import os as _os
5274

@@ -101,10 +123,10 @@ def _on_scope(sp: Path, _i: int = index) -> None:
101123
parent_cgroup=root_cgroup,
102124
scope_callback=_on_scope if self._resource_limits else None,
103125
)
104-
success, score = ret
105-
result.success = bool(success)
126+
success, score = self._score(ret, b.path, index)
127+
result.success = success
106128
result.score = score
107-
result.return_value = (success, score)
129+
result.return_value = ret
108130
except Exception as e:
109131
result.exception = e
110132

tests/test_speculate.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ def test_picks_highest_score(self):
121121
assert outcome.winner.score == 0.9
122122
assert len(outcome.all_results) == 3
123123

124+
def test_bool_candidates(self):
    """A bare bool return maps to the default 1.0 (pass) / 0.0 (fail) score."""
    ws = _make_workspace()

    # First candidate fails, second passes; bind the flag via default arg.
    candidates = [lambda p, ok=ok: ok for ok in (False, True)]

    outcome = BestOfN(candidates)(ws)
    assert outcome.committed
    assert outcome.winner.branch_index == 1
    assert outcome.winner.score == 1.0
137+
124138
def test_skips_failures(self):
125139
ws = _make_workspace()
126140

@@ -136,7 +150,7 @@ def test_skips_failures(self):
136150
def test_all_fail(self):
137151
ws = _make_workspace()
138152

139-
candidates = [lambda p: (False, 0.0) for _ in range(3)]
153+
candidates = [lambda p: False for _ in range(3)]
140154

141155
outcome = BestOfN(candidates)(ws)
142156
assert not outcome.committed
@@ -177,16 +191,99 @@ def test_runs_in_parallel(self):
177191
ws = _make_workspace()
178192
start = time.monotonic()
179193

180-
def slow(path: Path) -> tuple[bool, float]:
194+
def slow(path: Path) -> bool:
181195
time.sleep(0.2)
182-
return True, 1.0
196+
return True
183197

184198
outcome = BestOfN([slow, slow, slow])(ws)
185199
elapsed = time.monotonic() - start
186200
assert outcome.committed
187201
# 3 tasks @ 0.2s each; parallel should be ~0.2s, sequential ~0.6s
188202
assert elapsed < 0.5
189203

204+
def test_scores_param(self):
    """External logprob-style scores rank bool-returning candidates."""
    ws = _make_workspace()

    def passing(p):
        return True

    confidences = [-2.5, -0.1, -1.3]

    outcome = BestOfN([passing] * 3, scores=confidences)(ws)
    assert outcome.committed
    # Index 1 carries the least-negative (i.e. highest) logprob.
    assert outcome.winner.branch_index == 1
    assert outcome.winner.score == -0.1
215+
216+
def test_scores_ignored_for_tuple_return(self):
    """A (bool, float) return takes precedence over the scores param."""
    ws = _make_workspace()

    def low(p):
        return True, 5.0

    def high(p):
        return True, 10.0

    outcome = BestOfN([low, high], scores=[99.0, 1.0])(ws)
    assert outcome.committed
    # The candidates' own 10.0 beats the external 99.0, which is ignored.
    assert outcome.winner.branch_index == 1
    assert outcome.winner.score == 10.0
229+
230+
def test_scores_skipped_for_failures(self):
    """A failing bool candidate cannot win on its pre-computed score."""
    ws = _make_workspace()

    outcome = BestOfN(
        [lambda p: False, lambda p: True],
        scores=[99.0, 0.5],
    )(ws)
    assert outcome.committed
    # The 99.0 belongs to the failed branch, so the 0.5 branch wins.
    assert outcome.winner.branch_index == 1
    assert outcome.winner.score == 0.5
239+
240+
def test_evaluate_callback(self):
    """evaluate(path) replaces the scores candidates report themselves."""
    ws = _make_workspace()

    candidates = [
        lambda p: (True, 10.0),  # candidate claims 10.0
        lambda p: (True, 1.0),   # candidate claims 1.0
    ]

    calls = []

    def evaluate(path):
        calls.append(path)
        return float(len(calls))  # 1.0 for the first call, 2.0 for the second

    outcome = BestOfN(candidates, evaluate=evaluate)(ws)
    assert outcome.committed
    assert len(calls) == 2
    # Thread scheduling decides which branch draws 2.0, but the winner's
    # score is always evaluate's maximum. Pinning it proves the override:
    # if a candidate's own score leaked through, the winner would score 10.0.
    assert outcome.winner.score == 2.0
257+
258+
def test_evaluate_not_called_on_failure(self):
    """The external scorer runs only for passing candidates."""
    ws = _make_workspace()

    seen = []

    def evaluate(path):
        seen.append(path)
        return 1.0

    outcome = BestOfN([lambda p: False, lambda p: True], evaluate=evaluate)(ws)
    assert outcome.committed
    # Only the single passing branch should have been scored.
    assert len(seen) == 1
271+
272+
def test_evaluate_overrides_scores_param(self):
    """evaluate outranks the scores parameter as well as tuple scores."""
    ws = _make_workspace()

    outcome = BestOfN(
        [lambda p: True, lambda p: True],
        scores=[100.0, 1.0],
        evaluate=lambda p: 42.0,
    )(ws)
    assert outcome.committed
    # Both branches score 42.0, so either may win — but the winning score
    # must be evaluate's value, never 100.0 from the scores param.
    assert outcome.winner.score == 42.0
286+
190287

191288
class TestReflexion:
192289
def test_succeeds_first_try(self):

0 commit comments

Comments
 (0)