66"""
77
88import time
9+ from collections .abc import Callable
910from concurrent .futures import Future
1011from inspect import BoundArguments
1112from typing import Any , override
1213
13- from effectful .handlers .futures import ThreadPoolFuturesInterpretation
14+ import effectful .handlers .futures as futures
15+ from effectful .handlers .futures import Executor , ThreadPoolFuturesInterpretation
1416from effectful .handlers .llm import Template
1517from effectful .handlers .llm .providers import OpenAIAPIProvider
1618from effectful .ops .semantics import handler
@@ -28,34 +30,73 @@ def __init__(self, response, delay: float = 0.05, mapping={}):
2830
2931 @override
3032 def _openai_api_call [T ](
31- self , template : Any , args : BoundArguments , retty : type [T ]
33+ self , template : Template , args : BoundArguments , retty : type [T ]
3234 ) -> T :
3335 self .calls .append ((template , args .args , retty ))
3436 time .sleep (self .delay )
35-
36- return self .mapping .get ((template , tuple (args .args )), self .response )
37+ return self .mapping .get (template , {}).get (tuple (args .args ), self .response )
3738
3839
@Template.define
def hiaku(topic: str) -> str:
    # NOTE(review): the docstring below doubles as the LLM prompt template —
    # "{topic}" is substituted by Template.define, so its text is runtime
    # behavior and must not be edited. The body never runs under a handler.
    """Return a hiaku about {topic}."""
    raise NotHandled
4344
4445
45- # synchronous template for comparison
46- @Template .define
47- def hiaku_s (topic : str ) -> str :
48- """Return a hiaku about {topic}."""
49- raise NotImplementedError
50-
51-
def test_future_return_type_decodes_inner_type():
    """Test that llm templates correctly decode to int, even wrapped in a future."""
    expected = "apples to oranges, oranges to pears, I don't know what a hiaku is"
    provider = SlowMockLLMProvider(expected, delay=0.001)

    with handler(ThreadPoolFuturesInterpretation()), handler(provider):
        # Submitting through the Executor wraps the template call in a Future.
        fut = Executor.submit(hiaku, "apples")
        assert isinstance(fut, Future)
        assert fut.result() == expected
56+
57+
@Template.define
def generate_program(task: str) -> Callable[[int], int]:
    # NOTE(review): the docstring below is the LLM prompt template — "{task}"
    # is substituted by Template.define, so its text is runtime behavior and
    # must not be edited. The body never runs under a handler.
    """Generate a Python program that {task}."""
    raise NotHandled
62+
63+
def test_concurrent_program_generation():
    """Simulate concurrent LLM calls to generate Python programs and pick the best one."""
    # Mock responses for different approaches to the same task, keyed first by
    # template, then by the bound-argument tuple (matching the provider lookup).
    responses = {
        generate_program: {
            ("implement fibonacci algorithm 0",): "def fib(n: int) -> int: return n",
            (
                "implement fibonacci algorithm 1",
            ): "def fib(n: int) -> int: return n * fib(n - 1)",
            (
                "implement fibonacci algorithm 2",
            ): "def fib(n: int) -> int: return fib(n - 2) + fib(n - 1) if n > 1 else 0",
        }
    }

    mock_provider = SlowMockLLMProvider(
        response="print('Default')", delay=0.01, mapping=responses
    )

    user_request: str = "implement fibonacci algorithm"

    with handler(ThreadPoolFuturesInterpretation()), handler(mock_provider):
        # Launch multiple LLM calls concurrently
        tasks = [
            Executor.submit(generate_program, (user_request + f" {i}"))
            for i in range(3)
        ]

        # Collect all results as they finish
        results_as_completed = (f.result() for f in futures.as_completed(tasks))

        valid_results = [(result, len(result)) for result in results_as_completed]

        # Pick the "best" result (here: the shortest program, as a naive heuristic).
        # FIX: the original used max(), which selects the *longest* program and
        # contradicts the stated shortest-program heuristic; min() matches it.
        best_program = min(valid_results, key=lambda pair: pair[1])[0]

        # Assertions
        assert len(valid_results) == 3
        assert best_program in set(responses[generate_program].values())
0 commit comments