From 607da69d1ed6315448fa41d73efc45463f548f2a Mon Sep 17 00:00:00 2001 From: smhanan Date: Wed, 7 Jan 2026 22:13:40 -0800 Subject: [PATCH] fix: tool result sanitization and context truncation --- .../openai_to_anthropic/requests.py | 284 +++++++-- ccproxy/llms/utils/__init__.py | 18 + ccproxy/llms/utils/context_truncation.py | 210 +++++++ ccproxy/llms/utils/token_estimation.py | 195 ++++++ ccproxy/plugins/claude_api/adapter.py | 31 + ccproxy/plugins/claude_sdk/adapter.py | 30 + ccproxy/streaming/deferred.py | 19 + .../test_native_anthropic_sanitization.py | 591 ++++++++++++++++++ tests/test_tool_result_sanitization.py | 268 ++++++-- .../llms/test_context_window_management.py | 290 +++++++++ .../llms/test_truncation_creates_orphans.py | 222 +++++++ .../streaming/test_deferred_stream_errors.py | 66 ++ 12 files changed, 2121 insertions(+), 103 deletions(-) create mode 100644 ccproxy/llms/utils/__init__.py create mode 100644 ccproxy/llms/utils/context_truncation.py create mode 100644 ccproxy/llms/utils/token_estimation.py create mode 100644 tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py create mode 100644 tests/unit/llms/test_context_window_management.py create mode 100644 tests/unit/llms/test_truncation_creates_orphans.py create mode 100644 tests/unit/streaming/test_deferred_stream_errors.py diff --git a/ccproxy/llms/formatters/openai_to_anthropic/requests.py b/ccproxy/llms/formatters/openai_to_anthropic/requests.py index 5540289c..0129e126 100644 --- a/ccproxy/llms/formatters/openai_to_anthropic/requests.py +++ b/ccproxy/llms/formatters/openai_to_anthropic/requests.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from collections import Counter from typing import Any from ccproxy.core.constants import DEFAULT_MAX_TOKENS @@ -15,83 +16,242 @@ def _sanitize_tool_results(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Remove orphaned tool_result blocks that don't have matching tool_use blocks. + """Remove orphaned tool blocks that don't have matching counterparts. - The Anthropic API requires that each tool_result block must have a corresponding - tool_use block in the immediately preceding assistant message. This function removes - tool_result blocks that don't meet this requirement, converting them to text to + The Anthropic API requires: + 1. Each tool_result block must have a corresponding tool_use in the preceding assistant message + 2. Each tool_use block must have a corresponding tool_result in the next user message + + This function removes orphaned blocks of both types, converting them to text to preserve information. Args: messages: List of Anthropic format messages Returns: - Sanitized messages with orphaned tool_results removed or converted to text + Sanitized messages with orphaned tool blocks removed or converted to text """ if not messages: return messages - sanitized = [] - for i, msg in enumerate(messages): - if msg.get("role") == "user" and isinstance(msg.get("content"), list): - # Find tool_use_ids from the immediately preceding assistant message - valid_tool_use_ids: set[str] = set() - if i > 0 and messages[i - 1].get("role") == "assistant": - prev_content = messages[i - 1].get("content", []) - if isinstance(prev_content, list): - for block in prev_content: - if isinstance(block, dict) and block.get("type") == "tool_use": - tool_id = block.get("id") - if tool_id: - valid_tool_use_ids.add(tool_id) - - # Filter content blocks - new_content = [] - orphaned_results = [] - for block in msg["content"]: - if isinstance(block, dict) and block.get("type") == "tool_result": - tool_use_id = block.get("tool_use_id") - if tool_use_id in valid_tool_use_ids: + def _iter_content_blocks(content: Any) -> list[Any]: + if isinstance(content, list): + return content + if isinstance(content, dict): + return [content] + return [] + + def _collect_tool_use_counts(content: Any) -> Counter[str]: + counts: Counter[str] = Counter() + for block in _iter_content_blocks(content): + if isinstance(block, dict) and block.get("type") == "tool_use": + tool_id = block.get("id") + if tool_id: + counts[str(tool_id)] += 1 + return counts + + def _collect_tool_result_counts(content: Any) -> Counter[str]: + counts: Counter[str] = Counter() + for block in _iter_content_blocks(content): + if isinstance(block, dict) and block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + if tool_use_id: + counts[str(tool_use_id)] += 1 + return counts + + def _sanitize_once( + current_messages: list[dict[str, Any]], + ) -> tuple[list[dict[str, Any]], bool]: + assistant_tool_use_counts: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + user_tool_result_counts: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + + for i, msg in enumerate(current_messages): + role = msg.get("role") + content = msg.get("content") + if role == "assistant": + assistant_tool_use_counts[i] = _collect_tool_use_counts(content) + elif role == "user": + user_tool_result_counts[i] = _collect_tool_result_counts(content) + + paired_counts_for_assistant: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + paired_counts_for_user: list[Counter[str]] = [ + Counter() for _ in current_messages + ] + + for i, msg in enumerate(current_messages): + if msg.get("role") != "assistant": + continue + if ( + i + 1 < len(current_messages) + and current_messages[i + 1].get("role") == "user" + ): + paired: Counter[str] = Counter() + assistant_counts = assistant_tool_use_counts[i] + user_counts = user_tool_result_counts[i + 1] + for tool_id, tool_use_count in assistant_counts.items(): + if tool_id in user_counts: + paired[tool_id] = min(tool_use_count, user_counts[tool_id]) + paired_counts_for_assistant[i] = paired + paired_counts_for_user[i + 1] = paired + + sanitized: list[dict[str, Any]] = [] + changed = False + for i, msg in enumerate(current_messages): + role = msg.get("role") + content = msg.get("content") + content_blocks = _iter_content_blocks(content) + content_was_dict = isinstance(content, dict) + + # Handle assistant messages with tool_use blocks + if role == "assistant" and content_blocks: + valid_tool_use_counts = paired_counts_for_assistant[i] + kept_tool_use_counts: Counter[str] = Counter() + + new_content = [] + orphaned_tool_uses = [] + + for block in content_blocks: + if isinstance(block, dict) and block.get("type") == "tool_use": + raw_tool_id = block.get("id") + tool_use_id = str(raw_tool_id) if raw_tool_id else None + # Only keep tool_use when the *next* user message provides the result + if tool_use_id and kept_tool_use_counts[ + tool_use_id + ] < valid_tool_use_counts.get(tool_use_id, 0): + kept_tool_use_counts[tool_use_id] += 1 + new_content.append(block) + else: + orphaned_tool_uses.append(block) + changed = True + logger.warning( + "orphaned_tool_use_removed", + tool_use_id=tool_use_id, + tool_name=block.get("name"), + message_index=i, + category="message_sanitization", + ) + else: new_content.append(block) + + # Convert orphaned tool_use blocks to text + if orphaned_tool_uses: + orphan_text = ( + "[Tool calls from compacted history - results not available]\n" + ) + for orphan in orphaned_tool_uses: + tool_name = orphan.get("name", "unknown") + tool_input = orphan.get("input", {}) + input_str = str(tool_input) + if len(input_str) > 200: + input_str = input_str[:200] + "..." + orphan_text += f"- Called {tool_name}: {input_str}\n" + + # Add text block at the beginning (or prepend to existing text) + if ( + new_content + and isinstance(new_content[0], dict) + and new_content[0].get("type") == "text" + ): + new_content[0] = { + **new_content[0], + "text": orphan_text + "\n" + new_content[0]["text"], + } + else: + new_content.insert(0, {"type": "text", "text": orphan_text}) + + if new_content: + if content_was_dict: + changed = True + sanitized.append({**msg, "content": new_content}) + else: + # If no content left, add minimal text to avoid empty assistant message + changed = True + sanitized.append( + {**msg, "content": "[Previous response content compacted]"} + ) + continue + + # Handle user messages with tool_result blocks + elif role == "user" and content_blocks: + # Find tool_use_ids from the immediately preceding assistant message + valid_tool_use_counts = paired_counts_for_user[i] + kept_tool_result_counts: Counter[str] = Counter() + + # Filter content blocks + new_content = [] + orphaned_results = [] + for block in content_blocks: + if isinstance(block, dict) and block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + tool_use_id = str(tool_use_id) if tool_use_id else None + if tool_use_id and kept_tool_result_counts[ + tool_use_id + ] < valid_tool_use_counts.get(tool_use_id, 0): + kept_tool_result_counts[tool_use_id] += 1 + new_content.append(block) + else: + # Track orphaned tool_result for conversion to text + orphaned_results.append(block) + changed = True + logger.warning( + "orphaned_tool_result_removed", + tool_use_id=tool_use_id, + valid_ids=list(valid_tool_use_counts.keys()), + message_index=i, + category="message_sanitization", + ) else: - # Track orphaned tool_result for conversion to text - orphaned_results.append(block) - logger.warning( - "orphaned_tool_result_removed", - tool_use_id=tool_use_id, - valid_ids=list(valid_tool_use_ids), - message_index=i, - category="message_sanitization", + new_content.append(block) + + # Convert orphaned results to text block to preserve information. + # Avoid injecting text into a valid tool_result reply message. + if orphaned_results and sum(kept_tool_result_counts.values()) == 0: + orphan_text = "[Previous tool results from compacted history]\n" + for orphan in orphaned_results: + result_content = orphan.get("content", "") + if isinstance(result_content, list): + text_parts = [] + for c in result_content: + if isinstance(c, dict) and c.get("type") == "text": + text_parts.append(c.get("text", "")) + result_content = "\n".join(text_parts) + # Truncate long content + content_str = str(result_content) + if len(content_str) > 500: + content_str = content_str[:500] + "..." + orphan_text += ( + f"- Tool {orphan.get('tool_use_id', 'unknown')}: " + f"{content_str}\n" ) + + # Add as text block at the beginning + new_content.insert(0, {"type": "text", "text": orphan_text}) + + # Update message content (only if we have content left) + if new_content: + if content_was_dict: + changed = True + sanitized.append({**msg, "content": new_content}) else: - new_content.append(block) - - # Convert orphaned results to text block to preserve information - if orphaned_results: - orphan_text = "[Previous tool results from compacted history]\n" - for orphan in orphaned_results: - content = orphan.get("content", "") - if isinstance(content, list): - text_parts = [] - for c in content: - if isinstance(c, dict) and c.get("type") == "text": - text_parts.append(c.get("text", "")) - content = "\n".join(text_parts) - # Truncate long content - content_str = str(content) - if len(content_str) > 500: - content_str = content_str[:500] + "..." - orphan_text += f"- Tool {orphan.get('tool_use_id', 'unknown')}: {content_str}\n" - - # Add as text block at the beginning - new_content.insert(0, {"type": "text", "text": orphan_text}) - - # Update message content (only if we have content left) - if new_content: - sanitized.append({**msg, "content": new_content}) - # If no content left, skip this message entirely - else: - sanitized.append(msg) + # If no content left, skip this message entirely + changed = True + continue + else: + sanitized.append(msg) + + return sanitized, changed + + sanitized = messages + for _ in range(2): + sanitized, changed = _sanitize_once(sanitized) + if not changed: + break return sanitized diff --git a/ccproxy/llms/utils/__init__.py b/ccproxy/llms/utils/__init__.py new file mode 100644 index 00000000..4c2debbb --- /dev/null +++ b/ccproxy/llms/utils/__init__.py @@ -0,0 +1,18 @@ +"""LLM utility modules for token estimation and context management.""" + +from .context_truncation import truncate_to_fit +from .token_estimation import ( + estimate_messages_tokens, + estimate_request_tokens, + estimate_tokens, + get_max_input_tokens, +) + + +__all__ = [ + "estimate_tokens", + "estimate_messages_tokens", + "estimate_request_tokens", + "get_max_input_tokens", + "truncate_to_fit", +] diff --git a/ccproxy/llms/utils/context_truncation.py b/ccproxy/llms/utils/context_truncation.py new file mode 100644 index 00000000..43e0afec --- /dev/null +++ b/ccproxy/llms/utils/context_truncation.py @@ -0,0 +1,210 @@ +"""Context window truncation utilities.""" + +import copy +from typing import Any + +from ccproxy.core.logging import get_logger + +from .token_estimation import estimate_request_tokens + + +logger = get_logger(__name__) + + +# Maximum characters to keep for truncated content blocks +MAX_TRUNCATED_CONTENT_CHARS = 10000 + + +def _truncate_large_content_blocks( + messages: list[dict[str, Any]], + max_chars: int = MAX_TRUNCATED_CONTENT_CHARS, +) -> tuple[list[dict[str, Any]], int]: + """Truncate large content blocks within messages. + + This is a fallback when message-level truncation isn't enough. + Targets large tool_result blocks and text content. + + Args: + messages: List of messages to process + max_chars: Maximum characters to keep per content block + + Returns: + Tuple of (modified_messages, blocks_truncated_count) + """ + truncated_count = 0 + modified_messages = [] + + for msg in messages: + msg_copy = copy.deepcopy(msg) + content = msg_copy.get("content") + + if isinstance(content, str) and len(content) > max_chars: + # Truncate large string content + msg_copy["content"] = ( + content[:max_chars] + + f"\n\n[Content truncated - {len(content) - max_chars} characters removed]" + ) + truncated_count += 1 + + elif isinstance(content, list): + # Process content blocks + new_content = [] + for block in content: + if isinstance(block, dict): + block_copy = copy.deepcopy(block) + + # Handle tool_result blocks with large content + if block_copy.get("type") == "tool_result": + tool_content = block_copy.get("content", "") + if ( + isinstance(tool_content, str) + and len(tool_content) > max_chars + ): + block_copy["content"] = ( + tool_content[:max_chars] + + f"\n\n[Tool result truncated - {len(tool_content) - max_chars} characters removed]" + ) + truncated_count += 1 + + # Handle text blocks with large content + elif block_copy.get("type") == "text": + text = block_copy.get("text", "") + if len(text) > max_chars: + block_copy["text"] = ( + text[:max_chars] + + f"\n\n[Text truncated - {len(text) - max_chars} characters removed]" + ) + truncated_count += 1 + + new_content.append(block_copy) + else: + new_content.append(block) + + msg_copy["content"] = new_content + + modified_messages.append(msg_copy) + + return modified_messages, truncated_count + + +def truncate_to_fit( + request_data: dict[str, Any], + max_input_tokens: int, + preserve_recent: int = 10, + safety_margin: float = 0.9, +) -> tuple[dict[str, Any], bool]: + """Truncate request to fit within token limit. + + Strategy: + 1. Always preserve system prompt and tools + 2. Try to preserve the last N messages (preserve_recent) + 3. Remove oldest messages first + 4. If too few messages to truncate, reduce preserve_recent dynamically + 5. As a last resort, truncate large content blocks within messages + 6. Add a truncation notice when content is removed + + Args: + request_data: The request payload + max_input_tokens: Model's max input token limit + preserve_recent: Number of recent messages to always keep + safety_margin: Target this fraction of max to allow for estimation error + + Returns: + Tuple of (modified_request_data, was_truncated) + """ + target_tokens = int(max_input_tokens * safety_margin) + + current_tokens = estimate_request_tokens(request_data) + if current_tokens <= target_tokens: + return request_data, False + + # Work on a copy + modified = copy.deepcopy(request_data) + messages = modified.get("messages", []) + + # If we have fewer messages than preserve_recent, reduce preserve_recent + # We need at least 1 message to be truncatable for this strategy to work + effective_preserve = min(preserve_recent, len(messages) - 1) + + # If we have 0 or 1 messages, we can't do message-level truncation + # Skip to content-level truncation + if effective_preserve < 0: + effective_preserve = 0 + + # Split into truncatable and preserved messages + if effective_preserve > 0: + truncatable = messages[:-effective_preserve] + preserved = messages[-effective_preserve:] + else: + truncatable = list(messages) + preserved = [] + + # Remove oldest messages until we're under the limit + removed_count = 0 + while truncatable and estimate_request_tokens(modified) > target_tokens: + truncatable.pop(0) + removed_count += 1 + modified["messages"] = truncatable + preserved + + # Check if we're still over the limit after removing all truncatable messages + if estimate_request_tokens(modified) > target_tokens: + logger.info( + "context_truncation_message_level_insufficient", + reason="still_over_limit_after_message_truncation", + message_count=len(modified.get("messages", [])), + current_tokens=estimate_request_tokens(modified), + target_tokens=target_tokens, + category="context_management", + ) + + # Fallback: truncate large content blocks within remaining messages + truncated_messages, blocks_truncated = _truncate_large_content_blocks( + modified.get("messages", []) + ) + modified["messages"] = truncated_messages + + if blocks_truncated > 0: + logger.info( + "context_truncation_content_level", + blocks_truncated=blocks_truncated, + current_tokens=estimate_request_tokens(modified), + target_tokens=target_tokens, + category="context_management", + ) + + # If still over limit after content truncation, log error + final_tokens = estimate_request_tokens(modified) + if final_tokens > target_tokens: + logger.error( + "context_truncation_failed", + reason="still_over_limit_after_all_truncation", + final_tokens=final_tokens, + target_tokens=target_tokens, + messages_removed=removed_count, + blocks_truncated=blocks_truncated, + category="context_management", + ) + # Still return the truncated version - it's better than nothing + # The API will return an error, but at least we tried + + # Add truncation notice as first user message if we removed content + if removed_count > 0: + notice = { + "role": "user", + "content": f"[Context truncated - {removed_count} earlier messages removed to fit context window]", + } + modified["messages"] = [notice] + modified["messages"] + + final_tokens = estimate_request_tokens(modified) + + logger.info( + "context_truncated", + original_tokens=current_tokens, + final_tokens=final_tokens, + messages_removed=removed_count, + target_tokens=target_tokens, + effective_preserve_recent=effective_preserve, + category="context_management", + ) + + return modified, True diff --git a/ccproxy/llms/utils/token_estimation.py b/ccproxy/llms/utils/token_estimation.py new file mode 100644 index 00000000..1c3abfdb --- /dev/null +++ b/ccproxy/llms/utils/token_estimation.py @@ -0,0 +1,195 @@ +"""Token estimation utilities for context window management.""" + +import json +from pathlib import Path +from typing import Any + + +# Cache for loaded token limits +_token_limits_cache: dict[str, int] | None = None + + +def estimate_tokens(content: Any) -> int: + """Estimate token count for content. + + Uses ~3 characters per token heuristic for English text. + This is a conservative estimate - actual may be lower. + + Args: + content: Message content (string, list of blocks, or dict) + + Returns: + Estimated token count + """ + if content is None: + return 0 + + if isinstance(content, str): + # ~3 chars per token for English, be conservative + return max(1, len(content) // 3) + + if isinstance(content, list): + total = 0 + for block in content: + if isinstance(block, dict): + block_type = block.get("type", "") + if block_type == "text": + total += estimate_tokens(block.get("text", "")) + elif block_type == "tool_use": + # Tool name + input + total += estimate_tokens(block.get("name", "")) + total += estimate_tokens(json.dumps(block.get("input", {}))) + elif block_type == "tool_result": + total += estimate_tokens(block.get("content", "")) + elif block_type == "image": + # Images are ~1600 tokens for typical size + total += 1600 + else: + # Generic block - serialize and estimate + total += estimate_tokens(json.dumps(block)) + else: + total += estimate_tokens(block) + return total + + if isinstance(content, dict): + return estimate_tokens(json.dumps(content)) + + return estimate_tokens(str(content)) + + +def estimate_messages_tokens(messages: list[dict[str, Any]]) -> int: + """Estimate total tokens for a list of messages. + + Args: + messages: List of message dicts with role and content + + Returns: + Estimated total token count + """ + total = 0 + for msg in messages: + # Role contributes ~2 tokens + total += 2 + total += estimate_tokens(msg.get("content")) + return total + + +def estimate_request_tokens(request_data: dict[str, Any]) -> int: + """Estimate total input tokens for a request. + + Includes messages, system prompt, and tool definitions. + + Args: + request_data: The request payload dictionary + + Returns: + Estimated total input token count + """ + total = 0 + + # Messages + messages = request_data.get("messages", []) + total += estimate_messages_tokens(messages) + + # System prompt + system = request_data.get("system") + if system: + total += estimate_tokens(system) + + # Tools + tools = request_data.get("tools", []) + if tools: + total += estimate_tokens(json.dumps(tools)) + + return total + + +def _load_token_limits() -> dict[str, int]: + """Load token limits from available sources. + + Loads from: + 1. Local token_limits.json in max_tokens plugin + 2. Pricing cache at ~/.cache/ccproxy/model_pricing.json + + Returns: + Dict mapping model names to max_input_tokens + """ + global _token_limits_cache + if _token_limits_cache is not None: + return _token_limits_cache + + _token_limits_cache = {} + + # Try local token_limits.json first + local_file = ( + Path(__file__).parent.parent.parent + / "plugins" + / "max_tokens" + / "token_limits.json" + ) + if local_file.exists(): + try: + with local_file.open("r", encoding="utf-8") as f: + data = json.load(f) + for model_name, model_data in data.items(): + if model_name.startswith("_"): + continue + if isinstance(model_data, dict): + max_input = model_data.get("max_input_tokens") + if isinstance(max_input, int): + _token_limits_cache[model_name] = max_input + except Exception: + pass # Fall through to pricing cache + + # Also try pricing cache for additional models + pricing_cache = Path.home() / ".cache" / "ccproxy" / "model_pricing.json" + if pricing_cache.exists(): + try: + with pricing_cache.open("r", encoding="utf-8") as f: + data = json.load(f) + for model_name, model_data in data.items(): + if model_name in _token_limits_cache: + continue # Local file takes precedence + if isinstance(model_data, dict): + max_input = model_data.get("max_input_tokens") + if isinstance(max_input, int): + _token_limits_cache[model_name] = max_input + except Exception: + pass + + return _token_limits_cache + + +def get_max_input_tokens(model: str) -> int | None: + """Get max input tokens for a model. + + Supports pattern matching for model variants: + - Exact match: "claude-opus-4-5-20251101" + - Prefix match: "claude-opus-4-5-*" matches "claude-opus-4-5-20251101" + + Args: + model: Model name or identifier + + Returns: + Max input tokens if known, None otherwise + """ + limits = _load_token_limits() + + # Try exact match first + if model in limits: + return limits[model] + + # Try prefix matching (for patterns like claude-opus-4-5-*) + for pattern, max_tokens in limits.items(): + if pattern.endswith("*"): + prefix = pattern[:-1] + if model.startswith(prefix): + return max_tokens + + # Try matching known model families + model_lower = model.lower() + for known_model, max_tokens in limits.items(): + if known_model.lower() in model_lower or model_lower in known_model.lower(): + return max_tokens + + return None diff --git a/ccproxy/plugins/claude_api/adapter.py b/ccproxy/plugins/claude_api/adapter.py index c056b198..509bd3b0 100644 --- a/ccproxy/plugins/claude_api/adapter.py +++ b/ccproxy/plugins/claude_api/adapter.py @@ -12,7 +12,9 @@ DetectionServiceProtocol, TokenManagerProtocol, ) +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results from ccproxy.llms.formatters.utils import strict_parse_tool_arguments +from ccproxy.llms.utils import get_max_input_tokens, truncate_to_fit from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter from ccproxy.utils.headers import ( extract_response_headers, @@ -58,6 +60,35 @@ async def prepare_provider_request( if body_data.get("temperature") is None: body_data.pop("temperature", None) + # Get model from request for context truncation + model = body_data.get("model", "") + + # Auto-truncate context if request exceeds model limits + # IMPORTANT: Truncation must happen BEFORE sanitization because truncation + # can create orphaned tool blocks by removing messages with tool_use while + # keeping messages with tool_result + max_input = get_max_input_tokens(model) + if max_input: + body_data, was_truncated = truncate_to_fit( + body_data, + max_input_tokens=max_input, + preserve_recent=getattr(self.config, "preserve_recent_messages", 10), + safety_margin=getattr(self.config, "context_safety_margin", 0.9), + ) + if was_truncated: + logger.info( + "request_truncated_for_context_limit", + model=model, + max_input_tokens=max_input, + category="context_management", + ) + + # Sanitize tool_result blocks to remove orphaned references + # This fixes "unexpected tool_use_id" errors from conversation compaction + # Must run AFTER truncation to catch orphans created by truncation + if "messages" in body_data: + body_data["messages"] = _sanitize_tool_results(body_data["messages"]) + # Anthropic API constraint: cannot accept both temperature and top_p # Prioritize temperature over top_p when both are present if "temperature" in body_data and "top_p" in body_data: diff --git a/ccproxy/plugins/claude_sdk/adapter.py b/ccproxy/plugins/claude_sdk/adapter.py index 75f7988f..74bed887 100644 --- a/ccproxy/plugins/claude_sdk/adapter.py +++ b/ccproxy/plugins/claude_sdk/adapter.py @@ -14,7 +14,9 @@ from ccproxy.config.utils import OPENAI_CHAT_COMPLETIONS_PATH from ccproxy.core.logging import get_plugin_logger from ccproxy.core.request_context import RequestContext +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results from ccproxy.llms.streaming import OpenAIStreamProcessor +from ccproxy.llms.utils import get_max_input_tokens, truncate_to_fit from ccproxy.services.adapters.chain_composer import compose_from_chain from ccproxy.services.adapters.format_adapter import FormatAdapterProtocol from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter @@ -234,6 +236,34 @@ async def handle_request( # Extract parameters for SDK handler messages = request_data.get("messages", []) model = request_data.get("model", "claude-3-opus-20240229") + + # Auto-truncate context if request exceeds model limits + # IMPORTANT: Truncation must happen BEFORE sanitization because truncation + # can create orphaned tool blocks by removing messages with tool_use while + # keeping messages with tool_result + max_input = get_max_input_tokens(model) + if max_input: + request_data, was_truncated = truncate_to_fit( + request_data, + max_input_tokens=max_input, + preserve_recent=getattr(self.config, "preserve_recent_messages", 10), + safety_margin=getattr(self.config, "context_safety_margin", 0.9), + ) + if was_truncated: + logger.info( + "request_truncated_for_context_limit", + model=model, + max_input_tokens=max_input, + category="context_management", + ) + # Update messages reference after truncation + messages = request_data.get("messages", []) + + # Sanitize tool_result blocks to remove orphaned references + # This fixes "unexpected tool_use_id" errors from conversation compaction + # Must run AFTER truncation to catch orphans created by truncation + messages = _sanitize_tool_results(messages) + request_data["messages"] = messages temperature = request_data.get("temperature") max_tokens = request_data.get("max_tokens") stream = request_data.get("stream", False) diff --git a/ccproxy/streaming/deferred.py b/ccproxy/streaming/deferred.py index b721a105..a9bbb498 100644 --- a/ccproxy/streaming/deferred.py +++ b/ccproxy/streaming/deferred.py @@ -13,6 +13,7 @@ import structlog from starlette.responses import JSONResponse, Response, StreamingResponse +from ccproxy.core.constants import FORMAT_ANTHROPIC_MESSAGES from ccproxy.core.plugins.hooks import HookEvent, HookManager from ccproxy.core.plugins.hooks.base import HookContext from ccproxy.llms.streaming.accumulators import StreamAccumulator @@ -233,6 +234,7 @@ async def body_generator() -> AsyncGenerator[bytes, None]: async def _emit_error_sse( error_obj: dict[str, Any], ) -> AsyncGenerator[bytes, None]: + error_obj = self._format_stream_error(error_obj) adapted: dict[str, Any] | None = None try: if self.handler_config and self.handler_config.response_adapter: @@ -840,6 +842,23 @@ async def _serialize_json_to_sse_stream( ): yield chunk + def _format_stream_error(self, error_obj: dict[str, Any]) -> dict[str, Any]: + """Normalize streaming error payloads for client-specific SSE schemas.""" + if isinstance(error_obj, dict) and error_obj.get("type"): + return error_obj + + format_chain = ( + self.request_context.format_chain + if self.request_context and self.request_context.format_chain + else [] + ) + client_format = format_chain[0] if format_chain else None + + if client_format == FORMAT_ANTHROPIC_MESSAGES: + return {"type": "error", "error": error_obj.get("error", error_obj)} + + return error_obj + def _record_tool_event(self, event_name: str, payload: Any) -> None: if not self._stream_accumulator or not isinstance(payload, dict): return diff --git a/tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py b/tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py new file mode 100644 index 00000000..6db9ca9f --- /dev/null +++ b/tests/plugins/claude_api/unit/test_native_anthropic_sanitization.py @@ -0,0 +1,591 @@ +"""Test native Anthropic request sanitization for orphaned tool blocks. + +This test module verifies that native Anthropic format requests (sent to /v1/messages) +properly sanitize orphaned tool blocks that don't have matching counterparts. + +Two types of orphaned blocks are handled: + +1. Orphaned tool_result blocks (tool_result without matching tool_use): + - Occurs when conversation is compacted, removing old tool_use blocks + - tool_result blocks remain without their corresponding tool_use blocks + - API rejects with: "unexpected tool_use_id found in tool_result blocks" + +2. Orphaned tool_use blocks (tool_use without matching tool_result): + - Occurs when conversation is compacted, removing tool_result blocks + - tool_use blocks remain without their corresponding tool_result blocks + - API rejects with: "tool_use ids were found without tool_result blocks immediately after" + +The fix applies _sanitize_tool_results() in the claude_api and claude_sdk adapters' +prepare_provider_request() method before forwarding to the Anthropic API. +""" + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results + + +Message = dict[str, Any] + + +class TestNativeAnthropicSanitization: + """Test sanitization of native Anthropic requests with orphaned tool_result blocks.""" + + def test_orphaned_tool_result_removed_from_native_request(self): + """Native Anthropic request with orphaned tool_result should be sanitized. + + This reproduces the exact error reported: + "unexpected tool_use_id found in tool_result blocks: toolu_019M2sPZmfSNC57WBuV9NaRb" + """ + # Simulate a compacted conversation where the tool_use was summarized + # but the tool_result remains with its original ID + messages: list[Message] = [ + {"role": "user", "content": "Search for files matching *.py"}, + { + "role": "assistant", + "content": "I'll search for Python files.", # tool_use was compacted out + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_019M2sPZmfSNC57WBuV9NaRb", # orphaned! + "content": "Found 15 Python files", + }, + {"type": "text", "text": "Thanks, now please analyze them"}, + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # The orphaned tool_result should be converted to text + assert len(result) == 3 + user_msg = result[2] + assert user_msg["role"] == "user" + + # Should have text blocks but no tool_result + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 0 + + # Original text should be preserved + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + assert len(text_blocks) == 2 # original text + converted orphan + + # Check orphan was converted to informative text + orphan_text_block = text_blocks[0] # inserted at beginning + assert ( + "Previous tool results from compacted history" in orphan_text_block["text"] + ) + assert "toolu_019M2sPZmfSNC57WBuV9NaRb" in orphan_text_block["text"] + + def test_valid_tool_result_preserved_in_native_request(self): + """Native Anthropic request with valid tool_result should pass through unchanged.""" + messages: list[Message] = [ + {"role": "user", "content": "Search for files"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll search for files."}, + { + "type": "tool_use", + "id": "toolu_valid123", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid123", # matches tool_use above + "content": "Found 15 files", + } + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # Messages should be unchanged + assert len(result) == 3 + user_msg = result[2] + + # tool_result should be preserved + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 1 + assert tool_results[0]["tool_use_id"] == "toolu_valid123" + + def test_mixed_valid_and_orphaned_tool_results(self): + """Request with both valid and orphaned tool_results should keep valid ones.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_valid", + "name": "read", + "input": {"path": "file.py"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", # valid + "content": "file contents", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan_from_compaction", # orphaned + "content": "old result", + }, + ], + }, + ] + + result = _sanitize_tool_results(messages) + + user_msg = result[1] + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + + # Only valid tool_result should remain + assert len(tool_results) == 1 + assert tool_results[0]["tool_use_id"] == "toolu_valid" + assert len(text_blocks) == 0 + + def test_superdesign_conversation_compaction_scenario(self): + """Reproduce the SuperDesign VS Code extension compaction scenario. + + SuperDesign uses @ai-sdk/anthropic which sends native Anthropic format. + When the conversation gets long, it compacts history, removing old messages + but sometimes leaving orphaned tool_result blocks. + """ + # Simulated compacted conversation from SuperDesign + messages: list[Message] = [ + # Earlier context was compacted into a summary + {"role": "user", "content": "Help me build a React component"}, + { + "role": "assistant", + "content": "[Summary: Previously searched for files and read component code]", + }, + # User message still has tool_results from before compaction + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_old_glob_call", + "content": "src/components/Button.tsx\nsrc/components/Modal.tsx", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_old_read_call", + "content": "export const Button = () => ", + }, + {"type": "text", "text": "Now create a new Header component"}, + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # Orphaned tool_results should be converted to text + user_msg = result[2] + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 0 + + # Content should be preserved as text + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + assert len(text_blocks) == 2 # original text + orphan summary + + # User's actual request should be there + assert any("Header component" in b["text"] for b in text_blocks) + + def test_empty_messages_handled(self): + """Empty messages list should return empty list.""" + assert _sanitize_tool_results([]) == [] + + def test_messages_without_tool_content_unchanged(self): + """Messages without any tool-related content pass through unchanged.""" + messages: list[Message] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + result = _sanitize_tool_results(messages) + assert result == messages + + def test_long_orphan_content_truncated(self): + """Long orphaned tool_result content should be truncated to 500 chars.""" + long_content = "x" * 1000 + messages: list[Message] = [ + {"role": "assistant", "content": "No tool_use here"}, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": long_content, + } + ], + }, + ] + + result = _sanitize_tool_results(messages) + + user_msg = result[1] + text_blocks = [b for b in user_msg["content"] if b.get("type") == "text"] + orphan_text = text_blocks[0]["text"] + + # Should be truncated with "..." + assert "..." in orphan_text + # Should not contain full 1000 chars + assert len(orphan_text) < 1000 + + +class TestClaudeAPIAdapterSanitization: + """Test that the claude_api adapter properly applies sanitization.""" + + @pytest.mark.asyncio + async def test_adapter_sanitizes_native_anthropic_request(self): + """The adapter's prepare_provider_request should sanitize messages. + + This test directly verifies the sanitization is applied by checking + the code path rather than instantiating the full adapter. + """ + # Test the sanitization function directly with the message format + # that would come from a native Anthropic request + messages: list[Message] = [ + {"role": "assistant", "content": "Summary of previous work"}, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": "old result", + }, + {"type": "text", "text": "Continue please"}, + ], + }, + ] + + # Simulate what the adapter does + sanitized = _sanitize_tool_results(messages) + + # Check that orphaned tool_result was sanitized + user_msg = sanitized[1] + tool_results = [ + b + for b in user_msg["content"] + if isinstance(b, dict) and b.get("type") == "tool_result" + ] + assert len(tool_results) == 0 + + # Verify the fix is properly imported in the adapter module + from ccproxy.plugins.claude_api import adapter as api_adapter + + assert hasattr(api_adapter, "_sanitize_tool_results") + + +class TestClaudeSDKAdapterSanitization: + """Test that the claude_sdk adapter properly applies sanitization.""" + + def test_sdk_adapter_has_sanitization_import(self): + """Verify the SDK adapter imports the sanitization function.""" + from ccproxy.plugins.claude_sdk import adapter + + assert hasattr(adapter, "_sanitize_tool_results") + + +class TestOrphanedToolUseSanitization: + """Test sanitization of orphaned tool_use blocks (tool_use without matching tool_result). + + This addresses the error: + "tool_use ids were found without tool_result blocks immediately after: . + Each tool_use block must have a corresponding tool_result block in the next message." + """ + + def test_orphaned_tool_use_removed_from_assistant_message(self): + """Assistant message with tool_use but no matching tool_result should be sanitized.""" + messages: list[Message] = [ + {"role": "user", "content": "Search for files"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll search for files."}, + { + "type": "tool_use", + "id": "toolu_orphan123", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + ], + }, + { + "role": "user", + "content": "please continue", # No tool_result for the tool_use above + }, + ] + + result = _sanitize_tool_results(messages) + + # The tool_use should be converted to text + assistant_msg = result[1] + assert isinstance(assistant_msg["content"], list) + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 0 + + # Should have text describing the orphaned tool call + text_blocks = [b for b in assistant_msg["content"] if b.get("type") == "text"] + assert len(text_blocks) >= 1 + combined_text = " ".join(b["text"] for b in text_blocks) + assert ( + "glob" in combined_text + or "Tool calls from compacted history" in combined_text + ) + + def test_valid_tool_use_with_result_preserved(self): + """tool_use with matching tool_result should be preserved.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll search."}, + { + "type": "tool_use", + "id": "toolu_valid", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "Found 10 files", + } + ], + }, + ] + + result = _sanitize_tool_results(messages) + + # Both should be preserved + assistant_msg = result[0] + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 1 + assert tool_uses[0]["id"] == "toolu_valid" + + user_msg = result[1] + tool_results = [ + b for b in user_msg["content"] if b.get("type") == "tool_result" + ] + assert len(tool_results) == 1 + + def test_mixed_valid_and_orphaned_tool_uses(self): + """Message with both valid and orphaned tool_uses should keep only valid ones.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_valid", + "name": "read", + "input": {"path": "file.py"}, + }, + { + "type": "tool_use", + "id": "toolu_orphan", + "name": "write", + "input": {"path": "out.py"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "file contents", + } + # No tool_result for toolu_orphan + ], + }, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 1 + assert tool_uses[0]["id"] == "toolu_valid" + + def test_superdesign_compaction_scenario_with_orphaned_tool_use(self): + """Reproduce SuperDesign compaction where tool_use remains but tool_result is lost. + + This is the exact error reported: + "tool_use ids were found without tool_result blocks immediately after: toolu_01YJquBpATfUqskN381pdJdP" + """ + messages: list[Message] = [ + {"role": "user", "content": "Help me build a component"}, + { + "role": "assistant", + "content": "[Summary: Previously searched for files]", # Compacted + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me search for more files."}, + { + "type": "tool_use", + "id": "toolu_01YJquBpATfUqskN381pdJdP", # Exact ID from error + "name": "glob", + "input": {"pattern": "src/**/*.tsx"}, + }, + ], + }, + { + "role": "user", + "content": "please continue", # User message without tool_result + }, + ] + + result = _sanitize_tool_results(messages) + + # The orphaned tool_use should be removed/converted + for msg in result: + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): + tool_uses = [ + b + for b in msg["content"] + if isinstance(b, dict) and b.get("type") == "tool_use" + ] + assert len(tool_uses) == 0, f"Found orphaned tool_use: {tool_uses}" + + def test_tool_use_only_message_converted_to_text(self): + """Assistant message with only orphaned tool_use should become text.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_only_orphan", + "name": "write", + "input": {"path": "test.py", "content": "print('hello')"}, + } + ], + }, + {"role": "user", "content": "continue"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + # Content should be text (either string or list with text block) + content = assistant_msg["content"] + if isinstance(content, list): + tool_uses = [b for b in content if b.get("type") == "tool_use"] + assert len(tool_uses) == 0 + text_blocks = [b for b in content if b.get("type") == "text"] + assert len(text_blocks) >= 1 + assert "write" in text_blocks[0]["text"] + else: + # String content is also acceptable + assert isinstance(content, str) + + def test_long_tool_input_truncated(self): + """Long orphaned tool_use input should be truncated to 200 chars.""" + long_input = {"content": "x" * 500} + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_long", + "name": "write", + "input": long_input, + } + ], + }, + {"role": "user", "content": "continue"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + text_blocks = [b for b in assistant_msg["content"] if b.get("type") == "text"] + orphan_text = text_blocks[0]["text"] + + # Should be truncated with "..." + assert "..." in orphan_text + + def test_multiple_consecutive_orphaned_tool_uses(self): + """Multiple orphaned tool_use blocks should all be converted.""" + messages: list[Message] = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll do multiple operations."}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "read", + "input": {"path": "a.py"}, + }, + { + "type": "tool_use", + "id": "toolu_2", + "name": "read", + "input": {"path": "b.py"}, + }, + { + "type": "tool_use", + "id": "toolu_3", + "name": "write", + "input": {"path": "c.py"}, + }, + ], + }, + {"role": "user", "content": "okay, what's next?"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_msg = result[0] + tool_uses = [b for b in assistant_msg["content"] if b.get("type") == "tool_use"] + assert len(tool_uses) == 0 + + # All three should be mentioned in the text + text_blocks = [b for b in assistant_msg["content"] if b.get("type") == "text"] + combined_text = " ".join(b["text"] for b in text_blocks) + assert "read" in combined_text + assert "write" in combined_text diff --git a/tests/test_tool_result_sanitization.py b/tests/test_tool_result_sanitization.py index d902a4ea..6b9a713f 100644 --- a/tests/test_tool_result_sanitization.py +++ b/tests/test_tool_result_sanitization.py @@ -1,4 +1,4 @@ -"""Test _sanitize_tool_results method for removing orphaned tool_result blocks. +"""Test _sanitize_tool_results method for orphaned tool blocks. This module tests the bug fix for orphaned tool_result blocks that occur when conversation history is compacted. When tool_use blocks are removed during @@ -9,8 +9,9 @@ The _sanitize_tool_results method fixes this by: 1. Removing orphaned tool_result blocks that don't have matching tool_use blocks in the immediately preceding assistant message -2. Converting orphaned results to text blocks to preserve information -3. Keeping valid tool_result blocks that have matching tool_use blocks +2. Removing tool_use blocks that don't have a tool_result in the *next* user message +3. Converting orphaned blocks to text to preserve information +4. Keeping valid tool_use/tool_result pairs Real-world scenario: - A long conversation with multiple tool calls gets compacted to stay within token limits @@ -109,7 +110,7 @@ def test_valid_tool_result_preserved(self, mock_logger: Mock) -> None: - User message with tool_result(tool_use_id="tool_123") Result: tool_result should be kept unchanged """ - messages = [ + messages: list[dict[str, Any]] = [ create_assistant_with_tool_use( "I'll help you with that.", [{"id": "tool_123", "name": "calculator", "input": {"x": 5}}], @@ -137,7 +138,7 @@ def test_orphaned_tool_result_removed(self, mock_logger: Mock) -> None: - NO preceding assistant with matching tool_use Result: tool_result should be removed and converted to text """ - messages = [ + messages: list[dict[str, Any]] = [ create_user_text_message("Hello"), create_user_with_tool_result( [{"tool_use_id": "orphan_456", "content": "orphaned result"}] @@ -163,9 +164,9 @@ def test_mixed_valid_and_orphaned(self, mock_logger: Mock) -> None: Scenario: Partial compaction - Assistant with tool_use(id="valid_1") - User with tool_result(tool_use_id="valid_1") AND tool_result(tool_use_id="orphan_2") - Result: valid_1 kept, orphan_2 converted to text + Result: valid_1 kept, orphan_2 dropped (no text injected) """ - messages = [ + messages: list[dict[str, Any]] = [ create_assistant_with_tool_use( "Let me check that.", [{"id": "valid_1", "name": "search", "input": {"query": "test"}}], @@ -183,17 +184,10 @@ def test_mixed_valid_and_orphaned(self, mock_logger: Mock) -> None: assert len(result) == 2 user_content = result[1]["content"] - # Should have text block (from orphaned) + valid tool_result - assert len(user_content) == 2 - - # First should be text block with orphaned info - assert user_content[0]["type"] == "text" - assert "Previous tool results" in user_content[0]["text"] - assert "orphan_2" in user_content[0]["text"] - - # Second should be the valid tool_result - assert user_content[1]["type"] == "tool_result" - assert user_content[1]["tool_use_id"] == "valid_1" + # Only the valid tool_result should remain + assert len(user_content) == 1 + assert user_content[0]["type"] == "tool_result" + assert user_content[0]["tool_use_id"] == "valid_1" # Should log warning about orphaned result mock_logger.warning.assert_called_once() @@ -206,7 +200,7 @@ def test_multiple_tool_uses_preserved(self, mock_logger: Mock) -> None: - User with tool_result for all three Result: all three should be preserved """ - messages = [ + messages: list[dict[str, Any]] = [ create_assistant_with_tool_use( "I'll use three tools.", [ @@ -238,6 +232,150 @@ def test_multiple_tool_uses_preserved(self, mock_logger: Mock) -> None: # No warnings should be logged mock_logger.warning.assert_not_called() + def test_orphaned_tool_use_removed_when_no_next_message( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when no following user message exists.""" + messages: list[dict[str, Any]] = [ + create_assistant_with_tool_use( + "I'll need to run a tool.", + [{"id": "tool_1", "name": "helper", "input": {"q": "test"}}], + ) + ] + + result = _sanitize_tool_results(messages) + + assert len(result) == 1 + assert result[0]["role"] == "assistant" + content = result[0]["content"] + assert isinstance(content, list) + assert all(block.get("type") != "tool_use" for block in content) + assert "Tool calls from compacted history" in content[0]["text"] + mock_logger.warning.assert_called_once() + + def test_orphaned_tool_use_removed_when_next_user_missing_result( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when the next user message lacks tool_result.""" + messages: list[dict[str, Any]] = [ + create_assistant_with_tool_use( + "Let me check.", + [{"id": "tool_1", "name": "search", "input": {"q": "test"}}], + ), + create_user_text_message("Thanks!"), + ] + + result = _sanitize_tool_results(messages) + + assert len(result) == 2 + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert "Tool calls from compacted history" in assistant_content[0]["text"] + assert result[1] == messages[1] + mock_logger.warning.assert_called_once() + + def test_tool_use_removed_when_next_user_only_has_orphaned_results( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when next user only includes unrelated tool_results.""" + messages: list[dict[str, Any]] = [ + create_assistant_with_tool_use( + "Let me check.", + [{"id": "tool_keep", "name": "search", "input": {"q": "test"}}], + ), + create_user_with_tool_result( + [{"tool_use_id": "orphan_only", "content": "old result"}] + ), + ] + + result = _sanitize_tool_results(messages) + + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert "Tool calls from compacted history" in assistant_content[0]["text"] + + user_content = result[1]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Previous tool results" in user_content[0]["text"] + + assert mock_logger.warning.call_count == 2 + + def test_duplicate_tool_use_pruned_to_match_result_count( + self, mock_logger: Mock + ) -> None: + """Remove extra tool_use blocks when tool_result count is lower. + + Scenario: Assistant has duplicate tool_use IDs; user has a single result. + Result: Keep only one tool_use and one tool_result. + """ + messages: list[dict[str, Any]] = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Calling tool twice."}, + { + "type": "tool_use", + "id": "dup_tool", + "name": "search", + "input": {}, + }, + { + "type": "tool_use", + "id": "dup_tool", + "name": "search", + "input": {}, + }, + ], + }, + create_user_with_tool_result( + [{"tool_use_id": "dup_tool", "content": "result once"}] + ), + ] + + result = _sanitize_tool_results(messages) + + assistant_content = result[0]["content"] + tool_use_blocks = [b for b in assistant_content if b.get("type") == "tool_use"] + assert len(tool_use_blocks) == 1 + + user_content = result[1]["content"] + tool_result_blocks = [b for b in user_content if b.get("type") == "tool_result"] + assert len(tool_result_blocks) == 1 + + mock_logger.warning.assert_called_once() + + def test_tool_use_removed_when_result_not_immediately_after( + self, mock_logger: Mock + ) -> None: + """Remove tool_use when tool_result is not in the next message.""" + messages = [ + create_assistant_with_tool_use( + "Calling a tool.", + [{"id": "tool_1", "name": "lookup", "input": {"q": "test"}}], + ), + create_assistant_text_message("Continuing without result."), + create_user_with_tool_result( + [{"tool_use_id": "tool_1", "content": "late result"}] + ), + ] + + result = _sanitize_tool_results(messages) + + assert len(result) == 3 + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert "Tool calls from compacted history" in assistant_content[0]["text"] + + user_content = result[2]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Previous tool results" in user_content[0]["text"] + mock_logger.warning.assert_called() + def test_conversation_compaction_scenario(self, mock_logger: Mock) -> None: """Test the real bug scenario: conversation compaction leaves orphaned results. @@ -275,16 +413,10 @@ def test_conversation_compaction_scenario(self, mock_logger: Mock) -> None: assert len(result) == 3 user_content = result[2]["content"] - # Should have text block (orphaned) + valid tool_result - assert len(user_content) == 2 - - # First is text with orphaned info - assert user_content[0]["type"] == "text" - assert "original_tool" in user_content[0]["text"] - - # Second is valid tool_result - assert user_content[1]["type"] == "tool_result" - assert user_content[1]["tool_use_id"] == "new_tool" + # Only the valid tool_result should remain + assert len(user_content) == 1 + assert user_content[0]["type"] == "tool_result" + assert user_content[0]["tool_use_id"] == "new_tool" # Should log warning mock_logger.warning.assert_called_once() @@ -521,7 +653,7 @@ def test_partial_match_orphaned(self, mock_logger: Mock) -> None: Scenario: Multiple results, only some have matching tool_use - Assistant with tool_use(id="valid_1") and tool_use(id="valid_2") - User with results for "valid_1", "orphan_3", "valid_2" - Result: valid_1 and valid_2 kept, orphan_3 converted to text + Result: valid_1 and valid_2 kept, orphan_3 dropped """ messages = [ create_assistant_with_tool_use( @@ -543,18 +675,12 @@ def test_partial_match_orphaned(self, mock_logger: Mock) -> None: result = _sanitize_tool_results(messages) user_content = result[1]["content"] - # Should have text block + 2 valid results - assert len(user_content) == 3 - - # First is text with orphaned info - assert user_content[0]["type"] == "text" - assert "orphan_3" in user_content[0]["text"] - - # Other two are valid results + # Only the valid tool_results should remain + assert len(user_content) == 2 + assert user_content[0]["type"] == "tool_result" + assert user_content[0]["tool_use_id"] == "valid_1" assert user_content[1]["type"] == "tool_result" - assert user_content[1]["tool_use_id"] == "valid_1" - assert user_content[2]["type"] == "tool_result" - assert user_content[2]["tool_use_id"] == "valid_2" + assert user_content[1]["tool_use_id"] == "valid_2" def test_assistant_with_string_content_no_tool_use(self, mock_logger: Mock) -> None: """Test assistant message with string content (no tool_use blocks). @@ -611,3 +737,63 @@ def test_multiple_orphaned_conversions(self, mock_logger: Mock) -> None: assert "result one" in text_block assert "result two" in text_block assert "result three" in text_block + + def test_assistant_dict_content_tool_use_removed(self, mock_logger: Mock) -> None: + """Handle assistant content supplied as a dict (single tool_use block). + + Scenario: Non-list content with tool_use and no following tool_result + Result: tool_use should be converted to text and content normalized to list + """ + messages: list[dict[str, Any]] = [ + { + "role": "assistant", + "content": { + "type": "tool_use", + "id": "tool_dict", + "name": "glob", + "input": {"pattern": "*.py"}, + }, + }, + {"role": "user", "content": "continue"}, + ] + + result = _sanitize_tool_results(messages) + + assistant_content = result[0]["content"] + assert isinstance(assistant_content, list) + assert all(block.get("type") != "tool_use" for block in assistant_content) + assert any( + block.get("type") == "text" + for block in assistant_content + if isinstance(block, dict) + ) + + mock_logger.warning.assert_called_once() + + def test_user_dict_content_tool_result_converted(self, mock_logger: Mock) -> None: + """Handle user content supplied as a dict (single tool_result block). + + Scenario: Orphaned tool_result provided as a dict + Result: tool_result should be converted to text and content normalized to list + """ + messages: list[dict[str, Any]] = [ + {"role": "assistant", "content": "No tools here"}, + { + "role": "user", + "content": { + "type": "tool_result", + "tool_use_id": "orphan_dict", + "content": "old result", + }, + }, + ] + + result = _sanitize_tool_results(messages) + + user_content = result[1]["content"] + assert isinstance(user_content, list) + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "orphan_dict" in user_content[0]["text"] + + mock_logger.warning.assert_called_once() diff --git a/tests/unit/llms/test_context_window_management.py b/tests/unit/llms/test_context_window_management.py new file mode 100644 index 00000000..0b366ac0 --- /dev/null +++ b/tests/unit/llms/test_context_window_management.py @@ -0,0 +1,290 @@ +"""Tests for context window management utilities. + +This module tests token estimation and context truncation logic +for managing requests that exceed model context limits. +""" + +import pytest + +from ccproxy.llms.utils import ( + estimate_messages_tokens, + estimate_request_tokens, + estimate_tokens, + get_max_input_tokens, + truncate_to_fit, +) + + +class TestTokenEstimation: + """Tests for token estimation functions.""" + + def test_estimate_tokens_string(self) -> None: + """Test token estimation for plain strings.""" + # ~3 chars per token + text = "Hello, world!" # 13 chars -> ~4 tokens + tokens = estimate_tokens(text) + assert tokens >= 1 + assert tokens <= 10 # Reasonable upper bound + + def test_estimate_tokens_empty_string(self) -> None: + """Test token estimation for empty string.""" + tokens = estimate_tokens("") + assert tokens == 1 # min(1, ...) + + def test_estimate_tokens_none(self) -> None: + """Test token estimation for None.""" + tokens = estimate_tokens(None) + assert tokens == 0 + + def test_estimate_tokens_text_block(self) -> None: + """Test token estimation for text content block.""" + content = [ + {"type": "text", "text": "This is a test message with some content."} + ] + tokens = estimate_tokens(content) + assert tokens > 0 + + def test_estimate_tokens_tool_use_block(self) -> None: + """Test token estimation for tool_use content block.""" + content = [ + { + "type": "tool_use", + "id": "tool_123", + "name": "read_file", + "input": {"path": "/some/file/path.txt"}, + } + ] + tokens = estimate_tokens(content) + assert tokens > 0 + + def test_estimate_tokens_tool_result_block(self) -> None: + """Test token estimation for tool_result content block.""" + content = [ + { + "type": "tool_result", + "tool_use_id": "tool_123", + "content": "File contents here with some data.", + } + ] + tokens = estimate_tokens(content) + assert tokens > 0 + + def test_estimate_tokens_image_block(self) -> None: + """Test token estimation for image content block.""" + content = [{"type": "image", "source": {"type": "base64", "data": "..."}}] + tokens = estimate_tokens(content) + assert tokens == 1600 # Fixed estimate for images + + def test_estimate_tokens_mixed_content(self) -> None: + """Test token estimation for mixed content blocks.""" + content = [ + {"type": "text", "text": "Check this image:"}, + {"type": "image", "source": {"type": "url", "url": "https://..."}}, + {"type": "text", "text": "What do you see?"}, + ] + tokens = estimate_tokens(content) + assert tokens > 1600 # At least the image tokens + + def test_estimate_messages_tokens_single(self) -> None: + """Test token estimation for a single message.""" + messages = [{"role": "user", "content": "Hello, how are you?"}] + tokens = estimate_messages_tokens(messages) + assert tokens > 2 # At least role tokens + + def test_estimate_messages_tokens_conversation(self) -> None: + """Test token estimation for a conversation.""" + messages = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "And 3+3?"}, + {"role": "assistant", "content": "3+3 equals 6."}, + ] + tokens = estimate_messages_tokens(messages) + assert tokens > 8 # At least role tokens (2 per message) + + def test_estimate_request_tokens_with_system(self) -> None: + """Test token estimation for request with system prompt.""" + request = { + "model": "claude-3-opus-20240229", + "system": "You are a helpful assistant.", + "messages": [{"role": "user", "content": "Hello"}], + } + tokens = estimate_request_tokens(request) + assert tokens > 0 + + def test_estimate_request_tokens_with_tools(self) -> None: + """Test token estimation for request with tools.""" + request = { + "model": "claude-3-opus-20240229", + "messages": [{"role": "user", "content": "Read a file"}], + "tools": [ + { + "name": "read_file", + "description": "Read contents of a file", + "input_schema": { + "type": "object", + "properties": {"path": {"type": "string"}}, + }, + } + ], + } + tokens = estimate_request_tokens(request) + assert tokens > 0 + + +class TestGetMaxInputTokens: + """Tests for max input tokens lookup.""" + + def test_get_max_input_tokens_known_model(self) -> None: + """Test getting max input tokens for a known model.""" + # This test may need adjustment based on what models are in the limits file + max_tokens = get_max_input_tokens("claude-3-opus-20240229") + # Should return a value if model is in limits + if max_tokens is not None: + assert max_tokens > 0 + + def test_get_max_input_tokens_unknown_model(self) -> None: + """Test getting max input tokens for an unknown model.""" + max_tokens = get_max_input_tokens("totally-unknown-model-xyz") + assert max_tokens is None + + +class TestTruncateToFit: + """Tests for context truncation.""" + + def test_truncate_no_truncation_needed(self) -> None: + """Test that small requests are not truncated.""" + request = { + "model": "claude-3-opus-20240229", + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + } + result, was_truncated = truncate_to_fit( + request, max_input_tokens=200000, preserve_recent=10 + ) + assert was_truncated is False + assert result == request + + def test_truncate_removes_old_messages(self) -> None: + """Test that truncation removes oldest messages first.""" + # Create a request with many messages + messages = [] + for i in range(20): + messages.append({"role": "user", "content": f"Message {i} " * 100}) + messages.append({"role": "assistant", "content": f"Response {i} " * 100}) + + request = {"model": "claude-3-opus-20240229", "messages": messages} + + # Force truncation with a very low limit + result, was_truncated = truncate_to_fit( + request, max_input_tokens=1000, preserve_recent=4 + ) + + assert was_truncated is True + # Should have fewer messages + assert len(result["messages"]) < len(messages) + # Should preserve recent messages + assert len(result["messages"]) >= 4 + + def test_truncate_preserves_recent_messages(self) -> None: + """Test that recent messages are preserved during truncation.""" + messages = [ + {"role": "user", "content": "Old message " * 500}, + {"role": "assistant", "content": "Old response " * 500}, + {"role": "user", "content": "Recent message 1"}, + {"role": "assistant", "content": "Recent response 1"}, + {"role": "user", "content": "Recent message 2"}, + {"role": "assistant", "content": "Recent response 2"}, + ] + request = {"model": "claude-3-opus-20240229", "messages": messages} + + result, was_truncated = truncate_to_fit( + request, max_input_tokens=500, preserve_recent=4 + ) + + if was_truncated: + # Check that the last 4 messages are preserved + result_messages = result["messages"] + # Account for truncation notice being added + recent_messages = ( + result_messages[-4:] if len(result_messages) >= 4 else result_messages + ) + + # Verify recent content is in the preserved messages + recent_content = [m.get("content", "") for m in recent_messages] + assert any("Recent" in str(c) for c in recent_content) + + def test_truncate_adds_notice(self) -> None: + """Test that truncation adds a notice message.""" + messages = [] + for i in range(10): + messages.append({"role": "user", "content": f"Message {i} " * 200}) + messages.append({"role": "assistant", "content": f"Response {i} " * 200}) + + request = {"model": "claude-3-opus-20240229", "messages": messages} + + result, was_truncated = truncate_to_fit( + request, max_input_tokens=500, preserve_recent=2 + ) + + if was_truncated: + # First message should be the truncation notice + first_msg = result["messages"][0] + assert first_msg["role"] == "user" + assert "truncated" in first_msg["content"].lower() + + def test_truncate_not_enough_messages(self) -> None: + """Test truncation behavior when only one message exceeds the limit.""" + messages = [ + {"role": "user", "content": "Single message " * 1000}, + ] + request = {"model": "claude-3-opus-20240229", "messages": messages} + + # Try to truncate with preserve_recent > message count + result, was_truncated = truncate_to_fit( + request, max_input_tokens=100, preserve_recent=10 + ) + + # Should truncate and insert a notice if content can't fit + assert was_truncated is True + assert result["messages"][0]["role"] == "user" + assert "truncated" in result["messages"][0]["content"].lower() + + def test_truncate_preserves_system_and_tools(self) -> None: + """Test that system prompt and tools are preserved.""" + request = { + "model": "claude-3-opus-20240229", + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Old message " * 500}, + {"role": "assistant", "content": "Old response " * 500}, + {"role": "user", "content": "Recent message"}, + ], + "tools": [{"name": "test_tool", "description": "A test tool"}], + } + + result, was_truncated = truncate_to_fit( + request, max_input_tokens=500, preserve_recent=1 + ) + + # System and tools should be preserved regardless of truncation + assert result.get("system") == "You are a helpful assistant." + assert result.get("tools") == request["tools"] + + def test_truncate_safety_margin(self) -> None: + """Test that safety margin is applied correctly.""" + messages = [] + for i in range(5): + messages.append({"role": "user", "content": f"Message {i}"}) + + request = {"model": "claude-3-opus-20240229", "messages": messages} + + # With safety_margin=0.5, effective limit is 50000 + result, was_truncated = truncate_to_fit( + request, max_input_tokens=100000, preserve_recent=2, safety_margin=0.5 + ) + + # Should not truncate since content is small + assert was_truncated is False diff --git a/tests/unit/llms/test_truncation_creates_orphans.py b/tests/unit/llms/test_truncation_creates_orphans.py new file mode 100644 index 00000000..ccf0024b --- /dev/null +++ b/tests/unit/llms/test_truncation_creates_orphans.py @@ -0,0 +1,222 @@ +"""Regression tests for truncation creating invalid Anthropic tool blocks.""" + +from collections import Counter +from typing import Any + +from ccproxy.llms.formatters.openai_to_anthropic.requests import _sanitize_tool_results +from ccproxy.llms.utils import estimate_request_tokens, truncate_to_fit + + +Message = dict[str, Any] + + +def _content_blocks(content: Any) -> list[Any]: + if isinstance(content, list): + return content + if isinstance(content, dict): + return [content] + return [] + + +def _count_orphaned_tool_results(messages: list[Message]) -> int: + """Count tool_result blocks without a matching preceding assistant tool_use.""" + orphan_count = 0 + + for index, message in enumerate(messages): + if message.get("role") != "user": + continue + + valid_tool_uses: Counter[str] = Counter() + if index > 0 and messages[index - 1].get("role") == "assistant": + for block in _content_blocks(messages[index - 1].get("content")): + if isinstance(block, dict) and block.get("type") == "tool_use": + tool_id = block.get("id") + if tool_id: + valid_tool_uses[str(tool_id)] += 1 + + seen_results: Counter[str] = Counter() + for block in _content_blocks(message.get("content")): + if not isinstance(block, dict) or block.get("type") != "tool_result": + continue + + tool_use_id = block.get("tool_use_id") + key = str(tool_use_id) if tool_use_id else "" + if key and seen_results[key] < valid_tool_uses.get(key, 0): + seen_results[key] += 1 + else: + orphan_count += 1 + + return orphan_count + + +def _count_orphaned_tool_uses(messages: list[Message]) -> int: + """Count assistant tool_use blocks without a matching next user tool_result.""" + orphan_count = 0 + + for index, message in enumerate(messages): + if message.get("role") != "assistant": + continue + + next_results: Counter[str] = Counter() + if index + 1 < len(messages) and messages[index + 1].get("role") == "user": + for block in _content_blocks(messages[index + 1].get("content")): + if isinstance(block, dict) and block.get("type") == "tool_result": + tool_use_id = block.get("tool_use_id") + if tool_use_id: + next_results[str(tool_use_id)] += 1 + + seen_uses: Counter[str] = Counter() + for block in _content_blocks(message.get("content")): + if not isinstance(block, dict) or block.get("type") != "tool_use": + continue + + tool_id = block.get("id") + key = str(tool_id) if tool_id else "" + if key and seen_uses[key] < next_results.get(key, 0): + seen_uses[key] += 1 + else: + orphan_count += 1 + + return orphan_count + + +def _tool_pair(tool_id: str, tool_payload_size: int = 2000) -> list[Message]: + return [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": tool_id, + "name": "analyze_data", + "input": {"dataset": "x" * tool_payload_size}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_id, + "content": "Analysis complete: found 42 items", + } + ], + }, + ] + + +def test_sanitize_before_truncate_can_leave_orphaned_tool_result() -> None: + """Document why adapters must truncate before sanitizing.""" + messages = [ + *_tool_pair("tool_split"), + {"role": "user", "content": "Summarize the analysis."}, + ] + + initially_sanitized = _sanitize_tool_results(messages) + truncated, was_truncated = truncate_to_fit( + {"model": "claude-3-opus-20240229", "messages": initially_sanitized}, + max_input_tokens=200, + preserve_recent=2, + safety_margin=0.9, + ) + + assert was_truncated is True + assert _count_orphaned_tool_results(truncated["messages"]) == 1 + + +def test_truncate_then_sanitize_removes_orphaned_tool_result() -> None: + messages = [ + *_tool_pair("tool_split"), + {"role": "user", "content": "Summarize the analysis."}, + ] + + truncated, was_truncated = truncate_to_fit( + {"model": "claude-3-opus-20240229", "messages": messages}, + max_input_tokens=200, + preserve_recent=2, + safety_margin=0.9, + ) + + assert was_truncated is True + assert _count_orphaned_tool_results(truncated["messages"]) == 1 + + sanitized = _sanitize_tool_results(truncated["messages"]) + + assert _count_orphaned_tool_results(sanitized) == 0 + assert _count_orphaned_tool_uses(sanitized) == 0 + assert any( + isinstance(block, dict) + and block.get("type") == "text" + and "Previous tool results" in block.get("text", "") + for message in sanitized + for block in _content_blocks(message.get("content")) + ) + + +def test_sanitize_removes_tool_use_without_next_result() -> None: + messages = [ + {"role": "user", "content": "Inspect the project."}, + *_tool_pair("tool_removed_result")[:1], + {"role": "user", "content": "Continue after compaction."}, + ] + + assert _count_orphaned_tool_uses(messages) == 1 + + sanitized = _sanitize_tool_results(messages) + + assert _count_orphaned_tool_results(sanitized) == 0 + assert _count_orphaned_tool_uses(sanitized) == 0 + assert any( + isinstance(block, dict) + and block.get("type") == "text" + and "Tool calls from compacted history" in block.get("text", "") + for message in sanitized + for block in _content_blocks(message.get("content")) + ) + + +def test_few_messages_with_massive_tool_results_are_reduced() -> None: + messages: list[Message] = [ + {"role": "user", "content": "Read the README file"}, + *_tool_pair("tool_readme", tool_payload_size=100), + {"role": "user", "content": "Now read the package.json"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_package", + "name": "Read", + "input": {"file_path": "/project/package.json"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tool_package", + "content": "package.json content line\n" * 15000, + } + ], + }, + {"role": "user", "content": "Summarize both files for me."}, + ] + request = {"model": "claude-3-opus-20240229", "messages": messages} + original_tokens = estimate_request_tokens(request) + + truncated, was_truncated = truncate_to_fit( + request, + max_input_tokens=50000, + preserve_recent=10, + safety_margin=0.9, + ) + + assert was_truncated is True + assert estimate_request_tokens(truncated) < original_tokens + + sanitized = _sanitize_tool_results(truncated["messages"]) + assert _count_orphaned_tool_results(sanitized) == 0 + assert _count_orphaned_tool_uses(sanitized) == 0 diff --git a/tests/unit/streaming/test_deferred_stream_errors.py b/tests/unit/streaming/test_deferred_stream_errors.py new file mode 100644 index 00000000..3025c02f --- /dev/null +++ b/tests/unit/streaming/test_deferred_stream_errors.py @@ -0,0 +1,66 @@ +"""Tests for streaming error payload normalization.""" + +from __future__ import annotations + +import httpx +import pytest + +from ccproxy.core.constants import FORMAT_ANTHROPIC_MESSAGES, FORMAT_OPENAI_CHAT +from ccproxy.core.logging import get_logger +from ccproxy.core.request_context import RequestContext +from ccproxy.streaming.deferred import DeferredStreaming + + +@pytest.mark.anyio +async def test_anthropic_error_wrapped_with_type() -> None: + ctx = RequestContext( + request_id="req_1", + start_time=0.0, + logger=get_logger(__name__), + ) + ctx.format_chain = [FORMAT_ANTHROPIC_MESSAGES] + + client = httpx.AsyncClient() + stream = DeferredStreaming( + method="POST", + url="http://example.test/v1/messages", + headers={}, + body=b"{}", + client=client, + request_context=ctx, + ) + + error_obj = {"error": {"type": "timeout_error", "message": "Request timeout"}} + formatted = stream._format_stream_error(error_obj) + + assert formatted["type"] == "error" + assert formatted["error"]["type"] == "timeout_error" + + await client.aclose() + + +@pytest.mark.anyio +async def test_non_anthropic_error_left_unchanged() -> None: + ctx = RequestContext( + request_id="req_2", + start_time=0.0, + logger=get_logger(__name__), + ) + ctx.format_chain = [FORMAT_OPENAI_CHAT] + + client = httpx.AsyncClient() + stream = DeferredStreaming( + method="POST", + url="http://example.test/v1/chat/completions", + headers={}, + body=b"{}", + client=client, + request_context=ctx, + ) + + error_obj = {"error": {"type": "timeout_error", "message": "Request timeout"}} + formatted = stream._format_stream_error(error_obj) + + assert formatted == error_obj + + await client.aclose()