Skip to content

Commit 19eafd5

Browse files
authored
Deduplicate agent system message and add tests (#600)
* Deduplicate agent system message and add tests
* Pin sphinx version to avoid breaking
1 parent 20c18ff commit 19eafd5

3 files changed

Lines changed: 160 additions & 1 deletion

File tree

effectful/handlers/llm/completions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,5 +487,6 @@ def _call[**P, T](
487487
try:
488488
_get_history()
489489
except NotImplementedError:
490+
history.clear()
490491
history.update(history_copy)
491492
return typing.cast(T, result)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ docs = [
5858
"sphinx_rtd_theme",
5959
"myst-parser",
6060
"nbsphinx",
61-
"sphinx_autodoc_typehints",
61+
"sphinx_autodoc_typehints>=3.6,<3.9",
6262
"pypandoc_binary<1.16",
6363
]
6464
test = [

tests/test_handlers_llm_provider.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,3 +2038,161 @@ def _completion(self, model, messages=None, **kwargs):
20382038
# Only messages from the successful call should be in history
20392039
assert len(agent.__history__) >= 2
20402040
assert len(agent.__history__) > history_after_error
2041+
2042+
2043+
class TestAgentSystemMessageDeduplication:
    """Regression tests guarding against duplicated agent system messages.

    LiteLLMProvider._call works on a copy of the agent history, and
    call_system swaps out the system message inside that copy. Merging the
    copy back with history.update(history_copy) alone is purely additive: the
    key of the replaced system message is never dropped from the original
    history, so stale system messages pile up and an assertion trips from the
    third call onwards.

    Calling history.clear() right before history.update(history_copy) fixes
    this.
    """

    def test_three_consecutive_calls_no_system_message_duplication(self):
        """Calling the same agent three times in a row must succeed every time."""
        import dataclasses

        @dataclasses.dataclass
        class ThreeCallAgent(Agent):
            """You are a test agent for system message deduplication."""

            @Template.define
            def ask(self, question: str) -> str:
                """Answer: {question}"""
                raise NotHandled

        completions_served = 0

        class CountingHandler(ObjectInterpretation):
            @implements(completion)
            def _completion(self, model, messages=None, **kwargs):
                # Number each canned reply so the test can tell the calls apart.
                nonlocal completions_served
                completions_served += 1
                return make_text_response(f"answer {completions_served}")

        agent = ThreeCallAgent()
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(CountingHandler()):
            answers = [agent.ask(question) for question in ("q1", "q2", "q3")]

        assert answers == ["answer 1", "answer 2", "answer 3"]

    def test_history_has_exactly_one_system_message_after_multiple_calls(self):
        """However many calls are made, only one system message is kept in history."""
        import dataclasses

        @dataclasses.dataclass
        class SystemMsgAgent(Agent):
            """You are a system message count test agent."""

            @Template.define
            def do(self, task: str) -> str:
                """Do: {task}"""
                raise NotHandled

        completions_served = 0

        class MultiHandler(ObjectInterpretation):
            @implements(completion)
            def _completion(self, model, messages=None, **kwargs):
                nonlocal completions_served
                completions_served += 1
                return make_text_response(f"done {completions_served}")

        agent = SystemMsgAgent()
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(MultiHandler()):
            for task in ("a", "b", "c", "d"):
                agent.do(task)

        system_total = sum(
            1 for entry in agent.__history__.values() if entry["role"] == "system"
        )
        assert system_total == 1, (
            f"Expected exactly 1 system message, got {system_total}"
        )

    def test_conversation_history_preserved_across_calls(self):
        """User and assistant messages from earlier calls must survive later calls."""
        import dataclasses

        @dataclasses.dataclass
        class MemoryAgent(Agent):
            """You are a memory test agent."""

            @Template.define
            def chat(self, msg: str) -> str:
                """User says: {msg}"""
                raise NotHandled

        completions_served = 0

        class MemoryHandler(ObjectInterpretation):
            @implements(completion)
            def _completion(self, model, messages=None, **kwargs):
                nonlocal completions_served
                completions_served += 1
                # On the third completion the prompts of the two earlier
                # calls must still be part of the conversation.
                if completions_served == 3:
                    user_total = sum(
                        1 for entry in messages if entry["role"] == "user"
                    )
                    assert user_total == 3, (
                        f"Third call should see 3 user messages, got {user_total}"
                    )
                return make_text_response(f"reply {completions_served}")

        agent = MemoryAgent()
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(MemoryHandler()):
            for msg in ("first", "second", "third"):
                agent.chat(msg)

        # Expected content: 1 system + 3 user + 3 assistant = 7 messages.
        assert len(agent.__history__) == 7
        past_roles = [entry["role"] for entry in agent.__history__.values()]
        for role, expected in (("system", 1), ("user", 3), ("assistant", 3)):
            assert past_roles.count(role) == expected

    def test_system_message_is_always_first(self):
        """After several calls the history must still start with the system message."""
        import dataclasses

        @dataclasses.dataclass
        class OrderAgent(Agent):
            """You are a message order test agent."""

            @Template.define
            def step(self, n: int) -> str:
                """Step {n}"""
                raise NotHandled

        completions_served = 0

        class OrderHandler(ObjectInterpretation):
            @implements(completion)
            def _completion(self, model, messages=None, **kwargs):
                nonlocal completions_served
                completions_served += 1
                return make_text_response(f"step {completions_served}")

        agent = OrderAgent()
        provider = LiteLLMProvider(model="test")

        with handler(provider), handler(OrderHandler()):
            for n in (1, 2, 3):
                agent.step(n)

        stored = list(agent.__history__.values())
        assert stored[0]["role"] == "system", (
            "System message should be the first message in history"
        )

0 commit comments

Comments
 (0)