Skip to content

Commit 7679e47

Browse files
kiranandcode and eb8680 authored
Implement internal synthesis API (#519)
* Fresh diff * remove instructionhandler * updated internal interface to make all tests pass * fixed caching tests * updated llm.ipynb * removed unnecessarily defensive validation * updated tool call decoding to use concrete type of tool result instead of annotations * updated completions to fix basic type errors * updated call assistant to handle decoding tool calls * dropped stale comments * moved model and param model back to internals of `completions` * added default encodable instance for Callable * fixed type errors * update to use more structured type for synthesis * updated callable encoding tests * s/TypeError/NotImplementedError * simplified smart constructor * bare callables not allowed * dropped synthesis and removed encoding_instructions * fixed imports * fixed imports and tests --------- Co-authored-by: Eli <eli@basis.ai>
1 parent a4a5086 commit 7679e47

14 files changed

Lines changed: 1285 additions & 60 deletions

effectful/handlers/llm/encoding.py

Lines changed: 270 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
import ast
12
import base64
3+
import inspect
24
import io
5+
import textwrap
6+
import types
37
import typing
48
from abc import ABC, abstractmethod
5-
from collections.abc import Callable, Mapping, Sequence
9+
from collections.abc import Callable, Mapping, MutableMapping, Sequence
610
from dataclasses import dataclass
11+
from types import CodeType
712
from typing import Any
813

914
import pydantic
@@ -13,6 +18,7 @@
1318
)
1419
from PIL import Image
1520

21+
import effectful.handlers.llm.evaluation as evaluation
1622
from effectful.ops.semantics import _simple_type
1723
from effectful.ops.syntax import _CustomSingleDispatchCallable
1824
from effectful.ops.types import Operation, Term
@@ -253,6 +259,236 @@ def deserialize(self, serialized_value: str) -> typing.Any:
253259
return typing.cast(typing.Any, adapter.validate_json(serialized_value))
254260

255261

262+
def _format_callable_type(callable_type: type[Callable]) -> str:
    """Format a Callable type annotation as a string for LLM instructions.

    Args:
        callable_type: A ``Callable`` annotation, bare or parameterized
            (``Callable[[A, B], R]`` or ``Callable[..., R]``).

    Returns:
        ``"Callable"`` for a bare annotation, ``"Callable[..., R]"`` when the
        parameter list is an ellipsis, and ``"Callable[[A, B], R]"`` otherwise.
        Parameterized generics keep their arguments (``list[int]``, not
        ``list``) and ``None``/``NoneType`` is rendered as ``None``.
    """

    def _type_str(tp: typing.Any) -> str:
        # ``None`` and ``NoneType`` are both spelled ``None`` in annotations.
        if tp is None or tp is type(None):
            return "None"
        # Parameterized generics proxy ``__name__`` to their origin, which
        # would drop the type arguments; use their full string form instead.
        if typing.get_origin(tp) is None and hasattr(tp, "__name__"):
            return tp.__name__
        return str(tp)

    args = typing.get_args(callable_type)
    if not args:
        return "Callable"
    if len(args) < 2:
        return str(callable_type)

    # typing.get_args yields (param_spec, return_type) for Callable aliases.
    param_types, return_type = args[0], args[-1]
    return_str = _type_str(return_type)

    if param_types is ...:
        # The ellipsis form takes no inner brackets: Callable[..., R].
        return f"Callable[..., {return_str}]"

    if isinstance(param_types, (list, tuple)):
        params_str = ", ".join(_type_str(t) for t in param_types)
    else:
        params_str = str(param_types)

    return f"Callable[[{params_str}], {return_str}]"
284+
285+
286+
class SynthesizedFunction(pydantic.BaseModel):
    """Structured output for function synthesis.

    Pydantic model carrying synthesized Python source in its single
    ``module_code`` field.
    """

    # Required field (no default): the full source text of the module.
    module_code: str = pydantic.Field(
        ...,
        description="Complete Python module code (no imports needed)",
    )
