|
3 | 3 | import inspect |
4 | 4 | import io |
5 | 5 | import textwrap |
| 6 | +import types |
6 | 7 | import typing |
7 | 8 | from abc import ABC, abstractmethod |
8 | 9 | from collections.abc import Callable, Mapping, MutableMapping, Sequence |
@@ -593,48 +594,28 @@ def _encodable_callable( |
593 | 594 | ) -> Encodable[Callable, SynthesizedFunction]: |
594 | 595 | ctx = ctx or {} |
595 | 596 |
|
596 | | - # Extract type args - Callable requires a type signature |
597 | 597 | type_args = typing.get_args(ty) |
598 | 598 |
|
599 | | - # Handle bare Callable without type args - allow encoding but disable decode |
600 | | - # this occurs when encoding Tools which return callable (need to Encodable.define(return_type) for return type) |
| 599 | + # Bare Callable without type args - allow encoding but disable decode |
| 600 | + # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type) |
601 | 601 | if not type_args: |
| 602 | + assert ty is types.FunctionType, f"Callable must have type signatures {ty}" |
602 | 603 | typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type] |
603 | | - return CallableEncodable( |
604 | | - ty, typed_enc, ctx, expected_params=None, expected_return=None |
605 | | - ) |
| 604 | + return CallableEncodable(ty, typed_enc, ctx) |
606 | 605 |
|
607 | 606 | if len(type_args) < 2: |
608 | 607 | raise TypeError( |
609 | 608 | f"Callable type signature incomplete: {ty}. " |
610 | 609 | "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]." |
611 | 610 | ) |
612 | 611 |
|
613 | | - # Extract param and return types for validation |
614 | | - param_types = type_args[0] |
615 | | - expected_return: type | None = type_args[-1] |
616 | | - |
617 | | - # Handle Any as return type - allow encoding but disable decode |
618 | | - # Any doesn't provide useful information for synthesis (expected_return=None) |
619 | | - if expected_return is typing.Any: |
620 | | - typed_enc = _create_typed_synthesized_function(ty) |
621 | | - return CallableEncodable( |
622 | | - ty, typed_enc, ctx, expected_params=None, expected_return=None |
623 | | - ) |
| 612 | + param_types, expected_return = type_args[0], type_args[-1] |
624 | 613 |
|
625 | | - # Create a typed SynthesizedFunction model with the type signature in the description |
626 | 614 | typed_enc = _create_typed_synthesized_function(ty) |
627 | 615 |
|
628 | | - # Handle Callable[..., ReturnType] - ellipsis means any params, skip param validation |
| 616 | + # Ellipsis means any params, skip param validation |
629 | 617 | expected_params: list[type] | None = None |
630 | | - if param_types is not ...: |
631 | | - if isinstance(param_types, (list, tuple)): |
632 | | - expected_params = list(param_types) |
633 | | - |
634 | | - return CallableEncodable( |
635 | | - ty, |
636 | | - typed_enc, |
637 | | - ctx, |
638 | | - expected_params=expected_params, |
639 | | - expected_return=expected_return, |
640 | | - ) |
| 618 | + if param_types is not ... and isinstance(param_types, (list, tuple)): |
| 619 | + expected_params = list(param_types) |
| 620 | + |
| 621 | + return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return) |
0 commit comments