Skip to content

Commit a06296d

Browse files
committed
updated completions to fix basic type errors
1 parent 1400d19 commit a06296d

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

effectful/handlers/llm/completions.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222

2323
from effectful.handlers.llm import Template, Tool
24-
from effectful.handlers.llm.encoding import Encodable, type_to_encodable_type
24+
from effectful.handlers.llm.encoding import Encodable
2525
from effectful.ops.semantics import fwd
2626
from effectful.ops.syntax import ObjectInterpretation, implements
2727
from effectful.ops.types import Operation
@@ -40,7 +40,7 @@ def _parameter_model(sig: inspect.Signature) -> type[pydantic.BaseModel]:
4040
"Params",
4141
__config__={"extra": "forbid"},
4242
**{
43-
name: type_to_encodable_type(param.annotation).t
43+
name: Encodable.define(param.annotation).enc
4444
for name, param in sig.parameters.items()
4545
}, # type: ignore
4646
)
@@ -49,7 +49,7 @@ def _parameter_model(sig: inspect.Signature) -> type[pydantic.BaseModel]:
4949
def _response_model(sig: inspect.Signature) -> type[pydantic.BaseModel]:
5050
return pydantic.create_model(
5151
"Response",
52-
value=type_to_encodable_type(sig.return_annotation).t,
52+
value=Encodable.define(sig.return_annotation).enc,
5353
__config__={"extra": "forbid"},
5454
)
5555

@@ -121,8 +121,8 @@ def call_tool(
121121
# use encoders to decode Us to python types T
122122
bound_sig: inspect.BoundArguments = sig.bind(
123123
**{
124-
param_name: type_to_encodable_type(
125-
sig.parameters[param_name].annotation
124+
param_name: Encodable.define(
125+
sig.parameters[param_name].annotation, {}
126126
).decode(getattr(raw_args, param_name))
127127
for param_name in raw_args.model_fields_set
128128
}
@@ -132,7 +132,7 @@ def call_tool(
132132
result = tool(*bound_sig.args, **bound_sig.kwargs)
133133

134134
# serialize back to U using encoder for return type
135-
return_type = type_to_encodable_type(type(result))
135+
return_type = Encodable.define(type(result))
136136
encoded_result = return_type.serialize(return_type.encode(result))
137137
return typing.cast(
138138
Message, dict(role="tool", content=encoded_result, tool_call_id=tool_call.id)
@@ -167,8 +167,8 @@ def flush_text() -> None:
167167
continue
168168

169169
obj, _ = formatter.get_field(field_name, (), env)
170-
encoder = type_to_encodable_type(type(obj))
171-
encoded_obj: list[OpenAIMessageContentListBlock] = encoder.serialize(
170+
encoder = Encodable.define(type(obj))
171+
encoded_obj: typing.Sequence[OpenAIMessageContentListBlock] = encoder.serialize(
172172
encoder.encode(obj)
173173
)
174174
for part in encoded_obj:
@@ -216,8 +216,8 @@ def _completion(self, *args, **kwargs):
216216
def _call[**P, T](
217217
self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs
218218
) -> T:
219-
response_encoding_type: Encodable[T] = type_to_encodable_type(
220-
inspect.signature(template).return_annotation
219+
response_encoding_type: Encodable = Encodable.define(
220+
inspect.signature(template).return_annotation, template.__context__
221221
)
222222
response_model = _response_model(inspect.signature(template))
223223

0 commit comments

Comments
 (0)