296+
297+
298+
def _create_typed_synthesized_function(
    callable_type: type[Callable],
) -> type[SynthesizedFunction]:
    """Build a ``SynthesizedFunction`` subclass documented with a type signature.

    The formatted signature of ``callable_type`` is embedded in an instruction
    text that is passed to ``pydantic.create_model`` as ``__doc__``, so that it
    becomes the description of the generated model.

    Args:
        callable_type: The ``Callable`` annotation the synthesized function
            is expected to satisfy.

    Returns:
        A new model class named ``TypedSynthesizedFunction`` deriving from
        ``SynthesizedFunction``.
    """
    signature_text = _format_callable_type(callable_type)

    # Assembled line by line; the trailing empty entry yields the final newline.
    prompt = "\n".join(
        (
            "Given the specification above, generate a Python function satisfying the following specification and type signature.",
            "",
            f"<signature>{signature_text}</signature>",
            "",
            "<instructions>",
            "1. Produce one block of Python code.",
            "2. The function MUST have type annotations for all parameters and the return type.",
            "3. The function definition must be the LAST statement - do not add any code after it.",
            "4. Do not include usage examples or function calls.",
            "</instructions>",
            "",
        )
    )

    # The __doc__ passed here serves as the description of the created model.
    return pydantic.create_model(
        "TypedSynthesizedFunction",
        __base__=SynthesizedFunction,
        __doc__=prompt,
    )
