|
13 | 13 | ChatCompletionMessageToolCall, |
14 | 14 | ChatCompletionTextObject, |
15 | 15 | ChatCompletionToolMessage, |
16 | | - ChatCompletionToolParam, |
17 | 16 | OpenAIChatCompletionAssistantMessage, |
18 | 17 | OpenAIChatCompletionSystemMessage, |
19 | 18 | OpenAIChatCompletionUserMessage, |
|
22 | 21 |
|
23 | 22 | from effectful.handlers.llm import Template, Tool |
24 | 23 | from effectful.handlers.llm.encoding import Encodable |
25 | | -from effectful.ops.semantics import fwd |
26 | 24 | from effectful.ops.syntax import ObjectInterpretation, implements |
27 | 25 | from effectful.ops.types import Operation |
28 | 26 |
|
|
34 | 32 | | OpenAIChatCompletionUserMessage |
35 | 33 | ) |
36 | 34 |
|
| 35 | +type ToolCallID = str |
37 | 36 |
|
38 | | -def _parameter_model(sig: inspect.Signature) -> type[pydantic.BaseModel]: |
39 | | - return pydantic.create_model( |
40 | | - "Params", |
41 | | - __config__={"extra": "forbid"}, |
| 37 | + |
| 38 | +class DecodedToolCall[T](typing.NamedTuple): |
| 39 | + tool: Tool[..., T] |
| 40 | + bound_args: inspect.BoundArguments |
| 41 | + id: ToolCallID |
| 42 | + |
| 43 | + |
| 44 | +type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None] |
| 45 | + |
| 46 | + |
| 47 | +def decode_tool_call( |
| 48 | + tool_call: ChatCompletionMessageToolCall, |
| 49 | + tools: collections.abc.Mapping[str, Tool], |
| 50 | +) -> DecodedToolCall: |
| 51 | + """Decode a tool call from the LLM response into a DecodedToolCall.""" |
| 52 | + assert tool_call.function.name is not None |
| 53 | + tool = tools[tool_call.function.name] |
| 54 | + json_str = tool_call.function.arguments |
| 55 | + |
| 56 | + sig = inspect.signature(tool) |
| 57 | + |
| 58 | + # validate the arguments JSON against the tool's param model (raw encodable types U) |
| 59 | + raw_args = tool.param_model.model_validate_json(json_str) |
| 60 | + |
| 61 | + # use encoders to decode each U into its Python type T |
| 62 | + bound_sig: inspect.BoundArguments = sig.bind( |
42 | 63 | **{ |
43 | | - name: Encodable.define(param.annotation).enc |
44 | | - for name, param in sig.parameters.items() |
45 | | - }, # type: ignore |
| 64 | + param_name: Encodable.define( |
| 65 | + sig.parameters[param_name].annotation, {} |
| 66 | + ).decode(getattr(raw_args, param_name)) |
| 67 | + for param_name in raw_args.model_fields_set |
| 68 | + } |
46 | 69 | ) |
| 70 | + return DecodedToolCall(tool, bound_sig, tool_call.id) |
47 | 71 |
|
48 | 72 |
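To make the decode step concrete, here is a minimal self-contained sketch of the same validate-then-bind round trip, with the `Tool`/`Encodable` machinery simulated by plain pydantic and inspect. The `add` tool, its arguments JSON, and the payload values are invented for illustration, and primitive fields are assumed to decode to themselves:

```python
import inspect

import pydantic


def add(x: int, y: int) -> int:
    """Add two integers."""
    return x + y


# Stand-in for tool.param_model: one required field per parameter.
Params = pydantic.create_model(
    "Params", x=(int, ...), y=(int, ...), __config__={"extra": "forbid"}
)

# 1. validate the model-supplied arguments JSON (invented payload)
raw_args = Params.model_validate_json('{"x": 2, "y": 3}')

# 2. bind the (identity-decoded) values to the tool's signature
sig = inspect.signature(add)
bound = sig.bind(
    **{name: getattr(raw_args, name) for name in raw_args.model_fields_set}
)

# 3. invoke the tool exactly as call_tool later does
assert add(*bound.args, **bound.kwargs) == 5
```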
|
49 | | -def _response_model(sig: inspect.Signature) -> type[pydantic.BaseModel]: |
50 | | - return pydantic.create_model( |
51 | | - "Response", |
52 | | - value=Encodable.define(sig.return_annotation).enc, |
53 | | - __config__={"extra": "forbid"}, |
54 | | - ) |
| 73 | +@Operation.define |
| 74 | +@functools.wraps(litellm.completion) |
| 75 | +def completion(*args, **kwargs) -> typing.Any: |
| 76 | + """Raw LLM request: a thin, interceptable wrapper around litellm.completion. |
55 | 77 |
|
| 78 | + This effect is emitted once per raw model request/response round so |
| 79 | + handlers can observe, log, or modify requests and delegate via fwd(). |
56 | 80 |
|
57 | | -def _tool_model(tool: Tool) -> ChatCompletionToolParam: |
58 | | - param_model = _parameter_model(inspect.signature(tool)) |
59 | | - response_format = litellm.utils.type_to_response_format_param(param_model) |
60 | | - assert response_format is not None |
61 | | - assert tool.__default__.__doc__ is not None |
62 | | - return { |
63 | | - "type": "function", |
64 | | - "function": { |
65 | | - "name": tool.__name__, |
66 | | - "description": textwrap.dedent(tool.__default__.__doc__), |
67 | | - "parameters": response_format["json_schema"]["schema"], |
68 | | - "strict": True, |
69 | | - }, |
70 | | - } |
| 81 | + """ |
| 82 | + return litellm.completion(*args, **kwargs) |
71 | 83 |
|
72 | 84 |
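Because `completion` is an `Operation` rather than a plain function, any handler in scope can intercept the raw round trip. A hedged sketch of a logging handler follows; `LogCompletions` is an invented name, and it assumes `fwd` accepts forwarded arguments the way the removed `_completion` handler in this same diff used it:

```python
from effectful.ops.semantics import fwd
from effectful.ops.syntax import ObjectInterpretation, implements


class LogCompletions(ObjectInterpretation):
    """Invented example: log every raw model request, then delegate."""

    @implements(completion)
    def _completion(self, *args, **kwargs):
        n_messages = len(kwargs.get("messages", []))
        print(f"completion: args={args!r} n_messages={n_messages}")
        return fwd(*args, **kwargs)  # delegate to litellm.completion
```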
|
73 | 85 | @Operation.define |
74 | | -def call_assistant( |
| 86 | +def call_assistant[T, U]( |
75 | 87 | messages: collections.abc.Sequence[Message], |
76 | | - response_format: type[pydantic.BaseModel] | None, |
77 | | - tools: collections.abc.Mapping[str, ChatCompletionToolParam], |
| 88 | + tools: collections.abc.Mapping[str, Tool], |
| 89 | + response_format: Encodable[T, U], |
78 | 90 | model: str, |
79 | 91 | **kwargs, |
80 | | -) -> Message: |
| 92 | +) -> MessageResult[T]: |
81 | 93 | """Low-level LLM request. Handlers may log/modify requests and delegate via fwd(). |
82 | 94 |
|
83 | 95 | This effect is emitted for model request/response rounds so handlers can |
84 | 96 | observe/log requests. |
85 | 97 |
|
86 | 98 | """ |
87 | | - response: litellm.types.utils.ModelResponse = litellm.completion( |
| 99 | + tool_specs = {k: t.model for k, t in tools.items()} |
| 100 | + response_model = pydantic.create_model( |
| 101 | + "Response", value=response_format.enc, __config__={"extra": "forbid"} |
| 102 | + ) |
| 103 | + |
| 104 | + response: litellm.types.utils.ModelResponse = completion( |
88 | 105 | model, |
89 | 106 | messages=list(messages), |
90 | | - response_format=response_format, |
91 | | - tools=list(tools.values()), |
| 107 | + response_format=response_model, |
| 108 | + tools=list(tool_specs.values()), |
92 | 109 | **kwargs, |
93 | 110 | ) |
94 | 111 | choice = response.choices[0] |
95 | 112 | assert isinstance(choice, litellm.types.utils.Choices) |
| 113 | + |
96 | 114 | message: litellm.Message = choice.message |
97 | 115 | assert message.role == "assistant" |
98 | | - return typing.cast(Message, message.model_dump(mode="json")) |
| 116 | + |
| 117 | + tool_calls: list[DecodedToolCall] = [] |
| 118 | + raw_tool_calls = message.get("tool_calls") or [] |
| 119 | + for tool_call in raw_tool_calls: |
| 120 | + tool_call = ChatCompletionMessageToolCall.model_validate(tool_call) |
| 121 | + decoded_tool_call = decode_tool_call(tool_call, tools) |
| 122 | + tool_calls.append(decoded_tool_call) |
| 123 | + |
| 124 | + result = None |
| 125 | + if not tool_calls: |
| 126 | + # no tool calls: decode the model's final response |
| 127 | + serialized_result = message.get("content") or message.get("reasoning_content") |
| 128 | + assert isinstance(serialized_result, str), ( |
| 129 | + "final response from the model should be a string" |
| 130 | + ) |
| 131 | + raw_result = response_model.model_validate_json(serialized_result) |
| 132 | + result = response_format.decode(raw_result.value) # type: ignore |
| 133 | + |
| 134 | + return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result) |
99 | 135 |
|
100 | 136 |
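The `MessageResult` triple encodes the contract the driver loop below relies on: when the model requests tool runs, `result` is `None` and the decoded calls must be executed; when it answers directly, `tool_calls` is empty and `result` already holds the decoded `T`. A hedged sketch of consuming one round, assuming `messages`, `tools`, and `response_format` are in scope as in `Template.__apply__`:

```python
message, tool_calls, result = call_assistant(
    messages, tools, response_format, model="gpt-4o"
)
messages.append(message)
if tool_calls:
    # the model asked for tool runs; result is None on this round
    for tool_call in tool_calls:
        messages.append(call_tool(tool_call))
else:
    # final answer: result is the decoded T
    assert result is not None
```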
|
101 | 137 | @Operation.define |
102 | | -def call_tool( |
103 | | - tool_call: ChatCompletionMessageToolCall, |
104 | | - tools: collections.abc.Mapping[str, Tool], |
105 | | -) -> Message: |
| 138 | +def call_tool(tool_call: DecodedToolCall) -> Message: |
106 | 139 | """Implements a roundtrip call to a Python function. The input is a |
107 | 140 | decoded LLM tool call; the output is the serialised response sent |
108 | 141 | back to the model. |
109 | 142 |
|
110 | 143 | """ |
111 | | - assert tool_call.function.name is not None |
112 | | - tool = tools[tool_call.function.name] |
113 | | - json_str = tool_call.function.arguments |
114 | | - |
115 | | - sig = inspect.signature(tool) |
116 | | - param_model = _parameter_model(sig) |
117 | | - |
118 | | - # build dict of raw encodable types U |
119 | | - raw_args = param_model.model_validate_json(json_str) |
120 | | - |
121 | | - # use encoders to decode Us to python types T |
122 | | - bound_sig: inspect.BoundArguments = sig.bind( |
123 | | - **{ |
124 | | - param_name: Encodable.define( |
125 | | - sig.parameters[param_name].annotation, {} |
126 | | - ).decode(getattr(raw_args, param_name)) |
127 | | - for param_name in raw_args.model_fields_set |
128 | | - } |
129 | | - ) |
130 | | - |
131 | 144 | # call tool with python types |
132 | | - result = tool(*bound_sig.args, **bound_sig.kwargs) |
| 145 | + result = tool_call.tool(*tool_call.bound_args.args, **tool_call.bound_args.kwargs) |
133 | 146 |
|
134 | 147 | # serialize back to U using encoder for return type |
135 | 148 | return_type = Encodable.define(type(result)) |
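The tail of `call_tool` is elided by this hunk. For reference, an OpenAI-style tool reply pairs the serialised return value with the id of the originating call, which is why `DecodedToolCall` carries its `ToolCallID`; the concrete payload below is invented:

```python
# Invented example of a tool-role reply; tool_call_id must match the
# DecodedToolCall.id it answers.
tool_message = {
    "role": "tool",
    "tool_call_id": "call_123",
    "content": '{"value": 5}',
}
```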
@@ -207,57 +220,39 @@ def __init__(self, model="gpt-4o", **config): |
207 | 220 | **inspect.signature(litellm.completion).bind_partial(**config).kwargs, |
208 | 221 | } |
209 | 222 |
|
210 | | - @implements(call_assistant) |
211 | | - @functools.wraps(call_assistant) |
212 | | - def _completion(self, *args, **kwargs): |
213 | | - return fwd(*args, **{**self.config, **kwargs}) |
214 | | - |
215 | 223 | @implements(Template.__apply__) |
216 | 224 | def _call[**P, T]( |
217 | 225 | self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs |
218 | 226 | ) -> T: |
219 | | - response_encoding_type: Encodable = Encodable.define( |
220 | | - inspect.signature(template).return_annotation, template.__context__ |
221 | | - ) |
222 | | - response_model = _response_model(inspect.signature(template)) |
223 | | - |
224 | 227 | messages: list[Message] = [*call_system(template)] |
225 | 228 |
|
226 | 229 | # encode arguments |
227 | 230 | bound_args = inspect.signature(template).bind(*args, **kwargs) |
228 | 231 | bound_args.apply_defaults() |
229 | 232 | env = template.__context__.new_child(bound_args.arguments) |
230 | 233 |
|
| 234 | + # Create response_model with env so tools passed as arguments are available |
| 235 | + response_model = Encodable.define(template.__signature__.return_annotation, env) |
| 236 | + |
231 | 237 | user_messages: list[Message] = call_user(template.__prompt_template__, env) |
232 | 238 | messages.extend(user_messages) |
233 | 239 |
|
234 | | - tools = { |
235 | | - **template.tools, |
236 | | - **{k: t for k, t in bound_args.arguments.items() if isinstance(t, Tool)}, |
237 | | - } |
238 | | - tool_specs = {k: _tool_model(t) for k, t in tools.items()} |
239 | | - |
240 | 240 | # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls |
241 | | - tool_calls: list[ChatCompletionMessageToolCall] = [] |
| 241 | + tool_calls: list[DecodedToolCall] = [] |
242 | 242 |
|
243 | 243 | message = messages[-1] |
| 244 | + result: T | None = None |
244 | 245 | while message["role"] != "assistant" or tool_calls: |
245 | | - message = call_assistant(messages, response_model, tool_specs) |
| 246 | + message, tool_calls, result = call_assistant( |
| 247 | + messages, template.tools, response_model, **self.config |
| 248 | + ) |
246 | 249 | messages.append(message) |
247 | | - tool_calls = message.get("tool_calls") or [] |
248 | 250 | for tool_call in tool_calls: |
249 | | - tool_call = ChatCompletionMessageToolCall.model_validate(tool_call) |
250 | | - message = call_tool(tool_call, tools) |
| 251 | + message = call_tool(tool_call) |
251 | 252 | messages.append(message) |
252 | 253 |
|
253 | | - # return response |
254 | | - serialized_result = message.get("content") or message.get("reasoning_content") |
255 | | - assert isinstance(serialized_result, str), ( |
256 | | - "final response from the model should be a string" |
| 254 | + assert result is not None, ( |
| 255 | + "call_assistant produced neither a result nor tool calls" |
257 | 256 | ) |
258 | | - encoded_result = ( |
259 | | - serialized_result |
260 | | - if response_model is None |
261 | | - else response_model.model_validate_json(serialized_result).value # type: ignore |
262 | | - ) |
263 | | - return response_encoding_type.decode(encoded_result) |
| 257 | + # return the decoded final result |
| 258 | + return result |
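To see why the loop condition keys on both the last message's role and the pending tool calls, here is a self-contained simulation of the driver loop's shape, with the model and tools stubbed out (all values invented):

```python
from typing import Any


def fake_assistant(step: int) -> tuple[dict, list[str], Any]:
    """Stub for call_assistant: one tool round, then a final answer."""
    if step == 0:
        return {"role": "assistant"}, ["tool_call_1"], None
    return {"role": "assistant"}, [], 42


messages: list[dict] = [{"role": "user", "content": "2 + 3?"}]
tool_calls: list[str] = []
result: Any = None
step = 0

message = messages[-1]
while message["role"] != "assistant" or tool_calls:
    message, tool_calls, result = fake_assistant(step)
    messages.append(message)
    for _ in tool_calls:
        message = {"role": "tool", "content": "5"}  # stand-in for call_tool
        messages.append(message)
    step += 1

# the loop exits only once the assistant answers without requesting tools
assert result == 42 and messages[-1]["role"] == "assistant"
```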