|
| 1 | +import ast |
1 | 2 | import base64 |
| 3 | +import inspect |
2 | 4 | import io |
| 5 | +import textwrap |
| 6 | +import types |
3 | 7 | import typing |
4 | 8 | from abc import ABC, abstractmethod |
5 | | -from collections.abc import Callable, Mapping, Sequence |
| 9 | +from collections.abc import Callable, Mapping, MutableMapping, Sequence |
6 | 10 | from dataclasses import dataclass |
| 11 | +from types import CodeType |
7 | 12 | from typing import Any |
8 | 13 |
|
9 | 14 | import pydantic |
|
13 | 18 | ) |
14 | 19 | from PIL import Image |
15 | 20 |
|
| 21 | +import effectful.handlers.llm.evaluation as evaluation |
16 | 22 | from effectful.ops.semantics import _simple_type |
17 | 23 | from effectful.ops.syntax import _CustomSingleDispatchCallable |
18 | 24 | from effectful.ops.types import Operation, Term |
@@ -253,6 +259,236 @@ def deserialize(self, serialized_value: str) -> typing.Any: |
253 | 259 | return typing.cast(typing.Any, adapter.validate_json(serialized_value)) |
254 | 260 |
|
255 | 261 |
|
| 262 | +def _format_callable_type(callable_type: type[Callable]) -> str: |
| 263 | + """Format a Callable type annotation as a string for LLM instructions.""" |
| 264 | + args = typing.get_args(callable_type) |
| 265 | + if not args: |
| 266 | + return "Callable" |
| 267 | + |
| 268 | + # Callable[[arg1, arg2, ...], return_type] |
| 269 | + if len(args) >= 2: |
| 270 | + param_types = args[0] |
| 271 | + return_type = args[-1] |
| 272 | + |
| 273 | + if param_types is ...: |
| 274 | + params_str = "..." |
| 275 | + elif isinstance(param_types, (list, tuple)): |
| 276 | + params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types) |
| 277 | + else: |
| 278 | + params_str = str(param_types) |
| 279 | + |
| 280 | + return_str = getattr(return_type, "__name__", str(return_type)) |
| 281 | + return f"Callable[[{params_str}], {return_str}]" |
| 282 | + |
| 283 | + return str(callable_type) |
| 284 | + |
| 285 | + |
# Base schema for code produced (or consumed) by function synthesis.
# NOTE(review): the docstring mentions a "function name", but the only declared
# field is ``module_code``. The docstring is presumably surfaced as the model's
# JSON-schema description, so confirm before rewording it.
class SynthesizedFunction(pydantic.BaseModel):
    """Structured output for function synthesis.

    Pydantic model representing synthesized code with function name and module code.
    """

    # Required field (``...`` means no default): the module source as text.
    module_code: str = pydantic.Field(
        ...,
        description="Complete Python module code (no imports needed)",
    )
| 296 | + |
| 297 | + |
def _create_typed_synthesized_function(
    callable_type: type[Callable],
) -> type[SynthesizedFunction]:
    """Build a ``SynthesizedFunction`` subclass whose docstring states the signature.

    The formatted ``callable_type`` is embedded in a prompt that is installed as
    the new model's ``__doc__``, so the expected signature travels with the
    schema handed to the LLM.

    Args:
        callable_type: The ``Callable`` annotation the synthesized function must satisfy.

    Returns:
        A model class named ``TypedSynthesizedFunction`` derived from
        ``SynthesizedFunction``.
    """
    signature_text = _format_callable_type(callable_type)

    prompt = f"""Given the specification above, generate a Python function satisfying the following specification and type signature.

<signature>{signature_text}</signature>

<instructions>
1. Produce one block of Python code.
2. The function MUST have type annotations for all parameters and the return type.
3. The function definition must be the LAST statement - do not add any code after it.
4. Do not include usage examples or function calls.
</instructions>
"""

    # create_model (rather than a hand-written subclass) lets the docstring be
    # supplied dynamically; it becomes the model description in the JSON schema.
    return pydantic.create_model(
        "TypedSynthesizedFunction",
        __base__=SynthesizedFunction,
        __doc__=prompt,
    )
