Commit bd650b7

updated tests

1 parent 5af3568 commit bd650b7

1 file changed: 55 additions & 14 deletions

tests/test_handlers_llm_futures.py
@@ -6,11 +6,13 @@
 """
 
 import time
+from collections.abc import Callable
 from concurrent.futures import Future
 from inspect import BoundArguments
 from typing import Any, override
 
-from effectful.handlers.futures import ThreadPoolFuturesInterpretation
+import effectful.handlers.futures as futures
+from effectful.handlers.futures import Executor, ThreadPoolFuturesInterpretation
 from effectful.handlers.llm import Template
 from effectful.handlers.llm.providers import OpenAIAPIProvider
 from effectful.ops.semantics import handler
@@ -28,34 +30,73 @@ def __init__(self, response, delay: float = 0.05, mapping={}):
 
     @override
     def _openai_api_call[T](
-        self, template: Any, args: BoundArguments, retty: type[T]
+        self, template: Template, args: BoundArguments, retty: type[T]
     ) -> T:
         self.calls.append((template, args.args, retty))
         time.sleep(self.delay)
-
-        return self.mapping.get((template, tuple(args.args)), self.response)
+        return self.mapping.get(template, {}).get(tuple(args.args), self.response)
 
 
 @Template.define
-def hiaku(topic: str) -> Future[str]:
+def hiaku(topic: str) -> str:
     """Return a hiaku about {topic}."""
     raise NotHandled
 
 
-# synchronous template for comparison
-@Template.define
-def hiaku_s(topic: str) -> str:
-    """Return a hiaku about {topic}."""
-    raise NotImplementedError
-
-
 def test_future_return_type_decodes_inner_type():
-    """Test that Future[int] templates correctly decode to int."""
+    """Test that llm templates correctly decode to the inner type, even wrapped in a future."""
     ref_hiaku = "apples to oranges, oranges to pears, I don't know what a hiaku is"
     mock_provider = SlowMockLLMProvider(ref_hiaku, delay=0.001)
 
     with handler(ThreadPoolFuturesInterpretation()), handler(mock_provider):
-        future = hiaku("apples")
+        future = Executor.submit(hiaku, "apples")
         assert isinstance(future, Future)
         result = future.result()
         assert result == ref_hiaku
+
+
+@Template.define
+def generate_program(task: str) -> Callable[[int], int]:
+    """Generate a Python program that {task}."""
+    raise NotHandled
+
+
+def test_concurrent_program_generation():
+    """Simulate concurrent LLM calls to generate Python programs and pick the best one."""
+    # Mock responses for different approaches to the same task
+    responses = {
+        generate_program: {
+            ("implement fibonacci algorithm 0",): "def fib(n: int) -> int: return n",
+            (
+                "implement fibonacci algorithm 1",
+            ): "def fib(n: int) -> int: return n * fib(n - 1)",
+            (
+                "implement fibonacci algorithm 2",
+            ): "def fib(n: int) -> int: return fib(n - 2) + fib(n - 1) if n > 1 else 0",
+        }
+    }
+
+    mock_provider = SlowMockLLMProvider(
+        response="print('Default')", delay=0.01, mapping=responses
+    )
+
+    user_request: str = "implement fibonacci algorithm"
+
+    with handler(ThreadPoolFuturesInterpretation()), handler(mock_provider):
+        # Launch multiple LLM calls concurrently
+        tasks = [
+            Executor.submit(generate_program, user_request + f" {i}")
+            for i in range(3)
+        ]
+
+        # Collect all results as they finish
+        results_as_completed = (f.result() for f in futures.as_completed(tasks))
+
+        valid_results = [(result, len(result)) for result in results_as_completed]
+
+        # Pick the "best" result (here: the longest program, as a naive heuristic)
+        best_program = max(valid_results, key=lambda pair: pair[1])[0]
+
+        # Assertions
+        assert len(valid_results) == 3
+        assert best_program in set(responses[generate_program].values())
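
Note on the new test: test_concurrent_program_generation exercises the standard concurrent.futures fan-out/fan-in pattern — submit several calls, harvest them with as_completed, then rank the results. Below is a minimal sketch of the same pattern using only the Python standard library; generate_program here is a hypothetical plain function standing in for the effectful Template/Executor machinery, purely for illustration.

from concurrent.futures import ThreadPoolExecutor, as_completed

def generate_program(task: str) -> str:
    # Hypothetical stand-in for the mocked LLM call in the test.
    return f"def solve(): ...  # {task}"

with ThreadPoolExecutor(max_workers=3) as pool:
    # Fan out: one future per task variant, all running concurrently.
    futs = [
        pool.submit(generate_program, f"implement fibonacci algorithm {i}")
        for i in range(3)
    ]
    # Fan in: results arrive in completion order, not submission order.
    results = [f.result() for f in as_completed(futs)]

# Same naive heuristic as the test: the longest program wins.
best = max(results, key=len)
print(best)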
