Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import logging
import time
import uuid
from collections.abc import Sequence
from collections.abc import MutableMapping, Sequence
from typing import Any, ClassVar, TypedDict

from agent_framework import Message
from agent_framework import Message, Content
from agent_framework._sessions import HistoryProvider
from agent_framework._settings import SecretString, load_settings
from agent_framework._telemetry import get_user_agent
Expand All @@ -24,6 +24,39 @@
logger = logging.getLogger(__name__)


def _merge_code_interpreter_chunks(chunks: list[Content], call_id: str | None) -> Content:
"""Merge code_interpreter_tool_call chunks into a single Content item.

Concatenates text inputs in order and preserves non-text inputs.
Drops per-chunk sequence_number from the merged additional_properties
since the aggregated item no longer represents a single chunk.
"""
all_text_parts: list[str] = []
non_text_inputs: list[Content] = []
merged_additional_properties: MutableMapping[str, Any] = {}
for chunk in chunks:
for inp in (chunk.inputs or []):
if inp.type == "text" and inp.text:
all_text_parts.append(inp.text)
Comment on lines +37 to +40
else:
non_text_inputs.append(inp)
if chunk.additional_properties:
for k, v in chunk.additional_properties.items():
if k == "sequence_number":
continue
merged_additional_properties[k] = v
merged_inputs: list[Content] = []
merged_text = "".join(all_text_parts)
if merged_text:
merged_inputs.append(Content.from_text(merged_text))
merged_inputs.extend(non_text_inputs)
return Content.from_code_interpreter_tool_call(
call_id=call_id,
inputs=merged_inputs or None,
additional_properties=merged_additional_properties or None,
)


class AzureCosmosHistorySettings(TypedDict, total=False):
"""Settings for CosmosHistoryProvider resolved from args and environment."""

Expand Down Expand Up @@ -167,6 +200,38 @@ async def get_messages(

return messages

@staticmethod
def _aggregate_code_interpreter_calls(contents: list[Content]) -> list[Content]:
"""Merge consecutive code_interpreter_tool_call chunks with the same call_id.

During streaming, code interpreter output arrives as multiple content items
each with the same call_id but different sequence_number values. This method
aggregates them into a single item with the complete text.
"""
if not contents:
return contents
aggregated: list[Content] = []
pending: list[Content] = []
pending_call_id: str | None = None
for item in contents:
if item.type != "code_interpreter_tool_call":
if pending:
aggregated.append(_merge_code_interpreter_chunks(pending, pending_call_id))
pending = []
pending_call_id = None
aggregated.append(item)
continue
if pending_call_id is not None and item.call_id == pending_call_id:
pending.append(item)
else:
if pending:
aggregated.append(_merge_code_interpreter_chunks(pending, pending_call_id))
pending = [item]
pending_call_id = item.call_id
if pending:
aggregated.append(_merge_code_interpreter_chunks(pending, pending_call_id))
return aggregated

async def save_messages(
self,
session_id: str | None,
Expand All @@ -185,12 +250,16 @@ async def save_messages(
base_sort_key = time.time_ns()
operations: list[tuple[str, tuple[dict[str, Any]]]] = []
for index, message in enumerate(messages):
message_dict = message.to_dict()
if message.contents:
aggregated = self._aggregate_code_interpreter_calls(message.contents)
message_dict["contents"] = [c.to_dict() for c in aggregated]
document = {
"id": str(uuid.uuid4()),
"session_id": session_key,
"sort_key": base_sort_key + index,
"source_id": self.source_id,
"message": message.to_dict(),
"message": message_dict,
}
operations.append(("upsert", (document,)))

Expand Down
136 changes: 134 additions & 2 deletions python/packages/azure-cosmos/tests/test_cosmos_history_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agent_framework import AgentResponse, Message
from agent_framework import AgentResponse, Content, Message
from agent_framework._sessions import AgentSession, SessionContext
from agent_framework.exceptions import SettingNotFoundError
from azure.cosmos.aio import CosmosClient
from azure.cosmos.exceptions import CosmosResourceNotFoundError

import agent_framework_azure_cosmos._history_provider as history_provider_module
from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider
from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider, _merge_code_interpreter_chunks

skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif(
any(
Expand Down Expand Up @@ -409,3 +409,135 @@ async def test_cosmos_history_provider_roundtrip_with_emulator() -> None:
finally:
with suppress(CosmosResourceNotFoundError):
await cosmos_client.delete_database(database_name)


class TestCodeInterpreterAggregation:
"""Tests for code_interpreter_tool_call chunk aggregation."""

def test_merge_code_interpreter_chunks(self) -> None:
"""_merge_code_interpreter_chunks merges multiple text chunks."""
chunk1 = Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("import pandas")],
additional_properties={"sequence_number": 1},
)
chunk2 = Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text(" as pd")],
additional_properties={"sequence_number": 2},
)
merged = _merge_code_interpreter_chunks([chunk1, chunk2], "ci_abc")
assert merged.type == "code_interpreter_tool_call"
assert merged.call_id == "ci_abc"
assert merged.inputs is not None
assert len(merged.inputs) == 1
assert merged.inputs[0].type == "text"
assert merged.inputs[0].text == "import pandas as pd"

def test_merge_skips_sequence_number(self) -> None:
"""_merge_code_interpreter_chunks drops sequence_number from merged properties."""
chunk = Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("hello")],
additional_properties={"sequence_number": 5, "item_id": "ci_abc"},
)
merged = _merge_code_interpreter_chunks([chunk], "ci_abc")
assert merged.additional_properties is not None
assert "sequence_number" not in merged.additional_properties
assert merged.additional_properties.get("item_id") == "ci_abc"

