Skip to content

Commit 2b4449a

Browse files
committed
simplified smart constructor
1 parent 2a8af94 commit 2b4449a

2 files changed

Lines changed: 11 additions & 44 deletions

File tree

effectful/handlers/llm/encoding.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import io
55
import textwrap
6+
import types
67
import typing
78
from abc import ABC, abstractmethod
89
from collections.abc import Callable, Mapping, MutableMapping, Sequence
@@ -593,48 +594,28 @@ def _encodable_callable(
593594
) -> Encodable[Callable, SynthesizedFunction]:
594595
ctx = ctx or {}
595596

596-
# Extract type args - Callable requires a type signature
597597
type_args = typing.get_args(ty)
598598

599-
# Handle bare Callable without type args - allow encoding but disable decode
600-
# this occurs when encoding Tools which return callable (need to Encodable.define(return_type) for return type)
599+
# Bare Callable without type args - allow encoding but disable decode
600+
# this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type)
601601
if not type_args:
602+
assert ty is types.FunctionType, f"Callable must have type signatures {ty}"
602603
typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type]
603-
return CallableEncodable(
604-
ty, typed_enc, ctx, expected_params=None, expected_return=None
605-
)
604+
return CallableEncodable(ty, typed_enc, ctx)
606605

607606
if len(type_args) < 2:
608607
raise TypeError(
609608
f"Callable type signature incomplete: {ty}. "
610609
"Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]."
611610
)
612611

613-
# Extract param and return types for validation
614-
param_types = type_args[0]
615-
expected_return: type | None = type_args[-1]
616-
617-
# Handle Any as return type - allow encoding but disable decode
618-
# Any doesn't provide useful information for synthesis (expected_return=None)
619-
if expected_return is typing.Any:
620-
typed_enc = _create_typed_synthesized_function(ty)
621-
return CallableEncodable(
622-
ty, typed_enc, ctx, expected_params=None, expected_return=None
623-
)
612+
param_types, expected_return = type_args[0], type_args[-1]
624613

625-
# Create a typed SynthesizedFunction model with the type signature in the description
626614
typed_enc = _create_typed_synthesized_function(ty)
627615

628-
# Handle Callable[..., ReturnType] - ellipsis means any params, skip param validation
616+
# Ellipsis means any params, skip param validation
629617
expected_params: list[type] | None = None
630-
if param_types is not ...:
631-
if isinstance(param_types, (list, tuple)):
632-
expected_params = list(param_types)
633-
634-
return CallableEncodable(
635-
ty,
636-
typed_enc,
637-
ctx,
638-
expected_params=expected_params,
639-
expected_return=expected_return,
640-
)
618+
if param_types is not ... and isinstance(param_types, (list, tuple)):
619+
expected_params = list(param_types)
620+
621+
return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return)

tests/test_handlers_llm_encoding.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -742,20 +742,6 @@ def add(a: int, b: int) -> int:
742742
with handler(UnsafeEvalProvider()):
743743
encodable.decode(encoded)
744744

745-
def test_callable_with_any_return_allows_encode_but_not_decode(self):
746-
def add(a: int, b: int) -> int:
747-
return a + b
748-
749-
# Callable[..., Any] allows encoding
750-
encodable = Encodable.define(Callable[..., Any], {})
751-
encoded = encodable.encode(add)
752-
assert isinstance(encoded, SynthesizedFunction)
753-
754-
# But decode is disabled
755-
with pytest.raises(TypeError, match="Cannot decode/synthesize callable"):
756-
with handler(UnsafeEvalProvider()):
757-
encodable.decode(encoded)
758-
759745
def test_encode_decode_function(self):
760746
def add(a: int, b: int) -> int:
761747
return a + b

0 commit comments

Comments
 (0)