Skip to content

Commit 0e0b3b3

Browse files
committed
fix(oci): preserve v1 stream end finish reason
1 parent ffa62a1 commit 0e0b3b3

4 files changed

Lines changed: 45 additions & 29 deletions

File tree

src/cohere/aws_client.py

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,15 @@
44
import typing
55

66
import httpx
7-
from httpx import URL, SyncByteStream, ByteStream
7+
from httpx import URL, ByteStream
88

99
from . import GenerateStreamedResponse, Generation, \
1010
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
1111
ApiMetaBilledUnits
1212
from .client import Client, ClientEnvironment
1313
from .core import construct_type
1414
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
15+
from .manually_maintained.streaming import Streamer
1516
from .client_v2 import ClientV2
1617

1718
class AwsClient(Client):
@@ -112,16 +113,6 @@ def get_event_hooks(
112113
})
113114

114115

115-
class Streamer(SyncByteStream):
116-
lines: typing.Iterator[bytes]
117-
118-
def __init__(self, lines: typing.Iterator[bytes]):
119-
self.lines = lines
120-
121-
def __iter__(self) -> typing.Iterator[bytes]:
122-
return self.lines
123-
124-
125116
response_mapping: typing.Dict[str, typing.Any] = {
126117
"chat": NonStreamedChatResponse,
127118
"embed": EmbedResponse,
@@ -291,4 +282,4 @@ def get_api_version(*, version: str):
291282
"v2": 2,
292283
}
293284

294-
return int_version.get(version, 1)
285+
return int_version.get(version, 1)
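src/cohere/manually_maintained/streaming.py (new file; path inferred from the `.manually_maintained.streaming` import added in this commit)

Lines changed: 15 additions & 0 deletions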
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,15 @@
1+
import typing
2+
3+
from httpx import SyncByteStream
4+
5+
6+
class Streamer(SyncByteStream):
    """Wrap an iterator of bytes for httpx streaming responses.

    The wrapped chunks are handed back unchanged when the stream is iterated;
    no buffering or copying is done here.
    """

    # Source of the byte chunks returned by ``__iter__``.
    lines: typing.Iterator[bytes]

    def __init__(self, lines: typing.Iterator[bytes]):
        """Store *lines* so it can be returned when the stream is iterated."""
        self.lines = lines

    def __iter__(self) -> typing.Iterator[bytes]:
        """Return an iterator over the wrapped byte chunks."""
        # iter() returns a real iterator unchanged, so existing callers see the
        # same object as before. It also accepts a plain iterable (e.g. a list
        # of chunks), which would otherwise make __iter__ return a non-iterator
        # and raise TypeError.
        return iter(self.lines)

src/cohere/oci_client.py

Lines changed: 7 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,8 @@
1212
from .client import Client, ClientEnvironment
1313
from .client_v2 import ClientV2
1414
from .manually_maintained.lazy_oci_deps import lazy_oci
15-
from httpx import URL, ByteStream, SyncByteStream
15+
from .manually_maintained.streaming import Streamer
16+
from httpx import URL, ByteStream
1617

1718

1819
class OciClient(Client):
@@ -232,20 +233,6 @@ def __init__(
232233
EventHook = typing.Callable[..., typing.Any]
233234

234235

235-
236-
237-
class Streamer(SyncByteStream):
238-
"""Wraps an iterator of bytes for streaming responses."""
239-
240-
lines: typing.Iterator[bytes]
241-
242-
def __init__(self, lines: typing.Iterator[bytes]):
243-
self.lines = lines
244-
245-
def __iter__(self) -> typing.Iterator[bytes]:
246-
return self.lines
247-
248-
249236
def _load_oci_config(
250237
auth_type: str,
251238
config_path: typing.Optional[str],
@@ -1006,6 +993,7 @@ def transform_oci_stream_wrapper(
1006993
final_finish_reason = "COMPLETE"
1007994
final_usage: typing.Optional[typing.Dict[str, typing.Any]] = None
1008995
full_v1_text = ""
996+
final_v1_finish_reason = "COMPLETE"
1009997
buffer = b""
1010998

1011999
def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes:
@@ -1053,10 +1041,12 @@ def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10531041
yield _emit_v2_event(cohere_event)
10541042

10551043
def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> bytes:
1056-
nonlocal full_v1_text
1044+
nonlocal full_v1_text, final_v1_finish_reason
10571045
event = typing.cast(typing.Dict[str, typing.Any], transform_stream_event(endpoint, oci_event, is_v2=False))
10581046
if event.get("event_type") == "text-generation" and event.get("text"):
10591047
full_v1_text += typing.cast(str, event["text"])
1048+
if "finishReason" in oci_event:
1049+
final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
10601050
return _emit_v1_event(event)
10611051

10621052
def _process_line(line: str) -> typing.Iterator[bytes]:
@@ -1083,7 +1073,7 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
10831073
"response": {
10841074
"text": full_v1_text,
10851075
"generation_id": generation_id,
1086-
"finish_reason": final_finish_reason,
1076+
"finish_reason": final_v1_finish_reason,
10871077
},
10881078
}
10891079
)

tests/test_oci_client.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1039,6 +1039,26 @@ def test_stream_wrapper_skips_malformed_json_with_warning(self):
10391039
# Should get message-start + content-start + content-delta + content-end + message-end.
10401040
self.assertEqual(len(events), 5)
10411041

1042+
def test_v1_stream_wrapper_preserves_finish_reason_in_stream_end(self):
    """V1 stream-end must report the finish reason sent in OCI's final event."""
    import json
    from cohere.oci_client import transform_oci_stream_wrapper

    # Two text chunks (the second carries the finish reason), then the
    # end-of-stream sentinel.
    oci_chunks = [
        b'data: {"text": "Hello", "isFinished": false}\n',
        b'data: {"text": " world", "isFinished": true, "finishReason": "MAX_TOKENS"}\n',
        b"data: [DONE]\n",
    ]

    events = []
    for raw_event in transform_oci_stream_wrapper(iter(oci_chunks), "chat_stream", is_v2=False):
        events.append(json.loads(raw_event.decode("utf-8")))

    # The third emitted event is expected to be the stream-end summary.
    stream_end = events[2]
    self.assertEqual(stream_end["event_type"], "stream-end")
    response = stream_end["response"]
    self.assertEqual(response["text"], "Hello world")
    self.assertEqual(response["finish_reason"], "MAX_TOKENS")
1061+
10421062
def test_stream_wrapper_raises_on_transform_error(self):
10431063
"""Test that transform errors in stream produce OCI-specific error, not opaque httpx error."""
10441064
from cohere.oci_client import transform_oci_stream_wrapper

0 commit comments

Comments
 (0)