328+
329+
330+
def _validate_signature_ast(
331+
func_ast: ast.FunctionDef | ast.AsyncFunctionDef,
332+
expected_params: list[type] | None,
333+
) -> None:
334+
"""Validate the function signature from AST before execution."""
335+
if expected_params is not None:
336+
ast_params = func_ast.args.args + func_ast.args.posonlyargs
337+
if len(ast_params) != len(expected_params):
338+
raise ValueError(
339+
f"decode() expected function with {len(expected_params)} parameters, "
340+
f"got {len(ast_params)}"
341+
)
342+
343+
344+
def _validate_signature_callable(
345+
func: Callable,
346+
expected_params: list[type] | None,
347+
expected_return: type,
348+
) -> None:
349+
"""Validate the function signature from runtime callable after execution.
350+
351+
The synthesized function must have type annotations for parameters and return type.
352+
"""
353+
sig = inspect.signature(func)
354+
355+
if expected_params is not None:
356+
actual_params = list(sig.parameters.values())
357+
if len(actual_params) != len(expected_params):
358+
raise ValueError(
359+
f"decode() expected function with {len(expected_params)} parameters, "
360+
f"got {len(actual_params)}"
361+
)
362+
363+
actual_return = sig.return_annotation
364+
if actual_return is inspect.Parameter.empty:
365+
raise ValueError(
366+
"decode() requires synthesized function to have a return type annotation"
367+
)
368+
369+
expected_name = getattr(expected_return, "__name__", str(expected_return))
370+
actual_name = getattr(actual_return, "__name__", str(actual_return))
371+
if expected_name != actual_name:
372+
raise ValueError(
373+
f"decode() expected function with return type {expected_name}, "
374+
f"got {actual_name}"
375+
)
376+
377+
378+
@dataclass
class CallableEncodable(Encodable[Callable, SynthesizedFunction]):
    """Encodable that maps callables to and from ``SynthesizedFunction`` models.

    ``encode`` turns a callable into source text (or a stub when no source is
    available); ``decode`` parses, compiles and executes synthesized module
    code through the ``evaluation`` operations and returns the function
    defined by the module's last statement.
    """

    # The callable type this encodable was defined for.
    base: type[Callable]
    # Model class used to build encoded values in encode().
    enc: type[SynthesizedFunction]
    # Names made available as globals when synthesized code is executed.
    ctx: Mapping[str, Any]
    # Expected parameter types; None skips parameter-count validation.
    expected_params: list[type] | None = None
    expected_return: type | None = None  # None means decode is disabled

    def encode(self, t: Callable) -> SynthesizedFunction:
        """Encode a callable as source code, falling back to a documented stub.

        Raises:
            TypeError: If ``t`` is not callable.
            RuntimeError: If no source is available and ``t`` lacks a
                ``__name__`` or a docstring.
        """
        # (https://github.com/python/mypy/issues/14928)
        if not isinstance(t, Callable):  # type: ignore
            raise TypeError(f"Expected callable, got {type(t)}")

        try:
            source = inspect.getsource(t)
        except (OSError, TypeError):
            source = None

        if source:
            return self.enc(module_code=textwrap.dedent(source))

        # Source not available - create stub from name, signature, and docstring
        # This is useful for builtins and C extensions
        name = getattr(t, "__name__", None)
        if not name:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no __name__"
            )

        try:
            sig_str = str(inspect.signature(t))
        except (ValueError, TypeError):
            # Some builtins don't have inspectable signatures
            sig_str = "(...)"

        docstring = inspect.getdoc(t)
        if not docstring:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no docstring"
            )

        # Format as a stub function with docstring
        stub_code = f'''def {name}{sig_str}:
    """{docstring}"""
    ...
'''
        return self.enc(module_code=stub_code)

    def decode(self, encoded_value: SynthesizedFunction) -> Callable:
        """Execute synthesized module code and return its final function.

        Raises:
            TypeError: If this encodable has no expected return type.
            ValueError: If the module is empty, its last statement is not a
                function definition, the function is missing or not callable
                after execution, or its signature fails validation.
        """
        # Decode requires a concrete return type for synthesis
        if self.expected_return is None:
            raise TypeError(
                "Cannot decode/synthesize callable without a concrete type signature. "
                "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] "
                "with a concrete return type (not Any)."
            )

        module_code = encoded_value.module_code

        # The filename is the key under which an eval provider may cache the
        # source for inspect.getsource(). Keying on id(self) alone lets a
        # later decode (same instance, or a new instance reusing the address)
        # overwrite the source of an earlier function, so the module text is
        # mixed into the name as well.
        filename = f"<synthesis:{id(self)}:{hash(module_code) & 0xFFFFFFFF:08x}>"

        # Parse and validate AST before execution
        module: ast.AST = evaluation.parse(module_code, filename)

        if not isinstance(module, ast.Module) or not module.body:
            raise ValueError(
                "decode() requires module code with at least one statement."
            )

        last_stmt = module.body[-1]
        if not isinstance(last_stmt, ast.FunctionDef):
            raise ValueError(
                f"decode() requires the last statement to be a function definition, "
                f"got {type(last_stmt).__name__}"
            )

        # Validate signature from AST before execution
        _validate_signature_ast(last_stmt, self.expected_params)

        # Compile and execute
        # https://docs.python.org/3/library/functions.html#exec
        # A fresh namespace is used so executing the module never mutates ctx.
        g: MutableMapping[str, Any] = {}
        g.update(self.ctx or {})

        bytecode: CodeType = evaluation.compile(module, filename)
        evaluation.exec(bytecode, g)

        func_name = last_stmt.name
        if func_name not in g:
            raise ValueError(
                f"decode() expected function '{func_name}' to be defined in globals"
            )

        result = g[func_name]
        if not callable(result):
            raise ValueError(
                f"decode() expected '{func_name}' to be callable, got {type(result)}"
            )

        # Validate signature from runtime callable after execution
        _validate_signature_callable(result, self.expected_params, self.expected_return)

        return result

    def serialize(
        self, encoded_value: SynthesizedFunction
    ) -> Sequence[OpenAIMessageContentListBlock]:
        """Serialize the encoded model as a single JSON text content block."""
        return [{"type": "text", "text": encoded_value.model_dump_json()}]

    def deserialize(self, serialized_value: str) -> SynthesizedFunction:
        """Parse a JSON string into a ``SynthesizedFunction``."""
        return SynthesizedFunction.model_validate_json(serialized_value)
