Skip to content

Commit fcad5d8

Browse files
committed
Reverted the flags for customizing rglobals, use the same rglobals for both locals and globals in exec, and fix a bug
1 parent bc1fcc0 commit fcad5d8

1 file changed

Lines changed: 190 additions & 105 deletions

File tree

effectful/handlers/llm/completions.py

Lines changed: 190 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,68 @@
2222

2323
from effectful.handlers.llm.encoding import Encodable
2424
from effectful.handlers.llm.template import Template, Tool
25+
from effectful.ops.semantics import fwd
2526
from effectful.ops.syntax import ObjectInterpretation, implements
2627
from effectful.ops.types import Operation
2728

29+
30+
class ToolCallDecodingError(Exception):
    """Raised when a tool call in an LLM response cannot be decoded.

    Attributes:
        tool_name: Name of the tool whose call failed to decode.
        tool_call_id: ID of the failing tool call.
        original_error: Underlying exception that caused the failure.
        raw_message: Raw assistant message containing the failed tool call
            (``None`` when not attached yet).
    """

    def __init__(
        self,
        tool_name: str,
        tool_call_id: str,
        original_error: Exception,
        raw_message: typing.Any = None,
    ):
        # Retain the full failure context so retry handlers can append
        # the malformed message plus targeted feedback and try again.
        self.tool_name = tool_name
        self.tool_call_id = tool_call_id
        self.original_error = original_error
        self.raw_message = raw_message
        super().__init__(f"Error decoding tool call '{tool_name}': {original_error}")
52+
53+
54+
class ResultDecodingError(Exception):
    """Raised when the final LLM response result cannot be decoded.

    Attributes:
        original_error: Underlying exception that caused the failure.
        raw_message: Raw assistant message containing the failed result
            (``None`` when not attached yet).
    """

    def __init__(
        self,
        original_error: Exception,
        raw_message: typing.Any = None,
    ):
        # Keep the cause and the offending message so a retry handler can
        # feed the error back to the model.
        self.original_error = original_error
        self.raw_message = raw_message
        super().__init__(f"Error decoding response: {original_error}")
70+
71+
72+
class ToolExecutionError(Exception):
    """Raised when a tool invocation fails at runtime (not during decoding)."""

    def __init__(
        self,
        tool_name: str,
        tool_call_id: str,
        original_error: Exception,
    ):
        # Unlike decoding errors, execution errors have no raw_message:
        # the assistant message itself was valid, only the tool failed.
        self.tool_name = tool_name
        self.tool_call_id = tool_call_id
        self.original_error = original_error
        super().__init__(f"Error executing tool '{tool_name}': {original_error}")
