diff --git a/pyproject.toml b/pyproject.toml index e1ab0d7d4..ae77f4165 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,8 @@ sagemaker = [ "openai>=1.68.0,<3.0.0", # SageMaker uses OpenAI-compatible interface ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] +# Rename this extra once compression/context management features land #555 +token-estimation = ["tiktoken>=0.7.0,<1.0.0"] docs = [ "sphinx>=5.0.0,<10.0.0", "sphinx-rtd-theme>=1.0.0,<4.0.0", diff --git a/src/strands/models/model.py b/src/strands/models/model.py index f084d24d5..f910ae74e 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,6 +1,7 @@ """Abstract base class for Agent model providers.""" import abc +import json import logging from collections.abc import AsyncGenerator, AsyncIterable from dataclasses import dataclass @@ -10,7 +11,7 @@ from ..hooks.events import AfterInvocationEvent from ..plugins.plugin import Plugin -from ..types.content import Messages, SystemContentBlock +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -21,6 +22,110 @@ T = TypeVar("T", bound=BaseModel) +_DEFAULT_ENCODING = "cl100k_base" +_cached_encoding: Any = None + + +def _get_encoding() -> Any: + """Get the default tiktoken encoding, caching to avoid repeated lookups.""" + global _cached_encoding + if _cached_encoding is None: + try: + import tiktoken + except ImportError as err: + raise ImportError( + "tiktoken is required for token estimation. 
" + "Install it with: pip install strands-agents[token-estimation]" + ) from err + _cached_encoding = tiktoken.get_encoding(_DEFAULT_ENCODING) + return _cached_encoding + + +def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int: + """Count tokens for a single content block.""" + total = 0 + + if "text" in block: + total += len(encoding.encode(block["text"])) + + if "toolUse" in block: + tool_use = block["toolUse"] + total += len(encoding.encode(tool_use.get("name", ""))) + try: + total += len(encoding.encode(json.dumps(tool_use.get("input", {})))) + except (TypeError, ValueError): + logger.debug( + "tool_name=<%s> | skipping non-serializable toolUse input for token estimation", + tool_use.get("name", "unknown"), + ) + + if "toolResult" in block: + tool_result = block["toolResult"] + for item in tool_result.get("content", []): + if "text" in item: + total += len(encoding.encode(item["text"])) + + if "reasoningContent" in block: + reasoning = block["reasoningContent"] + if "reasoningText" in reasoning: + reasoning_text = reasoning["reasoningText"] + if "text" in reasoning_text: + total += len(encoding.encode(reasoning_text["text"])) + + if "guardContent" in block: + guard = block["guardContent"] + if "text" in guard: + total += len(encoding.encode(guard["text"]["text"])) + + if "citationsContent" in block: + citations = block["citationsContent"] + if "content" in citations: + for citation_item in citations["content"]: + if "text" in citation_item: + total += len(encoding.encode(citation_item["text"])) + + return total + + +def _estimate_tokens_with_tiktoken( + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, +) -> int: + """Estimate tokens by serializing messages/tools to text and counting with tiktoken. + + This is a best-effort fallback for providers that don't expose native counting. 
+ Accuracy varies by model but is sufficient for threshold-based decisions. + """ + encoding = _get_encoding() + total = 0 + + # Prefer system_prompt_content (structured) over system_prompt (plain string) to avoid double-counting, + # since providers wrap system_prompt into system_prompt_content when both are provided. + if system_prompt_content: + for block in system_prompt_content: + if "text" in block: + total += len(encoding.encode(block["text"])) + elif system_prompt: + total += len(encoding.encode(system_prompt)) + + for message in messages: + for block in message["content"]: + total += _count_content_block_tokens(block, encoding) + + if tool_specs: + for spec in tool_specs: + try: + total += len(encoding.encode(json.dumps(spec))) + except (TypeError, ValueError): + logger.debug( + "tool_name=<%s> | skipping non-serializable tool spec for token estimation", + spec.get("name", "unknown"), + ) + + return total + @dataclass class CacheConfig: @@ -130,6 +235,34 @@ def stream( """ pass + def _estimate_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Estimate token count for the given input before sending to the model. + + Used for proactive context management (e.g., triggering compression at a + threshold). This is a naive approximation using tiktoken's cl100k_base encoding. + Accuracy varies by model provider but is typically within 5-10% for most providers. + Not intended for billing or precise quota calculations. + + Subclasses may override this method to provide model-specific token counting + using native APIs for improved accuracy. + + Args: + messages: List of message objects to estimate tokens for. + tool_specs: List of tool specifications to include in the estimate. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. 
def test_estimate_tokens_empty_messages(model):
    """No messages and no prompts should estimate to zero tokens."""
    assert model._estimate_tokens(messages=[]) == 0


def test_estimate_tokens_system_prompt_only(model):
    token_count = model._estimate_tokens(messages=[], system_prompt="You are a helpful assistant.")
    assert token_count == 6


def test_estimate_tokens_text_messages(model, messages):
    # The shared `messages` fixture carries the single text "hello".
    assert model._estimate_tokens(messages=messages) == 1


def test_estimate_tokens_with_tool_specs(model, messages, tool_specs):
    baseline = model._estimate_tokens(messages=messages)
    with_specs = model._estimate_tokens(messages=messages, tool_specs=tool_specs)
    assert baseline == 1  # "hello"
    assert with_specs == 49  # "hello" (1) + tool_spec (48)


def test_estimate_tokens_with_system_prompt(model, messages, system_prompt):
    baseline = model._estimate_tokens(messages=messages)
    with_prompt = model._estimate_tokens(messages=messages, system_prompt=system_prompt)
    assert baseline == 1  # "hello"
    assert with_prompt == 3  # "hello" (1) + "s1" (2)


def test_estimate_tokens_combined(model, messages, tool_specs, system_prompt):
    token_count = model._estimate_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt)
    assert token_count == 51  # "hello" (1) + tool_spec (48) + "s1" (2)


