Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 88 additions & 25 deletions effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import string
import textwrap
import threading
import traceback
import typing
import uuid
Expand All @@ -25,7 +26,7 @@
)

from effectful.handlers.llm.encoding import DecodedToolCall, Encodable
from effectful.handlers.llm.template import Template, Tool
from effectful.handlers.llm.template import Template, Tool, get_bound_agent
from effectful.internals.unification import nested_type
from effectful.ops.semantics import fwd, handler
from effectful.ops.syntax import ObjectInterpretation, implements
Expand Down Expand Up @@ -71,6 +72,30 @@ def append_message(message: Message):
pass


@Operation.define
def get_agent_history(agent_id: str) -> collections.OrderedDict[str, Message]:
    """Default handler: no persistent storage, so every agent sees an empty history.

    Install a handler such as :class:`AgentHistoryHandler` to give agents
    a real per-id message history.
    """
    empty_history: collections.OrderedDict[str, Message] = collections.OrderedDict()
    return empty_history


class AgentHistoryHandler(ObjectInterpretation):
    """In-memory store of per-agent message histories.

    While installed, each agent id maps to one OrderedDict of messages
    that persists across template calls::

        with handler(AgentHistoryHandler()), handler(LiteLLMProvider()):
            bot.ask("question")  # history accumulates across calls
    """

    def __init__(self) -> None:
        # agent id -> ordered message history; populated lazily on first lookup
        self._histories: dict[str, collections.OrderedDict[str, Message]] = {}

    @implements(get_agent_history)
    def _get(self, agent_id: str) -> collections.OrderedDict[str, Message]:
        # Create the agent's history bucket on first access so repeated
        # lookups always return the same (mutable) OrderedDict.
        if agent_id not in self._histories:
            self._histories[agent_id] = collections.OrderedDict()
        return self._histories[agent_id]


def _make_message(content: dict) -> Message:
m_id = content.get("id") or str(uuid.uuid1())
message = typing.cast(Message, {**content, "id": m_id})
Expand Down Expand Up @@ -442,7 +467,11 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message:


class LiteLLMProvider(ObjectInterpretation):
"""Implements templates using the LiteLLM API."""
"""Implements templates using the LiteLLM API.

Also provides per-agent message history storage via
:func:`get_agent_history`.
"""

config: collections.abc.Mapping[str, typing.Any]

Expand All @@ -451,6 +480,19 @@ def __init__(self, model="gpt-4o", **config):
"model": model,
**inspect.signature(litellm.completion).bind_partial(**config).kwargs,
}
self._histories: dict[str, collections.OrderedDict[str, Message]] = {}
self._tls = threading.local()

def _get_depths(self) -> dict[str, int]:
    """Return this thread's agent-id -> call-depth map, creating it on first use.

    Stored in thread-local state so concurrent threads track nesting
    independently.
    """
    try:
        return self._tls.depths
    except AttributeError:
        # First access on this thread: initialize an empty depth map.
        depths: dict[str, int] = {}
        self._tls.depths = depths
        return depths

@implements(get_agent_history)
def _get_agent_history(
    self, agent_id: str
) -> collections.OrderedDict[str, Message]:
    """Return the provider-local history for *agent_id*, creating it if absent.

    Repeated calls with the same id return the same mutable OrderedDict.
    """
    if agent_id not in self._histories:
        self._histories[agent_id] = collections.OrderedDict()
    return self._histories[agent_id]

@implements(Template.__apply__)
def _call[**P, T](
Expand All @@ -464,29 +506,50 @@ def _call[**P, T](
# Create response_model with env so tools passed as arguments are available
response_model = Encodable.define(template.__signature__.return_annotation, env)

history: collections.OrderedDict[str, Message] = getattr(
template, "__history__", collections.OrderedDict()
) # type: ignore
history_copy = history.copy()
# Get history: from agent history handler if bound to an agent, else fresh
agent = get_bound_agent(template)
if agent is not None:
agent_id = agent.__agent_id__
history = get_agent_history(agent_id)
else:
agent_id = None
history = collections.OrderedDict()

# Track nesting depth per agent so only the outermost call writes back.
# Inner calls work on their own copy but discard it on return.
# See: TestNestedTemplateCalling.test_only_outermost_writes_to_history
depths = self._get_depths()
if agent_id is not None:
depth = depths.get(agent_id, 0)
depths[agent_id] = depth + 1
is_outermost = depth == 0
else:
depth = 0
is_outermost = False

with handler({_get_history: lambda: history_copy}):
call_system(template)

message: Message = call_user(template.__prompt_template__, env)

# loop based on: https://cookbook.openai.com/examples/reasoning_function_calls
tool_calls: list[DecodedToolCall] = []
result: T | None = None
while message["role"] != "assistant" or tool_calls:
message, tool_calls, result = call_assistant(
template.tools, response_model, **self.config
)
for tool_call in tool_calls:
message = call_tool(tool_call)
history_copy = history.copy()

try:
_get_history()
except NotImplementedError:
history.clear()
history.update(history_copy)
return typing.cast(T, result)
with handler({_get_history: lambda: history_copy}):
call_system(template)

message: Message = call_user(template.__prompt_template__, env)

# loop based on: https://cookbook.openai.com/examples/reasoning_function_calls
tool_calls: list[DecodedToolCall] = []
result: T | None = None
while message["role"] != "assistant" or tool_calls:
message, tool_calls, result = call_assistant(
template.tools, response_model, **self.config
)
for tool_call in tool_calls:
message = call_tool(tool_call)

# Only outermost call writes back to canonical history
if is_outermost:
history.clear()
history.update(history_copy)
return typing.cast(T, result)
finally:
if agent_id is not None:
depths[agent_id] = depth
Loading
Loading