
Commit d2bafd3

enable prompt caching for agent calls
1 parent 7364125 commit d2bafd3

3 files changed

Lines changed: 199 additions & 2 deletions


effectful/handlers/llm/completions.py

Lines changed: 47 additions & 1 deletion
@@ -159,6 +159,34 @@ def to_feedback_message(self, include_traceback: bool) -> Message:
 
 type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]
 
+CACHE_CONTROL_EPHEMERAL = {"type": "ephemeral"}
+
+
+def _add_cache_control_to_history(
+    history: collections.OrderedDict[str, "Message"],
+) -> None:
+    """Add cache_control to the last user/tool message in an agent's history.
+
+    This enables prompt caching on providers that support it (e.g. Anthropic).
+    Providers that don't support it (e.g. OpenAI) have cache_control stripped
+    by litellm's request transformation, so this is always safe to apply.
+
+    Mutates the history OrderedDict in place.
+    """
+    if not history:
+        return
+    key = next(reversed(history))
+    msg = history[key]
+    if msg["role"] not in ("user", "tool"):
+        return
+    content = msg.get("content")
+    if isinstance(content, list) and content:
+        last_block = content[-1]
+        if isinstance(last_block, dict) and "cache_control" not in last_block:
+            new_content = list(content)
+            new_content[-1] = {**last_block, "cache_control": CACHE_CONTROL_EPHEMERAL}
+            history[key] = typing.cast(Message, {**msg, "content": new_content})
+
 
 @Operation.define
 @functools.wraps(litellm.completion)
@@ -326,7 +354,18 @@ def flush_text() -> None:
 def call_system(template: Template) -> Message:
     """Get system instruction message(s) to prepend to all LLM prompts."""
     system_prompt = template.__system_prompt__ or DEFAULT_SYSTEM_PROMPT
-    message = _make_message(dict(role="system", content=system_prompt))
+    message = _make_message(
+        dict(
+            role="system",
+            content=[
+                {
+                    "type": "text",
+                    "text": system_prompt,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        )
+    )
     try:
         history: collections.OrderedDict[str, Message] = _get_history()
         if any(m["role"] == "system" for m in history.values()):
@@ -467,13 +506,20 @@ def _call[**P, T](
     history: collections.OrderedDict[str, Message] = getattr(
         template, "__history__", collections.OrderedDict()
     )  # type: ignore
+    is_agent = hasattr(template, "__history__")
     history_copy = history.copy()
 
     with handler({_get_history: lambda: history_copy}):
         call_system(template)
 
         message: Message = call_user(template.__prompt_template__, env)
 
+        # For agents with persistent history, add cache_control to the
+        # last user message so the growing prefix gets cached on providers
+        # that support it (Anthropic). litellm strips it for OpenAI.
+        if is_agent:
+            _add_cache_control_to_history(history_copy)
+
         # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls
         tool_calls: list[DecodedToolCall] = []
         result: T | None = None
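
For reference, a minimal standalone sketch of what _add_cache_control_to_history does to an agent's history. It uses plain dicts rather than the module's Message type, so it is illustrative only:

import collections

CACHE_CONTROL_EPHEMERAL = {"type": "ephemeral"}

history = collections.OrderedDict(
    {
        "m1": {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]},
        "m2": {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
    }
)

# Same logic as the helper: tag the last content block of the last
# user/tool message with an ephemeral cache_control marker.
key = next(reversed(history))
msg = history[key]
if msg["role"] in ("user", "tool"):
    content = msg.get("content")
    if isinstance(content, list) and content and "cache_control" not in content[-1]:
        new_content = list(content)
        new_content[-1] = {**content[-1], "cache_control": CACHE_CONTROL_EPHEMERAL}
        history[key] = {**msg, "content": new_content}

print(history["m2"]["content"][-1])
# {'type': 'text', 'text': 'What is 2+2?', 'cache_control': {'type': 'ephemeral'}}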

tests/test_handlers_llm_provider.py

Lines changed: 146 additions & 0 deletions
@@ -2196,3 +2196,149 @@ def _completion(self, model, messages=None, **kwargs):
         assert messages[0]["role"] == "system", (
             "System message should be the first message in history"
         )
+
+
+# ============================================================================
+# Prompt Caching Tests
+# ============================================================================
+
+
+def _has_cache_control(msg: dict) -> bool:
+    """Check if a message dict contains cache_control in any content block."""
+    content = msg.get("content")
+    if isinstance(content, list):
+        return any(isinstance(b, dict) and "cache_control" in b for b in content)
+    return False
+
+
+class CachingAgent(Agent):
+    """A test agent with persistent history."""
+
+    @Template.define
+    def ask(self, question: str) -> str:
+        """You are a helpful assistant. Answer concisely: {question}"""
+        raise NotHandled
+
+
+class TestPromptCaching:
+    """Tests that cache_control is present in messages sent to litellm."""
+
+    def test_system_message_has_cache_control(self):
+        """System message should include cache_control for prompt caching."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+
+        with handler(provider), handler(capture):
+            simple_prompt("test")
+
+        msgs = capture.received_messages[0]
+        system_msgs = [m for m in msgs if m["role"] == "system"]
+        assert len(system_msgs) == 1
+        assert _has_cache_control(system_msgs[0]), (
+            f"System message should have cache_control. Got: {system_msgs[0]}"
+        )
+
+    def test_agent_user_message_has_cache_control(self):
+        """Agent calls should add cache_control to the last user message."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+        agent = CachingAgent()
+
+        with handler(provider), handler(capture):
+            agent.ask("What is 2+2?")
+
+        msgs = capture.received_messages[0]
+        user_msgs = [m for m in msgs if m["role"] == "user"]
+        assert len(user_msgs) == 1
+        content = user_msgs[0]["content"]
+        assert isinstance(content, list)
+        assert "cache_control" in content[-1], (
+            f"Agent user message should have cache_control. Got: {content[-1]}"
+        )
+
+    def test_non_agent_user_message_no_cache_control(self):
+        """Non-agent calls should NOT add cache_control to user messages."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+
+        with handler(provider), handler(capture):
+            simple_prompt("test")
+
+        msgs = capture.received_messages[0]
+        user_msgs = [m for m in msgs if m["role"] == "user"]
+        content = user_msgs[0]["content"]
+        assert isinstance(content, list)
+        assert "cache_control" not in content[-1], (
+            "Non-agent user messages should NOT have cache_control"
+        )
+
+    def test_cache_control_format_is_ephemeral(self):
+        """cache_control should use the ephemeral type."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+
+        with handler(provider), handler(capture):
+            simple_prompt("test")
+
+        for msg in capture.received_messages[0]:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict) and "cache_control" in block:
+                        assert block["cache_control"] == {"type": "ephemeral"}
+
+    def test_litellm_strips_cache_control_for_openai(self):
+        """Verify litellm strips cache_control when transforming for OpenAI."""
+        from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+        msgs = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Hi.",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Hi",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+        ]
+        config = OpenAIGPTConfig()
+        transformed = config.transform_request(
+            model="gpt-4o",
+            messages=msgs,
+            optional_params={},
+            litellm_params={},
+            headers={},
+        )
+        for msg in transformed["messages"]:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    assert "cache_control" not in block
+
+    @requires_openai
+    def test_openai_accepts_cache_control_via_litellm(self):
+        """OpenAI works fine with cache_control (litellm strips it)."""
+        provider = LiteLLMProvider(model="gpt-4o-mini")
+        with handler(provider):
+            result = simple_prompt("math")
+        assert isinstance(result, str)
+
+    @requires_anthropic
+    def test_anthropic_accepts_cache_control(self):
+        """Anthropic should accept messages with cache_control."""
+        provider = LiteLLMProvider(model="claude-opus-4-6", max_tokens=20)
+        with handler(provider):
+            result = simple_prompt("math")
+        assert isinstance(result, str)
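
For context, test_anthropic_accepts_cache_control exercises this request shape end to end. A rough standalone equivalent using litellm directly is sketched below; the model name is taken from the test, the prompt text is made up, and the usage inspection at the end is an assumption (which cache-hit fields are reported varies by provider and litellm version):

import litellm

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Answer tersely.",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What is 2+2?",
                "cache_control": {"type": "ephemeral"},
            }
        ],
    },
]

# Anthropic honors the cache_control markers; for OpenAI models the same
# messages go through litellm's request transformation, which strips them.
resp = litellm.completion(model="claude-opus-4-6", messages=messages, max_tokens=20)
print(resp.choices[0].message.content)
print(resp.usage)  # cache-related token counts, if reported, appear here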

tests/test_handlers_llm_template.py

Lines changed: 6 additions & 1 deletion
@@ -464,7 +464,12 @@ def standalone(topic: str) -> str:
         standalone("fish")
 
         assert_single_system_message_first(mock.received_messages[0])
-        assert mock.received_messages[0][0]["content"] == DEFAULT_SYSTEM_PROMPT
+        content = mock.received_messages[0][0]["content"]
+        # System message content is now a list of blocks with cache_control
+        if isinstance(content, list):
+            assert content[0]["text"] == DEFAULT_SYSTEM_PROMPT
+        else:
+            assert content == DEFAULT_SYSTEM_PROMPT
 
 
 class TestAgentDocstringFallback:
