Skip to content

Commit fa3fb6b

Browse files
committed
Add Agent class
1 parent d65dc4e commit fa3fb6b

6 files changed

Lines changed: 384 additions & 71 deletions

File tree

docs/source/agent.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

docs/source/agent_example.rst

Lines changed: 0 additions & 18 deletions
This file was deleted.

docs/source/index.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ Table of Contents
1818
lambda_example
1919
semi_ring_example
2020
beam_search_example
21-
agent_example
2221

2322
.. toctree::
2423
:maxdepth: 2

effectful/handlers/llm/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .template import Template, Tool
1+
from .template import Agent, Template, Tool
22

3-
__all__ = ["Template", "Tool"]
3+
__all__ = ["Agent", "Template", "Tool"]

effectful/handlers/llm/template.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import abc
2+
import collections
3+
import functools
14
import inspect
25
import types
36
import typing
4-
from collections import ChainMap
57
from collections.abc import Callable, Mapping, MutableMapping
68
from dataclasses import dataclass
79
from typing import Annotated, Any
810

11+
from effectful.ops.semantics import handler
912
from effectful.ops.types import INSTANCE_OP_PREFIX, Annotation, Operation
1013

1114

@@ -183,7 +186,7 @@ class Template[**P, T](Tool[P, T]):
183186
184187
"""
185188

186-
__context__: ChainMap[str, Any]
189+
__context__: collections.ChainMap[str, Any]
187190

188191
@property
189192
def __prompt_template__(self) -> str:
@@ -283,11 +286,72 @@ def define[**Q, V](
283286
frame = frame.f_back
284287

285288
contexts.append(globals_proxy)
286-
context: ChainMap[str, Any] = ChainMap(
289+
context: collections.ChainMap[str, Any] = collections.ChainMap(
287290
*typing.cast(list[MutableMapping[str, Any]], contexts)
288291
)
289292

290293
op = super().define(default, *args, **kwargs)
291294
op.__context__ = context # type: ignore[attr-defined]
292295

293296
return typing.cast(Template[Q, V], op)
297+
298+
299+
class Agent(abc.ABC):
300+
"""Mixin that gives each instance a persistent LLM message history.
301+
302+
Subclass and decorate methods with :func:`Template.define`.
303+
Each instance accumulates messages across calls so the LLM sees
304+
prior conversation context.
305+
306+
Agents compose freely with :func:`dataclasses.dataclass` and other
307+
base classes. Instance attributes are available in template
308+
docstrings via ``{self.attr}``.
309+
310+
Example::
311+
312+
import dataclasses
313+
from effectful.handlers.llm import Agent, Template
314+
from effectful.handlers.llm.completions import LiteLLMProvider
315+
from effectful.ops.semantics import handler
316+
from effectful.ops.types import NotHandled
317+
318+
@dataclasses.dataclass
319+
class ChatBot(Agent):
320+
bot_name: str = dataclasses.field(default="ChatBot")
321+
322+
@Template.define
323+
def send(self, user_input: str) -> str:
324+
\"""Friendly bot named {self.bot_name}. User writes: {user_input}\"""
325+
raise NotHandled
326+
327+
provider = LiteLLMProvider()
328+
chatbot = ChatBot()
329+
330+
with handler(provider):
331+
chatbot.send("Hi! How are you? I am in France.")
332+
chatbot.send("Remind me again, where am I?") # sees prior context
333+
334+
"""
335+
336+
__history__: collections.OrderedDict[str, Any]
337+
338+
def __init_subclass__(cls, **kwargs):
339+
super().__init_subclass__(**kwargs)
340+
prop = functools.cached_property(lambda _: collections.OrderedDict())
341+
prop.__set_name__(cls, "__history__")
342+
cls.__history__ = prop
343+
344+
for name in list(cls.__dict__):
345+
attr = cls.__dict__[name]
346+
if not isinstance(attr, Template):
347+
continue
348+
_template = attr
349+
350+
@functools.wraps(_template)
351+
def wrapper(self, *args, _t=_template, **kwargs):
352+
from effectful.handlers.llm.completions import get_message_sequence
353+
354+
with handler({get_message_sequence: lambda: self.__history__}):
355+
return _t(self, *args, **kwargs)
356+
357+
setattr(cls, name, wrapper)

0 commit comments

Comments
 (0)