Skip to content

Commit 5d44086

Browse files
authored
Implement RetryHandler on top of new internal API (closes #495) (#522)
* Implemented RetryHandler
* Made exception catches more specific
* Updated the RetryLLMHandler implementation to handle tool calls and delegate to fwd
* Restored documentation on retrying via litellm
* Switched exception classes to cleaner dataclasses
* Dropped the redundant ToolExecutionError
* Added a parameter to include tracebacks in calls
* Moved formatting into the exception classes
* Fixed failing tests
* Raised inside except to preserve the backtrace
* Unified exception message formatting
* Made RetryLLMHandler parametric in the errors it catches
1 parent 10a4ebd commit 5d44086

2 files changed

Lines changed: 772 additions & 24 deletions

File tree

effectful/handlers/llm/completions.py

Lines changed: 226 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import collections
22
import collections.abc
3+
import dataclasses
34
import functools
45
import inspect
56
import string
67
import textwrap
8+
import traceback
79
import typing
810

911
import litellm
@@ -20,8 +22,9 @@
2022
OpenAIMessageContentListBlock,
2123
)
2224

23-
from effectful.handlers.llm import Template, Tool
2425
from effectful.handlers.llm.encoding import Encodable
26+
from effectful.handlers.llm.template import Template, Tool
27+
from effectful.ops.semantics import fwd
2528
from effectful.ops.syntax import ObjectInterpretation, implements
2629
from effectful.ops.types import Operation
2730

@@ -36,6 +39,83 @@
3639
type ToolCallID = str
3740

3841

42+
@dataclasses.dataclass
class ToolCallDecodingError(Exception):
    """Error raised when decoding a tool call fails.

    Carries enough context (the raw assistant message and the tool call id)
    for a retry handler to feed the failure back to the model.
    """

    tool_name: str  # name of the tool the model tried to call
    tool_call_id: str  # provider-assigned id of the offending tool call
    original_error: Exception  # underlying failure (e.g. KeyError, ValidationError)
    raw_message: Message  # raw assistant message that contained the bad call

    def __str__(self) -> str:
        return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again."

    def to_feedback_message(self, include_traceback: bool) -> Message:
        """Render this error as a ``tool``-role message the model can react to.

        Args:
            include_traceback: If True, append the formatted traceback of this
                exception (including its ``__cause__`` chain) to the message.

        Returns:
            A ``tool`` message answering ``self.tool_call_id``.
        """
        error_message = f"{self}"
        if include_traceback:
            # Format this exception object directly instead of calling
            # traceback.format_exc(): format_exc() only works while the
            # exception is being handled and silently yields "NoneType: None"
            # outside an ``except`` block. Since this error is raised with
            # ``raise ... from original_error``, formatting ``self`` includes
            # the chained cause, matching format_exc() at the call site.
            tb = "".join(traceback.format_exception(self))
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "tool",
                "tool_call_id": self.tool_call_id,
                "content": error_message,
            },
        )
67+
68+
69+
@dataclasses.dataclass
class ResultDecodingError(Exception):
    """Error raised when decoding the LLM response result fails.

    Carries the raw assistant message so a retry handler can replay it to the
    model together with corrective feedback.
    """

    original_error: Exception  # underlying failure (e.g. pydantic ValidationError)
    raw_message: Message  # raw assistant message whose content failed to decode

    def __str__(self) -> str:
        return f"Error decoding response: {self.original_error}. Please provide a valid response and try again."

    def to_feedback_message(self, include_traceback: bool) -> Message:
        """Render this error as a ``user``-role message the model can react to.

        Args:
            include_traceback: If True, append the formatted traceback of this
                exception (including its ``__cause__`` chain) to the message.

        Returns:
            A ``user`` message describing the decoding failure.
        """
        error_message = f"{self}"
        if include_traceback:
            # Format this exception object directly instead of calling
            # traceback.format_exc(): format_exc() only works while the
            # exception is being handled and silently yields "NoneType: None"
            # outside an ``except`` block. Since this error is raised with
            # ``raise ... from original_error``, formatting ``self`` includes
            # the chained cause, matching format_exc() at the call site.
            tb = "".join(traceback.format_exception(self))
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "user",
                "content": error_message,
            },
        )
