Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c192c78
Fresh diff
kiranandcode Jan 30, 2026
ea3d25c
remove instructionhandler
eb8680 Jan 16, 2026
f83d312
updated internal interface to make all tests pass
kiranandcode Jan 28, 2026
d8d52e7
fixed caching tests
kiranandcode Jan 28, 2026
a322a35
updated llm.ipynb
kiranandcode Jan 28, 2026
41beb78
removed unnecessarily defensive validation
kiranandcode Jan 29, 2026
1400d19
updated tool call decoding to use concrete type of tool result instea…
kiranandcode Jan 29, 2026
a06296d
updated completions to fix basic type errors
kiranandcode Jan 30, 2026
c47abd6
updated call assistant to handle decoding tool calls
kiranandcode Jan 30, 2026
43e5b78
dropped stale comments
kiranandcode Jan 30, 2026
6bb2b13
moved model and param model back to internals of `completions`
kiranandcode Jan 30, 2026
88da657
added default encodable instance for Callable
kiranandcode Jan 30, 2026
88c65ee
fixed type errors
kiranandcode Jan 30, 2026
18da11b
update to use more structured type for synthesis
kiranandcode Jan 31, 2026
52df3f6
updated callable encoding tests
kiranandcode Jan 31, 2026
4b10ac3
s/TypeError/NotImplementedError
kiranandcode Jan 31, 2026
2a8af94
Merge branch 'master' into kg-encodable-default
kiranandcode Jan 31, 2026
2b4449a
simplified smart constructor
kiranandcode Jan 31, 2026
553450f
bare callables not allowed
kiranandcode Jan 31, 2026
aac0eb9
droped synthesis and removed encoding_instructions
kiranandcode Jan 31, 2026
5b3a559
fixed imports
kiranandcode Jan 31, 2026
711b27d
fixed imports and tests
kiranandcode Jan 31, 2026
3f1aa65
added restricted python again
kiranandcode Jan 31, 2026
fd4041e
added test for custom policies for restricted python
kiranandcode Jan 31, 2026
fbe478f
more specific arguments to RestrictedEvalProvider and made exec more …
kiranandcode Jan 31, 2026
6aa4d1b
reverted flags for customizing rglobals, using same rglobals for loca…
kiranandcode Jan 31, 2026
4f0c963
fixed failing tests (env was not being mutated)
kiranandcode Jan 31, 2026
86d8b41
Merge branch 'master' into kg-restricted-eval
kiranandcode Jan 31, 2026
1998d38
updated restricted python to be a llm dependency
kiranandcode Jan 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 270 additions & 1 deletion effectful/handlers/llm/encoding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import ast
import base64
import inspect
import io
import textwrap
import types
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from dataclasses import dataclass
from types import CodeType
from typing import Any

import pydantic
Expand All @@ -13,6 +18,7 @@
)
from PIL import Image

import effectful.handlers.llm.evaluation as evaluation
from effectful.ops.semantics import _simple_type
from effectful.ops.syntax import _CustomSingleDispatchCallable
from effectful.ops.types import Operation, Term
Expand Down Expand Up @@ -253,6 +259,236 @@ def deserialize(self, serialized_value: str) -> typing.Any:
return typing.cast(typing.Any, adapter.validate_json(serialized_value))


def _format_callable_type(callable_type: type[Callable]) -> str:
    """Format a Callable type annotation as a string for LLM instructions.

    Args:
        callable_type: A ``Callable`` annotation, either bare or parameterized
            as ``Callable[[A, B], R]`` / ``Callable[..., R]``.

    Returns:
        ``"Callable"`` for a bare annotation, otherwise a rendering that mirrors
        the annotation syntax, e.g. ``"Callable[[int, str], bool]"`` or
        ``"Callable[..., int]"``.
    """

    def _type_str(tp: object) -> str:
        # Render a single parameter/return type.
        if tp is None or tp is type(None):
            return "None"
        if tp is ...:
            return "..."
        # Parameterized types (list[int], int | None, ...) either expose only
        # the bare origin's __name__ or none at all; repr keeps the arguments.
        if typing.get_args(tp):
            return repr(tp)
        return getattr(tp, "__name__", str(tp))

    args = typing.get_args(callable_type)
    if not args:
        return "Callable"
    if len(args) < 2:
        return str(callable_type)

    # Callable[[arg1, arg2, ...], return_type]
    param_types, return_type = args[0], args[-1]
    if isinstance(param_types, (list, tuple)):
        params_str = "[" + ", ".join(_type_str(t) for t in param_types) + "]"
    else:
        # Ellipsis or a ParamSpec-like expression: these are written without
        # the surrounding list brackets in Callable syntax.
        params_str = _type_str(param_types)

    return f"Callable[{params_str}, {_type_str(return_type)}]"


class SynthesizedFunction(pydantic.BaseModel):
    """Structured output for function synthesis.

    Pydantic model representing synthesized code with function name and module code.
    """

    # NOTE(review): the docstring above mentions a function name, but the only
    # field declared on this model is `module_code`. The docstring is left
    # untouched here because a model's __doc__ is used as its description in
    # the JSON schema, so editing it changes what the LLM is shown.

    # Source text of the synthesized module. Passing `...` as the default makes
    # the field required.
    module_code: str = pydantic.Field(
        ...,
        description="Complete Python module code (no imports needed)",
    )