| 328 | + |
| 329 | + |
| 330 | +def _validate_signature_ast( |
| 331 | + func_ast: ast.FunctionDef | ast.AsyncFunctionDef, |
| 332 | + expected_params: list[type] | None, |
| 333 | +) -> None: |
| 334 | + """Validate the function signature from AST before execution.""" |
| 335 | + if expected_params is not None: |
| 336 | + ast_params = func_ast.args.args + func_ast.args.posonlyargs |
| 337 | + if len(ast_params) != len(expected_params): |
| 338 | + raise ValueError( |
| 339 | + f"decode() expected function with {len(expected_params)} parameters, " |
| 340 | + f"got {len(ast_params)}" |
| 341 | + ) |
| 342 | + |
| 343 | + |
| 344 | +def _validate_signature_callable( |
| 345 | + func: Callable, |
| 346 | + expected_params: list[type] | None, |
| 347 | + expected_return: type, |
| 348 | +) -> None: |
| 349 | + """Validate the function signature from runtime callable after execution. |
| 350 | +
|
| 351 | + The synthesized function must have type annotations for parameters and return type. |
| 352 | + """ |
| 353 | + sig = inspect.signature(func) |
| 354 | + |
| 355 | + if expected_params is not None: |
| 356 | + actual_params = list(sig.parameters.values()) |
| 357 | + if len(actual_params) != len(expected_params): |
| 358 | + raise ValueError( |
| 359 | + f"decode() expected function with {len(expected_params)} parameters, " |
| 360 | + f"got {len(actual_params)}" |
| 361 | + ) |
| 362 | + |
| 363 | + actual_return = sig.return_annotation |
| 364 | + if actual_return is inspect.Parameter.empty: |
| 365 | + raise ValueError( |
| 366 | + "decode() requires synthesized function to have a return type annotation" |
| 367 | + ) |
| 368 | + |
| 369 | + expected_name = getattr(expected_return, "__name__", str(expected_return)) |
| 370 | + actual_name = getattr(actual_return, "__name__", str(actual_return)) |
| 371 | + if expected_name != actual_name: |
| 372 | + raise ValueError( |
| 373 | + f"decode() expected function with return type {expected_name}, " |
| 374 | + f"got {actual_name}" |
| 375 | + ) |
| 376 | + |
| 377 | + |
@dataclass
class CallableEncodable(Encodable[Callable, SynthesizedFunction]):
    """Encodable that maps callables to and from ``SynthesizedFunction`` source code.

    ``encode`` turns a callable into source text (or a stub when no source is
    available); ``decode`` parses, validates, compiles and executes synthesized
    module code and returns the function defined by its last statement.
    """

    base: type[Callable]
    enc: type[SynthesizedFunction]
    ctx: Mapping[str, Any]
    expected_params: list[type] | None = None
    expected_return: type | None = None  # None means decode is disabled

    def encode(self, t: Callable) -> SynthesizedFunction:
        """Return ``t`` as source code, falling back to a signature/docstring stub.

        Raises:
            TypeError: If ``t`` is not callable.
            RuntimeError: If no source is available and ``t`` lacks a
                ``__name__`` or a docstring.
        """
        # The ignore works around https://github.com/python/mypy/issues/14928
        if not isinstance(t, Callable):  # type: ignore
            raise TypeError(f"Expected callable, got {type(t)}")

        try:
            source_text = inspect.getsource(t)
        except (OSError, TypeError):
            source_text = None

        if source_text:
            return self.enc(module_code=textwrap.dedent(source_text))

        # No retrievable source (e.g. builtins, C extensions): describe the
        # callable through its name, signature and docstring instead.
        func_name = getattr(t, "__name__", None)
        if not func_name:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no __name__"
            )

        try:
            signature_text = str(inspect.signature(t))
        except (ValueError, TypeError):
            # Some builtins don't have inspectable signatures
            signature_text = "(...)"

        doc_text = inspect.getdoc(t)
        if not doc_text:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no docstring"
            )

        # Stub function whose body is only the docstring and an ellipsis.
        stub_source = f'''def {func_name}{signature_text}:
    """{doc_text}"""
    ...
'''
        return self.enc(module_code=stub_source)

    def decode(self, encoded_value: SynthesizedFunction) -> Callable:
        """Execute the synthesized module and return the function it ends with.

        Raises:
            TypeError: If this encodable has no expected return type.
            ValueError: If the module is empty, does not end in a function
                definition, or the function fails signature validation.
        """
        # Without a concrete return type there is nothing to validate against.
        if self.expected_return is None:
            raise TypeError(
                "Cannot decode/synthesize callable without a concrete type signature. "
                "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] "
                "with a concrete return type (not Any)."
            )

        source_name = f"<synthesis:{id(self)}>"

        # Static checks happen on the parsed tree, before anything runs.
        tree: ast.AST = evaluation.parse(encoded_value.module_code, source_name)

        if not (isinstance(tree, ast.Module) and tree.body):
            raise ValueError(
                "decode() requires module code with at least one statement."
            )

        final_stmt = tree.body[-1]
        if not isinstance(final_stmt, ast.FunctionDef):
            raise ValueError(
                f"decode() requires the last statement to be a function definition, "
                f"got {type(final_stmt).__name__}"
            )

        _validate_signature_ast(final_stmt, self.expected_params)

        # Execute in a fresh namespace seeded with a copy of the context.
        # https://docs.python.org/3/library/functions.html#exec
        namespace: MutableMapping[str, Any] = dict(self.ctx or {})

        code_obj: CodeType = evaluation.compile(tree, source_name)
        evaluation.exec(code_obj, namespace)

        target_name = final_stmt.name
        if target_name not in namespace:
            raise ValueError(
                f"decode() expected function '{target_name}' to be defined in globals"
            )

        synthesized = namespace[target_name]
        if not callable(synthesized):
            raise ValueError(
                f"decode() expected '{target_name}' to be callable, got {type(synthesized)}"
            )

        # Runtime check on the real object, complementing the AST check above.
        _validate_signature_callable(
            synthesized, self.expected_params, self.expected_return
        )

        return synthesized

    def serialize(
        self, encoded_value: SynthesizedFunction
    ) -> Sequence[OpenAIMessageContentListBlock]:
        """Wrap the model's JSON dump in a single text content block."""
        payload = encoded_value.model_dump_json()
        return [{"type": "text", "text": payload}]

    def deserialize(self, serialized_value: str) -> SynthesizedFunction:
        """Parse a JSON string into a ``SynthesizedFunction``."""
        return SynthesizedFunction.model_validate_json(serialized_value)
