Skip to content

Commit c47abd6

Browse files
committed
updated call assistant to handle decoding tool calls
1 parent a06296d commit c47abd6

23 files changed

Lines changed: 949 additions & 232 deletions

File tree

effectful/handlers/llm/completions.py

Lines changed: 88 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
ChatCompletionMessageToolCall,
1414
ChatCompletionTextObject,
1515
ChatCompletionToolMessage,
16-
ChatCompletionToolParam,
1716
OpenAIChatCompletionAssistantMessage,
1817
OpenAIChatCompletionSystemMessage,
1918
OpenAIChatCompletionUserMessage,
@@ -22,7 +21,6 @@
2221

2322
from effectful.handlers.llm import Template, Tool
2423
from effectful.handlers.llm.encoding import Encodable
25-
from effectful.ops.semantics import fwd
2624
from effectful.ops.syntax import ObjectInterpretation, implements
2725
from effectful.ops.types import Operation
2826

@@ -34,102 +32,117 @@
3432
| OpenAIChatCompletionUserMessage
3533
)
3634

35+
type ToolCallID = str
3736

38-
def _parameter_model(sig: inspect.Signature) -> type[pydantic.BaseModel]:
39-
return pydantic.create_model(
40-
"Params",
41-
__config__={"extra": "forbid"},
37+
38+
class DecodedToolCall[T](typing.NamedTuple):
39+
tool: Tool[..., T]
40+
bound_args: inspect.BoundArguments
41+
id: ToolCallID
42+
43+
44+
type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]
45+
46+
47+
def decode_tool_call(
48+
tool_call: ChatCompletionMessageToolCall,
49+
tools: collections.abc.Mapping[str, Tool],
50+
) -> DecodedToolCall:
51+
"""Decode a tool call from the LLM response into a DecodedToolCall."""
52+
assert tool_call.function.name is not None
53+
tool = tools[tool_call.function.name]
54+
json_str = tool_call.function.arguments
55+
56+
sig = inspect.signature(tool)
57+
58+
# build dict of raw encodable types U
59+
raw_args = tool.param_model.model_validate_json(json_str)
60+
61+
# use encoders to decode Us to python types T
62+
bound_sig: inspect.BoundArguments = sig.bind(
4263
**{
43-
name: Encodable.define(param.annotation).enc
44-
for name, param in sig.parameters.items()
45-
}, # type: ignore
64+
param_name: Encodable.define(
65+
sig.parameters[param_name].annotation, {}
66+
).decode(getattr(raw_args, param_name))
67+
for param_name in raw_args.model_fields_set
68+
}
4669
)
70+
return DecodedToolCall(tool, bound_sig, tool_call.id)
4771

4872

49-
def _response_model(sig: inspect.Signature) -> type[pydantic.BaseModel]:
50-
return pydantic.create_model(
51-
"Response",
52-
value=Encodable.define(sig.return_annotation).enc,
53-
__config__={"extra": "forbid"},
54-
)
73+
@Operation.define
74+
@functools.wraps(litellm.completion)
75+
def completion(*args, **kwargs) -> typing.Any:
76+
"""Low-level LLM request. Handlers may log/modify requests and delegate via fwd().
5577
78+
This effect is emitted for model request/response rounds so handlers can
79+
observe/log requests.
5680
57-
def _tool_model(tool: Tool) -> ChatCompletionToolParam:
58-
param_model = _parameter_model(inspect.signature(tool))
59-
response_format = litellm.utils.type_to_response_format_param(param_model)
60-
assert response_format is not None
61-
assert tool.__default__.__doc__ is not None
62-
return {
63-
"type": "function",
64-
"function": {
65-
"name": tool.__name__,
66-
"description": textwrap.dedent(tool.__default__.__doc__),
67-
"parameters": response_format["json_schema"]["schema"],
68-
"strict": True,
69-
},
70-
}
81+
"""
82+
return litellm.completion(*args, **kwargs)
7183

7284