def _create_typed_synthesized_function(
    callable_type: type[Callable],
) -> type[SynthesizedFunction]:
    """Build a SynthesizedFunction subclass that advertises ``callable_type``.

    The formatted signature is embedded in a block of instructions that is
    installed as the new model's ``__doc__``; pydantic exposes that docstring
    as the model description in the JSON schema, which is how the expected
    function signature reaches the LLM.

    Args:
        callable_type: The Callable annotation the synthesized function must
            satisfy.

    Returns:
        A fresh ``TypedSynthesizedFunction`` model class derived from
        ``SynthesizedFunction``.
    """
    signature = _format_callable_type(callable_type)

    schema_doc = f"""Given the specification above, generate a Python function satisfying the following specification and type signature.

<signature>{signature}</signature>

<instructions>
1. Produce one block of Python code.
2. The function MUST have type annotations for all parameters and the return type.
3. The function definition must be the LAST statement - do not add any code after it.
4. Do not include usage examples or function calls.
</instructions>
"""

    # create_model (rather than a plain subclass) lets the description be
    # supplied dynamically through __doc__.
    return pydantic.create_model(
        "TypedSynthesizedFunction",
        __base__=SynthesizedFunction,
        __doc__=schema_doc,
    )


def _validate_signature_ast(
func_ast: ast.FunctionDef | ast.AsyncFunctionDef,
expected_params: list[type] | None,
) -> None:
"""Validate the function signature from AST before execution."""
if expected_params is not None:
ast_params = func_ast.args.args + func_ast.args.posonlyargs
if len(ast_params) != len(expected_params):
raise ValueError(
f"decode() expected function with {len(expected_params)} parameters, "
f"got {len(ast_params)}"
)


def _validate_signature_callable(
func: Callable,
expected_params: list[type] | None,
expected_return: type,
) -> None:
"""Validate the function signature from runtime callable after execution.

The synthesized function must have type annotations for parameters and return type.
"""
sig = inspect.signature(func)

if expected_params is not None:
actual_params = list(sig.parameters.values())
if len(actual_params) != len(expected_params):
raise ValueError(
f"decode() expected function with {len(expected_params)} parameters, "
f"got {len(actual_params)}"
)

actual_return = sig.return_annotation
if actual_return is inspect.Parameter.empty:
raise ValueError(
"decode() requires synthesized function to have a return type annotation"
)

expected_name = getattr(expected_return, "__name__", str(expected_return))
actual_name = getattr(actual_return, "__name__", str(actual_return))
if expected_name != actual_name:
raise ValueError(
f"decode() expected function with return type {expected_name}, "
f"got {actual_name}"
)


