11import collections
22import collections .abc
3+ import dataclasses
34import functools
45import inspect
56import string
67import textwrap
8+ import traceback
79import typing
810
911import litellm
2022 OpenAIMessageContentListBlock ,
2123)
2224
23- from effectful .handlers .llm import Template , Tool
2425from effectful .handlers .llm .encoding import Encodable
26+ from effectful .handlers .llm .template import Template , Tool
27+ from effectful .ops .semantics import fwd
2528from effectful .ops .syntax import ObjectInterpretation , implements
2629from effectful .ops .types import Operation
2730
3639type ToolCallID = str
3740
3841
@dataclasses.dataclass
class ToolCallDecodingError(Exception):
    """Error raised when decoding a tool call fails.

    Carries enough context (tool name, call id, and the raw assistant
    message) for a retry handler to replay the failed exchange and tell the
    model what went wrong.
    """

    # Name of the tool the model tried to call.
    tool_name: str
    # Identifier of the tool call; echoed back in the feedback message.
    tool_call_id: str
    # The underlying exception that made decoding fail.
    original_error: Exception
    # The raw assistant message that contained the undecodable tool call.
    raw_message: Message

    def __str__(self) -> str:
        return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again."

    def to_feedback_message(self, include_traceback: bool) -> Message:
        """Build a ``tool``-role message describing this error.

        Args:
            include_traceback: If True, append a formatted traceback in a
                fenced code block after the error text.

        Returns:
            A message with role ``"tool"`` addressed to ``tool_call_id``.
        """
        error_message = f"{self}"
        if include_traceback:
            # Format the stored exception explicitly instead of relying on
            # traceback.format_exc(), which only works while an exception is
            # being handled. If this error was raised, format it (including
            # its cause chain); otherwise fall back to the original error.
            exc: BaseException = (
                self if self.__traceback__ is not None else self.original_error
            )
            tb = "".join(traceback.format_exception(exc))
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "tool",
                "tool_call_id": self.tool_call_id,
                "content": error_message,
            },
        )
67+
68+
@dataclasses.dataclass
class ResultDecodingError(Exception):
    """Error raised when decoding the LLM response result fails.

    Carries the raw assistant message so a retry handler can replay the
    failed exchange and tell the model what went wrong.
    """

    # The underlying exception that made decoding fail.
    original_error: Exception
    # The raw assistant message whose content could not be decoded.
    raw_message: Message

    def __str__(self) -> str:
        return f"Error decoding response: {self.original_error}. Please provide a valid response and try again."

    def to_feedback_message(self, include_traceback: bool) -> Message:
        """Build a ``user``-role message describing this error.

        Args:
            include_traceback: If True, append a formatted traceback in a
                fenced code block after the error text.

        Returns:
            A message with role ``"user"`` containing the error text.
        """
        error_message = f"{self}"
        if include_traceback:
            # Format the stored exception explicitly instead of relying on
            # traceback.format_exc(), which only works while an exception is
            # being handled. If this error was raised, format it (including
            # its cause chain); otherwise fall back to the original error.
            exc: BaseException = (
                self if self.__traceback__ is not None else self.original_error
            )
            tb = "".join(traceback.format_exception(exc))
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "user",
                "content": error_message,
            },
        )
91+
92+
@dataclasses.dataclass
class ToolCallExecutionError(Exception):
    """Error raised when a tool execution fails at runtime.

    Wraps the exception thrown by the tool so it can be reported back to
    the model as the response to the failed tool call.
    """

    # Name of the tool whose execution failed.
    tool_name: str
    # Identifier of the tool call; echoed back in the feedback message.
    tool_call_id: str
    # The exception raised while the tool was running.
    original_error: BaseException

    def __str__(self) -> str:
        return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}"

    def to_feedback_message(self, include_traceback: bool) -> Message:
        """Build a ``tool``-role message describing this error.

        Args:
            include_traceback: If True, append a formatted traceback in a
                fenced code block after the error text.

        Returns:
            A message with role ``"tool"`` addressed to ``tool_call_id``.
        """
        error_message = f"{self}"
        if include_traceback:
            # Format the stored exception explicitly instead of relying on
            # traceback.format_exc(), which only works while an exception is
            # being handled. This wrapper is usually built without being
            # raised, so the tool's own exception carries the traceback.
            exc: BaseException = (
                self if self.__traceback__ is not None else self.original_error
            )
            tb = "".join(traceback.format_exception(exc))
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "tool",
                "tool_call_id": self.tool_call_id,
                "content": error_message,
            },
        )
