Skip to content

Commit 88da657

Browse files
committed
added default encodable instance for Callable
1 parent 6bb2b13 commit 88da657

3 files changed

Lines changed: 272 additions & 1 deletion

File tree

effectful/handlers/llm/encoding.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import ast
12
import base64
3+
import inspect
24
import io
5+
import textwrap
36
import typing
47
from abc import ABC, abstractmethod
5-
from collections.abc import Callable, Mapping, Sequence
8+
from collections.abc import Callable, Mapping, MutableMapping, Sequence
69
from dataclasses import dataclass
10+
from types import CodeType
711
from typing import Any
812

913
import pydantic
@@ -13,6 +17,7 @@
1317
)
1418
from PIL import Image
1519

20+
import effectful.handlers.llm.evaluation as evaluation
1621
from effectful.ops.semantics import _simple_type
1722
from effectful.ops.syntax import _CustomSingleDispatchCallable
1823
from effectful.ops.types import Operation, Term
@@ -253,6 +258,60 @@ def deserialize(self, serialized_value: str) -> typing.Any:
253258
return typing.cast(typing.Any, adapter.validate_json(serialized_value))
254259

255260

261+
@dataclass
262+
class CallableEncodable(Encodable[Callable, str]):
263+
base: type[Callable]
264+
enc: type[str]
265+
ctx: Mapping[str, Any]
266+
267+
def encode(self, t: Callable) -> typing.Any:
268+
# (https://github.com/python/mypy/issues/14928)
269+
if not isinstance(t, Callable): # type: ignore
270+
raise TypeError(f"Expected callable, got {type(t)}")
271+
try:
272+
source = inspect.getsource(t)
273+
except (OSError, TypeError):
274+
source = None
275+
276+
if not source:
277+
# create source stub using signature and docstring of callable (useful for builtins etc.)
278+
pass
279+
280+
assert source, "Could not retrieve source code or docstring for function"
281+
282+
return textwrap.dedent(source)
283+
284+
def decode(self, encoded_value: str) -> Callable:
285+
filename = f"<synthesis:{id(self)}>"
286+
287+
# https://docs.python.org/3/library/functions.html#exec
288+
g: MutableMapping[str, Any] = {}
289+
g.update(self.ctx or {})
290+
291+
before_keys = set(g.keys())
292+
293+
module: ast.AST = evaluation.parse(encoded_value, filename)
294+
bytecode: CodeType = evaluation.compile(module, filename)
295+
evaluation.exec(bytecode, g)
296+
297+
# Otherwise: find newly-created callables (in insertion order).
298+
new_callables = [
299+
v for k, v in g.items() if k not in before_keys and callable(v)
300+
]
301+
if not new_callables or len(new_callables) > 1:
302+
raise ValueError(
303+
"decode() required source code to define exactly one callable."
304+
)
305+
306+
return new_callables[0]
307+
308+
@Operation.define
309+
@classmethod
310+
def encoding_instructions(cls) -> str | None:
311+
"""Instructions to be prefixed onto synthesis prompts to tune the encoding of the result."""
312+
return None
313+
314+
256315
@Encodable.define.register(object)
257316
def _encodable_object[T, U](
258317
ty: type[T], ctx: Mapping[str, Any] | None
@@ -355,3 +414,8 @@ def _encodable_list[T, U](
355414
return typing.cast(
356415
Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder)
357416
)
417+
418+
419+
@Encodable.define.register(Callable)
420+
def _encodable_callable(*args, **kwargs):
421+
raise NotImplementedError
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import ast
2+
import builtins
3+
import linecache
4+
import typing
5+
from collections.abc import MutableMapping
6+
from types import CodeType
7+
from typing import Any
8+
9+
from effectful.ops.syntax import ObjectInterpretation, defop, implements
10+
11+
12+
@defop
13+
def parse(source: str, filename: str) -> ast.AST:
14+
"""
15+
Parse source text into an AST.
16+
17+
source: The Python source code to parse.
18+
filename: The filename recorded in the resulting AST for tracebacks and tooling.
19+
20+
Returns the parsed AST.
21+
"""
22+
raise TypeError("An eval provider must be installed in order to parse code.")
23+
24+
25+
@defop
26+
def compile(module: ast.AST, filename: str) -> CodeType:
27+
"""
28+
Compile an AST into a Python code object.
29+
30+
module: The AST to compile (typically produced by parse()).
31+
filename: The filename recorded in the resulting code object (CodeType.co_filename), used in tracebacks and by inspect.getsource().
32+
33+
Returns the compiled code object.
34+
"""
35+
raise TypeError("An eval provider must be installed in order to compile code.")
36+
37+
38+
@defop
39+
def exec(
40+
bytecode: CodeType,
41+
env: MutableMapping[str, Any],
42+
) -> None:
43+
"""
44+
Execute a compiled code object.
45+
46+
bytecode: A code object to execute (typically produced by compile()).
47+
env: The namespace mapping used during execution.
48+
"""
49+
raise TypeError("An eval provider must be installed in order to execute code.")
50+
51+
52+
class UnsafeEvalProvider(ObjectInterpretation):
53+
"""UNSAFE provider that handles parse, comple and exec operations
54+
by shelling out to python *without* any further checks. Only use for testing."""
55+
56+
@implements(parse)
57+
def parse(self, source: str, filename: str) -> ast.AST:
58+
# Cache source under `filename` so inspect.getsource() can retrieve it later.
59+
# inspect uses f.__code__.co_filename -> linecache.getlines(filename)
60+
linecache.cache[filename] = (
61+
len(source),
62+
None,
63+
source.splitlines(True),
64+
filename,
65+
)
66+
67+
return ast.parse(source, filename=filename, mode="exec")
68+
69+
@implements(compile)
70+
def compile(self, module: ast.AST, filename: str) -> CodeType:
71+
return builtins.compile(typing.cast(typing.Any, module), filename, "exec")
72+
73+
@implements(exec)
74+
def exec(
75+
self,
76+
bytecode: CodeType,
77+
env: dict[str, Any],
78+
) -> None:
79+
# Ensure builtins exist in the execution environment.
80+
env.setdefault("__builtins__", __builtins__)
81+
82+
# Execute module-style so top-level defs land in `env`.
83+
builtins.exec(bytecode, env, env)

tests/test_handlers_llm_encoding.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from PIL import Image
77

88
from effectful.handlers.llm.encoding import Encodable
9+
from effectful.handlers.llm.evaluation import UnsafeEvalProvider
10+
from effectful.ops.semantics import handler
911
from effectful.ops.types import Operation, Term
1012

1113

@@ -718,3 +720,125 @@ class Person(pydantic.BaseModel):
718720
assert decoded_from_model == person
719721
assert isinstance(decoded_from_model, Person)
720722
assert isinstance(decoded_from_model.address, Address)
723+
724+
725+
class TestCallableEncodable:
726+
"""Tests for CallableEncodable - encoding/decoding callables as source code."""
727+
728+
def test_encode_decode_function(self):
729+
from collections.abc import Callable
730+
731+
def add(a: int, b: int) -> int:
732+
return a + b
733+
734+
encodable = Encodable.define(Callable, {})
735+
encoded = encodable.encode(add)
736+
assert isinstance(encoded, str)
737+
assert "def add" in encoded
738+
assert "return a + b" in encoded
739+
740+
with handler(UnsafeEvalProvider()):
741+
decoded = encodable.decode(encoded)
742+
assert callable(decoded)
743+
assert decoded(2, 3) == 5
744+
assert decoded.__name__ == "add"
745+
746+
def test_decode_lambda(self):
747+
from collections.abc import Callable
748+
749+
# Lambdas should work if defined in a way that inspect.getsource can find them
750+
# Note: lambdas defined inline may not always have retrievable source
751+
encodable = Encodable.define(Callable, {})
752+
753+
# Test decoding a lambda from source string
754+
lambda_source = "f = lambda x: x * 2"
755+
with handler(UnsafeEvalProvider()):
756+
decoded = encodable.decode(lambda_source)
757+
assert callable(decoded)
758+
assert decoded(5) == 10
759+
760+
def test_decode_with_env(self):
761+
from collections.abc import Callable
762+
763+
# Test decoding a function that uses env variables
764+
encodable = Encodable.define(Callable, {"factor": 3})
765+
source = """def multiply(x):
766+
return x * factor"""
767+
768+
with handler(UnsafeEvalProvider()):
769+
decoded = encodable.decode(source)
770+
assert callable(decoded)
771+
assert decoded(4) == 12
772+
773+
def test_encode_non_callable_raises(self):
774+
from collections.abc import Callable
775+
776+
encodable = Encodable.define(Callable, {})
777+
with pytest.raises(TypeError, match="Expected callable"):
778+
encodable.encode("not a callable", {})
779+
780+
def test_encode_builtin_raises(self):
781+
from collections.abc import Callable
782+
783+
encodable = Encodable.define(Callable, {})
784+
# Built-in functions don't have source code
785+
with pytest.raises(RuntimeError, match="Source code of callable .* not found"):
786+
with handler(UnsafeEvalProvider()):
787+
encodable.encode(len)
788+
789+
def test_decode_no_callable_raises(self):
790+
from collections.abc import Callable
791+
792+
encodable = Encodable.define(Callable, {})
793+
# Source code that defines no callable
794+
source = "x = 42"
795+
with pytest.raises(ValueError, match="exactly one callable"):
796+
with handler(UnsafeEvalProvider()):
797+
encodable.decode(source)
798+
799+
def test_decode_multiple_callables_raises(self):
800+
from collections.abc import Callable
801+
802+
encodable = Encodable.define(Callable, {})
803+
# Source code that defines multiple callables
804+
source = """def foo():
805+
return 1
806+
807+
def bar():
808+
return 2"""
809+
with pytest.raises(ValueError, match="exactly one callable"):
810+
with handler(UnsafeEvalProvider()):
811+
encodable.decode(source)
812+
813+
def test_decode_class(self):
814+
from collections.abc import Callable
815+
816+
encodable = Encodable.define(Callable, {})
817+
# Classes are callable, decode should work with class definitions
818+
source = """class Greeter:
819+
def __init__(self, name):
820+
self.name = name
821+
822+
def greet(self):
823+
return f"Hello, {self.name}!\""""
824+
825+
with handler(UnsafeEvalProvider()):
826+
decoded = encodable.decode(source)
827+
assert callable(decoded)
828+
instance = decoded("World")
829+
assert instance.greet() == "Hello, World!"
830+
831+
def test_roundtrip(self):
832+
from collections.abc import Callable
833+
834+
def greet(name: str) -> str:
835+
return f"Hello, {name}!"
836+
837+
encodable = Encodable.define(Callable, {})
838+
with handler(UnsafeEvalProvider()):
839+
encoded = encodable.encode(greet)
840+
decoded = encodable.decode(encoded)
841+
842+
assert callable(decoded)
843+
assert decoded("Alice") == "Hello, Alice!"
844+
assert decoded.__name__ == "greet"

0 commit comments

Comments
 (0)