@dataclass
class CallableEncodable(Encodable[Callable, SynthesizedFunction]):
    """Encodable that carries callables as Python source in a SynthesizedFunction.

    ``encode`` turns a callable into source text (or a stub when no source is
    available); ``decode`` executes source text and returns the function
    defined by its last statement.
    """

    # Callable type this encodable was created for.
    base: type[Callable]
    # Model class instantiated by encode() to hold the source text.
    enc: type[SynthesizedFunction]
    # Names copied into the globals dict that decode() executes the code in.
    ctx: Mapping[str, Any]
    # Expected parameter types; only their count is checked. None skips the check.
    expected_params: list[type] | None = None
    expected_return: type | None = None  # None means decode is disabled

    def encode(self, t: Callable) -> SynthesizedFunction:
        """Encode a callable as source text.

        Returns the dedented result of ``inspect.getsource`` when it is
        available; otherwise a stub built from the callable's name, signature
        and docstring.

        Raises:
            TypeError: If ``t`` is not callable.
            RuntimeError: If no source is available and ``t`` has no
                ``__name__`` or no docstring.
        """
        # The ignore below is for a mypy limitation with isinstance(..., Callable)
        # (https://github.com/python/mypy/issues/14928)
        if not isinstance(t, Callable):  # type: ignore
            raise TypeError(f"Expected callable, got {type(t)}")

        try:
            source = inspect.getsource(t)
        except (OSError, TypeError):
            # getsource could not locate source for this object; fall through
            # to the stub path below.
            source = None

        if source:
            return self.enc(module_code=textwrap.dedent(source))

        # Source not available - create stub from name, signature, and docstring
        # This is useful for builtins and C extensions
        name = getattr(t, "__name__", None)
        if not name:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no __name__"
            )

        try:
            sig = inspect.signature(t)
            sig_str = str(sig)
        except (ValueError, TypeError):
            # Some builtins don't have inspectable signatures
            sig_str = "(...)"

        docstring = inspect.getdoc(t)
        if not docstring:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no docstring"
            )

        # Format as a stub function with docstring
        stub_code = f'''def {name}{sig_str}:
"""{docstring}"""
...
'''
        return self.enc(module_code=stub_code)

    def decode(self, encoded_value: SynthesizedFunction) -> Callable:
        """Execute the encoded module code and return the function it defines.

        The code is parsed, its last statement is required to be a (non-async)
        function definition with the expected parameter count, and only then
        is it compiled and executed in a fresh globals dict seeded from
        ``self.ctx``. The resulting function's runtime signature is validated
        before it is returned.

        Raises:
            TypeError: If ``expected_return`` is None (decode disabled).
            ValueError: If the code has no statements, does not end in a
                function definition, does not define the function, or fails
                either signature validation.
        """
        # Decode requires a concrete return type for synthesis
        if self.expected_return is None:
            raise TypeError(
                "Cannot decode/synthesize callable without a concrete type signature. "
                "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] "
                "with a concrete return type (not Any)."
            )

        # Pseudo-filename handed to parse/compile; derived from this encodable's id.
        filename = f"<synthesis:{id(self)}>"

        module_code = encoded_value.module_code

        # Parse and validate AST before execution
        module: ast.AST = evaluation.parse(module_code, filename)

        if not isinstance(module, ast.Module) or not module.body:
            raise ValueError(
                "decode() requires module code with at least one statement."
            )

        # The function to return is identified purely by position: it must be
        # the final top-level statement.
        last_stmt = module.body[-1]
        if not isinstance(last_stmt, ast.FunctionDef):
            raise ValueError(
                f"decode() requires the last statement to be a function definition, "
                f"got {type(last_stmt).__name__}"
            )

        # Validate signature from AST before execution
        _validate_signature_ast(last_stmt, self.expected_params)

        # Compile and execute
        # https://docs.python.org/3/library/functions.html#exec
        # A new dict is used so that self.ctx itself is not the exec globals.
        g: MutableMapping[str, Any] = {}
        g.update(self.ctx or {})

        bytecode: CodeType = evaluation.compile(module, filename)
        evaluation.exec(bytecode, g)

        func_name = last_stmt.name
        if func_name not in g:
            raise ValueError(
                f"decode() expected function '{func_name}' to be defined in globals"
            )

        result = g[func_name]
        if not callable(result):
            raise ValueError(
                f"decode() expected '{func_name}' to be callable, got {type(result)}"
            )

        # Validate signature from runtime callable after execution
        _validate_signature_callable(result, self.expected_params, self.expected_return)

        return result

    def serialize(
        self, encoded_value: SynthesizedFunction
    ) -> Sequence[OpenAIMessageContentListBlock]:
        """Return a single text content block holding the model's JSON dump."""
        return [{"type": "text", "text": encoded_value.model_dump_json()}]

    def deserialize(self, serialized_value: str) -> SynthesizedFunction:
        """Parse a JSON string into a SynthesizedFunction.

        Validation goes through the base ``SynthesizedFunction`` model, not
        ``self.enc``.
        """
        return SynthesizedFunction.model_validate_json(serialized_value)


@Encodable.define.register(object)
def _encodable_object[T, U](
ty: type[T], ctx: Mapping[str, Any] | None
Expand Down Expand Up @@ -355,3 +591,36 @@ def _encodable_list[T, U](
return typing.cast(
Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder)
)


@Encodable.define.register(Callable)
def _encodable_callable(
    ty: type[Callable], ctx: Mapping[str, Any] | None
) -> Encodable[Callable, SynthesizedFunction]:
    """Build the Encodable for a Callable type.

    Args:
        ty: A parameterized ``Callable[[...], R]`` / ``Callable[..., R]``
            annotation, or ``types.FunctionType`` for the encode-only case.
        ctx: Names made available to decoded code; ``None`` is treated as empty.

    Returns:
        A ``CallableEncodable``. For ``types.FunctionType`` it is created
        without an expected return type, which leaves decoding disabled.

    Raises:
        AssertionError: If ``ty`` has no type arguments and is not
            ``types.FunctionType``.
        TypeError: If ``ty`` has fewer than two type arguments.
    """
    ctx = ctx or {}

    type_args = typing.get_args(ty)

    # Bare Callable without type args - allow encoding but disable decode
    # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type)
    if not type_args:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept as AssertionError so
        # existing callers observe the same error.
        if ty is not types.FunctionType:
            raise AssertionError(f"Callable must have type signatures {ty}")
        typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any])  # type: ignore[arg-type]
        return CallableEncodable(ty, typed_enc, ctx)

    if len(type_args) < 2:
        raise TypeError(
            f"Callable type signature incomplete: {ty}. "
            "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]."
        )

    param_types, expected_return = type_args[0], type_args[-1]

    # collections.abc.Callable[[X], None] reports its return type as the value
    # None. CallableEncodable reads expected_return=None as "decode disabled",
    # so a declared None return is passed on as NoneType instead.
    if expected_return is None:
        expected_return = type(None)

    typed_enc = _create_typed_synthesized_function(ty)

    # Ellipsis (or any non-list parameter expression) means any params, skip
    # param validation
    expected_params: list[type] | None = None
    if isinstance(param_types, (list, tuple)):
        expected_params = list(param_types)

    return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return)
Loading
Loading