Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,16 +693,17 @@ def end_agent_span(
if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"):
if self.is_langfuse:
attributes.update({"langfuse.observation.type": "span"})
accumulated_usage = response.metrics.accumulated_usage
latest_invocation = response.metrics.latest_agent_invocation
usage = latest_invocation.usage if latest_invocation else response.metrics.accumulated_usage
attributes.update(
{
"gen_ai.usage.prompt_tokens": accumulated_usage["inputTokens"],
"gen_ai.usage.completion_tokens": accumulated_usage["outputTokens"],
"gen_ai.usage.input_tokens": accumulated_usage["inputTokens"],
"gen_ai.usage.output_tokens": accumulated_usage["outputTokens"],
"gen_ai.usage.total_tokens": accumulated_usage["totalTokens"],
"gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0),
"gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0),
"gen_ai.usage.prompt_tokens": usage["inputTokens"],
"gen_ai.usage.completion_tokens": usage["outputTokens"],
"gen_ai.usage.input_tokens": usage["inputTokens"],
"gen_ai.usage.output_tokens": usage["outputTokens"],
"gen_ai.usage.total_tokens": usage["totalTokens"],
"gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0),
"gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0),
}
)

Expand Down
78 changes: 73 additions & 5 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,11 @@ def test_end_agent_span(mock_span):
tracer = Tracer()

# Mock AgentResult with metrics
mock_invocation = mock.MagicMock()
mock_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics.accumulated_usage = {"inputTokens": 500, "outputTokens": 1000, "totalTokens": 1500}
mock_metrics.latest_agent_invocation = mock_invocation

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down Expand Up @@ -924,8 +927,11 @@ def test_end_agent_span_with_langfuse_observation_type(mock_span, monkeypatch):
tracer = Tracer()

# Mock AgentResult with metrics
mock_invocation = mock.MagicMock()
mock_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics.accumulated_usage = {"inputTokens": 500, "outputTokens": 1000, "totalTokens": 1500}
mock_metrics.latest_agent_invocation = mock_invocation

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down Expand Up @@ -960,8 +966,11 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch):
tracer = Tracer()

# Mock AgentResult with metrics
mock_invocation = mock.MagicMock()
mock_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
mock_metrics.accumulated_usage = {"inputTokens": 500, "outputTokens": 1000, "totalTokens": 1500}
mock_metrics.latest_agent_invocation = mock_invocation

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down Expand Up @@ -999,6 +1008,56 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch):
mock_span.end.assert_called_once()


def test_end_agent_span_uses_per_invocation_usage_not_accumulated(mock_span):
    """Check that the agent span carries the latest invocation's usage, not the session totals."""
    tracer = Tracer()

    # The accumulated totals are deliberately ten times the latest invocation's
    # figures, so reading from the wrong source cannot satisfy the assertions.
    latest_invocation = mock.MagicMock(
        usage={"inputTokens": 100, "outputTokens": 50, "totalTokens": 150},
    )
    session_metrics = mock.MagicMock(
        accumulated_usage={"inputTokens": 1000, "outputTokens": 500, "totalTokens": 1500},
        latest_agent_invocation=latest_invocation,
    )

    agent_result = mock.MagicMock(metrics=session_metrics, stop_reason="end_turn")
    agent_result.__str__ = mock.MagicMock(return_value="Agent response")

    tracer.end_agent_span(mock_span, agent_result)

    # First positional argument of the most recent set_attributes() call.
    recorded = mock_span.set_attributes.call_args[0][0]

    # Expected values are the per-invocation figures (100/50/150),
    # NOT the accumulated ones (1000/500/1500).
    expected_tokens = {
        "gen_ai.usage.input_tokens": 100,
        "gen_ai.usage.output_tokens": 50,
        "gen_ai.usage.total_tokens": 150,
        "gen_ai.usage.prompt_tokens": 100,
        "gen_ai.usage.completion_tokens": 50,
    }
    for attribute, tokens in expected_tokens.items():
        assert recorded[attribute] == tokens


def test_end_agent_span_falls_back_to_accumulated_when_no_invocations(mock_span):
    """Check that accumulated_usage is reported when latest_agent_invocation is None."""
    tracer = Tracer()

    # No invocation is recorded, so accumulated_usage is the only usage data
    # available on the metrics object.
    session_metrics = mock.MagicMock(
        accumulated_usage={"inputTokens": 200, "outputTokens": 100, "totalTokens": 300},
        latest_agent_invocation=None,
    )

    agent_result = mock.MagicMock(metrics=session_metrics, stop_reason="end_turn")
    agent_result.__str__ = mock.MagicMock(return_value="Agent response")

    tracer.end_agent_span(mock_span, agent_result)

    # First positional argument of the most recent set_attributes() call.
    recorded = mock_span.set_attributes.call_args[0][0]

    expected_tokens = {
        "gen_ai.usage.input_tokens": 200,
        "gen_ai.usage.output_tokens": 100,
        "gen_ai.usage.total_tokens": 300,
    }
    for attribute, tokens in expected_tokens.items():
        assert recorded[attribute] == tokens


def test_end_model_invoke_span_with_cache_metrics(mock_span):
"""Test ending a model invoke span with cache metrics."""
tracer = Tracer()
Expand Down Expand Up @@ -1035,14 +1094,23 @@ def test_end_agent_span_with_cache_metrics(mock_span):
tracer = Tracer()

# Mock AgentResult with metrics including cache tokens
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {
mock_invocation = mock.MagicMock()
mock_invocation.usage = {
"inputTokens": 50,
"outputTokens": 100,
"totalTokens": 150,
"cacheReadInputTokens": 25,
"cacheWriteInputTokens": 10,
}
mock_metrics = mock.MagicMock()
mock_metrics.accumulated_usage = {
"inputTokens": 500,
"outputTokens": 1000,
"totalTokens": 1500,
"cacheReadInputTokens": 250,
"cacheWriteInputTokens": 100,
}
mock_metrics.latest_agent_invocation = mock_invocation

mock_response = mock.MagicMock()
mock_response.metrics = mock_metrics
Expand Down