91+
92+
93+
@dataclasses.dataclass
class ToolCallExecutionError(Exception):
    """Error raised when a tool execution fails at runtime.

    Unlike the decoding errors, this exception is typically constructed (not
    raised) inside an ``except`` block and immediately turned into a feedback
    message for the model.
    """

    tool_name: str  # name of the tool that was executed
    tool_call_id: str  # provider-assigned id of the tool call
    original_error: BaseException  # exception raised by the tool itself

    def __str__(self) -> str:
        return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}"

    def to_feedback_message(self, include_traceback: bool) -> Message:
        """Render this error as a ``tool``-role message the model can react to.

        Args:
            include_traceback: If True, append the formatted traceback of the
                tool's original exception to the message.

        Returns:
            A ``tool`` message answering ``self.tool_call_id``.
        """
        error_message = f"{self}"
        if include_traceback:
            # Format the stored original exception rather than calling
            # traceback.format_exc(): format_exc() only works while an
            # exception is being handled (outside an ``except`` it yields
            # "NoneType: None"), and this wrapper itself is never raised, so
            # the tool's own exception is the one carrying the traceback.
            tb = "".join(traceback.format_exception(self.original_error))
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "tool",
                "tool_call_id": self.tool_call_id,
                "content": error_message,
            },
        )
117+
118+
39119
class DecodedToolCall[T](typing.NamedTuple):
40120
tool: Tool[..., T]
41121
bound_args: inspect.BoundArguments
@@ -77,26 +157,49 @@ def _function_model(tool: Tool) -> ChatCompletionToolParam:
77157
def decode_tool_call(
    tool_call: ChatCompletionMessageToolCall,
    tools: collections.abc.Mapping[str, Tool],
    raw_message: Message,
) -> DecodedToolCall:
    """Decode a tool call from the LLM response into a DecodedToolCall.

    Args:
        tool_call: The tool call to decode.
        tools: Mapping of tool names to Tool objects.
        raw_message: The raw assistant message containing this tool call,
            attached to any decoding error so retry handlers can replay it.

    Returns:
        The decoded tool, its bound arguments, and the tool call id.

    Raises:
        ToolCallDecodingError: If the tool call cannot be decoded (unknown
            tool name, or arguments that fail validation/binding).
    """
    tool_name = tool_call.function.name
    assert tool_name is not None

    # An unknown tool name is a model error, not a programming error:
    # surface it as a decoding error so retry handlers can give feedback.
    try:
        tool = tools[tool_name]
    except KeyError as e:
        raise ToolCallDecodingError(
            tool_name, tool_call.id, e, raw_message=raw_message
        ) from e

    json_str = tool_call.function.arguments
    sig = inspect.signature(tool)

    try:
        # build dict of raw encodable types U
        raw_args = _param_model(tool).model_validate_json(json_str)

        # use encoders to decode Us to python types T
        bound_sig: inspect.BoundArguments = sig.bind(
            **{
                param_name: Encodable.define(
                    sig.parameters[param_name].annotation, {}
                ).decode(getattr(raw_args, param_name))
                for param_name in raw_args.model_fields_set
            }
        )
    except (pydantic.ValidationError, TypeError, ValueError) as e:
        # Malformed JSON, schema violations, or arguments that fail to bind
        # to the tool signature are all recoverable model mistakes.
        raise ToolCallDecodingError(
            tool_name, tool_call.id, e, raw_message=raw_message
        ) from e

    return DecodedToolCall(tool, bound_sig, tool_call.id)
101204

102205

