Skip to content

Commit d4ff382

Browse files
committed
feat: Add V2 API support for OCI with Command A models
- Implemented automatic V1/V2 API detection based on request structure
- Added V2 request transformation for the messages format
- Added V2 response transformation for Command A models
- Removed hardcoded region-specific model OCIDs; now uses display names (e.g., cohere.command-a-03-2025) that work across all OCI regions
- V2 chat fully functional with the command-a-03-2025 model
- Updated tests to use command-a-03-2025 for V2 API testing

Test results: 14 passed, 8 skipped, 0 failed
1 parent ad2bad1 commit d4ff382

2 files changed

Lines changed: 271 additions & 46 deletions

File tree

src/cohere/oci_client.py

Lines changed: 246 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def _event_hook(request: httpx.Request) -> None:
388388
request._content = oci_body_bytes
389389
request.extensions["endpoint"] = endpoint
390390
request.extensions["cohere_body"] = body
391+
request.extensions["is_stream"] = "stream" in endpoint or body.get("stream", False)
391392

392393
return _event_hook
393394

@@ -402,18 +403,31 @@ def map_response_from_oci() -> EventHook:
402403

403404
def _hook(response: httpx.Response) -> None:
404405
endpoint = response.request.extensions["endpoint"]
405-
is_stream = "stream" in endpoint
406+
is_stream = response.request.extensions.get("is_stream", False)
406407

407408
output: typing.Iterator[bytes]
408409

410+
# Only transform successful responses (200-299)
411+
# Let error responses pass through unchanged so SDK error handling works
412+
if not (200 <= response.status_code < 300):
413+
return
414+
415+
# For streaming responses, wrap the stream with a transformer
409416
if is_stream:
410-
# Handle streaming responses
411-
output = transform_oci_stream_response(response, endpoint)
412-
else:
413-
# Handle non-streaming responses
414-
oci_response = json.loads(response.read())
415-
cohere_response = transform_oci_response_to_cohere(endpoint, oci_response)
416-
output = iter([json.dumps(cohere_response).encode("utf-8")])
417+
original_stream = response.stream
418+
transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint)
419+
response.stream = Streamer(transformed_stream)
420+
# Reset consumption flags
421+
if hasattr(response, "_content"):
422+
del response._content
423+
response.is_stream_consumed = False
424+
response.is_closed = False
425+
return
426+
427+
# Handle non-streaming responses
428+
oci_response = json.loads(response.read())
429+
cohere_response = transform_oci_response_to_cohere(endpoint, oci_response)
430+
output = iter([json.dumps(cohere_response).encode("utf-8")])
417431

418432
response.stream = Streamer(output)
419433

@@ -452,13 +466,45 @@ def get_oci_url(
452466
"chat_stream": "chat",
453467
"generate": "generateText",
454468
"generate_stream": "generateText",
455-
"rerank": "rerank",
469+
"rerank": "rerankText", # OCI uses rerankText, not rerank
456470
}
457471

458472
action = action_map.get(endpoint, endpoint)
459473
return f"{base}/{api_version}/actions/{action}"
460474

461475

476+
def normalize_model_for_oci(model: str) -> str:
    """
    Normalize a model identifier for use with OCI.

    OCI accepts either a full model OCID or a display name in the
    "cohere.<model-name>" format; display names work across all regions.
    Identifiers that already carry an "ocid1." or "cohere." prefix are
    returned untouched, anything else gets the "cohere." prefix added.

    Args:
        model: Model name (e.g., "command-r-08-2024") or a full OCID.

    Returns:
        Normalized model identifier (e.g., "cohere.command-r-08-2024"
        or the original OCID).

    Examples:
        >>> normalize_model_for_oci("command-a-03-2025")
        'cohere.command-a-03-2025'
        >>> normalize_model_for_oci("cohere.embed-english-v3.0")
        'cohere.embed-english-v3.0'
        >>> normalize_model_for_oci("ocid1.generativeaimodel.oc1...")
        'ocid1.generativeaimodel.oc1...'
    """
    # Already an OCID or already prefixed — pass through unchanged.
    if model.startswith(("ocid1.", "cohere.")):
        return model
    return f"cohere.{model}"
