Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 132 additions & 1 deletion effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
OpenAIMessageContentListBlock,
)

from effectful.handlers.llm import Template, Tool
from effectful.handlers.llm.encoding import Encodable
from effectful.handlers.llm.template import Template, Tool
from effectful.ops.syntax import ObjectInterpretation, implements
from effectful.ops.types import Operation

Expand Down Expand Up @@ -239,6 +239,137 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]:
return ()


class RetryHandler(ObjectInterpretation):
"""Retries LLM requests if tool call or result decoding fails.

Args:
num_retries: The maximum number of retries (default: 3).
"""

def __init__(self, num_retries: int = 3):
self.num_retries = num_retries

@implements(call_assistant)
def _call_assistant[T, U](
Comment thread
kiranandcode marked this conversation as resolved.
self,
messages: collections.abc.Sequence[Message],
tools: collections.abc.Mapping[str, Tool],
response_format: Encodable[T, U],
model: str,
**kwargs,
) -> MessageResult[T]:
messages_list = list(messages)
last_error: Exception | None = None

tool_specs = {k: _function_model(t) for k, t in tools.items()}
response_model = pydantic.create_model(
"Response", value=response_format.enc, __config__={"extra": "forbid"}
)

for _attempt in range(self.num_retries + 1):
response: litellm.types.utils.ModelResponse = completion(
model,
messages=messages_list,
response_format=response_model,
tools=list(tool_specs.values()),
**kwargs,
)
choice = response.choices[0]
assert isinstance(choice, litellm.types.utils.Choices)

message: litellm.Message = choice.message
assert message.role == "assistant"

raw_tool_calls = message.get("tool_calls") or []

# Try to decode tool calls, catching any decoding errors
tool_calls: list[DecodedToolCall] = []
decoding_errors: list[tuple[ChatCompletionMessageToolCall, Exception]] = []

for raw_tool_call in raw_tool_calls:
validated_tool_call = ChatCompletionMessageToolCall.model_validate(
raw_tool_call
)
try:
decoded_tool_call = decode_tool_call(validated_tool_call, tools)
tool_calls.append(decoded_tool_call)
except (KeyError, pydantic.ValidationError) as e:
decoding_errors.append((validated_tool_call, e))

# If there were tool call decoding errors, add error feedback and retry
if decoding_errors:
# Add the malformed assistant message
messages_list.append(
typing.cast(Message, message.model_dump(mode="json"))
)

# Add error feedback for each failed tool call
for failed_tool_call, error in decoding_errors:
last_error = error
error_msg = (
f"Error decoding tool call '{failed_tool_call.function.name}': {error}. "
f"Please fix the tool call arguments and try again."
)
error_feedback: Message = typing.cast(
Message,
{
"role": "tool",
"tool_call_id": failed_tool_call.id,
"content": error_msg,
},
)
messages_list.append(error_feedback)
continue

# If there are tool calls, return them without decoding result
if tool_calls:
return (
typing.cast(Message, message.model_dump(mode="json")),
tool_calls,
None,
)

# No tool calls - try to decode the result
serialized_result = message.get("content") or message.get(
"reasoning_content"
)
assert isinstance(serialized_result, str), (
"final response from the model should be a string"
)

try:
raw_result = response_model.model_validate_json(serialized_result)
result = response_format.decode(raw_result.value) # type: ignore
return (
typing.cast(Message, message.model_dump(mode="json")),
tool_calls,
result,
)
except pydantic.ValidationError as e:
last_error = e
# Add the assistant message and error feedback for result decoding failure
messages_list.append(
typing.cast(Message, message.model_dump(mode="json"))
)
error_msg = (
f"Error decoding response: {e}. "
f"Please provide a valid response and try again."
)
result_error_feedback: Message = typing.cast(
Message,
{
"role": "user",
"content": error_msg,
},
)
messages_list.append(result_error_feedback)
continue

# If all retries failed, raise the last error
assert last_error is not None
raise last_error
Comment thread
kiranandcode marked this conversation as resolved.
Outdated


class LiteLLMProvider(ObjectInterpretation):
"""Implements templates using the LiteLLM API."""

Expand Down
Loading
Loading