Skip to content

Commit 5a6df59

Browse files
feat: add count_token method to model with naive estimation using tiktoken (#2031)
Co-authored-by: opieter-aws <opieter@amazon.com>
1 parent 3e08d5e commit 5a6df59

2 files changed

Lines changed: 564 additions & 2 deletions

File tree

src/strands/models/model.py

Lines changed: 194 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,19 @@
11
"""Abstract base class for Agent model providers."""
22

33
import abc
4+
import functools
5+
import json
46
import logging
5-
from collections.abc import AsyncGenerator, AsyncIterable
7+
import math
8+
from collections.abc import AsyncGenerator, AsyncIterable, Callable
69
from dataclasses import dataclass
710
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar
811

912
from pydantic import BaseModel
1013

1114
from ..hooks.events import AfterInvocationEvent
1215
from ..plugins.plugin import Plugin
13-
from ..types.content import Messages, SystemContentBlock
16+
from ..types.content import ContentBlock, Messages, SystemContentBlock
1417
from ..types.streaming import StreamEvent
1518
from ..types.tools import ToolChoice, ToolSpec
1619

@@ -21,6 +24,164 @@
2124

2225
T = TypeVar("T", bound=BaseModel)
2326

27+
# Name of the tiktoken encoding that _get_encoding() loads for token estimation.
_DEFAULT_ENCODING = "cl100k_base"
28+
29+
30+
def _heuristic_estimate_text(text: str) -> int:
31+
"""Estimate token count from text using characters / 4 heuristic."""
32+
return math.ceil(len(text) / 4)
33+
34+
35+
def _heuristic_estimate_json(obj: Any) -> int:
36+
"""Estimate token count from a JSON-serializable object using characters / 2 heuristic."""
37+
try:
38+
return math.ceil(len(json.dumps(obj)) / 2)
39+
except (TypeError, ValueError):
40+
return 0
41+
42+
43+
@functools.lru_cache(maxsize=1)
def _get_encoding() -> Any:
    """Get the default tiktoken encoding, caching to avoid repeated lookups.

    The return value is memoized by ``lru_cache``, so a ``None`` result is also remembered and
    the lookup is not retried on later calls.

    Returns:
        The tiktoken encoding, or None if tiktoken is not installed or the encoding could not
        be loaded.
    """
    try:
        import tiktoken
    except ImportError:
        logger.debug("tiktoken not available, falling back to heuristic token estimation")
        return None

    try:
        return tiktoken.get_encoding(_DEFAULT_ENCODING)
    except Exception:
        # Loading can fail even when the package imports fine (e.g. the encoding data cannot be
        # fetched or read). Token counting is best-effort, so degrade to the heuristic instead
        # of propagating an error to the caller.
        logger.warning(
            "failed to load tiktoken encoding %s, falling back to heuristic token estimation",
            _DEFAULT_ENCODING,
            exc_info=True,
        )
        return None
57+
58+
59+
def _count_content_block_tokens(
60+
block: ContentBlock, count_text: Callable[[str], int], count_json: Callable[[Any], int]
61+
) -> int:
62+
"""Count tokens for a single content block.
63+
64+
Args:
65+
block: The content block to count tokens for.
66+
count_text: Function that returns token count for a text string.
67+
count_json: Function that returns token count for a JSON-serializable object.
68+
"""
69+
total = 0
70+
71+
if "text" in block:
72+
total += count_text(block["text"])
73+
74+
if "toolUse" in block:
75+
tool_use = block["toolUse"]
76+
total += count_text(tool_use.get("name", ""))
77+
total += count_json(tool_use.get("input", {}))
78+
79+
if "toolResult" in block:
80+
tool_result = block["toolResult"]
81+
for item in tool_result.get("content", []):
82+
if "text" in item:
83+
total += count_text(item["text"])
84+
85+
if "reasoningContent" in block:
86+
reasoning = block["reasoningContent"]
87+
if "reasoningText" in reasoning:
88+
reasoning_text = reasoning["reasoningText"]
89+
if "text" in reasoning_text:
90+
total += count_text(reasoning_text["text"])
91+
92+
if "guardContent" in block:
93+
guard = block["guardContent"]
94+
if "text" in guard and "text" in guard["text"]:
95+
total += count_text(guard["text"]["text"])
96+
97+
if "citationsContent" in block:
98+
citations = block["citationsContent"]
99+
if "content" in citations:
100+
for citation_item in citations["content"]:
101+
if "text" in citation_item:
102+
total += count_text(citation_item["text"])
103+
104+
return total
105+
106+
107+
def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate tokens by serializing messages/tools to text and counting with tiktoken.

    This is a best-effort fallback for providers that don't expose native counting.
    Accuracy varies by model but is sufficient for threshold-based decisions.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided.
        system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt.

    Returns:
        Estimated total input tokens.

    Raises:
        ImportError: If tiktoken is not installed.
    """
    encoding = _get_encoding()
    if encoding is None:
        raise ImportError("tiktoken is not available")

    def count_text(text: str) -> int:
        # By default tiktoken raises ValueError when the text contains a special-token literal
        # such as "<|endoftext|>". User content may legitimately contain such strings, so
        # disable that check and encode them as ordinary text.
        return len(encoding.encode(text, disallowed_special=()))

    def count_json(obj: Any) -> int:
        try:
            serialized = json.dumps(obj)
        except (TypeError, ValueError):
            # Unserializable payloads contribute nothing rather than failing the whole estimate.
            return 0
        return count_text(serialized)

    total = 0

    # Prefer system_prompt_content (structured) over system_prompt (plain string) to avoid double-counting,
    # since providers wrap system_prompt into system_prompt_content when both are provided.
    if system_prompt_content:
        for block in system_prompt_content:
            if "text" in block:
                total += count_text(block["text"])
    elif system_prompt:
        total += count_text(system_prompt)

    for message in messages:
        for block in message["content"]:
            total += _count_content_block_tokens(block, count_text, count_json)

    if tool_specs:
        for spec in tool_specs:
            total += count_json(spec)

    return total
154+
155+
156+
def _estimate_tokens_with_heuristic(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate tokens with character-count heuristics (text: chars/4, JSON: chars/2).

    Requires no third-party packages, so it serves as the fallback when tiktoken is not installed.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: Plain string system prompt. Only used when system_prompt_content is empty or None.
        system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt.

    Returns:
        Estimated total input tokens.
    """
    # The structured system prompt wins over the plain string so the prompt is counted only once.
    if system_prompt_content:
        prompt_tokens = sum(
            _heuristic_estimate_text(part["text"]) for part in system_prompt_content if "text" in part
        )
    elif system_prompt:
        prompt_tokens = _heuristic_estimate_text(system_prompt)
    else:
        prompt_tokens = 0

    message_tokens = sum(
        _count_content_block_tokens(content_block, _heuristic_estimate_text, _heuristic_estimate_json)
        for message in messages
        for content_block in message["content"]
    )

    tool_tokens = sum(_heuristic_estimate_json(spec) for spec in tool_specs) if tool_specs else 0

    return prompt_tokens + message_tokens + tool_tokens
184+
24185

25186
class BaseModelConfig(TypedDict, total=False):
26187
"""Base configuration shared by all model providers.
@@ -151,6 +312,37 @@ def stream(
151312
"""
152313
pass
153314

315+
async def count_tokens(
    self,
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate how many input tokens a request would consume before it is sent to the model.

    Intended for proactive context management, such as triggering compression once a
    threshold is reached. The estimate comes from tiktoken's cl100k_base encoding when that is
    available; otherwise a character heuristic is used (characters / 4 for text,
    characters / 2 for JSON). Accuracy varies by model provider, so the result is not suitable
    for billing or precise quota calculations.

    Subclasses may override this method to count tokens with a provider's native API for
    better accuracy.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided.
        system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt.

    Returns:
        Estimated total input tokens.
    """
    estimate_args = (messages, tool_specs, system_prompt, system_prompt_content)
    try:
        estimate = _estimate_tokens_with_tiktoken(*estimate_args)
    except ImportError:
        # tiktoken is unavailable; use the dependency-free heuristic instead.
        estimate = _estimate_tokens_with_heuristic(*estimate_args)
    return estimate
345+
154346

155347
class _ModelPlugin(Plugin):
156348
"""Plugin that manages model-related lifecycle hooks."""

0 commit comments

Comments
 (0)