def test_estimate_tokens_tool_use_block(model):
    tool_use_message = {
        "role": "assistant",
        "content": [{"toolUse": {"toolUseId": "123", "name": "my_tool", "input": {"query": "test"}}}],
    }
    # name "my_tool" (2) + json.dumps(input) (6) = 8
    assert model._estimate_tokens(messages=[tool_use_message]) == 8


def test_estimate_tokens_tool_result_block(model):
    tool_result_message = {
        "role": "user",
        "content": [
            {"toolResult": {"toolUseId": "123", "content": [{"text": "tool output here"}], "status": "success"}}
        ],
    }
    assert model._estimate_tokens(messages=[tool_result_message]) == 3  # "tool output here"


def test_estimate_tokens_reasoning_block(model):
    reasoning_message = {
        "role": "assistant",
        "content": [{"reasoningContent": {"reasoningText": {"text": "Let me think about this step by step."}}}],
    }
    assert model._estimate_tokens(messages=[reasoning_message]) == 9
def test_estimate_tokens_skips_binary_content(model):
    """Image blocks carry no countable text and must estimate to zero."""
    image_message = {
        "role": "user",
        "content": [{"image": {"format": "png", "source": {"bytes": b"fake image data"}}}],
    }
    assert model._estimate_tokens(messages=[image_message]) == 0


def test_estimate_tokens_tool_result_with_bytes_only(model):
    binary_result = {
        "toolUseId": "123",
        "content": [{"image": {"format": "png", "source": {"bytes": b"image data"}}}],
        "status": "success",
    }
    assert model._estimate_tokens(messages=[{"role": "user", "content": [{"toolResult": binary_result}]}]) == 0


def test_estimate_tokens_tool_result_with_text_and_bytes(model):
    mixed_result = {
        "toolUseId": "123",
        "content": [
            {"text": "Here is the screenshot"},
            {"image": {"format": "png", "source": {"bytes": b"image data"}}},
        ],
        "status": "success",
    }
    token_count = model._estimate_tokens(messages=[{"role": "user", "content": [{"toolResult": mixed_result}]}])
    assert token_count > 0