7385
@Operation.define
74-
def call_assistant(
86+
def call_assistant[T, U](
7587
messages: collections.abc.Sequence[Message],
76-
response_format: type[pydantic.BaseModel] | None,
77-
tools: collections.abc.Mapping[str, ChatCompletionToolParam],
88+
tools: collections.abc.Mapping[str, Tool],
89+
response_format: Encodable[T, U],
7890
model: str,
7991
**kwargs,
80-
) -> Message:
92+
) -> MessageResult[T]:
8193
"""Low-level LLM request. Handlers may log/modify requests and delegate via fwd().
8294
8395
This effect is emitted for model request/response rounds so handlers can
8496
observe/log requests.
8597
8698
"""
87-
response: litellm.types.utils.ModelResponse = litellm.completion(
99+
tool_specs = {k: t.model for k, t in tools.items()}
100+
response_model = pydantic.create_model(
101+
"Response", value=response_format.enc, __config__={"extra": "forbid"}
102+
)
103+
104+
response: litellm.types.utils.ModelResponse = completion(
88105
model,
89106
messages=list(messages),
90-
response_format=response_format,
91-
tools=list(tools.values()),
107+
response_format=response_model,
108+
tools=list(tool_specs.values()),
92109
**kwargs,
93110
)
94111
choice = response.choices[0]
95112
assert isinstance(choice, litellm.types.utils.Choices)
113+
96114
message: litellm.Message = choice.message
97115
assert message.role == "assistant"
98-
return typing.cast(Message, message.model_dump(mode="json"))
116+
117+
tool_calls: list[DecodedToolCall] = []
118+
raw_tool_calls = message.get("tool_calls") or []
119+
for tool_call in raw_tool_calls:
120+
tool_call = ChatCompletionMessageToolCall.model_validate(tool_call)
121+
decoded_tool_call = decode_tool_call(tool_call, tools)
122+
tool_calls.append(decoded_tool_call)
123+
124+
result = None
125+
if not tool_calls:
126+
# return response
127+
serialized_result = message.get("content") or message.get("reasoning_content")
128+
assert isinstance(serialized_result, str), (
129+
"final response from the model should be a string"
130+
)
131+
raw_result = response_model.model_validate_json(serialized_result)
132+
result = response_format.decode(raw_result.value) # type: ignore
133+
134+
return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result)
99135

100136

101137
@Operation.define
102-
def call_tool(
103-
tool_call: ChatCompletionMessageToolCall,
104-
tools: collections.abc.Mapping[str, Tool],
105-
) -> Message:
138+
def call_tool(tool_call: DecodedToolCall) -> Message:
106139
"""Implements a roundtrip call to a python function. Input is a json
107140
string representing an LLM tool call request parameters. The output is
108141
the serialised response to the model.
109142
110143
"""
111-
assert tool_call.function.name is not None
112-
tool = tools[tool_call.function.name]
113-
json_str = tool_call.function.arguments
114-
115-
sig = inspect.signature(tool)
116-
param_model = _parameter_model(sig)
117-
118-
# build dict of raw encodable types U
119-
raw_args = param_model.model_validate_json(json_str)
120-
121-
# use encoders to decode Us to python types T
122-
bound_sig: inspect.BoundArguments = sig.bind(
123-
**{
124-
param_name: Encodable.define(
125-
sig.parameters[param_name].annotation, {}
126-
).decode(getattr(raw_args, param_name))
127-
for param_name in raw_args.model_fields_set
128-
}
129-
)
130-
131144
# call tool with python types
132-
result = tool(*bound_sig.args, **bound_sig.kwargs)
145+
result = tool_call.tool(*tool_call.bound_args.args, **tool_call.bound_args.kwargs)
133146

134147
# serialize back to U using encoder for return type
135148
return_type = Encodable.define(type(result))
@@ -207,57 +220,39 @@ def __init__(self, model="gpt-4o", **config):
207220
**inspect.signature(litellm.completion).bind_partial(**config).kwargs,
208221
}
209222

210-
@implements(call_assistant)
211-
@functools.wraps(call_assistant)
212-
def _completion(self, *args, **kwargs):
213-
return fwd(*args, **{**self.config, **kwargs})
214-
215223
@implements(Template.__apply__)
216224
def _call[**P, T](
217225
self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs
218226
) -> T:
219-
response_encoding_type: Encodable = Encodable.define(
220-
inspect.signature(template).return_annotation, template.__context__
221-
)
222-
response_model = _response_model(inspect.signature(template))
223-
224227
messages: list[Message] = [*call_system(template)]
225228

226229
# encode arguments
227230
bound_args = inspect.signature(template).bind(*args, **kwargs)
228231
bound_args.apply_defaults()
229232
env = template.__context__.new_child(bound_args.arguments)
230233

234+
# Create response_model with env so tools passed as arguments are available
235+
response_model = Encodable.define(template.__signature__.return_annotation, env)
236+
231237
user_messages: list[Message] = call_user(template.__prompt_template__, env)
232238
messages.extend(user_messages)
233239

234-
tools = {
235-
**template.tools,
236-
**{k: t for k, t in bound_args.arguments.items() if isinstance(t, Tool)},
237-
}
238-
tool_specs = {k: _tool_model(t) for k, t in tools.items()}
239-
240240
# loop based on: https://cookbook.openai.com/examples/reasoning_function_calls
241-
tool_calls: list[ChatCompletionMessageToolCall] = []
241+
tool_calls: list[DecodedToolCall] = []
242242

243243
message = messages[-1]
244+
result: T | None = None
244245
while message["role"] != "assistant" or tool_calls:
245-
message = call_assistant(messages, response_model, tool_specs)
246+
message, tool_calls, result = call_assistant(
247+
messages, template.tools, response_model, **self.config
248+
)
246249
messages.append(message)
247-
tool_calls = message.get("tool_calls") or []
248250
for tool_call in tool_calls:
249-
tool_call = ChatCompletionMessageToolCall.model_validate(tool_call)
250-
message = call_tool(tool_call, tools)
251+
message = call_tool(tool_call)
251252
messages.append(message)
252253

