Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ sagemaker = [
"openai>=1.68.0,<3.0.0", # SageMaker uses OpenAI-compatible interface
]
otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
# Rename this extra once compression/context management features land #555
token-estimation = ["tiktoken>=0.7.0,<1.0.0"]
docs = [
"sphinx>=5.0.0,<10.0.0",
"sphinx-rtd-theme>=1.0.0,<4.0.0",
Expand Down
135 changes: 134 additions & 1 deletion src/strands/models/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Abstract base class for Agent model providers."""

import abc
import json
import logging
from collections.abc import AsyncGenerator, AsyncIterable
from dataclasses import dataclass
Expand All @@ -10,7 +11,7 @@

from ..hooks.events import AfterInvocationEvent
from ..plugins.plugin import Plugin
from ..types.content import Messages, SystemContentBlock
from ..types.content import ContentBlock, Messages, SystemContentBlock
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolSpec

Expand All @@ -21,6 +22,110 @@

T = TypeVar("T", bound=BaseModel)

_DEFAULT_ENCODING = "cl100k_base"
_cached_encoding: Any = None


def _get_encoding() -> Any:
    """Return the shared tiktoken encoding, importing and caching it on first use.

    The encoding object is stored in the module-level ``_cached_encoding`` so the
    (relatively expensive) encoding lookup happens only once per process.

    Returns:
        The tiktoken encoding named by ``_DEFAULT_ENCODING``.

    Raises:
        ImportError: If tiktoken is not installed.
    """
    global _cached_encoding
    # Fast path: reuse the encoding built by a previous call.
    if _cached_encoding is not None:
        return _cached_encoding

    try:
        import tiktoken
    except ImportError as err:
        raise ImportError(
            "tiktoken is required for token estimation. "
            "Install it with: pip install strands-agents[token-estimation]"
        ) from err

    _cached_encoding = tiktoken.get_encoding(_DEFAULT_ENCODING)
    return _cached_encoding


def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int:
"""Count tokens for a single content block."""
total = 0

if "text" in block:
total += len(encoding.encode(block["text"]))

if "toolUse" in block:
tool_use = block["toolUse"]
total += len(encoding.encode(tool_use.get("name", "")))
try:
total += len(encoding.encode(json.dumps(tool_use.get("input", {}))))
except (TypeError, ValueError):
logger.debug(
"tool_name=<%s> | skipping non-serializable toolUse input for token estimation",
tool_use.get("name", "unknown"),
)

if "toolResult" in block:
tool_result = block["toolResult"]
for item in tool_result.get("content", []):
if "text" in item:
total += len(encoding.encode(item["text"]))

if "reasoningContent" in block:
reasoning = block["reasoningContent"]
if "reasoningText" in reasoning:
reasoning_text = reasoning["reasoningText"]
if "text" in reasoning_text:
total += len(encoding.encode(reasoning_text["text"]))

if "guardContent" in block:
guard = block["guardContent"]
if "text" in guard:
total += len(encoding.encode(guard["text"]["text"]))

if "citationsContent" in block:
citations = block["citationsContent"]
if "content" in citations:
for citation_item in citations["content"]:
if "text" in citation_item:
total += len(encoding.encode(citation_item["text"]))

return total


def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate tokens by serializing messages/tools to text and counting with tiktoken.

    This is a best-effort fallback for providers that don't expose native counting.
    Accuracy varies by model but is sufficient for threshold-based decisions.
    """
    encoding = _get_encoding()

    # System prompt: count the structured blocks when present, otherwise the plain
    # string. Providers wrap system_prompt into system_prompt_content when both are
    # supplied, so counting only one side avoids double-counting.
    if system_prompt_content:
        system_tokens = sum(
            len(encoding.encode(block["text"]))
            for block in system_prompt_content
            if "text" in block
        )
    elif system_prompt:
        system_tokens = len(encoding.encode(system_prompt))
    else:
        system_tokens = 0

    # Conversation history: delegate per-block accounting to the shared helper.
    message_tokens = sum(
        _count_content_block_tokens(block, encoding)
        for message in messages
        for block in message["content"]
    )

    # Tool specs: count their JSON serialization, skipping anything json can't handle.
    tool_tokens = 0
    for spec in tool_specs or []:
        try:
            tool_tokens += len(encoding.encode(json.dumps(spec)))
        except (TypeError, ValueError):
            logger.debug(
                "tool_name=<%s> | skipping non-serializable tool spec for token estimation",
                spec.get("name", "unknown"),
            )

    return system_tokens + message_tokens + tool_tokens


@dataclass
class CacheConfig:
Expand Down Expand Up @@ -130,6 +235,34 @@ def stream(
"""
pass

def _estimate_tokens(
    self,
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Approximate the number of input tokens the given request would consume.

    Intended for proactive context management, e.g. deciding when to trigger
    compression against a threshold. The default implementation is a naive
    tiktoken (cl100k_base) count; accuracy varies by provider (typically within
    5-10%) and it is not suitable for billing or precise quota calculations.

    Subclasses with access to a native token-counting API may override this to
    return more accurate counts.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided.
        system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt.

    Returns:
        Estimated total input tokens.
    """
    return _estimate_tokens_with_tiktoken(
        messages,
        tool_specs=tool_specs,
        system_prompt=system_prompt,
        system_prompt_content=system_prompt_content,
    )


class _ModelPlugin(Plugin):
"""Plugin that manages model-related lifecycle hooks."""
Expand Down
Loading
Loading