Skip to content

Commit 9f1e924

Browse files
committed
fix(oci): resolve remaining stream and usage edge cases
1 parent 8bc5f5d commit 9f1e924

2 files changed

Lines changed: 56 additions & 42 deletions

File tree

src/cohere/oci_client.py

Lines changed: 42 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -305,12 +305,30 @@ def _remove_inherited_session_auth(
305305
if profile_name == "DEFAULT" or "security_token_file" not in oci_config:
306306
return
307307

308+
config_file = os.path.expanduser(config_path or "~/.oci/config")
308309
parser = configparser.ConfigParser(interpolation=None)
309-
if not parser.read(os.path.expanduser(config_path or "~/.oci/config")):
310+
if not parser.read(config_file):
310311
return
311312

312-
explicit_profile = parser._sections.get(profile_name, {})
313-
if "security_token_file" not in explicit_profile:
313+
if not parser.has_section(profile_name):
314+
oci_config.pop("security_token_file", None)
315+
return
316+
317+
explicit_security_token = False
318+
current_section: typing.Optional[str] = None
319+
with open(config_file, encoding="utf-8") as handle:
320+
for raw_line in handle:
321+
line = raw_line.strip()
322+
if not line or line.startswith(("#", ";")):
323+
continue
324+
if line.startswith("[") and line.endswith("]"):
325+
current_section = line[1:-1].strip()
326+
continue
327+
if current_section == profile_name and line.split("=", 1)[0].strip() == "security_token_file":
328+
explicit_security_token = True
329+
break
330+
331+
if not explicit_security_token:
314332
oci_config.pop("security_token_file", None)
315333

316334

@@ -469,7 +487,6 @@ def _event_hook(request: httpx.Request) -> None:
469487
request.stream = ByteStream(oci_body_bytes)
470488
request._content = oci_body_bytes
471489
request.extensions["endpoint"] = endpoint
472-
request.extensions["cohere_body"] = body
473490
request.extensions["is_stream"] = "stream" in endpoint or body.get("stream", False)
474491
# Store V2 detection for streaming event transformation
475492
# For chat, detect V2 by presence of "messages" field (V2) vs "message" field (V1)
@@ -853,20 +870,11 @@ def transform_oci_response_to_cohere(
853870
}
854871

855872
# Add usage info if available
856-
if "usage" in oci_response and oci_response["usage"]:
857-
usage = oci_response["usage"]
858-
# OCI usage has inputTokens, outputTokens, totalTokens
859-
input_tokens = usage.get("inputTokens", 0)
860-
output_tokens = usage.get("outputTokens", 0)
861-
862-
meta["billed_units"] = {
863-
"input_tokens": input_tokens,
864-
"output_tokens": output_tokens,
865-
}
866-
meta["tokens"] = {
867-
"input_tokens": input_tokens,
868-
"output_tokens": output_tokens,
869-
}
873+
usage = _usage_from_oci(oci_response.get("usage"))
874+
if "tokens" in usage:
875+
meta["tokens"] = usage["tokens"]
876+
if "billed_units" in usage:
877+
meta["billed_units"] = usage["billed_units"]
870878

871879
return {
872880
"id": oci_response.get("id", str(uuid.uuid4())),
@@ -916,19 +924,11 @@ def transform_oci_response_to_cohere(
916924
"api_version": {"version": "1"},
917925
}
918926

919-
if "usage" in chat_response and chat_response["usage"]:
920-
usage = chat_response["usage"]
921-
input_tokens = usage.get("inputTokens", 0)
922-
output_tokens = usage.get("outputTokens", 0)
923-
924-
meta["billed_units"] = {
925-
"input_tokens": input_tokens,
926-
"output_tokens": output_tokens,
927-
}
928-
meta["tokens"] = {
929-
"input_tokens": input_tokens,
930-
"output_tokens": output_tokens,
931-
}
927+
usage = _usage_from_oci(chat_response.get("usage"))
928+
if "tokens" in usage:
929+
meta["tokens"] = usage["tokens"]
930+
if "billed_units" in usage:
931+
meta["billed_units"] = usage["billed_units"]
932932

933933
return {
934934
"text": chat_response.get("text", ""),
@@ -1056,16 +1056,17 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
10561056
data_str = line[6:]
10571057
if data_str.strip() == "[DONE]":
10581058
if is_v2:
1059-
if emitted_start and not emitted_content_end:
1060-
yield _emit_v2_event({"type": "content-end", "index": 0})
1061-
message_end_event: typing.Dict[str, typing.Any] = {
1062-
"type": "message-end",
1063-
"id": generation_id,
1064-
"delta": {"finish_reason": final_finish_reason},
1065-
}
1066-
if final_usage:
1067-
message_end_event["delta"]["usage"] = final_usage
1068-
yield _emit_v2_event(message_end_event)
1059+
if emitted_start:
1060+
if not emitted_content_end:
1061+
yield _emit_v2_event({"type": "content-end", "index": 0})
1062+
message_end_event: typing.Dict[str, typing.Any] = {
1063+
"type": "message-end",
1064+
"id": generation_id,
1065+
"delta": {"finish_reason": final_finish_reason},
1066+
}
1067+
if final_usage:
1068+
message_end_event["delta"]["usage"] = final_usage
1069+
yield _emit_v2_event(message_end_event)
10691070
else:
10701071
yield _emit_v1_event(
10711072
{

tests/test_oci_client.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -966,13 +966,18 @@ def test_embed_response_lowercases_embedding_keys(self):
966966

967967
result = transform_oci_response_to_cohere(
968968
"embed",
969-
{"id": "embed-id", "embeddings": {"FLOAT": [[0.1, 0.2]], "INT8": [[1, 2]]}},
969+
{
970+
"id": "embed-id",
971+
"embeddings": {"FLOAT": [[0.1, 0.2]], "INT8": [[1, 2]]},
972+
"usage": {"inputTokens": 3, "completionTokens": 7},
973+
},
970974
is_v2=True,
971975
)
972976

973977
self.assertIn("float", result["embeddings"])
974978
self.assertIn("int8", result["embeddings"])
975979
self.assertNotIn("FLOAT", result["embeddings"])
980+
self.assertEqual(result["meta"]["tokens"]["output_tokens"], 7)
976981

977982
def test_normalize_model_for_oci_rejects_empty_model(self):
978983
"""Test model normalization fails clearly for empty model names."""
@@ -1050,6 +1055,14 @@ def test_stream_wrapper_skips_malformed_json_with_warning(self):
10501055
# Should get message-start + content-start + content-delta + content-end + message-end.
10511056
self.assertEqual(len(events), 5)
10521057

1058+
def test_stream_wrapper_skips_message_end_for_empty_v2_stream(self):
1059+
"""Test empty V2 streams do not emit message-end without a preceding message-start."""
1060+
from cohere.oci_client import transform_oci_stream_wrapper
1061+
1062+
events = list(transform_oci_stream_wrapper(iter([b"data: [DONE]\n"]), "chat", is_v2=True))
1063+
1064+
self.assertEqual(events, [])
1065+
10531066
def test_v1_stream_wrapper_preserves_finish_reason_in_stream_end(self):
10541067
"""Test that V1 stream-end uses the OCI finish reason from the final event."""
10551068
import json

0 commit comments

Comments
 (0)