Skip to content

Commit fa2e420

Browse files
jfesereb8680
andauthored
Make __signature__ a lazily computed property (#451)
* make __signature__ a lazily computed property * extend test * lint * Make __signature__ lazy * fix ci? --------- Co-authored-by: Eli <eli@basis.ai>
1 parent 19eafd5 commit fa2e420

5 files changed

Lines changed: 217 additions & 15 deletions

File tree

effectful/handlers/llm/template.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,14 @@ def vacation() -> str:
9090
9191
"""
9292

93-
def __init__(
94-
self, signature: inspect.Signature, name: str, default: Callable[P, T]
95-
):
93+
def __init__(self, default: Callable[P, T], name: str | None = None):
9694
if not default.__doc__:
9795
raise ValueError("Tools must have docstrings.")
98-
signature = IsRecursive.infer_annotations(signature)
99-
super().__init__(signature, name, default)
96+
super().__init__(default, name=name)
97+
98+
@property
99+
def __signature__(self):
100+
return IsRecursive.infer_annotations(super().__signature__)
100101

101102
@classmethod
102103
def define(cls, *args, **kwargs) -> "Tool[P, T]":

effectful/ops/types.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,30 @@ class Operation[**Q, V]:
7474
7575
"""
7676

77-
__signature__: inspect.Signature
7877
__name__: str
7978
__default__: Callable[Q, V]
8079
__apply__: typing.ClassVar["Operation"]
8180

82-
def __init__(
83-
self, signature: inspect.Signature, name: str, default: Callable[Q, V]
84-
):
81+
def __init__(self, default: Callable[Q, V], name: str | None = None):
8582
functools.update_wrapper(self, default)
86-
87-
self.__signature__ = signature
88-
self.__name__ = name
8983
self.__default__ = default
84+
self.__name__ = name or default.__name__
85+
86+
@property
87+
def __signature__(self):
88+
# Resolve forward references (e.g. -> "MyClass") using the
89+
# default function's __globals__. This handles module-level
90+
# forward refs; local forward refs will raise NameError.
91+
# Python 3.14's annotationlib.get_annotations(format=FORWARDREF)
92+
# could resolve local refs too via PEP 649 __annotate__ functions.
93+
annots = typing.get_type_hints(self.__default__, include_extras=True)
94+
sig = inspect.signature(self.__default__)
95+
updated_params = [
96+
p.replace(annotation=annots[p.name]) if p.name in annots else p
97+
for p in sig.parameters.values()
98+
]
99+
updated_ret = annots.get("return", sig.return_annotation)
100+
return sig.replace(parameters=updated_params, return_annotation=updated_ret)
90101

91102
def __eq__(self, other):
92103
if not isinstance(other, Operation):
@@ -267,8 +278,7 @@ def func(*args, **kwargs):
267278

268279
op = cls.define(func, name=name)
269280
else:
270-
name = name or t.__name__
271-
op = cls(inspect.signature(t), name, t) # type: ignore[arg-type]
281+
op = cls(t, name=name) # type: ignore[arg-type]
272282

273283
return op # type: ignore[return-value]
274284

tests/test_handlers_llm_template.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,3 +1518,18 @@ def test_validate_format_spec_on_undefined_var():
15181518
def bad(x: int) -> str:
15191519
"""Value: {x} and {missing:.2f}."""
15201520
raise NotHandled
1521+
1522+
1523+
# Forward ref through Tool subclass of Operation.
1524+
# Use types Pydantic can serialize (not arbitrary classes) to avoid
1525+
# PydanticSchemaGenerationError when other tests build tool schemas.
1526+
@Tool.define
1527+
def _tool_forward_ref(x: "int") -> "str":
1528+
"""A tool with forward-referenced parameter and return types."""
1529+
raise NotHandled
1530+
1531+
1532+
def test_tool_forward_ref():
1533+
sig = inspect.signature(_tool_forward_ref)
1534+
assert sig.parameters["x"].annotation is int
1535+
assert sig.return_annotation is str

tests/test_ops_syntax.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,3 +1126,140 @@ def id[T](base: T) -> T:
11261126
raise NotHandled
11271127

11281128
assert isinstance(id(A(0)).x, Term)
1129+
1130+
1131+
# Forward references in types only work on module-level definitions.
1132+
@defop
1133+
def forward_ref_op() -> "A":
1134+
raise NotHandled
1135+
1136+
1137+
class A: ...
1138+
1139+
1140+
def test_defop_forward_ref():
1141+
term = forward_ref_op()
1142+
assert term.op == forward_ref_op
1143+
assert typeof(term) is A
1144+
1145+
@defop
1146+
def local_forward_ref_op() -> "B":
1147+
raise NotHandled
1148+
1149+
class B: ...
1150+
1151+
with pytest.raises(NameError):
1152+
local_forward_ref_op()
1153+
1154+
1155+
# Forward ref in a parameter annotation.
1156+
@defop
1157+
def _forward_ref_param_op(x: "_ForwardRefParam") -> int:
1158+
raise NotHandled
1159+
1160+
1161+
class _ForwardRefParam:
1162+
pass
1163+
1164+
1165+
def test_defop_forward_ref_param():
1166+
sig = inspect.signature(_forward_ref_param_op)
1167+
assert sig.parameters["x"].annotation is _ForwardRefParam
1168+
assert sig.return_annotation is int
1169+
1170+
1171+
# Forward ref through Operation.define on a type.
1172+
class _ForwardRefType:
1173+
pass
1174+
1175+
1176+
_forward_ref_type_op = Operation.define(_ForwardRefType)
1177+
1178+
1179+
def test_define_type_forward_ref():
1180+
term = _forward_ref_type_op()
1181+
assert term.op == _forward_ref_type_op
1182+
assert typeof(term) is _ForwardRefType
1183+
1184+
1185+
# Forward ref on an instance method.
1186+
class _ForwardRefMethodHost:
1187+
@defop
1188+
def my_method(self, x: int) -> "_ForwardRefMethodResult":
1189+
raise NotHandled
1190+
1191+
1192+
class _ForwardRefMethodResult:
1193+
pass
1194+
1195+
1196+
def test_defop_forward_ref_method():
1197+
instance = _ForwardRefMethodHost()
1198+
term = instance.my_method(5)
1199+
assert isinstance(term, Term)
1200+
sig = inspect.signature(_ForwardRefMethodHost.my_method)
1201+
assert sig.return_annotation is _ForwardRefMethodResult
1202+
1203+
1204+
# Forward ref on a staticmethod.
1205+
class _ForwardRefStaticHost:
1206+
@defop
1207+
@staticmethod
1208+
def my_static(x: int) -> "_ForwardRefStaticResult":
1209+
raise NotHandled
1210+
1211+
1212+
class _ForwardRefStaticResult:
1213+
pass
1214+
1215+
1216+
def test_defop_forward_ref_staticmethod():
1217+
term = _ForwardRefStaticHost.my_static(5)
1218+
assert isinstance(term, Term)
1219+
sig = inspect.signature(_ForwardRefStaticHost.my_static)
1220+
assert sig.return_annotation is _ForwardRefStaticResult
1221+
1222+
1223+
# Forward ref on a classmethod.
1224+
class _ForwardRefClassmethodHost:
1225+
@defop
1226+
@classmethod
1227+
def my_classmethod(cls, x: int) -> "_ForwardRefClassmethodResult":
1228+
raise NotHandled
1229+
1230+
1231+
class _ForwardRefClassmethodResult:
1232+
pass
1233+
1234+
1235+
def test_defop_forward_ref_classmethod():
1236+
term = _ForwardRefClassmethodHost.my_classmethod(5)
1237+
assert isinstance(term, Term)
1238+
sig = inspect.signature(_ForwardRefClassmethodHost.my_classmethod)
1239+
assert sig.return_annotation is _ForwardRefClassmethodResult
1240+
1241+
1242+
# Mutual recursion: two classes with forward refs to each other.
1243+
class _Coordinate:
1244+
@defop
1245+
def log(self) -> "_CoordinateTangent":
1246+
raise NotHandled
1247+
1248+
1249+
class _CoordinateTangent:
1250+
@defop
1251+
def exp(self) -> "_Coordinate":
1252+
raise NotHandled
1253+
1254+
1255+
def test_defop_forward_ref_mutual_recursion():
1256+
coord = _Coordinate()
1257+
tangent = _CoordinateTangent()
1258+
1259+
log_term = coord.log()
1260+
assert isinstance(log_term, Term)
1261+
assert typeof(log_term) is _CoordinateTangent
1262+
1263+
exp_term = tangent.exp()
1264+
assert isinstance(exp_term, Term)
1265+
assert typeof(exp_term) is _Coordinate

tests/test_ops_types.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import inspect
2+
13
from effectful.ops.syntax import defop
2-
from effectful.ops.types import Interpretation
4+
from effectful.ops.types import Interpretation, NotHandled
35

46

57
def test_interpretation_isinstance():
@@ -10,3 +12,40 @@ def test_interpretation_isinstance():
1012
assert not isinstance({a: 0, b: "hello"}, Interpretation)
1113
assert not isinstance([a, b], Interpretation)
1214
assert not isinstance({"a": lambda: 0, "b": lambda: "hello"}, Interpretation)
15+
16+
17+
def test_instance_method_signature_excludes_self():
18+
"""Instance-bound operations should not have 'self' in their signature.
19+
20+
When an Operation is used as a method and accessed on an instance,
21+
__get__ creates a new Operation from a bound method. The signature
22+
should reflect the bound method (without 'self'), not the original
23+
unbound function.
24+
25+
This failed with cached_property because functools.update_wrapper
26+
copied a stale __signature__ (with 'self') into __dict__, shadowing
27+
the descriptor.
28+
"""
29+
30+
class MyClass:
31+
@defop
32+
def my_method(self, x: int) -> str:
33+
raise NotHandled
34+
35+
# Access the class-level signature first, which with cached_property
36+
# stores (self, x: int) -> str in MyClass.my_method.__dict__['__signature__'].
37+
# This is the key trigger: __get__ later copies __dict__ via functools.wraps
38+
# to the instance operation, shadowing a cached_property but not a property.
39+
cls_sig = MyClass.my_method.__signature__
40+
assert "self" in cls_sig.parameters # class-level should have self
41+
42+
instance = MyClass()
43+
instance_op = instance.my_method
44+
45+
# The instance operation should have a signature without 'self'
46+
sig = inspect.signature(instance_op)
47+
assert "self" not in sig.parameters
48+
assert "x" in sig.parameters
49+
50+
# Binding should work with just the real args (no 'self')
51+
sig.bind(42)

0 commit comments

Comments
 (0)