490+
491+
256492
@Encodable.define.register(object)
257493
def _encodable_object[T, U](
258494
ty: type[T], ctx: Mapping[str, Any] | None
@@ -355,3 +591,36 @@ def _encodable_list[T, U](
355591
return typing.cast(
356592
Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder)
357593
)
594+
595+
596+
@Encodable.define.register(Callable)
def _encodable_callable(
    ty: type[Callable], ctx: Mapping[str, Any] | None
) -> Encodable[Callable, SynthesizedFunction]:
    """Build the ``CallableEncodable`` for a ``Callable`` type annotation.

    Args:
        ty: A parameterized ``Callable`` annotation, or ``types.FunctionType``
            (the only accepted type without type arguments).
        ctx: Names made available to synthesized code; ``None`` means empty.

    Returns:
        A ``CallableEncodable``. Without type arguments it has no expected
        return type, so only encoding is supported.

    Raises:
        AssertionError: If ``ty`` has no type arguments and is not
            ``types.FunctionType``.
        TypeError: If the type arguments do not include both parameters and
            a return type.
    """
    ctx = ctx or {}

    type_args = typing.get_args(ty)

    # Bare Callable without type args - allow encoding but disable decode
    # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type)
    if not type_args:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept.
        if ty is not types.FunctionType:
            raise AssertionError(f"Callable must have type signatures {ty}")
        typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any])  # type: ignore[arg-type]
        return CallableEncodable(ty, typed_enc, ctx)

    if len(type_args) < 2:
        raise TypeError(
            f"Callable type signature incomplete: {ty}. "
            "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]."
        )

    param_types, expected_return = type_args[0], type_args[-1]

    # collections.abc.Callable[[X], None] reports its return type as the
    # object ``None``, which CallableEncodable would read as "decode
    # disabled"; normalize it to NoneType so the annotation stays decodable.
    if expected_return is None:
        expected_return = type(None)

    typed_enc = _create_typed_synthesized_function(ty)

    # Ellipsis means any params, skip param validation
    expected_params: list[type] | None = None
    if param_types is not ... and isinstance(param_types, (list, tuple)):
        expected_params = list(param_types)

    return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import ast
2+
import builtins
3+
import linecache
4+
import typing
5+
from types import CodeType
6+
from typing import Any
7+
8+
from effectful.ops.syntax import ObjectInterpretation, defop, implements
9+
10+
11+
@defop
def parse(source: str, filename: str) -> ast.Module:
    """
    Turn Python source text into an AST.

    source: Python source code to be parsed.
    filename: Name to associate with the resulting AST, for tracebacks and tooling.

    Returns the parsed module AST. This default implementation always raises
    NotImplementedError.
    """
    message = "An eval provider must be installed in order to parse code."
    raise NotImplementedError(message)
24+
25+
26+
@defop
def compile(module: ast.Module, filename: str) -> CodeType:
    """
    Turn an AST into a Python code object.

    module: The AST to compile, typically the result of parse().
    filename: Name stored on the resulting code object (CodeType.co_filename),
        which tracebacks and inspect.getsource() refer to.

    Returns the compiled code object. This default implementation always
    raises NotImplementedError.
    """
    message = "An eval provider must be installed in order to compile code."
    raise NotImplementedError(message)
39+
40+
41+
@defop
def exec(
    bytecode: CodeType,
    env: dict[str, Any],
) -> None:
    """
    Run a compiled code object.

    bytecode: The code object to run, typically the result of compile().
    env: Namespace mapping used while the code executes.

    This default implementation always raises NotImplementedError.
    """
    message = "An eval provider must be installed in order to execute code."
    raise NotImplementedError(message)
55+
56+
57+
class UnsafeEvalProvider(ObjectInterpretation):
    """UNSAFE provider for the parse, compile and exec operations.

    Each operation is forwarded directly to ``ast.parse``, ``builtins.compile``
    and ``builtins.exec`` *without* any further checks. Only use for testing.
    """

    @implements(parse)
    def parse(self, source: str, filename: str) -> ast.Module:
        """Register ``source`` in linecache under ``filename``, then parse it."""
        # inspect.getsource() resolves f.__code__.co_filename through
        # linecache.getlines(filename), so storing the text here lets the
        # source of executed functions be retrieved later.
        cache_entry = (len(source), None, source.splitlines(True), filename)
        linecache.cache[filename] = cache_entry

        return ast.parse(source, filename=filename, mode="exec")

    @implements(compile)
    def compile(self, module: ast.AST, filename: str) -> CodeType:
        """Compile ``module`` in "exec" mode with the given filename."""
        tree = typing.cast(typing.Any, module)
        return builtins.compile(tree, filename, "exec")

    @implements(exec)
    def exec(
        self,
        bytecode: CodeType,
        env: dict[str, Any],
    ) -> None:
        """Execute ``bytecode`` with ``env`` as both globals and locals."""
        # Make sure the execution environment can see the builtins.
        if "__builtins__" not in env:
            env["__builtins__"] = __builtins__

        # Same mapping for globals and locals: top-level defs land in `env`.
        builtins.exec(bytecode, env, env)

0 commit comments

Comments
 (0)