Skip to content

Commit dba2a76

Browse files
committed
added default encodable instance for Callable
1 parent e33873d commit dba2a76

3 files changed

Lines changed: 268 additions & 3 deletions

File tree

effectful/handlers/llm/encoding.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import ast
12
import base64
3+
import inspect
24
import io
5+
import textwrap
6+
import threading
7+
import time
38
import typing
49
from abc import ABC, abstractmethod
5-
from collections.abc import Callable, Mapping
10+
from collections.abc import Callable, Mapping, MutableMapping
11+
from types import CodeType
612
from typing import Any
713

814
import pydantic
@@ -12,7 +18,10 @@
1218
)
1319
from PIL import Image
1420

15-
from effectful.ops.syntax import _CustomSingleDispatchCallable
21+
import effectful.handlers.llm.evaluation as evaluation
22+
from effectful.ops.syntax import (
23+
_CustomSingleDispatchCallable,
24+
)
1625
from effectful.ops.types import Operation, Term
1726

1827

@@ -44,7 +53,7 @@ def decode(cls, vl: U, env: Mapping[str, Any] | None = None) -> T:
4453

4554
@classmethod
4655
def encoding_instructions(cls) -> str | None:
47-
"""Optional instructions to be prefixed onto synthesis prompts to tune the encoding of the result."""
56+
"""Optional instructions to be prefixed onto prompts to tune the encoding of the result."""
4857
return None
4958

5059
@classmethod
@@ -262,3 +271,53 @@ def serialize(cls, value: typing.Any) -> list[OpenAIMessageContentListBlock]:
262271
return super().serialize(value)
263272

264273
return typing.cast(Encodable[T], ListEncodable())
274+
275+
276+
@type_to_encodable_type.register(Callable)
277+
class CallableEncodable(EncodableAs[Callable, str]):
278+
t: type[typing.Any] = str
279+
280+
@classmethod
281+
def encode[T](cls, t: T, env: Mapping[str, Any] | None = None) -> typing.Any:
282+
if not callable(t):
283+
raise TypeError(f"Expected callable, got {type(t)}")
284+
try:
285+
source = inspect.getsource(t)
286+
except (OSError, TypeError):
287+
source = None
288+
289+
if not source:
290+
raise RuntimeError(f"Source code of callable {t} not found")
291+
292+
return textwrap.dedent(source)
293+
294+
@classmethod
295+
def decode(cls, t: str, env: Mapping[str, Any] | None = None) -> Callable:
296+
filename = f"<{cls.__name__}.decode:{int(time.time() * 1_000_000)}:{threading.get_ident()}>"
297+
298+
# https://docs.python.org/3/library/functions.html#exec
299+
g: MutableMapping[str, Any] = {}
300+
g.update(env or {})
301+
302+
before_keys = set(g.keys())
303+
304+
module: ast.AST = evaluation.parse(t, filename)
305+
bytecode: CodeType = evaluation.compile(module, filename)
306+
evaluation.exec(bytecode, g)
307+
308+
# Otherwise: find newly-created callables (in insertion order).
309+
new_callables = [
310+
v for k, v in g.items() if k not in before_keys and callable(v)
311+
]
312+
if not new_callables or len(new_callables) > 1:
313+
raise ValueError(
314+
"decode() required source code to define exactly one callable."
315+
)
316+
317+
return new_callables[0]
318+
319+
@Operation.define
320+
@classmethod
321+
def encoding_instructions(cls) -> str | None:
322+
"""Instructions to be prefixed onto synthesis prompts to tune the encoding of the result."""
323+
return None
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import ast
2+
import builtins
3+
import linecache
4+
from collections.abc import MutableMapping
5+
from types import CodeType
6+
from typing import Any
7+
8+
from effectful.ops.syntax import ObjectInterpretation, defop, implements
9+
10+
11+
@defop
12+
def parse(source: str, filename: str) -> ast.AST:
13+
"""
14+
Parse source text into an AST.
15+
16+
source: The Python source code to parse.
17+
filename: The filename recorded in the resulting AST for tracebacks and tooling.
18+
19+
Returns the parsed AST.
20+
"""
21+
raise TypeError("An eval provider must be installed in order to parse code.")
22+
23+
24+
@defop
25+
def compile(module: ast.AST, filename: str) -> CodeType:
26+
"""
27+
Compile an AST into a Python code object.
28+
29+
module: The AST to compile (typically produced by parse()).
30+
filename: The filename recorded in the resulting code object (CodeType.co_filename), used in tracebacks and by inspect.getsource().
31+
32+
Returns the compiled code object.
33+
"""
34+
raise TypeError("An eval provider must be installed in order to compile code.")
35+
36+
37+
@defop
38+
def exec(
39+
bytecode: CodeType,
40+
env: MutableMapping[str, Any],
41+
) -> None:
42+
"""
43+
Execute a compiled code object.
44+
45+
bytecode: A code object to execute (typically produced by compile()).
46+
env: The namespace mapping used during execution.
47+
"""
48+
raise TypeError("An eval provider must be installed in order to execute code.")
49+
50+
51+
class UnsafeEvalProvider(ObjectInterpretation):
52+
"""UNSAFE provider that handles parse, comple and exec operations
53+
by shelling out to python *without* any further checks. Only use for testing."""
54+
55+
@implements(parse)
56+
def parse(self, source: str, filename: str) -> ast.AST:
57+
# Cache source under `filename` so inspect.getsource() can retrieve it later.
58+
# inspect uses f.__code__.co_filename -> linecache.getlines(filename)
59+
linecache.cache[filename] = (
60+
len(source),
61+
None,
62+
source.splitlines(True),
63+
filename,
64+
)
65+
66+
return ast.parse(source, filename=filename, mode="exec")
67+
68+
@implements(compile)
69+
def compile(self, module: ast.AST, filename: str, *rags) -> CodeType:
70+
return builtins.compile(module, filename, "exec")
71+
72+
@implements(exec)
73+
def exec(
74+
self,
75+
bytecode: CodeType,
76+
env: MutableMapping[str, Any],
77+
) -> None:
78+
# Ensure builtins exist in the execution environment.
79+
env.setdefault("__builtins__", __builtins__)
80+
81+
# Execute module-style so top-level defs land in `env`.
82+
builtins.exec(bytecode, env, env)

