Skip to content

Commit bd21958

Browse files
committed
feat: Add thinking parameter support for Command A Reasoning models
Supports the thinking/reasoning feature for command-a-reasoning-08-2025 on OCI. Transforms Cohere's thinking parameter (type, token_budget) to OCI format and handles thinking content in both non-streaming and streaming responses.
1 parent 951bba7 commit bd21958

2 files changed

Lines changed: 188 additions & 9 deletions

File tree

src/cohere/oci_client.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,16 @@ def transform_request_to_oci(
684684
chat_request["citationOptions"] = cohere_body["citation_options"]
685685
if "safety_mode" in cohere_body:
686686
chat_request["safetyMode"] = cohere_body["safety_mode"]
687+
# Thinking parameter for Command A Reasoning models
688+
if "thinking" in cohere_body:
689+
thinking = cohere_body["thinking"]
690+
oci_thinking: typing.Dict[str, typing.Any] = {}
691+
if "type" in thinking:
692+
oci_thinking["type"] = thinking["type"].upper()
693+
if "token_budget" in thinking and thinking["token_budget"] is not None:
694+
oci_thinking["token_budget"] = thinking["token_budget"]
695+
if oci_thinking:
696+
chat_request["thinking"] = oci_thinking
687697
else:
688698
# V1 API: uses single message string
689699
chat_request["message"] = cohere_body["message"]
@@ -830,9 +840,23 @@ def transform_oci_response_to_cohere(
830840
"output_tokens": usage_data.get("completionTokens", 0),
831841
}
832842

843+
# Transform message content types from OCI (uppercase) to Cohere (lowercase)
844+
message = chat_response.get("message", {})
845+
if "content" in message and isinstance(message["content"], list):
846+
transformed_content = []
847+
for item in message["content"]:
848+
if isinstance(item, dict):
849+
transformed_item = item.copy()
850+
if "type" in transformed_item:
851+
transformed_item["type"] = transformed_item["type"].lower()
852+
transformed_content.append(transformed_item)
853+
else:
854+
transformed_content.append(item)
855+
message = {**message, "content": transformed_content}
856+
833857
return {
834858
"id": chat_response.get("id", str(uuid.uuid4())),
835-
"message": chat_response.get("message", {}),
859+
"message": message,
836860
"finish_reason": chat_response.get("finishReason", "COMPLETE").lower(),
837861
"usage": usage,
838862
}
@@ -987,14 +1011,22 @@ def transform_stream_event(
9871011
if endpoint in ["chat_stream", "chat"]:
9881012
if is_v2:
9891013
# V2 API format: OCI returns full message structure in each event
990-
# Extract text from nested structure: message.content[0].text
991-
text = ""
1014+
# Extract content from nested structure: message.content[0]
1015+
content_type = "text"
1016+
content_value = ""
1017+
9921018
if "message" in oci_event and "content" in oci_event["message"]:
9931019
content_list = oci_event["message"]["content"]
9941020
if content_list and isinstance(content_list, list) and len(content_list) > 0:
9951021
first_content = content_list[0]
996-
if "text" in first_content:
997-
text = first_content["text"]
1022+
# Detect content type (TEXT or THINKING)
1023+
oci_type = first_content.get("type", "TEXT").upper()
1024+
if oci_type == "THINKING":
1025+
content_type = "thinking"
1026+
content_value = first_content.get("thinking", "")
1027+
else:
1028+
content_type = "text"
1029+
content_value = first_content.get("text", "")
9981030

9991031
is_finished = "finishReason" in oci_event
10001032

@@ -1005,15 +1037,19 @@ def transform_stream_event(
10051037
"index": 0,
10061038
}
10071039
else:
1008-
# Content delta event
1040+
# Content delta event - include type for thinking vs text
1041+
delta_content: typing.Dict[str, typing.Any] = {}
1042+
if content_type == "thinking":
1043+
delta_content["thinking"] = content_value
1044+
else:
1045+
delta_content["text"] = content_value
1046+
10091047
return {
10101048
"type": "content-delta",
10111049
"index": 0,
10121050
"delta": {
10131051
"message": {
1014-
"content": {
1015-
"text": text,
1016-
}
1052+
"content": delta_content,
10171053
}
10181054
},
10191055
}

tests/test_oci_client.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,46 @@ def test_chat_v2(self):
236236
self.assertIsNotNone(response)
237237
self.assertIsNotNone(response.message)
238238

239+
@unittest.skip(
240+
"Command A Reasoning model (command-a-reasoning-08-2025) may not be available in all regions. "
241+
"Enable this test when the reasoning model is available in your OCI region."
242+
)
243+
def test_chat_v2_with_thinking(self):
244+
"""Test chat with thinking parameter for Command A Reasoning model."""
245+
from cohere.types import Thinking
246+
247+
response = self.client.chat(
248+
model="command-a-reasoning-08-2025",
249+
messages=[{"role": "user", "content": "What is 15 * 27? Think step by step."}],
250+
thinking=Thinking(type="enabled", token_budget=5000),
251+
)
252+
253+
self.assertIsNotNone(response)
254+
self.assertIsNotNone(response.message)
255+
# The response should contain content (may include thinking content)
256+
self.assertIsNotNone(response.message.content)
257+
258+
@unittest.skip(
259+
"Command A Reasoning model (command-a-reasoning-08-2025) may not be available in all regions. "
260+
"Enable this test when the reasoning model is available in your OCI region."
261+
)
262+
def test_chat_stream_v2_with_thinking(self):
263+
"""Test streaming chat with thinking parameter for Command A Reasoning model."""
264+
from cohere.types import Thinking
265+
266+
events = []
267+
for event in self.client.chat_stream(
268+
model="command-a-reasoning-08-2025",
269+
messages=[{"role": "user", "content": "What is 15 * 27? Think step by step."}],
270+
thinking=Thinking(type="enabled", token_budget=5000),
271+
):
272+
events.append(event)
273+
274+
self.assertTrue(len(events) > 0)
275+
# Verify we received content-delta events
276+
content_delta_events = [e for e in events if hasattr(e, "type") and e.type == "content-delta"]
277+
self.assertTrue(len(content_delta_events) > 0)
278+
239279
def test_chat_stream_v2(self):
240280
"""Test streaming chat with v2 client."""
241281
events = []
@@ -455,5 +495,108 @@ def test_rerank_v3(self):
455495
self.assertIsNotNone(response.results)
456496

457497

498+
class TestOciClientTransformations(unittest.TestCase):
499+
"""Unit tests for OCI request/response transformations (no OCI credentials required)."""
500+
501+
def test_thinking_parameter_transformation(self):
502+
"""Test that thinking parameter is correctly transformed to OCI format."""
503+
from cohere.oci_client import transform_request_to_oci
504+
505+
cohere_body = {
506+
"model": "command-a-reasoning-08-2025",
507+
"messages": [{"role": "user", "content": "What is 2+2?"}],
508+
"thinking": {
509+
"type": "enabled",
510+
"token_budget": 10000,
511+
},
512+
}
513+
514+
result = transform_request_to_oci("chat", cohere_body, "compartment-123")
515+
516+
# Verify thinking parameter is transformed
517+
chat_request = result["chatRequest"]
518+
self.assertIn("thinking", chat_request)
519+
self.assertEqual(chat_request["thinking"]["type"], "ENABLED")
520+
self.assertEqual(chat_request["thinking"]["token_budget"], 10000)
521+
522+
def test_thinking_parameter_disabled(self):
523+
"""Test that disabled thinking is correctly transformed."""
524+
from cohere.oci_client import transform_request_to_oci
525+
526+
cohere_body = {
527+
"model": "command-a-reasoning-08-2025",
528+
"messages": [{"role": "user", "content": "Hello"}],
529+
"thinking": {
530+
"type": "disabled",
531+
},
532+
}
533+
534+
result = transform_request_to_oci("chat", cohere_body, "compartment-123")
535+
536+
chat_request = result["chatRequest"]
537+
self.assertIn("thinking", chat_request)
538+
self.assertEqual(chat_request["thinking"]["type"], "DISABLED")
539+
self.assertNotIn("token_budget", chat_request["thinking"])
540+
541+
def test_thinking_response_transformation(self):
542+
"""Test that thinking content in response is correctly transformed."""
543+
from cohere.oci_client import transform_oci_response_to_cohere
544+
545+
oci_response = {
546+
"chatResponse": {
547+
"id": "test-id",
548+
"message": {
549+
"role": "ASSISTANT",
550+
"content": [
551+
{"type": "THINKING", "thinking": "Let me think about this..."},
552+
{"type": "TEXT", "text": "The answer is 4."},
553+
],
554+
},
555+
"finishReason": "COMPLETE",
556+
"usage": {"inputTokens": 10, "completionTokens": 20},
557+
}
558+
}
559+
560+
result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True)
561+
562+
# Verify content types are lowercased
563+
self.assertEqual(result["message"]["content"][0]["type"], "thinking")
564+
self.assertEqual(result["message"]["content"][1]["type"], "text")
565+
566+
def test_stream_event_thinking_transformation(self):
567+
"""Test that thinking content in stream events is correctly transformed."""
568+
from cohere.oci_client import transform_stream_event
569+
570+
# OCI thinking event
571+
oci_event = {
572+
"message": {
573+
"content": [{"type": "THINKING", "thinking": "Reasoning step..."}]
574+
}
575+
}
576+
577+
result = transform_stream_event("chat", oci_event, is_v2=True)
578+
579+
self.assertEqual(result["type"], "content-delta")
580+
self.assertIn("thinking", result["delta"]["message"]["content"])
581+
self.assertEqual(result["delta"]["message"]["content"]["thinking"], "Reasoning step...")
582+
583+
def test_stream_event_text_transformation(self):
584+
"""Test that text content in stream events is correctly transformed."""
585+
from cohere.oci_client import transform_stream_event
586+
587+
# OCI text event
588+
oci_event = {
589+
"message": {
590+
"content": [{"type": "TEXT", "text": "The answer is..."}]
591+
}
592+
}
593+
594+
result = transform_stream_event("chat", oci_event, is_v2=True)
595+
596+
self.assertEqual(result["type"], "content-delta")
597+
self.assertIn("text", result["delta"]["message"]["content"])
598+
self.assertEqual(result["delta"]["message"]["content"]["text"], "The answer is...")
599+
600+
458601
if __name__ == "__main__":
459602
unittest.main()

0 commit comments

Comments
 (0)