506+
507+
462508
def transform_request_to_oci(
463509
endpoint: str,
464510
cohere_body: typing.Dict[str, typing.Any],
@@ -475,9 +521,7 @@ def transform_request_to_oci(
475521
Returns:
476522
Transformed request body in OCI format
477523
"""
478-
model = cohere_body.get("model", "")
479-
if not model.startswith("cohere."):
480-
model = f"cohere.{model}"
524+
model = normalize_model_for_oci(cohere_body.get("model", ""))
481525

482526
if endpoint == "embed":
483527
# Transform Cohere input_type to OCI format
@@ -506,21 +550,96 @@ def transform_request_to_oci(
506550
return oci_body
507551

508552
elif endpoint in ["chat", "chat_stream"]:
553+
# Detect V1 vs V2 API based on request body structure
554+
is_v2 = "messages" in cohere_body # V2 uses messages array
555+
556+
# OCI uses a nested chatRequest structure
557+
chat_request = {
558+
"apiFormat": "COHEREV2" if is_v2 else "COHERE",
559+
}
560+
561+
if is_v2:
562+
# V2 API: uses messages array
563+
# Transform Cohere V2 messages to OCI V2 format
564+
# Cohere sends: [{"role": "user", "content": "text"}]
565+
# OCI expects: [{"role": "USER", "content": [{"type": "TEXT", "text": "..."}]}]
566+
oci_messages = []
567+
for msg in cohere_body["messages"]:
568+
oci_msg = {
569+
"role": msg["role"].upper(),
570+
}
571+
572+
# Transform content
573+
if isinstance(msg.get("content"), str):
574+
# Simple string content -> wrap in array
575+
oci_msg["content"] = [{"type": "TEXT", "text": msg["content"]}]
576+
elif isinstance(msg.get("content"), list):
577+
# Already array format (from tool calls, etc.)
578+
oci_msg["content"] = msg["content"]
579+
else:
580+
oci_msg["content"] = msg.get("content", [])
581+
582+
# Add tool_calls if present
583+
if "tool_calls" in msg:
584+
oci_msg["toolCalls"] = msg["tool_calls"]
585+
586+
oci_messages.append(oci_msg)
587+
588+
chat_request["messages"] = oci_messages
589+
590+
# V2 optional parameters (use Cohere's camelCase names for OCI)
591+
if "max_tokens" in cohere_body:
592+
chat_request["maxTokens"] = cohere_body["max_tokens"]
593+
if "temperature" in cohere_body:
594+
chat_request["temperature"] = cohere_body["temperature"]
595+
if "k" in cohere_body:
596+
chat_request["topK"] = cohere_body["k"]
597+
if "p" in cohere_body:
598+
chat_request["topP"] = cohere_body["p"]
599+
if "seed" in cohere_body:
600+
chat_request["seed"] = cohere_body["seed"]
601+
if "frequency_penalty" in cohere_body:
602+
chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
603+
if "presence_penalty" in cohere_body:
604+
chat_request["presencePenalty"] = cohere_body["presence_penalty"]
605+
if "stop_sequences" in cohere_body:
606+
chat_request["stopSequences"] = cohere_body["stop_sequences"]
607+
if "tools" in cohere_body:
608+
chat_request["tools"] = cohere_body["tools"]
609+
if "documents" in cohere_body:
610+
chat_request["documents"] = cohere_body["documents"]
611+
if "citation_options" in cohere_body:
612+
chat_request["citationOptions"] = cohere_body["citation_options"]
613+
if "safety_mode" in cohere_body:
614+
chat_request["safetyMode"] = cohere_body["safety_mode"]
615+
else:
616+
# V1 API: uses single message string
617+
chat_request["message"] = cohere_body["message"]
618+
619+
# V1 optional parameters
620+
if "temperature" in cohere_body:
621+
chat_request["temperature"] = cohere_body["temperature"]
622+
if "max_tokens" in cohere_body:
623+
chat_request["maxTokens"] = cohere_body["max_tokens"]
624+
if "preamble" in cohere_body:
625+
chat_request["preambleOverride"] = cohere_body["preamble"]
626+
if "chat_history" in cohere_body:
627+
chat_request["chatHistory"] = cohere_body["chat_history"]
628+
629+
# Handle streaming for both versions
630+
if "stream" in endpoint or cohere_body.get("stream"):
631+
chat_request["isStream"] = True
632+
633+
# Top level OCI request structure
509634
oci_body = {
510-
"message": cohere_body["message"],
511635
"servingMode": {
512636
"servingType": "ON_DEMAND",
513637
"modelId": model,
514638
},
515639
"compartmentId": compartment_id,
516-
"isStream": endpoint == "chat_stream" or cohere_body.get("stream", False),
640+
"chatRequest": chat_request,
517641
}
518-
if "chat_history" in cohere_body:
519-
oci_body["chatHistory"] = cohere_body["chat_history"]
520-
if "temperature" in cohere_body:
521-
oci_body["temperature"] = cohere_body["temperature"]
522-
if "max_tokens" in cohere_body:
523-
oci_body["maxTokens"] = cohere_body["max_tokens"]
642+
524643
return oci_body
525644

526645
elif endpoint in ["generate", "generate_stream"]:
@@ -540,17 +659,24 @@ def transform_request_to_oci(
540659
return oci_body
541660

542661
elif endpoint == "rerank":
662+
# OCI rerank uses a flat structure (not nested like chat)
663+
# and "input" instead of "query"
543664
oci_body = {
544-
"query": cohere_body["query"],
545-
"documents": cohere_body["documents"],
546665
"servingMode": {
547666
"servingType": "ON_DEMAND",
548667
"modelId": model,
549668
},
550669
"compartmentId": compartment_id,
670+
"input": cohere_body["query"], # OCI uses "input" not "query"
671+
"documents": cohere_body["documents"],
551672
}
673+
674+
# Add optional rerank parameters
552675
if "top_n" in cohere_body:
553676
oci_body["topN"] = cohere_body["top_n"]
677+
if "max_chunks_per_doc" in cohere_body:
678+
oci_body["maxChunksPerDocument"] = cohere_body["max_chunks_per_doc"]
679+
554680
return oci_body
555681

556682
return cohere_body
@@ -603,14 +729,66 @@ def transform_oci_response_to_cohere(
603729
"meta": meta,
604730
}
605731

606-
elif endpoint == "chat":
607-
return {
608-
"text": oci_response.get("chatResponse", {}).get("text", ""),
609-
"generation_id": str(uuid.uuid4()),
610-
"chat_history": [],
611-
"finish_reason": oci_response.get("finishReason", "COMPLETE"),
612-
"meta": {"api_version": {"version": "1"}},
613-
}
732+
elif endpoint == "chat" or endpoint == "chat_stream":
733+
chat_response = oci_response.get("chatResponse", {})
734+
735+
# Detect V2 response (has apiFormat field)
736+
is_v2 = chat_response.get("apiFormat") == "COHEREV2"
737+
738+
if is_v2:
739+
# V2 response transformation
740+
# Extract usage for V2
741+
usage_data = chat_response.get("usage", {})
742+
usage = {
743+
"tokens": {
744+
"input_tokens": usage_data.get("inputTokens", 0),
745+
"output_tokens": usage_data.get("completionTokens", 0),
746+
},
747+
}
748+
if usage_data.get("inputTokens") or usage_data.get("completionTokens"):
749+
usage["billed_units"] = {
750+
"input_tokens": usage_data.get("inputTokens", 0),
751+
"output_tokens": usage_data.get("completionTokens", 0),
752+
}
753+
754+
return {
755+
"id": chat_response.get("id", str(uuid.uuid4())),
756+
"message": chat_response.get("message", {}),
757+
"finish_reason": chat_response.get("finishReason", "COMPLETE").lower(),
758+
"usage": usage,
759+
}
760+
else:
761+
# V1 response transformation
762+
# Build proper meta structure
763+
meta = {
764+
"api_version": {"version": "1"},
765+
}
766+
767+
# Add usage info if available
768+
if "usage" in chat_response and chat_response["usage"]:
769+
usage = chat_response["usage"]
770+
input_tokens = usage.get("inputTokens", 0)
771+
output_tokens = usage.get("outputTokens", 0)
772+
773+
meta["billed_units"] = {
774+
"input_tokens": input_tokens,
775+
"output_tokens": output_tokens,
776+
}
777+
meta["tokens"] = {
778+
"input_tokens": input_tokens,
779+
"output_tokens": output_tokens,
780+
}
781+
782+
return {
783+
"text": chat_response.get("text", ""),
784+
"generation_id": oci_response.get("modelId", str(uuid.uuid4())),
785+
"chat_history": chat_response.get("chatHistory", []),
786+
"finish_reason": chat_response.get("finishReason", "COMPLETE"),
787+
"citations": chat_response.get("citations", []),
788+
"documents": chat_response.get("documents", []),
789+
"search_queries": chat_response.get("searchQueries", []),
790+
"meta": meta,
791+
}
614792

615793
elif endpoint == "generate":
616794
return {
@@ -627,22 +805,57 @@ def transform_oci_response_to_cohere(
627805
}
628806

629807
elif endpoint == "rerank":
630-
results = oci_response.get("results", [])
808+
# OCI returns flat structure with document_ranks
809+
document_ranks = oci_response.get("documentRanks", [])
810+
631811
return {
632-
"id": str(uuid.uuid4()),
812+
"id": oci_response.get("id", str(uuid.uuid4())),
633813
"results": [
634814
{
635815
"index": r.get("index"),
636816
"relevance_score": r.get("relevanceScore"),
637817
}
638-
for r in results
818+
for r in document_ranks
639819
],
640820
"meta": {"api_version": {"version": "1"}},
641821
}
642822

643823
return oci_response
644824

645825

826+
def transform_oci_stream_wrapper(
    stream: typing.Iterator[bytes], endpoint: str
) -> typing.Iterator[bytes]:
    """
    Wrap an OCI SSE byte stream and re-emit events in Cohere format.

    Buffers incoming byte chunks, splits the buffer on newlines, and for
    every ``data: ...`` SSE line parses the JSON payload, transforms it
    via ``transform_stream_event``, and yields the transformed event as a
    newline-terminated JSON byte string. Lines that are not ``data:``
    lines, and payloads that are not valid JSON (e.g. partial frames),
    are skipped.

    Args:
        stream: Original OCI stream iterator yielding raw bytes.
        endpoint: Cohere endpoint name (e.g. "chat_stream").

    Yields:
        Newline-terminated JSON bytes of transformed streaming events.
    """
    buffer = b""
    for chunk in stream:
        buffer += chunk
        while b"\n" in buffer:
            line_bytes, buffer = buffer.split(b"\n", 1)
            line = line_bytes.decode("utf-8").strip()

            if not line.startswith("data: "):
                continue
            data_str = line[6:]  # Strip the "data: " prefix
            if data_str.strip() == "[DONE]":
                # Terminate the whole generator. The previous `break`
                # only exited the inner while-loop, so any chunks that
                # arrived after the [DONE] terminator were still parsed
                # and emitted.
                return

            try:
                oci_event = json.loads(data_str)
            except json.JSONDecodeError:
                # Skip malformed or partial payloads rather than abort.
                continue
            # Transform outside the try so errors from the transformer
            # are not silently swallowed as decode errors.
            cohere_event = transform_stream_event(endpoint, oci_event)
            yield json.dumps(cohere_event).encode("utf-8") + b"\n"
857+
858+
646859
def transform_oci_stream_response(
647860
response: httpx.Response, endpoint: str
648861
) -> typing.Iterator[bytes]:

0 commit comments

Comments (0)