From 387a8dda8136c09109bd1268a786601bb060fe01 Mon Sep 17 00:00:00 2001 From: venti <1308199824@qq.com> Date: Sat, 30 May 2026 15:07:51 +0800 Subject: [PATCH 1/2] fix: aggregate code_interpreter_tool_call chunks in CosmosHistoryProvider (#5793) During streaming, code interpreter output arrives as multiple content items with the same call_id but different sequence_number values. Before persisting to Cosmos DB, consecutive code_interpreter_tool_call chunks with matching call_id are now merged into a single Content item with the complete, aggregated text. --- .../_history_provider.py | 57 ++++++++- .../tests/test_cosmos_history_provider.py | 116 ++++++++++++++++++ 2 files changed, 171 insertions(+), 2 deletions(-) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 62a83e6a0f..c27d51eab5 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -7,10 +7,10 @@ import logging import time import uuid -from collections.abc import Sequence +from collections.abc import MutableMapping, Sequence from typing import Any, ClassVar, TypedDict -from agent_framework import Message +from agent_framework import Message, Content from agent_framework._sessions import HistoryProvider from agent_framework._settings import SecretString, load_settings from agent_framework._telemetry import get_user_agent @@ -24,6 +24,25 @@ logger = logging.getLogger(__name__) +def _merge_code_interpreter_chunks(chunks: list[Content], call_id: str | None) -> Content: + """Merge code_interpreter_tool_call chunks into a single Content item.""" + all_text_parts: list[str] = [] + merged_additional_properties: MutableMapping[str, Any] = {} + for chunk in chunks: + for inp in (chunk.inputs or []): + if inp.type == "text" and inp.text: + all_text_parts.append(inp.text) + if chunk.additional_properties: + merged_additional_properties.update(chunk.additional_properties) + merged_text = "".join(all_text_parts) + merged_inputs = [Content.from_text(merged_text)] if merged_text else None + return Content.from_code_interpreter_tool_call( + call_id=call_id, + inputs=merged_inputs, + additional_properties=merged_additional_properties or None, + ) + + class AzureCosmosHistorySettings(TypedDict, total=False): """Settings for CosmosHistoryProvider resolved from args and environment.""" @@ -167,6 +186,38 @@ async def get_messages( return messages + @staticmethod + def _aggregate_code_interpreter_calls(contents: list[Content]) -> list[Content]: + """Merge consecutive code_interpreter_tool_call chunks with the same call_id. + + During streaming, code interpreter output arrives as multiple content items + each with the same call_id but different sequence_number values. This method + aggregates them into a single item with the complete text. + """ + if not contents: + return contents + aggregated: list[Content] = [] + pending: list[Content] = [] + pending_call_id: str | None = None + for item in contents: + if item.type != "code_interpreter_tool_call": + if pending: + aggregated.append(_merge_code_interpreter_chunks(pending, pending_call_id)) + pending = [] + pending_call_id = None + aggregated.append(item) + continue + if pending_call_id is not None and item.call_id == pending_call_id: + pending.append(item) + else: + if pending: + aggregated.append(_merge_code_interpreter_chunks(pending, pending_call_id)) + pending = [item] + pending_call_id = item.call_id + if pending: + aggregated.append(_merge_code_interpreter_chunks(pending, pending_call_id)) + return aggregated + async def save_messages( self, session_id: str | None, @@ -185,6 +236,8 @@ async def save_messages( base_sort_key = time.time_ns() operations: list[tuple[str, tuple[dict[str, Any]]]] = [] for index, message in enumerate(messages): + if message.contents: + message.contents = self._aggregate_code_interpreter_calls(message.contents) document = { "id": str(uuid.uuid4()), "session_id": session_key, diff --git a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py index e3ac636aa6..6d8651efa0 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py @@ -409,3 +409,119 @@ async def test_cosmos_history_provider_roundtrip_with_emulator() -> None: finally: with suppress(CosmosResourceNotFoundError): await cosmos_client.delete_database(database_name) + + +class TestCodeInterpreterAggregation: + """Tests for code_interpreter_tool_call chunk aggregation.""" + + def test_merge_code_interpreter_chunks(self) -> None: + """_merge_code_interpreter_chunks merges multiple text chunks.""" + from agent_framework import Content + from agent_framework_azure_cosmos._history_provider import _merge_code_interpreter_chunks + + chunk1 = Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("import pandas")], + additional_properties={"sequence_number": 1}, + ) + chunk2 = Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text(" as pd")], + additional_properties={"sequence_number": 2}, + ) + merged = _merge_code_interpreter_chunks([chunk1, chunk2], "ci_abc") + assert merged.type == "code_interpreter_tool_call" + assert merged.call_id == "ci_abc" + assert merged.inputs is not None + assert len(merged.inputs) == 1 + assert merged.inputs[0].type == "text" + assert merged.inputs[0].text == "import pandas as pd" + + def test_aggregate_code_interpreter_calls_merges_consecutive_chunks(self) -> None: + """_aggregate_code_interpreter_calls merges consecutive chunks with same call_id.""" + from agent_framework import Content + from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider + + contents = [ + Content.from_text_reasoning(text="thinking..."), + Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("import ")], + additional_properties={"sequence_number": 1}, + ), + Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("pandas")], + additional_properties={"sequence_number": 2}, + ), + Content.from_text("Final result: 42"), + ] + result = CosmosHistoryProvider._aggregate_code_interpreter_calls(contents) + assert len(result) == 3 + assert result[0].type == "text_reasoning" + assert result[1].type == "code_interpreter_tool_call" + assert result[1].inputs is not None + assert result[1].inputs[0].text == "import pandas" + assert result[2].type == "text" + + def test_aggregate_preserves_independent_tool_calls(self) -> None: + """Tool calls with different call_ids are not merged.""" + from agent_framework import Content + from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider + + contents = [ + Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("code a")], + ), + Content.from_code_interpreter_tool_call( + call_id="ci_def", + inputs=[Content.from_text("code b")], + ), + ] + result = CosmosHistoryProvider._aggregate_code_interpreter_calls(contents) + assert len(result) == 2 + assert result[0].call_id == "ci_abc" + assert result[1].call_id == "ci_def" + + def test_save_messages_calls_aggregation(self) -> None: + """save_messages aggregates code interpreter chunks before saving.""" + from unittest.mock import AsyncMock, patch, MagicMock + from agent_framework import Content, Message + from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider + + provider = CosmosHistoryProvider( + endpoint="https://mock.documents.azure.com:443/", + database_name="test-db", + container_name="test-container", + credential="mock-key", + ) + provider._container_proxy = MagicMock() + provider._container_proxy.execute_item_batch = AsyncMock() + + message = Message( + role="assistant", + contents=[ + Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("import ")], + additional_properties={"sequence_number": 1}, + ), + Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("pandas")], + additional_properties={"sequence_number": 2}, + ), + ], + ) + with patch.object(CosmosHistoryProvider, "_ensure_container_proxy", new_callable=AsyncMock): + import asyncio + asyncio.run(provider.save_messages("session-1", [message])) + call_kwargs = provider._container_proxy.execute_item_batch.call_args + assert call_kwargs is not None + batch_ops = call_kwargs[1]["batch_operations"] + assert len(batch_ops) == 1 + saved_msg = batch_ops[0][1][0]["message"] + assert len(saved_msg["contents"]) == 1 + assert saved_msg["contents"][0]["type"] == "code_interpreter_tool_call" + assert saved_msg["contents"][0]["inputs"][0]["text"] == "import pandas" From 2ae1e661fca95dd01bff439cd6096bd1c50a75a7 Mon Sep 17 00:00:00 2001 From: venti <1308199824@qq.com> Date: Sat, 30 May 2026 15:42:18 +0800 Subject: [PATCH 2/2] address PR review: immutability, non-text inputs, sequence_number, test style - save_messages: build message dict directly instead of mutating caller's Message\n- _merge_code_interpreter_chunks: preserve non-text inputs, drop sequence_number\n- Tests: module-level imports, async test methods, add coverage for new behaviors --- .../_history_provider.py | 28 +++++++-- .../tests/test_cosmos_history_provider.py | 58 ++++++++++++------- 2 files changed, 59 insertions(+), 27 deletions(-) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index c27d51eab5..12ac729c47 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -25,20 +25,34 @@ def _merge_code_interpreter_chunks(chunks: list[Content], call_id: str | None) -> Content: - """Merge code_interpreter_tool_call chunks into a single Content item.""" + """Merge code_interpreter_tool_call chunks into a single Content item. + + Concatenates text inputs in order and preserves non-text inputs. + Drops per-chunk sequence_number from the merged additional_properties + since the aggregated item no longer represents a single chunk. + """ all_text_parts: list[str] = [] + non_text_inputs: list[Content] = [] merged_additional_properties: MutableMapping[str, Any] = {} for chunk in chunks: for inp in (chunk.inputs or []): if inp.type == "text" and inp.text: all_text_parts.append(inp.text) + else: + non_text_inputs.append(inp) if chunk.additional_properties: - merged_additional_properties.update(chunk.additional_properties) + for k, v in chunk.additional_properties.items(): + if k == "sequence_number": + continue + merged_additional_properties[k] = v + merged_inputs: list[Content] = [] merged_text = "".join(all_text_parts) - merged_inputs = [Content.from_text(merged_text)] if merged_text else None + if merged_text: + merged_inputs.append(Content.from_text(merged_text)) + merged_inputs.extend(non_text_inputs) return Content.from_code_interpreter_tool_call( call_id=call_id, - inputs=merged_inputs, + inputs=merged_inputs or None, additional_properties=merged_additional_properties or None, ) @@ -236,14 +250,16 @@ async def save_messages( base_sort_key = time.time_ns() operations: list[tuple[str, tuple[dict[str, Any]]]] = [] for index, message in enumerate(messages): + message_dict = message.to_dict() if message.contents: - message.contents = self._aggregate_code_interpreter_calls(message.contents) + aggregated = self._aggregate_code_interpreter_calls(message.contents) + message_dict["contents"] = [c.to_dict() for c in aggregated] document = { "id": str(uuid.uuid4()), "session_id": session_key, "sort_key": base_sort_key + index, "source_id": self.source_id, - "message": message.to_dict(), + "message": message_dict, } operations.append(("upsert", (document,))) diff --git a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py index 6d8651efa0..baa4878756 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py @@ -10,14 +10,14 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentResponse, Message +from agent_framework import AgentResponse, Content, Message from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import SettingNotFoundError from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosResourceNotFoundError import agent_framework_azure_cosmos._history_provider as history_provider_module -from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider +from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider, _merge_code_interpreter_chunks skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif( any( @@ -416,9 +416,6 @@ class TestCodeInterpreterAggregation: def test_merge_code_interpreter_chunks(self) -> None: """_merge_code_interpreter_chunks merges multiple text chunks.""" - from agent_framework import Content - from agent_framework_azure_cosmos._history_provider import _merge_code_interpreter_chunks - chunk1 = Content.from_code_interpreter_tool_call( call_id="ci_abc", inputs=[Content.from_text("import pandas")], @@ -437,11 +434,33 @@ def test_merge_code_interpreter_chunks(self) -> None: assert merged.inputs[0].type == "text" assert merged.inputs[0].text == "import pandas as pd" + def test_merge_skips_sequence_number(self) -> None: + """_merge_code_interpreter_chunks drops sequence_number from merged properties.""" + chunk = Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("hello")], + additional_properties={"sequence_number": 5, "item_id": "ci_abc"}, + ) + merged = _merge_code_interpreter_chunks([chunk], "ci_abc") + assert merged.additional_properties is not None + assert "sequence_number" not in merged.additional_properties + assert merged.additional_properties.get("item_id") == "ci_abc" + + def test_merge_preserves_non_text_inputs(self) -> None: + """Non-text inputs in chunks are preserved in the merged item.""" + image_input = Content.from_image("https://example.com/img.png") + chunk = Content.from_code_interpreter_tool_call( + call_id="ci_abc", + inputs=[Content.from_text("plot"), image_input], + ) + merged = _merge_code_interpreter_chunks([chunk], "ci_abc") + assert merged.inputs is not None + assert len(merged.inputs) == 2 + assert merged.inputs[0].type == "text" + assert merged.inputs[1].type == "image" + def test_aggregate_code_interpreter_calls_merges_consecutive_chunks(self) -> None: """_aggregate_code_interpreter_calls merges consecutive chunks with same call_id.""" - from agent_framework import Content - from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider - contents = [ Content.from_text_reasoning(text="thinking..."), Content.from_code_interpreter_tool_call( @@ -466,9 +485,6 @@ def test_aggregate_code_interpreter_calls_merges_consecutive_chunks(self) -> Non def test_aggregate_preserves_independent_tool_calls(self) -> None: """Tool calls with different call_ids are not merged.""" - from agent_framework import Content - from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider - contents = [ Content.from_code_interpreter_tool_call( call_id="ci_abc", @@ -484,12 +500,9 @@ def test_aggregate_preserves_independent_tool_calls(self) -> None: assert result[0].call_id == "ci_abc" assert result[1].call_id == "ci_def" - def test_save_messages_calls_aggregation(self) -> None: - """save_messages aggregates code interpreter chunks before saving.""" - from unittest.mock import AsyncMock, patch, MagicMock - from agent_framework import Content, Message - from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider - + @pytest.mark.asyncio + async def test_save_messages_does_not_mutate_original(self) -> None: + """save_messages does not mutate the caller's Message objects.""" provider = CosmosHistoryProvider( endpoint="https://mock.documents.azure.com:443/", database_name="test-db", @@ -514,14 +527,17 @@ def test_save_messages_calls_aggregation(self) -> None: ), ], ) + original_contents = list(message.contents) with patch.object(CosmosHistoryProvider, "_ensure_container_proxy", new_callable=AsyncMock): - import asyncio - asyncio.run(provider.save_messages("session-1", [message])) + await provider.save_messages("session-1", [message]) + # Original message contents should be unchanged + assert len(message.contents) == 2 + assert message.contents[0].inputs is not None + assert message.contents[0].inputs[0].text == "import " + # Saved document should contain aggregated content call_kwargs = provider._container_proxy.execute_item_batch.call_args assert call_kwargs is not None batch_ops = call_kwargs[1]["batch_operations"] - assert len(batch_ops) == 1 saved_msg = batch_ops[0][1][0]["message"] assert len(saved_msg["contents"]) == 1 - assert saved_msg["contents"][0]["type"] == "code_interpreter_tool_call" assert saved_msg["contents"][0]["inputs"][0]["text"] == "import pandas"