Skip to content

Commit 9c3bb56

Browse files
committed
fix(streaming): handle top-level SSE error messages
1 parent 5ae2cc1 commit 9c3bb56

2 files changed

Lines changed: 108 additions & 66 deletions

File tree

src/openai/_streaming.py

Lines changed: 47 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,35 @@
2020
_T = TypeVar("_T")
2121

2222

23+
def _build_streaming_api_error(
24+
*,
25+
data: object,
26+
request: httpx.Request,
27+
is_error_event: bool,
28+
) -> APIError | None:
29+
if not is_mapping(data):
30+
return None
31+
32+
error = data.get("error")
33+
if is_error_event:
34+
body = error if error is not None else data
35+
elif error is not None:
36+
body = error
37+
else:
38+
return None
39+
40+
message = data.get("message")
41+
if not isinstance(message, str) and is_mapping(error):
42+
nested_message = error.get("message")
43+
if isinstance(nested_message, str):
44+
message = nested_message
45+
46+
if not isinstance(message, str) or not message:
47+
message = "An error occurred during streaming"
48+
49+
return APIError(message=message, request=request, body=body)
50+
51+
2352
class Stream(Generic[_T]):
2453
"""Provides the core interface to iterate over a synchronous stream response."""
2554

@@ -63,41 +92,19 @@ def __stream__(self) -> Iterator[_T]:
6392
if sse.data.startswith("[DONE]"):
6493
break
6594

95+
data = sse.json()
96+
api_error = _build_streaming_api_error(
97+
data=data,
98+
request=self.response.request,
99+
is_error_event=sse.event == "error",
100+
)
101+
if api_error is not None:
102+
raise api_error
103+
66104
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
67105
if sse.event and sse.event.startswith("thread."):
68-
data = sse.json()
69-
70-
if sse.event == "error" and is_mapping(data) and data.get("error"):
71-
message = None
72-
error = data.get("error")
73-
if is_mapping(error):
74-
message = error.get("message")
75-
if not message or not isinstance(message, str):
76-
message = "An error occurred during streaming"
77-
78-
raise APIError(
79-
message=message,
80-
request=self.response.request,
81-
body=data["error"],
82-
)
83-
84106
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
85107
else:
86-
data = sse.json()
87-
if is_mapping(data) and data.get("error"):
88-
message = None
89-
error = data.get("error")
90-
if is_mapping(error):
91-
message = error.get("message")
92-
if not message or not isinstance(message, str):
93-
message = "An error occurred during streaming"
94-
95-
raise APIError(
96-
message=message,
97-
request=self.response.request,
98-
body=data["error"],
99-
)
100-
101108
yield process_data(
102109
data={"data": data, "event": sse.event}
103110
if self._options is not None and self._options.synthesize_event_and_data
@@ -173,41 +180,19 @@ async def __stream__(self) -> AsyncIterator[_T]:
173180
if sse.data.startswith("[DONE]"):
174181
break
175182

183+
data = sse.json()
184+
api_error = _build_streaming_api_error(
185+
data=data,
186+
request=self.response.request,
187+
is_error_event=sse.event == "error",
188+
)
189+
if api_error is not None:
190+
raise api_error
191+
176192
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
177193
if sse.event and sse.event.startswith("thread."):
178-
data = sse.json()
179-
180-
if sse.event == "error" and is_mapping(data) and data.get("error"):
181-
message = None
182-
error = data.get("error")
183-
if is_mapping(error):
184-
message = error.get("message")
185-
if not message or not isinstance(message, str):
186-
message = "An error occurred during streaming"
187-
188-
raise APIError(
189-
message=message,
190-
request=self.response.request,
191-
body=data["error"],
192-
)
193-
194194
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
195195
else:
196-
data = sse.json()
197-
if is_mapping(data) and data.get("error"):
198-
message = None
199-
error = data.get("error")
200-
if is_mapping(error):
201-
message = error.get("message")
202-
if not message or not isinstance(message, str):
203-
message = "An error occurred during streaming"
204-
205-
raise APIError(
206-
message=message,
207-
request=self.response.request,
208-
body=data["error"],
209-
)
210-
211196
yield process_data(
212197
data={"data": data, "event": sse.event}
213198
if self._options is not None and self._options.synthesize_event_and_data

tests/test_streaming.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import Iterator, AsyncIterator
3+
from typing import TypeVar, Iterator, AsyncIterator
44

55
import httpx
66
import pytest
77

8-
from openai import OpenAI, AsyncOpenAI
8+
from openai import OpenAI, APIError, AsyncOpenAI
99
from openai._streaming import Stream, AsyncStream, ServerSentEvent
1010

11+
_ItemT = TypeVar("_ItemT")
12+
1113

1214
@pytest.mark.asyncio
1315
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
@@ -216,23 +218,78 @@ def body() -> Iterator[bytes]:
216218
assert sse.json() == {"content": "известни"}
217219

218220

221+
@pytest.mark.asyncio
222+
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
223+
async def test_error_event_uses_top_level_message(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
224+
def body() -> Iterator[bytes]:
225+
yield b"event: error\n"
226+
yield b'data: {"type":"error","code":"server_error","message":"Something went wrong"}\n'
227+
yield b"\n"
228+
229+
stream = make_stream(content=body(), sync=sync, client=client, async_client=async_client)
230+
231+
with pytest.raises(APIError, match="Something went wrong") as exc_info:
232+
await iter_next(stream)
233+
234+
assert exc_info.value.body == {"type": "error", "code": "server_error", "message": "Something went wrong"}
235+
236+
237+
@pytest.mark.asyncio
238+
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
239+
async def test_error_event_keeps_nested_error_message_fallback(
240+
sync: bool,
241+
client: OpenAI,
242+
async_client: AsyncOpenAI,
243+
) -> None:
244+
def body() -> Iterator[bytes]:
245+
yield b"event: error\n"
246+
yield b'data: {"error":{"type":"error","code":"nested_error","message":"Nested failure"}}\n'
247+
yield b"\n"
248+
249+
stream = make_stream(content=body(), sync=sync, client=client, async_client=async_client)
250+
251+
with pytest.raises(APIError, match="Nested failure") as exc_info:
252+
await iter_next(stream)
253+
254+
assert exc_info.value.body == {"type": "error", "code": "nested_error", "message": "Nested failure"}
255+
256+
219257
async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
220258
for chunk in iter:
221259
yield chunk
222260

223261

224-
async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent:
262+
async def iter_next(iter: Iterator[_ItemT] | AsyncIterator[_ItemT]) -> _ItemT:
225263
if isinstance(iter, AsyncIterator):
226264
return await iter.__anext__()
227265

228266
return next(iter)
229267

230268

231-
async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None:
269+
async def assert_empty_iter(iter: Iterator[object] | AsyncIterator[object]) -> None:
232270
with pytest.raises((StopAsyncIteration, RuntimeError)):
233271
await iter_next(iter)
234272

235273

274+
def make_stream(
275+
content: Iterator[bytes],
276+
*,
277+
sync: bool,
278+
client: OpenAI,
279+
async_client: AsyncOpenAI,
280+
) -> Stream[object] | AsyncStream[object]:
281+
request = httpx.Request("GET", "https://example.com/stream")
282+
283+
if sync:
284+
return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content, request=request))
285+
286+
return AsyncStream(
287+
cast_to=object,
288+
client=async_client,
289+
response=httpx.Response(200, content=to_aiter(content), request=request),
290+
)
291+
292+
236293
def make_event_iterator(
237294
content: Iterator[bytes],
238295
*,

0 commit comments

Comments
 (0)