@@ -125,6 +228,11 @@ def call_assistant[T, U](
125228
This effect is emitted for model request/response rounds so handlers can
126229
observe/log requests.
127230
231+
Raises:
232+
ToolCallDecodingError: If a tool call cannot be decoded. The error
233+
includes the raw assistant message for retry handling.
234+
ResultDecodingError: If the result cannot be decoded. The error
235+
includes the raw assistant message for retry handling.
128236
"""
129237
tool_specs = {k: _function_model(t) for k, t in tools.items()}
130238
response_model = pydantic.create_model(
@@ -144,11 +252,15 @@ def call_assistant[T, U](
144252
message: litellm.Message = choice.message
145253
assert message.role == "assistant"
146254

255+
raw_message = typing.cast(Message, message.model_dump(mode="json"))
256+
147257
tool_calls: list[DecodedToolCall] = []
148258
raw_tool_calls = message.get("tool_calls") or []
149-
for tool_call in raw_tool_calls:
150-
tool_call = ChatCompletionMessageToolCall.model_validate(tool_call)
151-
decoded_tool_call = decode_tool_call(tool_call, tools)
259+
for raw_tool_call in raw_tool_calls:
260+
validated_tool_call = ChatCompletionMessageToolCall.model_validate(
261+
raw_tool_call
262+
)
263+
decoded_tool_call = decode_tool_call(validated_tool_call, tools, raw_message)
152264
tool_calls.append(decoded_tool_call)
153265

154266
result = None
@@ -158,10 +270,13 @@ def call_assistant[T, U](
158270
assert isinstance(serialized_result, str), (
159271
"final response from the model should be a string"
160272
)
161-
raw_result = response_model.model_validate_json(serialized_result)
162-
result = response_format.decode(raw_result.value) # type: ignore
273+
try:
274+
raw_result = response_model.model_validate_json(serialized_result)
275+
result = response_format.decode(raw_result.value) # type: ignore
276+
except pydantic.ValidationError as e:
277+
raise ResultDecodingError(e, raw_message=raw_message) from e
163278

164-
return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result)
279+
return (raw_message, tool_calls, result)
165280

166281

167282
@Operation.define
@@ -239,6 +354,95 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]:
239354
return ()
240355

241356

357+
class RetryLLMHandler(ObjectInterpretation):
358+
"""Retries LLM requests if tool call or result decoding fails.
359+
360+
This handler intercepts `call_assistant` and catches `ToolCallDecodingError`
361+
and `ResultDecodingError`. When these errors occur, it appends error feedback
362+
to the messages and retries the request. Malformed messages from retry attempts
363+
are pruned from the final result.
364+
365+
For runtime tool execution failures (handled via `call_tool`), errors are
366+
captured and returned as tool response messages.
367+
368+
Args:
369+
num_retries: The maximum number of retries (default: 3).
370+
include_traceback: If True, include full traceback in error feedback
371+
for better debugging context (default: False).
372+
catch_tool_errors: Exception type(s) to catch during tool execution.
373+
Can be a single exception class or a tuple of exception classes.
374+
Defaults to Exception (catches all exceptions).
375+
"""
376+
377+
def __init__(
378+
self,
379+
num_retries: int = 3,
380+
include_traceback: bool = False,
381+
catch_tool_errors: type[BaseException]
382+
| tuple[type[BaseException], ...] = Exception,
383+
):
384+
self.num_retries = num_retries
385+
self.include_traceback = include_traceback
386+
self.catch_tool_errors = catch_tool_errors
387+
388+
@implements(call_assistant)
389+
def _call_assistant[T, U](
390+
self,
391+
messages: collections.abc.Sequence[Message],
392+
tools: collections.abc.Mapping[str, Tool],
393+
response_format: Encodable[T, U],
394+
model: str,
395+
**kwargs,
396+
) -> MessageResult[T]:
397+
messages_list = list(messages)
398+
last_attempt = self.num_retries
399+
400+
for attempt in range(self.num_retries + 1):
401+
try:
402+
message, tool_calls, result = fwd(
403+
messages_list, tools, response_format, model, **kwargs
404+
)
405+
406+
# Success! The returned message is the final successful response.
407+
# Malformed messages from retries are only in messages_list,
408+
# not in the returned result.
409+
return (message, tool_calls, result)
410+
411+
except (ToolCallDecodingError, ResultDecodingError) as e:
412+
# On last attempt, re-raise to preserve full traceback
413+
if attempt == last_attempt:
414+
raise
415+
416+
# Add the malformed assistant message
417+
messages_list.append(e.raw_message)
418+
419+
# Add error feedback as a tool response
420+
error_feedback: Message = e.to_feedback_message(self.include_traceback)
421+
messages_list.append(error_feedback)
422+
423+
# Should never reach here - either we return on success or raise on final failure
424+
raise AssertionError("Unreachable: retry loop exited without return or raise")
425+
426+
@implements(completion)
427+
def _completion(self, *args, **kwargs) -> typing.Any:
428+
"""Inject num_retries for litellm's built-in network error handling."""
429+
return fwd(*args, num_retries=self.num_retries, **kwargs)
430+
431+
@implements(call_tool)
432+
def _call_tool(self, tool_call: DecodedToolCall) -> Message:
433+
"""Handle tool execution with runtime error capture.
434+
435+
Runtime errors from tool execution are captured and returned as
436+
error messages to the LLM. Only exceptions matching `catch_tool_errors`
437+
are caught; others propagate up.
438+
"""
439+
try:
440+
return fwd(tool_call)
441+
except self.catch_tool_errors as e:
442+
error = ToolCallExecutionError(tool_call.tool.__name__, tool_call.id, e)
443+
return error.to_feedback_message(self.include_traceback)
444+
445+
242446
class LiteLLMProvider(ObjectInterpretation):
243447
"""Implements templates using the LiteLLM API."""
244448

0 commit comments

Comments
 (0)