@@ -2196,3 +2196,149 @@ def _completion(self, model, messages=None, **kwargs):
         assert messages[0]["role"] == "system", (
             "System message should be the first message in history"
         )
+
+
+# ============================================================================
+# Prompt Caching Tests
+# ============================================================================
+
+
+def _has_cache_control(msg: dict) -> bool:
+    """Check if a message dict contains cache_control in any content block."""
+    content = msg.get("content")
+    if isinstance(content, list):
+        return any(isinstance(b, dict) and "cache_control" in b for b in content)
+    return False
+
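+
+# For illustration: the block shape the helper above matches follows the
+# Anthropic-style prompt-caching convention, where a content block carries an
+# ephemeral cache_control marker. The example message below is hypothetical.
+_EXAMPLE_CACHED_MESSAGE = {
+    "role": "system",
+    "content": [
+        {"type": "text", "text": "Hi.", "cache_control": {"type": "ephemeral"}},
+    ],
+}
+assert _has_cache_control(_EXAMPLE_CACHED_MESSAGE)
+assert not _has_cache_control({"role": "user", "content": "plain text"})
+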
+
+class CachingAgent(Agent):
+    """A test agent with persistent history."""
+
+    @Template.define
+    def ask(self, question: str) -> str:
+        """You are a helpful assistant. Answer concisely: {question}"""
+        raise NotHandled
+
+
+class TestPromptCaching:
+    """Tests that cache_control is present in messages sent to litellm."""
+
+    def test_system_message_has_cache_control(self):
+        """System message should include cache_control for prompt caching."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+
+        with handler(provider), handler(capture):
+            simple_prompt("test")
+
+        msgs = capture.received_messages[0]
+        system_msgs = [m for m in msgs if m["role"] == "system"]
+        assert len(system_msgs) == 1
+        assert _has_cache_control(system_msgs[0]), (
+            f"System message should have cache_control. Got: {system_msgs[0]}"
+        )
+
+    def test_agent_user_message_has_cache_control(self):
+        """Agent calls should add cache_control to the last user message."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+        agent = CachingAgent()
+
+        with handler(provider), handler(capture):
+            agent.ask("What is 2+2?")
+
+        msgs = capture.received_messages[0]
+        user_msgs = [m for m in msgs if m["role"] == "user"]
+        assert len(user_msgs) == 1
+        content = user_msgs[0]["content"]
+        assert isinstance(content, list)
+        assert "cache_control" in content[-1], (
+            f"Agent user message should have cache_control. Got: {content[-1]}"
+        )
+
+    def test_non_agent_user_message_no_cache_control(self):
+        """Non-agent calls should NOT add cache_control to user messages."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+
+        with handler(provider), handler(capture):
+            simple_prompt("test")
+
+        msgs = capture.received_messages[0]
+        user_msgs = [m for m in msgs if m["role"] == "user"]
+        content = user_msgs[0]["content"]
+        assert isinstance(content, list)
+        assert "cache_control" not in content[-1], (
+            "Non-agent user messages should NOT have cache_control"
+        )
+
+    def test_cache_control_format_is_ephemeral(self):
+        """cache_control should use the ephemeral type."""
+        capture = MockCompletionHandler([make_text_response("42")])
+        provider = LiteLLMProvider(model="test")
+
+        with handler(provider), handler(capture):
+            simple_prompt("test")
+
+        for msg in capture.received_messages[0]:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict) and "cache_control" in block:
+                        assert block["cache_control"] == {"type": "ephemeral"}
+
+    def test_litellm_strips_cache_control_for_openai(self):
+        """Verify litellm strips cache_control when transforming for OpenAI."""
+        from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+        msgs = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Hi.",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Hi",
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ],
+            },
+        ]
+        config = OpenAIGPTConfig()
+        transformed = config.transform_request(
+            model="gpt-4o",
+            messages=msgs,
+            optional_params={},
+            litellm_params={},
+            headers={},
+        )
+        for msg in transformed["messages"]:
+            content = msg.get("content")
+            if isinstance(content, list):
+                for block in content:
+                    assert "cache_control" not in block
+
+    @requires_openai
+    def test_openai_accepts_cache_control_via_litellm(self):
+        """OpenAI works fine with cache_control (litellm strips it)."""
+        provider = LiteLLMProvider(model="gpt-4o-mini")
+        with handler(provider):
+            result = simple_prompt("math")
+        assert isinstance(result, str)
+
+    @requires_anthropic
+    def test_anthropic_accepts_cache_control(self):
+        """Anthropic should accept messages with cache_control."""
+        provider = LiteLLMProvider(model="claude-opus-4-6", max_tokens=20)
+        with handler(provider):
+            result = simple_prompt("math")
+        assert isinstance(result, str)