117+
118+
39119class DecodedToolCall [T ](typing .NamedTuple ):
40120 tool : Tool [..., T ]
41121 bound_args : inspect .BoundArguments
@@ -77,26 +157,49 @@ def _function_model(tool: Tool) -> ChatCompletionToolParam:
def decode_tool_call(
    tool_call: ChatCompletionMessageToolCall,
    tools: collections.abc.Mapping[str, Tool],
    raw_message: Message,
) -> DecodedToolCall:
    """Decode a tool call from the LLM response into a DecodedToolCall.

    Args:
        tool_call: The tool call to decode.
        tools: Mapping of tool names to Tool objects.
        raw_message: The raw assistant message that contained the tool call;
            attached to any raised error for retry handling.

    Returns:
        The resolved tool, its bound (decoded) arguments, and the call id.

    Raises:
        ToolCallDecodingError: If the tool name is unknown, or the arguments
            fail validation, decoding, or signature binding.
    """
    tool_name = tool_call.function.name
    assert tool_name is not None

    try:
        tool = tools[tool_name]
    except KeyError as e:
        # The model asked for a tool that was not offered.
        raise ToolCallDecodingError(
            tool_name, tool_call.id, e, raw_message=raw_message
        ) from e

    json_str = tool_call.function.arguments
    sig = inspect.signature(tool)

    try:
        # build dict of raw encodable types U
        raw_args = _param_model(tool).model_validate_json(json_str)

        # use encoders to decode Us to python types T; only fields the model
        # actually supplied are bound, so omitted parameters are not passed
        bound_sig: inspect.BoundArguments = sig.bind(
            **{
                param_name: Encodable.define(
                    sig.parameters[param_name].annotation, {}
                ).decode(getattr(raw_args, param_name))
                for param_name in raw_args.model_fields_set
            }
        )
    except (pydantic.ValidationError, TypeError, ValueError) as e:
        raise ToolCallDecodingError(
            tool_name, tool_call.id, e, raw_message=raw_message
        ) from e

    return DecodedToolCall(tool, bound_sig, tool_call.id)
101204
102205
@@ -125,6 +228,11 @@ def call_assistant[T, U](
125228 This effect is emitted for model request/response rounds so handlers can
126229 observe/log requests.
127230
231+ Raises:
232+ ToolCallDecodingError: If a tool call cannot be decoded. The error
233+ includes the raw assistant message for retry handling.
234+ ResultDecodingError: If the result cannot be decoded. The error
235+ includes the raw assistant message for retry handling.
128236 """
129237 tool_specs = {k : _function_model (t ) for k , t in tools .items ()}
130238 response_model = pydantic .create_model (
@@ -144,11 +252,15 @@ def call_assistant[T, U](
144252 message : litellm .Message = choice .message
145253 assert message .role == "assistant"
146254
255+ raw_message = typing .cast (Message , message .model_dump (mode = "json" ))
256+
147257 tool_calls : list [DecodedToolCall ] = []
148258 raw_tool_calls = message .get ("tool_calls" ) or []
149- for tool_call in raw_tool_calls :
150- tool_call = ChatCompletionMessageToolCall .model_validate (tool_call )
151- decoded_tool_call = decode_tool_call (tool_call , tools )
259+ for raw_tool_call in raw_tool_calls :
260+ validated_tool_call = ChatCompletionMessageToolCall .model_validate (
261+ raw_tool_call
262+ )
263+ decoded_tool_call = decode_tool_call (validated_tool_call , tools , raw_message )
152264 tool_calls .append (decoded_tool_call )
153265
154266 result = None
@@ -158,10 +270,13 @@ def call_assistant[T, U](
158270 assert isinstance (serialized_result , str ), (
159271 "final response from the model should be a string"
160272 )
161- raw_result = response_model .model_validate_json (serialized_result )
162- result = response_format .decode (raw_result .value ) # type: ignore
273+ try :
274+ raw_result = response_model .model_validate_json (serialized_result )
275+ result = response_format .decode (raw_result .value ) # type: ignore
276+ except pydantic .ValidationError as e :
277+ raise ResultDecodingError (e , raw_message = raw_message ) from e
163278
164- return (typing . cast ( Message , message . model_dump ( mode = "json" )) , tool_calls , result )
279+ return (raw_message , tool_calls , result )
165280
166281
167282@Operation .define
@@ -239,6 +354,95 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]:
239354 return ()
240355
241356
357+ class RetryLLMHandler (ObjectInterpretation ):
358+ """Retries LLM requests if tool call or result decoding fails.
359+
360+ This handler intercepts `call_assistant` and catches `ToolCallDecodingError`
361+ and `ResultDecodingError`. When these errors occur, it appends error feedback
362+ to the messages and retries the request. Malformed messages from retry attempts
363+ are pruned from the final result.
364+
365+ For runtime tool execution failures (handled via `call_tool`), errors are
366+ captured and returned as tool response messages.
367+
368+ Args:
369+ num_retries: The maximum number of retries (default: 3).
370+ include_traceback: If True, include full traceback in error feedback
371+ for better debugging context (default: False).
372+ catch_tool_errors: Exception type(s) to catch during tool execution.
373+ Can be a single exception class or a tuple of exception classes.
374+ Defaults to Exception (catches all exceptions).
375+ """
376+
377+ def __init__ (
378+ self ,
379+ num_retries : int = 3 ,
380+ include_traceback : bool = False ,
381+ catch_tool_errors : type [BaseException ]
382+ | tuple [type [BaseException ], ...] = Exception ,
383+ ):
384+ self .num_retries = num_retries
385+ self .include_traceback = include_traceback
386+ self .catch_tool_errors = catch_tool_errors
387+
388+ @implements (call_assistant )
389+ def _call_assistant [T , U ](
390+ self ,
391+ messages : collections .abc .Sequence [Message ],
392+ tools : collections .abc .Mapping [str , Tool ],
393+ response_format : Encodable [T , U ],
394+ model : str ,
395+ ** kwargs ,
396+ ) -> MessageResult [T ]:
397+ messages_list = list (messages )
398+ last_attempt = self .num_retries
399+
400+ for attempt in range (self .num_retries + 1 ):
401+ try :
402+ message , tool_calls , result = fwd (
403+ messages_list , tools , response_format , model , ** kwargs
404+ )
405+
406+ # Success! The returned message is the final successful response.
407+ # Malformed messages from retries are only in messages_list,
408+ # not in the returned result.
409+ return (message , tool_calls , result )
410+
411+ except (ToolCallDecodingError , ResultDecodingError ) as e :
412+ # On last attempt, re-raise to preserve full traceback
413+ if attempt == last_attempt :
414+ raise
415+
416+ # Add the malformed assistant message
417+ messages_list .append (e .raw_message )
418+
419+ # Add error feedback as a tool response
420+ error_feedback : Message = e .to_feedback_message (self .include_traceback )
421+ messages_list .append (error_feedback )
422+
423+ # Should never reach here - either we return on success or raise on final failure
424+ raise AssertionError ("Unreachable: retry loop exited without return or raise" )
425+
426+ @implements (completion )
427+ def _completion (self , * args , ** kwargs ) -> typing .Any :
428+ """Inject num_retries for litellm's built-in network error handling."""
429+ return fwd (* args , num_retries = self .num_retries , ** kwargs )
430+
431+ @implements (call_tool )
432+ def _call_tool (self , tool_call : DecodedToolCall ) -> Message :
433+ """Handle tool execution with runtime error capture.
434+
435+ Runtime errors from tool execution are captured and returned as
436+ error messages to the LLM. Only exceptions matching `catch_tool_errors`
437+ are caught; others propagate up.
438+ """
439+ try :
440+ return fwd (tool_call )
441+ except self .catch_tool_errors as e :
442+ error = ToolCallExecutionError (tool_call .tool .__name__ , tool_call .id , e )
443+ return error .to_feedback_message (self .include_traceback )
444+
445+
242446class LiteLLMProvider (ObjectInterpretation ):
243447 """Implements templates using the LiteLLM API."""
244448
0 commit comments