def test_estimate_tokens_guard_content_block(model):
    guard_message = {
        "role": "assistant",
        "content": [{"guardContent": {"text": {"text": "This content was filtered by guardrails."}}}],
    }
    assert model._estimate_tokens(messages=[guard_message]) == 8


def test_estimate_tokens_tool_use_with_bytes(model):
    # Input holds raw bytes (not JSON-serializable), so only the name counts.
    tool_use_message = {
        "role": "assistant",
        "content": [{"toolUse": {"toolUseId": "123", "name": "my_tool", "input": {"data": b"binary data"}}}],
    }
    assert model._estimate_tokens(messages=[tool_use_message]) == 2  # "my_tool" name only


def test_estimate_tokens_non_serializable_tool_spec(model, messages):
    bad_spec = {
        "name": "test",
        "description": "a tool",
        "inputSchema": {"json": {"default": b"bytes"}},
    }
    # Message tokens still count even though the spec itself is skipped.
    assert model._estimate_tokens(messages=messages, tool_specs=[bad_spec]) == 1  # "hello" only


def test_estimate_tokens_citations_block(model):
    citations_message = {
        "role": "assistant",
        "content": [
            {
                "citationsContent": {
                    "content": [{"text": "According to the document, the answer is 42."}],
                    "citations": [],
                }
            }
        ],
    }
    assert model._estimate_tokens(messages=[citations_message]) == 11


def test_estimate_tokens_system_prompt_content(model):
    token_count = model._estimate_tokens(
        messages=[],
        system_prompt_content=[{"text": "You are a helpful assistant."}],
    )
    assert token_count == 6  # "You are a helpful assistant."
def test_estimate_tokens_system_prompt_content_with_cache_point(model):
    """cachePoint blocks carry no text and must not affect the estimate."""
    result = model._estimate_tokens(
        messages=[],
        system_prompt_content=[
            {"text": "You are a helpful assistant."},
            {"cachePoint": {"type": "default"}},
        ],
    )
    assert result == 6  # "You are a helpful assistant.", cachePoint adds 0


def test_estimate_tokens_system_prompt_content_takes_priority(model):
    content_only = model._estimate_tokens(
        messages=[],
        system_prompt_content=[{"text": "Short."}],
    )
    # When both are provided, system_prompt_content wins — system_prompt is ignored
    both = model._estimate_tokens(
        messages=[],
        system_prompt="This is a much longer system prompt that should have more tokens.",
        system_prompt_content=[{"text": "Short."}],
    )
    assert content_only == 2  # "Short."
    assert content_only == both


def test_estimate_tokens_all_inputs(model):
    """All input kinds combine additively (content blocks win over system_prompt)."""
    messages = [
        {"role": "user", "content": [{"text": "hello world"}]},
        {"role": "assistant", "content": [{"text": "hi there"}]},
    ]
    result = model._estimate_tokens(
        messages=messages,
        tool_specs=[{"name": "test", "description": "a test tool", "inputSchema": {"json": {}}}],
        system_prompt="Be helpful.",
        system_prompt_content=[{"text": "Additional system context."}],
    )
    # system_prompt_content (4) + "hello world" (2) + "hi there" (2) + tool_spec (23) = 31
    assert result == 31


def test_get_encoding_raises_without_tiktoken(monkeypatch):
    """Test that _get_encoding raises ImportError with install instructions when tiktoken is missing."""
    import builtins

    import strands.models.model as model_module

    # Reset the module cache so _get_encoding attempts the import again.
    monkeypatch.setattr(model_module, "_cached_encoding", None)
    # Use the documented `builtins` module rather than `__builtins__`, whose
    # type (module vs dict) is a CPython implementation detail.
    original_import = builtins.__import__

    def _block_tiktoken(name, *args, **kwargs):
        if name == "tiktoken":
            raise ImportError("No module named 'tiktoken'")
        return original_import(name, *args, **kwargs)

    monkeypatch.setattr("builtins.__import__", _block_tiktoken)

    with pytest.raises(ImportError, match="pip install strands-agents\\[token-estimation\\]"):
        model_module._get_encoding()