From 442dbb76541fc8aa89df976e1d08402705e6357b Mon Sep 17 00:00:00 2001 From: FU-max-boop Date: Fri, 29 May 2026 02:46:18 +0800 Subject: [PATCH] fix streaming delta index merging --- src/openai/lib/streaming/_assistants.py | 64 +----------------- src/openai/lib/streaming/_deltas.py | 87 ++++++++++++++++++------- tests/lib/test_streaming_deltas.py | 76 +++++++++++++++++++++ 3 files changed, 141 insertions(+), 86 deletions(-) create mode 100644 tests/lib/test_streaming_deltas.py diff --git a/src/openai/lib/streaming/_assistants.py b/src/openai/lib/streaming/_assistants.py index 6efb3ca3f1..cc13e0bf2a 100644 --- a/src/openai/lib/streaming/_assistants.py +++ b/src/openai/lib/streaming/_assistants.py @@ -7,7 +7,8 @@ import httpx -from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator +from ._deltas import accumulate_delta +from ..._utils import consume_sync_iterator, consume_async_iterator from ..._compat import model_dump from ..._models import construct_type from ..._streaming import Stream, AsyncStream @@ -975,64 +976,3 @@ def accumulate_event( ) return current_message_snapshot, new_content - - -def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: - for key, delta_value in delta.items(): - if key not in acc: - acc[key] = delta_value - continue - - acc_value = acc[key] - if acc_value is None: - acc[key] = delta_value - continue - - # the `index` property is used in arrays of objects so it should - # not be accumulated like other values e.g. - # [{'foo': 'bar', 'index': 0}] - # - # the same applies to `type` properties as they're used for - # discriminated unions - if key == "index" or key == "type": - acc[key] = delta_value - continue - - if isinstance(acc_value, str) and isinstance(delta_value, str): - acc_value += delta_value - elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)): - acc_value += delta_value - elif is_dict(acc_value) and is_dict(delta_value): - acc_value = accumulate_delta(acc_value, delta_value) - elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue - - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") - - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - - acc_value[index] = accumulate_delta(acc_entry, delta_entry) - - acc[key] = acc_value - - return acc diff --git a/src/openai/lib/streaming/_deltas.py b/src/openai/lib/streaming/_deltas.py index a5e1317612..335c77b26d 100644 --- a/src/openai/lib/streaming/_deltas.py +++ b/src/openai/lib/streaming/_deltas.py @@ -6,11 +6,19 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: + if is_list(delta_value) and _has_indexed_entries(delta_value): + acc[key] = _accumulate_list_delta([], delta_value) + continue + acc[key] = delta_value continue acc_value = acc[key] if acc_value is None: + if is_list(delta_value) and _has_indexed_entries(delta_value): + acc[key] = _accumulate_list_delta([], delta_value) + continue + acc[key] = delta_value continue @@ -31,34 +39,65 @@ def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): - # for lists of non-dictionary items we'll only ever get new entries - # in the array, existing entries will never be changed - if all(isinstance(x, (str, int, float)) for x in acc_value): - acc_value.extend(delta_value) - continue + acc_value = _accumulate_list_delta(acc_value, delta_value) + + acc[key] = acc_value - for delta_entry in delta_value: - if not is_dict(delta_entry): - raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") + return acc - try: - index = delta_entry["index"] - except KeyError as exc: - raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc - if not isinstance(index, int): - raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") +def _has_indexed_entries(value: list[object]) -> bool: + return any(is_dict(entry) and "index" in entry for entry in value) - try: - acc_entry = acc_value[index] - except IndexError: - acc_value.insert(index, delta_entry) - else: - if not is_dict(acc_entry): - raise TypeError("not handled yet") - acc_value[index] = accumulate_delta(acc_entry, delta_entry) +def _accumulate_list_delta(acc_value: list[object], delta_value: list[object]) -> list[object]: + # for lists of non-dictionary items we'll only ever get new entries + # in the array, existing entries will never be changed + if not _has_indexed_entries(delta_value) and all(isinstance(x, (str, int, float)) for x in acc_value): + acc_value.extend(delta_value) + return acc_value - acc[key] = acc_value + for delta_entry in delta_value: + if not is_dict(delta_entry): + raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") - return acc + try: + index = delta_entry["index"] + except KeyError as exc: + raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc + + if not isinstance(index, int): + raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") + + acc_index = _find_entry_index(acc_value, index) + if acc_index is None: + acc_value.insert(_find_insert_position(acc_value, index), delta_entry) + continue + + acc_entry = acc_value[acc_index] + if not is_dict(acc_entry): + raise TypeError("not handled yet") + + acc_value[acc_index] = accumulate_delta(acc_entry, delta_entry) + + return acc_value + + +def _find_entry_index(entries: list[object], index: int) -> int | None: + for entry_index, entry in enumerate(entries): + if is_dict(entry) and entry.get("index") == index: + return entry_index + + return None + + +def _find_insert_position(entries: list[object], index: int) -> int: + for entry_index, entry in enumerate(entries): + if not is_dict(entry): + continue + + entry_delta_index = entry.get("index") + if isinstance(entry_delta_index, int) and entry_delta_index > index: + return entry_index + + return len(entries) diff --git a/tests/lib/test_streaming_deltas.py b/tests/lib/test_streaming_deltas.py new file mode 100644 index 0000000000..64fe227ad8 --- /dev/null +++ b/tests/lib/test_streaming_deltas.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from openai.lib.streaming._deltas import accumulate_delta as accumulate_chat_delta +from openai.lib.streaming._assistants import accumulate_delta as accumulate_assistant_delta + + +def test_accumulate_delta_merges_duplicate_indexed_entries_on_initial_chunk() -> None: + acc: dict[object, object] = {"tool_calls": None} + + accumulate_chat_delta( + acc, + { + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather"}, + "type": "function", + }, + {"index": 0, "function": {"arguments": '{"city"'}}, + ] + }, + ) + accumulate_chat_delta(acc, {"tool_calls": [{"index": 0, "function": {"arguments": ': "London"}'}}]}) + + assert acc == { + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather", "arguments": '{"city": "London"}'}, + "type": "function", + } + ] + } + + +def test_assistant_accumulate_delta_uses_logical_index_for_initial_chunk() -> None: + acc: dict[object, object] = {} + + accumulate_assistant_delta( + acc, + { + "tool_calls": [ + {"index": 0, "id": "call_abc", "function": {"name": "get_weather"}, "type": "function"}, + {"index": 0, "function": {"arguments": '{"path"'}}, + {"index": 1, "id": "call_def", "function": {"name": "list_files"}, "type": "function"}, + ] + }, + ) + accumulate_assistant_delta( + acc, + { + "tool_calls": [ + {"index": 1, "function": {"arguments": '{"limit": 10}'}}, + {"index": 0, "function": {"arguments": ': "."}'}}, + ] + }, + ) + + assert acc == { + "tool_calls": [ + { + "index": 0, + "id": "call_abc", + "function": {"name": "get_weather", "arguments": '{"path": "."}'}, + "type": "function", + }, + { + "index": 1, + "id": "call_def", + "function": {"name": "list_files", "arguments": '{"limit": 10}'}, + "type": "function", + }, + ] + }