253-
# return response
254-
serialized_result = message.get("content") or message.get("reasoning_content")
255-
assert isinstance(serialized_result, str), (
256-
"final response from the model should be a string"
254+
assert result is not None, (
255+
"call_assistant did not produce a result nor tool_calls"
257256
)
258-
encoded_result = (
259-
serialized_result
260-
if response_model is None
261-
else response_model.model_validate_json(serialized_result).value # type: ignore
262-
)
263-
return response_encoding_type.decode(encoded_result)
257+
# return response
258+
return result

effectful/handlers/llm/template.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
import functools
12
import inspect
3+
import textwrap
24
import types
35
import typing
46
from collections import ChainMap
57
from collections.abc import Callable, Mapping, MutableMapping
68
from dataclasses import dataclass
79
from typing import Annotated, Any
810

11+
import litellm
12+
import pydantic
13+
from litellm import ChatCompletionToolParam
14+
15+
from effectful.handlers.llm.encoding import Encodable
916
from effectful.ops.types import INSTANCE_OP_PREFIX, Annotation, Operation
1017

1118

@@ -95,6 +102,33 @@ def __init__(
95102
signature = IsRecursive.infer_annotations(signature)
96103
super().__init__(signature, name, default)
97104

105+
@functools.cached_property
106+
def param_model(self) -> type[pydantic.BaseModel]:
107+
sig = inspect.signature(self)
108+
return pydantic.create_model(
109+
"Params",
110+
__config__={"extra": "forbid"},
111+
**{
112+
name: Encodable.define(param.annotation).enc
113+
for name, param in sig.parameters.items()
114+
}, # type: ignore
115+
)
116+
117+
@functools.cached_property
118+
def model(self) -> ChatCompletionToolParam:
119+
response_format = litellm.utils.type_to_response_format_param(self.param_model)
120+
assert response_format is not None
121+
assert self.__default__.__doc__ is not None
122+
return {
123+
"type": "function",
124+
"function": {
125+
"name": self.__name__,
126+
"description": textwrap.dedent(self.__default__.__doc__),
127+
"parameters": response_format["json_schema"]["schema"],
128+
"strict": True,
129+
},
130+
}
131+
98132
@classmethod
99133
def define(cls, *args, **kwargs) -> "Tool[P, T]":
100134
"""Define a tool.
@@ -185,6 +219,7 @@ def tools(self) -> Mapping[str, Tool]:
185219
for name, obj in self.__context__.items():
186220
if obj is self and not is_recursive:
187221
continue
222+
188223
# Collect tools in context
189224
if isinstance(obj, Tool):
190225
result[name] = obj
@@ -260,6 +295,14 @@ def define[**Q, V](
260295
*typing.cast(list[MutableMapping[str, Any]], contexts)
261296
)
262297

298+
is_recursive = _is_recursive_signature(inspect.signature(default))
299+
# todo: make this more pythonic
300+
if not is_recursive:
301+
# drop default.__name__ from context
302+
pass
303+
263304
op = super().define(default, *args, **kwargs)
264305
op.__context__ = context # type: ignore[attr-defined]
306+
# todo: drop self from contexts if not is_recursive
307+
265308
return typing.cast(Template[Q, V], op)
Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,44 @@
11
{
2-
"annotations": [],
3-
"content": "{\"value\":73}",
4-
"function_call": null,
5-
"provider_specific_fields": {
6-
"refusal": null
2+
"id": "chatcmpl-D3rNbSUc9fUpX7qU6kanq3CXVh8eJ",
3+
"created": 1769812547,
4+
"model": "gpt-5-nano-2025-08-07",
5+
"object": "chat.completion",
6+
"system_fingerprint": null,
7+
"choices": [
8+
{
9+
"finish_reason": "stop",
10+
"index": 0,
11+
"message": {
12+
"content": "{\"value\":73}",
13+
"role": "assistant",
14+
"tool_calls": null,
15+
"function_call": null,
16+
"provider_specific_fields": {
17+
"refusal": null
18+
},
19+
"annotations": []
20+
},
21+
"provider_specific_fields": {}
22+
}
23+
],
24+
"usage": {
25+
"completion_tokens": 401,
26+
"prompt_tokens": 340,
27+
"total_tokens": 741,
28+
"completion_tokens_details": {
29+
"accepted_prediction_tokens": 0,
30+
"audio_tokens": 0,
31+
"reasoning_tokens": 384,
32+
"rejected_prediction_tokens": 0,
33+
"text_tokens": null,
34+
"image_tokens": null
35+
},
36+
"prompt_tokens_details": {
37+
"audio_tokens": 0,
38+
"cached_tokens": 0,
39+
"text_tokens": null,
40+
"image_tokens": null
41+
}
742
},
8-
"role": "assistant",
9-
"tool_calls": null
43+
"service_tier": "default"
1044
}

0 commit comments

Comments
 (0)