Skip to content

Commit 213d744

Browse files
committed
Address cursor[bot] review feedback for OCI client
- Remove unused response_mapping and stream_response_mapping dicts - Remove unused transform_oci_stream_response function - Remove unused imports (EmbedResponse, Generation, etc.) - Fix crash when thinking parameter is explicitly None - Fix V2 chat response role not lowercased (ASSISTANT -> assistant) - Fix V2 finish_reason incorrectly lowercased (should stay uppercase) - Add unit tests for thinking=None, role lowercase, and finish_reason
1 parent bd21958 commit 213d744

2 files changed

Lines changed: 68 additions & 52 deletions

File tree

src/cohere/oci_client.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,6 @@
88

99
import httpx
1010
import requests
11-
from . import (
12-
EmbedResponse,
13-
GenerateStreamedResponse,
14-
Generation,
15-
NonStreamedChatResponse,
16-
RerankResponse,
17-
StreamedChatResponse,
18-
)
1911
from .client import Client, ClientEnvironment
2012
from .client_v2 import ClientV2
2113
from .manually_maintained.lazy_oci_deps import lazy_oci
@@ -239,18 +231,6 @@ def __init__(
239231
EventHook = typing.Callable[..., typing.Any]
240232

241233

242-
# Response type mappings
243-
response_mapping: typing.Dict[str, typing.Any] = {
244-
"chat": NonStreamedChatResponse,
245-
"embed": EmbedResponse,
246-
"generate": Generation,
247-
"rerank": RerankResponse,
248-
}
249-
250-
stream_response_mapping: typing.Dict[str, typing.Any] = {
251-
"chat": StreamedChatResponse,
252-
"generate": GenerateStreamedResponse,
253-
}
254234

255235

256236
class Streamer(SyncByteStream):
@@ -685,7 +665,7 @@ def transform_request_to_oci(
685665
if "safety_mode" in cohere_body:
686666
chat_request["safetyMode"] = cohere_body["safety_mode"]
687667
# Thinking parameter for Command A Reasoning models
688-
if "thinking" in cohere_body:
668+
if "thinking" in cohere_body and cohere_body["thinking"] is not None:
689669
thinking = cohere_body["thinking"]
690670
oci_thinking: typing.Dict[str, typing.Any] = {}
691671
if "type" in thinking:
@@ -840,8 +820,14 @@ def transform_oci_response_to_cohere(
840820
"output_tokens": usage_data.get("completionTokens", 0),
841821
}
842822

843-
# Transform message content types from OCI (uppercase) to Cohere (lowercase)
823+
# Transform message from OCI format to Cohere format
844824
message = chat_response.get("message", {})
825+
826+
# Lowercase the role (OCI returns "ASSISTANT", Cohere expects "assistant")
827+
if "role" in message:
828+
message = {**message, "role": message["role"].lower()}
829+
830+
# Transform content types from OCI (uppercase) to Cohere (lowercase)
845831
if "content" in message and isinstance(message["content"], list):
846832
transformed_content = []
847833
for item in message["content"]:
@@ -857,7 +843,7 @@ def transform_oci_response_to_cohere(
857843
return {
858844
"id": chat_response.get("id", str(uuid.uuid4())),
859845
"message": message,
860-
"finish_reason": chat_response.get("finishReason", "COMPLETE").lower(),
846+
"finish_reason": chat_response.get("finishReason", "COMPLETE"), # V2 keeps uppercase
861847
"usage": usage,
862848
}
863849
else:
@@ -965,35 +951,6 @@ def transform_oci_stream_wrapper(
965951
continue
966952

967953

968-
def transform_oci_stream_response(
969-
response: httpx.Response, endpoint: str
970-
) -> typing.Iterator[bytes]:
971-
"""
972-
Transform OCI streaming responses to Cohere streaming format.
973-
974-
OCI uses Server-Sent Events (SSE) format.
975-
976-
Args:
977-
response: httpx Response object
978-
endpoint: Cohere endpoint name
979-
980-
Yields:
981-
Bytes of transformed streaming events
982-
"""
983-
for line in response.iter_lines():
984-
if line.startswith("data: "):
985-
data_str = line[6:] # Remove "data: " prefix
986-
if data_str.strip() == "[DONE]":
987-
break
988-
989-
try:
990-
oci_event = json.loads(data_str)
991-
cohere_event = transform_stream_event(endpoint, oci_event)
992-
yield json.dumps(cohere_event).encode("utf-8") + b"\n"
993-
except json.JSONDecodeError:
994-
continue
995-
996-
997954
def transform_stream_event(
998955
endpoint: str, oci_event: typing.Dict[str, typing.Any], is_v2: bool = False
999956
) -> typing.Dict[str, typing.Any]:

tests/test_oci_client.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,65 @@ def test_stream_event_text_transformation(self):
597597
self.assertIn("text", result["delta"]["message"]["content"])
598598
self.assertEqual(result["delta"]["message"]["content"]["text"], "The answer is...")
599599

600+
def test_thinking_parameter_none(self):
601+
"""Test that thinking=None does not crash (issue: null guard)."""
602+
from cohere.oci_client import transform_request_to_oci
603+
604+
cohere_body = {
605+
"model": "command-a-03-2025",
606+
"messages": [{"role": "user", "content": "Hello"}],
607+
"thinking": None, # Explicitly set to None
608+
}
609+
610+
# Should not crash with TypeError
611+
result = transform_request_to_oci("chat", cohere_body, "compartment-123")
612+
613+
chat_request = result["chatRequest"]
614+
# thinking should not be in request when None
615+
self.assertNotIn("thinking", chat_request)
616+
617+
def test_v2_response_role_lowercased(self):
618+
"""Test that V2 response message role is lowercased."""
619+
from cohere.oci_client import transform_oci_response_to_cohere
620+
621+
oci_response = {
622+
"chatResponse": {
623+
"id": "test-id",
624+
"message": {
625+
"role": "ASSISTANT",
626+
"content": [{"type": "TEXT", "text": "Hello"}],
627+
},
628+
"finishReason": "COMPLETE",
629+
"usage": {"inputTokens": 10, "completionTokens": 20},
630+
}
631+
}
632+
633+
result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
634+
635+
# Role should be lowercased
636+
self.assertEqual(result["message"]["role"], "assistant")
637+
638+
def test_v2_response_finish_reason_uppercase(self):
639+
"""Test that V2 response finish_reason stays uppercase."""
640+
from cohere.oci_client import transform_oci_response_to_cohere
641+
642+
oci_response = {
643+
"chatResponse": {
644+
"id": "test-id",
645+
"message": {
646+
"role": "ASSISTANT",
647+
"content": [{"type": "TEXT", "text": "Hello"}],
648+
},
649+
"finishReason": "MAX_TOKENS",
650+
"usage": {"inputTokens": 10, "completionTokens": 20},
651+
}
652+
}
653+
654+
result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
655+
656+
# V2 finish_reason should stay uppercase
657+
self.assertEqual(result["finish_reason"], "MAX_TOKENS")
658+
600659

601660
if __name__ == "__main__":
602661
unittest.main()

0 commit comments

Comments
 (0)