tests/test_handlers_llm_encoding.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from PIL import Image
77

88
from effectful.handlers.llm.encoding import type_to_encodable_type
9+
from effectful.handlers.llm.evaluation import UnsafeEvalProvider
10+
from effectful.ops.semantics import handler
911
from effectful.ops.types import Operation, Term
1012

1113

@@ -718,3 +720,125 @@ class Person(pydantic.BaseModel):
718720
assert decoded_from_model == person
719721
assert isinstance(decoded_from_model, Person)
720722
assert isinstance(decoded_from_model.address, Address)
723+
724+
725+
class TestCallableEncodable:
726+
"""Tests for CallableEncodable - encoding/decoding callables as source code."""
727+
728+
def test_encode_decode_function(self):
729+
from collections.abc import Callable
730+
731+
def add(a: int, b: int) -> int:
732+
return a + b
733+
734+
encodable = type_to_encodable_type(Callable)
735+
encoded = encodable.encode(add, {})
736+
assert isinstance(encoded, str)
737+
assert "def add" in encoded
738+
assert "return a + b" in encoded
739+
740+
with handler(UnsafeEvalProvider()):
741+
decoded = encodable.decode(encoded, {})
742+
assert callable(decoded)
743+
assert decoded(2, 3) == 5
744+
assert decoded.__name__ == "add"
745+
746+
def test_decode_lambda(self):
747+
from collections.abc import Callable
748+
749+
# Lambdas should work if defined in a way that inspect.getsource can find them
750+
# Note: lambdas defined inline may not always have retrievable source
751+
encodable = type_to_encodable_type(Callable)
752+
753+
# Test decoding a lambda from source string
754+
lambda_source = "f = lambda x: x * 2"
755+
with handler(UnsafeEvalProvider()):
756+
decoded = encodable.decode(lambda_source, {})
757+
assert callable(decoded)
758+
assert decoded(5) == 10
759+
760+
def test_decode_with_env(self):
761+
from collections.abc import Callable
762+
763+
# Test decoding a function that uses env variables
764+
encodable = type_to_encodable_type(Callable)
765+
source = """def multiply(x):
766+
return x * factor"""
767+
768+
with handler(UnsafeEvalProvider()):
769+
decoded = encodable.decode(source, {"factor": 3})
770+
assert callable(decoded)
771+
assert decoded(4) == 12
772+
773+
def test_encode_non_callable_raises(self):
774+
from collections.abc import Callable
775+
776+
encodable = type_to_encodable_type(Callable)
777+
with pytest.raises(TypeError, match="Expected callable"):
778+
encodable.encode("not a callable", {})
779+
780+
def test_encode_builtin_raises(self):
781+
from collections.abc import Callable
782+
783+
encodable = type_to_encodable_type(Callable)
784+
# Built-in functions don't have source code
785+
with pytest.raises(RuntimeError, match="Source code of callable .* not found"):
786+
with handler(UnsafeEvalProvider()):
787+
encodable.encode(len, {})
788+
789+
def test_decode_no_callable_raises(self):
790+
from collections.abc import Callable
791+
792+
encodable = type_to_encodable_type(Callable)
793+
# Source code that defines no callable
794+
source = "x = 42"
795+
with pytest.raises(ValueError, match="exactly one callable"):
796+
with handler(UnsafeEvalProvider()):
797+
encodable.decode(source, {})
798+
799+
def test_decode_multiple_callables_raises(self):
800+
from collections.abc import Callable
801+
802+
encodable = type_to_encodable_type(Callable)
803+
# Source code that defines multiple callables
804+
source = """def foo():
805+
return 1
806+
807+
def bar():
808+
return 2"""
809+
with pytest.raises(ValueError, match="exactly one callable"):
810+
with handler(UnsafeEvalProvider()):
811+
encodable.decode(source, {})
812+
813+
def test_decode_class(self):
814+
from collections.abc import Callable
815+
816+
encodable = type_to_encodable_type(Callable)
817+
# Classes are callable, decode should work with class definitions
818+
source = """class Greeter:
819+
def __init__(self, name):
820+
self.name = name
821+
822+
def greet(self):
823+
return f"Hello, {self.name}!\""""
824+
825+
with handler(UnsafeEvalProvider()):
826+
decoded = encodable.decode(source, {})
827+
assert callable(decoded)
828+
instance = decoded("World")
829+
assert instance.greet() == "Hello, World!"
830+
831+
def test_roundtrip(self):
832+
from collections.abc import Callable
833+
834+
def greet(name: str) -> str:
835+
return f"Hello, {name}!"
836+
837+
encodable = type_to_encodable_type(Callable)
838+
with handler(UnsafeEvalProvider()):
839+
encoded = encodable.encode(greet, {})
840+
decoded = encodable.decode(encoded, {})
841+
842+
assert callable(decoded)
843+
assert decoded("Alice") == "Hello, Alice!"
844+
assert decoded.__name__ == "greet"

0 commit comments

Comments
 (0)