def test_merge_preserves_non_text_inputs(self) -> None:
"""Non-text inputs in chunks are preserved in the merged item."""
image_input = Content.from_image("https://example.com/img.png")
chunk = Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("plot"), image_input],
)
merged = _merge_code_interpreter_chunks([chunk], "ci_abc")
assert merged.inputs is not None
assert len(merged.inputs) == 2
assert merged.inputs[0].type == "text"
assert merged.inputs[1].type == "image"

def test_aggregate_code_interpreter_calls_merges_consecutive_chunks(self) -> None:
"""_aggregate_code_interpreter_calls merges consecutive chunks with same call_id."""
contents = [
Content.from_text_reasoning(text="thinking..."),
Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("import ")],
additional_properties={"sequence_number": 1},
),
Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("pandas")],
additional_properties={"sequence_number": 2},
),
Content.from_text("Final result: 42"),
]
result = CosmosHistoryProvider._aggregate_code_interpreter_calls(contents)
assert len(result) == 3
assert result[0].type == "text_reasoning"
assert result[1].type == "code_interpreter_tool_call"
assert result[1].inputs is not None
assert result[1].inputs[0].text == "import pandas"
assert result[2].type == "text"

def test_aggregate_preserves_independent_tool_calls(self) -> None:
"""Tool calls with different call_ids are not merged."""
contents = [
Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("code a")],
),
Content.from_code_interpreter_tool_call(
call_id="ci_def",
inputs=[Content.from_text("code b")],
),
]
result = CosmosHistoryProvider._aggregate_code_interpreter_calls(contents)
assert len(result) == 2
assert result[0].call_id == "ci_abc"
assert result[1].call_id == "ci_def"

@pytest.mark.asyncio
async def test_save_messages_does_not_mutate_original(self) -> None:
"""save_messages does not mutate the caller's Message objects."""
provider = CosmosHistoryProvider(
endpoint="https://mock.documents.azure.com:443/",
database_name="test-db",
container_name="test-container",
credential="mock-key",
)
provider._container_proxy = MagicMock()
provider._container_proxy.execute_item_batch = AsyncMock()

message = Message(
role="assistant",
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("import ")],
additional_properties={"sequence_number": 1},
),
Content.from_code_interpreter_tool_call(
call_id="ci_abc",
inputs=[Content.from_text("pandas")],
additional_properties={"sequence_number": 2},
),
],
)
original_contents = list(message.contents)
with patch.object(CosmosHistoryProvider, "_ensure_container_proxy", new_callable=AsyncMock):
await provider.save_messages("session-1", [message])
# Original message contents should be unchanged
assert len(message.contents) == 2
assert message.contents[0].inputs is not None
assert message.contents[0].inputs[0].text == "import "
# Saved document should contain aggregated content
call_kwargs = provider._container_proxy.execute_item_batch.call_args
assert call_kwargs is not None
batch_ops = call_kwargs[1]["batch_operations"]
saved_msg = batch_ops[0][1][0]["message"]
assert len(saved_msg["contents"]) == 1
assert saved_msg["contents"][0]["inputs"][0]["text"] == "import pandas"