| 490 | + |
| 491 | + |
256 | 492 | @Encodable.define.register(object) |
257 | 493 | def _encodable_object[T, U]( |
258 | 494 | ty: type[T], ctx: Mapping[str, Any] | None |
@@ -355,3 +591,36 @@ def _encodable_list[T, U]( |
355 | 591 | return typing.cast( |
356 | 592 | Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder) |
357 | 593 | ) |
| 594 | + |
| 595 | + |
@Encodable.define.register(Callable)
def _encodable_callable(
    ty: type[Callable], ctx: Mapping[str, Any] | None
) -> Encodable[Callable, SynthesizedFunction]:
    """Build the ``CallableEncodable`` for a callable type annotation.

    Args:
        ty: ``Callable[[P...], R]``, ``Callable[..., R]``, or the bare
            ``types.FunctionType``.
        ctx: Names made available to synthesized code; ``None`` means empty.

    Returns:
        A ``CallableEncodable``. For the bare ``types.FunctionType`` it is
        encode-only: no expected return type is recorded, so ``decode`` is
        disabled.

    Raises:
        TypeError: If ``ty`` has no type arguments and is not
            ``types.FunctionType``, or if its signature is incomplete.
    """
    ctx = ctx or {}

    type_args = typing.get_args(ty)

    # Bare callable without type args - allow encoding but disable decode.
    # This occurs when decoding the result of Tools which return a callable.
    if not type_args:
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        if ty is not types.FunctionType:
            raise TypeError(f"Callable must have type signatures {ty}")
        typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any])  # type: ignore[arg-type]
        return CallableEncodable(ty, typed_enc, ctx)

    if len(type_args) < 2:
        raise TypeError(
            f"Callable type signature incomplete: {ty}. "
            "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]."
        )

    param_types, expected_return = type_args[0], type_args[-1]

    # ``collections.abc.Callable[[...], None]`` keeps a literal ``None`` as its
    # return argument. CallableEncodable reads ``expected_return=None`` as
    # "decode disabled", so spell the return type as ``NoneType`` instead.
    if expected_return is None:
        expected_return = type(None)

    typed_enc = _create_typed_synthesized_function(ty)

    # Only an explicit parameter list is validated; an ellipsis (or any other
    # parameter expression) means "any parameters".
    expected_params: list[type] | None = (
        list(param_types) if isinstance(param_types, (list, tuple)) else None
    )

    return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return)
0 commit comments