85+
86+
2887
Message = (
2988
OpenAIChatCompletionAssistantMessage
3089
| ChatCompletionToolMessage
def decode_tool_call(
    tool_call: ChatCompletionMessageToolCall,
    tools: collections.abc.Mapping[str, Tool],
) -> DecodedToolCall:
    """Decode a tool call from the LLM response into a DecodedToolCall.

    Raises:
        ToolCallDecodingError: If the tool call cannot be decoded due to
            an unknown tool name or invalid arguments.
    """
    tool_name = tool_call.function.name
    assert tool_name is not None

    try:
        tool = tools[tool_name]
    except KeyError as e:
        # The model invented a tool we never advertised.
        raise ToolCallDecodingError(tool_name, tool_call.id, e) from e

    json_str = tool_call.function.arguments
    sig = inspect.signature(tool)

    try:
        # Validate the raw JSON arguments against the tool's pydantic
        # parameter model (encodable intermediate types U).
        raw_args = _param_model(tool).model_validate_json(json_str)

        # Decode each provided field from its encodable form U back to the
        # tool's declared Python parameter type T, then bind to the signature.
        decoded_kwargs = {
            param_name: Encodable.define(
                sig.parameters[param_name].annotation, {}
            ).decode(getattr(raw_args, param_name))
            for param_name in raw_args.model_fields_set
        }
        bound_sig: inspect.BoundArguments = sig.bind(**decoded_kwargs)
    except (pydantic.ValidationError, TypeError, ValueError) as e:
        # Invalid JSON, schema mismatch, or arguments that don't bind.
        raise ToolCallDecodingError(tool_name, tool_call.id, e) from e

    return DecodedToolCall(tool, bound_sig, tool_call.id)
101174

102175

@@ -125,6 +198,11 @@ def call_assistant[T, U](
125198
This effect is emitted for model request/response rounds so handlers can
126199
observe/log requests.
127200
201+
Raises:
202+
ToolCallDecodingError: If a tool call cannot be decoded. The error
203+
includes the raw assistant message for retry handling.
204+
ResultDecodingError: If the result cannot be decoded. The error
205+
includes the raw assistant message for retry handling.
128206
"""
129207
tool_specs = {k: _function_model(t) for k, t in tools.items()}
130208
response_model = pydantic.create_model(
@@ -144,12 +222,26 @@ def call_assistant[T, U](
144222
message: litellm.Message = choice.message
145223
assert message.role == "assistant"
146224

225+
# Store the raw message for error reporting
226+
raw_message = typing.cast(Message, message.model_dump(mode="json"))
227+
147228
tool_calls: list[DecodedToolCall] = []
148229
raw_tool_calls = message.get("tool_calls") or []
149-
for tool_call in raw_tool_calls:
150-
tool_call = ChatCompletionMessageToolCall.model_validate(tool_call)
151-
decoded_tool_call = decode_tool_call(tool_call, tools)
152-
tool_calls.append(decoded_tool_call)
230+
for raw_tool_call in raw_tool_calls:
231+
validated_tool_call = ChatCompletionMessageToolCall.model_validate(
232+
raw_tool_call
233+
)
234+
try:
235+
decoded_tool_call = decode_tool_call(validated_tool_call, tools)
236+
tool_calls.append(decoded_tool_call)
237+
except ToolCallDecodingError as e:
238+
# Re-raise with the raw message attached
239+
raise ToolCallDecodingError(
240+
e.tool_name,
241+
e.tool_call_id,
242+
e.original_error,
243+
raw_message=raw_message,
244+
) from e.original_error
153245

154246
result = None
155247
if not tool_calls:
@@ -158,10 +250,13 @@ def call_assistant[T, U](
158250
assert isinstance(serialized_result, str), (
159251
"final response from the model should be a string"
160252
)
161-
raw_result = response_model.model_validate_json(serialized_result)
162-
result = response_format.decode(raw_result.value) # type: ignore
253+
try:
254+
raw_result = response_model.model_validate_json(serialized_result)
255+
result = response_format.decode(raw_result.value) # type: ignore
256+
except pydantic.ValidationError as e:
257+
raise ResultDecodingError(e, raw_message=raw_message) from e
163258

164-
return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result)
259+
return (raw_message, tool_calls, result)
165260

166261

167262
@Operation.define
@@ -242,6 +337,18 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]:
242337
class RetryHandler(ObjectInterpretation):
243338
"""Retries LLM requests if tool call or result decoding fails.
244339
340+
This handler intercepts `call_assistant` and catches `ToolCallDecodingError`
341+
and `ResultDecodingError`. When these errors occur, it appends error feedback
342+
to the messages and retries the request. Malformed messages from retry attempts
343+
are pruned from the final result.
344+
345+
For runtime tool execution failures (handled via `call_tool`), errors are
346+
captured and returned as tool response messages. These are NOT pruned since
347+
they represent legitimate runtime failures the LLM should be aware of.
348+
349+
Note: Server-side LLM API/network failures should be handled within litellm
350+
using its built-in retry mechanisms.
351+
245352
Args:
246353
num_retries: The maximum number of retries (default: 3).
247354
"""
@@ -261,100 +368,52 @@ def _call_assistant[T, U](
261368
messages_list = list(messages)
262369
last_error: Exception | None = None
263370

264-
tool_specs = {k: _function_model(t) for k, t in tools.items()}
265-
response_model = pydantic.create_model(
266-
"Response", value=response_format.enc, __config__={"extra": "forbid"}
267-
)
268-
269371
for _attempt in range(self.num_retries + 1):
270-
response: litellm.types.utils.ModelResponse = completion(
271-
model,
272-
messages=messages_list,
273-
response_format=response_model,
274-
tools=list(tool_specs.values()),
275-
**kwargs,
276-
)
277-
choice = response.choices[0]
278-
assert isinstance(choice, litellm.types.utils.Choices)
279-
280-
message: litellm.Message = choice.message
281-
assert message.role == "assistant"
282-
283-
raw_tool_calls = message.get("tool_calls") or []
372+
try:
373+
message, tool_calls, result = fwd(
374+
messages_list, tools, response_format, model, **kwargs
375+
)
284376

285-
# Try to decode tool calls, catching any decoding errors
286-
tool_calls: list[DecodedToolCall] = []
287-
decoding_errors: list[tuple[ChatCompletionMessageToolCall, Exception]] = []
377+
# Success! The returned message is the final successful response.
378+
# Malformed messages from retries are only in messages_list,
379+
# not in the returned result.
380+
return (message, tool_calls, result)
288381

289-
for raw_tool_call in raw_tool_calls:
290-
validated_tool_call = ChatCompletionMessageToolCall.model_validate(
291-
raw_tool_call
382+
except ToolCallDecodingError as e:
383+
last_error = e
384+
# The error includes the raw message from the failed attempt
385+
assert e.raw_message is not None, (
386+
"ToolCallDecodingError should include raw_message"
292387
)
293-
try:
294-
decoded_tool_call = decode_tool_call(validated_tool_call, tools)
295-
tool_calls.append(decoded_tool_call)
296-
except (KeyError, pydantic.ValidationError) as e:
297-
decoding_errors.append((validated_tool_call, e))
298-
299-
# If there were tool call decoding errors, add error feedback and retry
300-
if decoding_errors:
388+
301389
# Add the malformed assistant message
302-
messages_list.append(
303-
typing.cast(Message, message.model_dump(mode="json"))
304-
)
390+
messages_list.append(e.raw_message)
305391

306-
# Add error feedback for each failed tool call
307-
for failed_tool_call, error in decoding_errors:
308-
last_error = error
309-
error_msg = (
310-
f"Error decoding tool call '{failed_tool_call.function.name}': {error}. "
311-
f"Please fix the tool call arguments and try again."
312-
)
313-
error_feedback: Message = typing.cast(
314-
Message,
315-
{
316-
"role": "tool",
317-
"tool_call_id": failed_tool_call.id,
318-
"content": error_msg,
319-
},
320-
)
321-
messages_list.append(error_feedback)
392+
# Add error feedback as a tool response
393+
error_msg = f"{e}. Please fix the tool call arguments and try again."
394+
error_feedback: Message = typing.cast(
395+
Message,
396+
{
397+
"role": "tool",
398+
"tool_call_id": e.tool_call_id,
399+
"content": error_msg,
400+
},
401+
)
402+
messages_list.append(error_feedback)
322403
continue
323404

324-
# If there are tool calls, return them without decoding result
325-
if tool_calls:
326-
return (
327-
typing.cast(Message, message.model_dump(mode="json")),
328-
tool_calls,
329-
None,
405+
except ResultDecodingError as e:
406+
last_error = e
407+
# The error includes the raw message from the failed attempt
408+
assert e.raw_message is not None, (
409+
"ResultDecodingError should include raw_message"
330410
)
331411

332-
# No tool calls - try to decode the result
333-
serialized_result = message.get("content") or message.get(
334-
"reasoning_content"
335-
)
336-
assert isinstance(serialized_result, str), (
337-
"final response from the model should be a string"
338-
)
412+
# Add the malformed assistant message
413+
messages_list.append(e.raw_message)
339414

340-
try:
341-
raw_result = response_model.model_validate_json(serialized_result)
342-
result = response_format.decode(raw_result.value) # type: ignore
343-
return (
344-
typing.cast(Message, message.model_dump(mode="json")),
345-
tool_calls,
346-
result,
347-
)
348-
except pydantic.ValidationError as e:
349-
last_error = e
350-
# Add the assistant message and error feedback for result decoding failure
351-
messages_list.append(
352-
typing.cast(Message, message.model_dump(mode="json"))
353-
)
354-
error_msg = (
355-
f"Error decoding response: {e}. "
356-
f"Please provide a valid response and try again."
357-
)
415+
# Add error feedback as a user message
416+
error_msg = f"{e}. Please provide a valid response and try again."
358417
result_error_feedback: Message = typing.cast(
359418
Message,
360419
{
@@ -369,6 +428,32 @@ def _call_assistant[T, U](
369428
assert last_error is not None
370429
raise last_error
371430

431+
@implements(call_tool)
432+
def _call_tool(self, tool_call: DecodedToolCall) -> Message:
433+
"""Handle tool execution with runtime error capture.
434+
435+
Runtime errors from tool execution are captured and returned as
436+
error messages to the LLM. These messages are NOT pruned since
437+
they represent legitimate runtime failures.
438+
"""
439+
try:
440+
return fwd(tool_call)
441+
except Exception as e:
442+
# Wrap runtime errors and return as a tool message
443+
error = ToolExecutionError(
444+
tool_call.tool.__name__,
445+
tool_call.id,
446+
e,
447+
)
448+
return typing.cast(
449+
Message,
450+
{
451+
"role": "tool",
452+
"tool_call_id": tool_call.id,
453+
"content": f"Tool execution failed: {error}",
454+
},
455+
)
456+
372457

373458
class LiteLLMProvider(ObjectInterpretation):
374459
"""Implements templates using the LiteLLM API."""

0 commit comments

Comments
 (0)