diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index c03d9d962..997523988 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -693,16 +693,19 @@ def end_agent_span(
             if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"):
                 if self.is_langfuse:
                     attributes.update({"langfuse.observation.type": "span"})
-                accumulated_usage = response.metrics.accumulated_usage
+                # Report only the latest invocation's usage; accumulated_usage covers the whole session.
+                # getattr keeps the fallback working for metrics objects without latest_agent_invocation.
+                latest_invocation = getattr(response.metrics, "latest_agent_invocation", None)
+                usage = latest_invocation.usage if latest_invocation else response.metrics.accumulated_usage
                 attributes.update(
                     {
-                        "gen_ai.usage.prompt_tokens": accumulated_usage["inputTokens"],
-                        "gen_ai.usage.completion_tokens": accumulated_usage["outputTokens"],
-                        "gen_ai.usage.input_tokens": accumulated_usage["inputTokens"],
-                        "gen_ai.usage.output_tokens": accumulated_usage["outputTokens"],
-                        "gen_ai.usage.total_tokens": accumulated_usage["totalTokens"],
-                        "gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0),
-                        "gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0),
+                        "gen_ai.usage.prompt_tokens": usage["inputTokens"],
+                        "gen_ai.usage.completion_tokens": usage["outputTokens"],
+                        "gen_ai.usage.input_tokens": usage["inputTokens"],
+                        "gen_ai.usage.output_tokens": usage["outputTokens"],
+                        "gen_ai.usage.total_tokens": usage["totalTokens"],
+                        "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0),
+                        "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0),
                     }
                 )
 
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index 9176ce4ae..3d95d4b28 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -889,8 +889,11 @@ def test_end_agent_span(mock_span):
     tracer = Tracer()
 
     # Mock AgentResult with metrics
+    mock_invocation = mock.MagicMock()
+    mock_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
     mock_metrics = mock.MagicMock()
-    mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
+    mock_metrics.accumulated_usage = {"inputTokens": 500, "outputTokens": 1000, "totalTokens": 1500}
+    mock_metrics.latest_agent_invocation = mock_invocation
     mock_response = mock.MagicMock()
     mock_response.metrics = mock_metrics
 
@@ -924,8 +927,11 @@ def test_end_agent_span_with_langfuse_observation_type(mock_span, monkeypatch):
     tracer = Tracer()
 
     # Mock AgentResult with metrics
+    mock_invocation = mock.MagicMock()
+    mock_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
     mock_metrics = mock.MagicMock()
-    mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
+    mock_metrics.accumulated_usage = {"inputTokens": 500, "outputTokens": 1000, "totalTokens": 1500}
+    mock_metrics.latest_agent_invocation = mock_invocation
     mock_response = mock.MagicMock()
     mock_response.metrics = mock_metrics
 
@@ -960,8 +966,11 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch):
     tracer = Tracer()
 
     # Mock AgentResult with metrics
+    mock_invocation = mock.MagicMock()
+    mock_invocation.usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
     mock_metrics = mock.MagicMock()
-    mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150}
+    mock_metrics.accumulated_usage = {"inputTokens": 500, "outputTokens": 1000, "totalTokens": 1500}
+    mock_metrics.latest_agent_invocation = mock_invocation
     mock_response = mock.MagicMock()
     mock_response.metrics = mock_metrics
 
@@ -999,6 +1008,56 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch):
     mock_span.end.assert_called_once()
 
 
+def test_end_agent_span_uses_per_invocation_usage_not_accumulated(mock_span):
+    """Test that agent span reports per-invocation usage, not session-accumulated usage."""
+    tracer = Tracer()
+
+    # Simulate a multi-invocation session where accumulated_usage has grown large
+    # but the latest invocation only used a small amount of tokens
+    mock_invocation = mock.MagicMock()
+    mock_invocation.usage = {"inputTokens": 100, "outputTokens": 50, "totalTokens": 150}
+
+    mock_metrics = mock.MagicMock()
+    mock_metrics.accumulated_usage = {"inputTokens": 1000, "outputTokens": 500, "totalTokens": 1500}
+    mock_metrics.latest_agent_invocation = mock_invocation
+
+    mock_response = mock.MagicMock()
+    mock_response.metrics = mock_metrics
+    mock_response.stop_reason = "end_turn"
+    mock_response.__str__ = mock.MagicMock(return_value="Agent response")
+
+    tracer.end_agent_span(mock_span, mock_response)
+
+    call_args = mock_span.set_attributes.call_args[0][0]
+    # Should use per-invocation usage (100/50/150), NOT accumulated (1000/500/1500)
+    assert call_args["gen_ai.usage.input_tokens"] == 100
+    assert call_args["gen_ai.usage.output_tokens"] == 50
+    assert call_args["gen_ai.usage.total_tokens"] == 150
+    assert call_args["gen_ai.usage.prompt_tokens"] == 100
+    assert call_args["gen_ai.usage.completion_tokens"] == 50
+
+
+def test_end_agent_span_falls_back_to_accumulated_when_no_invocations(mock_span):
+    """Test fallback to accumulated_usage when no agent invocations exist."""
+    tracer = Tracer()
+
+    mock_metrics = mock.MagicMock()
+    mock_metrics.accumulated_usage = {"inputTokens": 200, "outputTokens": 100, "totalTokens": 300}
+    mock_metrics.latest_agent_invocation = None
+
+    mock_response = mock.MagicMock()
+    mock_response.metrics = mock_metrics
+    mock_response.stop_reason = "end_turn"
+    mock_response.__str__ = mock.MagicMock(return_value="Agent response")
+
+    tracer.end_agent_span(mock_span, mock_response)
+
+    call_args = mock_span.set_attributes.call_args[0][0]
+    assert call_args["gen_ai.usage.input_tokens"] == 200
+    assert call_args["gen_ai.usage.output_tokens"] == 100
+    assert call_args["gen_ai.usage.total_tokens"] == 300
+
+
 def test_end_model_invoke_span_with_cache_metrics(mock_span):
     """Test ending a model invoke span with cache metrics."""
     tracer = Tracer()
@@ -1035,14 +1094,23 @@ def test_end_agent_span_with_cache_metrics(mock_span):
     tracer = Tracer()
 
     # Mock AgentResult with metrics including cache tokens
-    mock_metrics = mock.MagicMock()
-    mock_metrics.accumulated_usage = {
+    mock_invocation = mock.MagicMock()
+    mock_invocation.usage = {
         "inputTokens": 50,
         "outputTokens": 100,
         "totalTokens": 150,
         "cacheReadInputTokens": 25,
         "cacheWriteInputTokens": 10,
     }
+    mock_metrics = mock.MagicMock()
+    mock_metrics.accumulated_usage = {
+        "inputTokens": 500,
+        "outputTokens": 1000,
+        "totalTokens": 1500,
+        "cacheReadInputTokens": 250,
+        "cacheWriteInputTokens": 100,
+    }
+    mock_metrics.latest_agent_invocation = mock_invocation
     mock_response = mock.MagicMock()
     mock_response.metrics = mock_metrics
 