@@ -388,6 +388,7 @@ def _event_hook(request: httpx.Request) -> None:
         request._content = oci_body_bytes
         request.extensions["endpoint"] = endpoint
         request.extensions["cohere_body"] = body
+        request.extensions["is_stream"] = "stream" in endpoint or body.get("stream", False)
 
     return _event_hook
 
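For orientation, httpx invokes hooks like these through its `event_hooks` client option. A minimal wiring sketch, assuming the request-side factory is named `map_request_to_oci` (only `map_response_from_oci` appears in this diff):

```python
import httpx

# Hypothetical wiring; the request-hook factory name and its arguments are
# assumptions -- only map_response_from_oci() is shown in this diff.
client = httpx.Client(
    event_hooks={
        "request": [map_request_to_oci(compartment_id="ocid1.compartment.oc1...")],
        "response": [map_response_from_oci()],
    }
)
```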
@@ -402,18 +403,31 @@ def map_response_from_oci() -> EventHook:
 
     def _hook(response: httpx.Response) -> None:
         endpoint = response.request.extensions["endpoint"]
-        is_stream = "stream" in endpoint
+        is_stream = response.request.extensions.get("is_stream", False)
 
         output: typing.Iterator[bytes]
 
+        # Only transform successful responses (200-299).
+        # Let error responses pass through unchanged so SDK error handling works.
+        if not (200 <= response.status_code < 300):
+            return
+
+        # For streaming responses, wrap the stream with a transformer
         if is_stream:
-            # Handle streaming responses
-            output = transform_oci_stream_response(response, endpoint)
-        else:
-            # Handle non-streaming responses
-            oci_response = json.loads(response.read())
-            cohere_response = transform_oci_response_to_cohere(endpoint, oci_response)
-            output = iter([json.dumps(cohere_response).encode("utf-8")])
+            original_stream = response.stream
+            transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint)
+            response.stream = Streamer(transformed_stream)
+            # Reset consumption flags
+            if hasattr(response, "_content"):
+                del response._content
+            response.is_stream_consumed = False
+            response.is_closed = False
+            return
+
+        # Handle non-streaming responses
+        oci_response = json.loads(response.read())
+        cohere_response = transform_oci_response_to_cohere(endpoint, oci_response)
+        output = iter([json.dumps(cohere_response).encode("utf-8")])
 
         response.stream = Streamer(output)
 
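The hook hands transformed bytes to `Streamer`, which is not shown in this diff. A minimal sketch of what such a wrapper could look like, assuming it is a plain `httpx.SyncByteStream` adapter (the SDK's real class may differ):

```python
import typing
import httpx

class Streamer(httpx.SyncByteStream):
    """Sketch only: adapts a bytes iterator to the stream interface httpx expects."""

    def __init__(self, iterator: typing.Iterator[bytes]) -> None:
        self._iterator = iterator

    def __iter__(self) -> typing.Iterator[bytes]:
        # httpx reads response content by iterating the stream object
        yield from self._iterator

    def close(self) -> None:
        # Nothing to release; the wrapped iterator owns no resources here
        pass
```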
@@ -452,13 +466,45 @@ def get_oci_url(
452466 "chat_stream" : "chat" ,
453467 "generate" : "generateText" ,
454468 "generate_stream" : "generateText" ,
455- "rerank" : "rerank" ,
469+ "rerank" : "rerankText" , # OCI uses rerankText, not rerank
456470 }
457471
458472 action = action_map .get (endpoint , endpoint )
459473 return f"{ base } /{ api_version } /actions/{ action } "
460474
461475
476+ def normalize_model_for_oci (model : str ) -> str :
477+ """
478+ Normalize model name for OCI.
479+
480+ OCI accepts model names in the format "cohere.model-name" or full OCIDs.
481+ This function ensures proper formatting for all regions.
482+
483+ Args:
484+ model: Model name (e.g., "command-r-08-2024") or full OCID
485+
486+ Returns:
487+ Normalized model identifier (e.g., "cohere.command-r-08-2024" or OCID)
488+
489+ Examples:
490+ >>> normalize_model_for_oci("command-a-03-2025")
491+ "cohere.command-a-03-2025"
492+ >>> normalize_model_for_oci("cohere.embed-english-v3.0")
493+ "cohere.embed-english-v3.0"
494+ >>> normalize_model_for_oci("ocid1.generativeaimodel.oc1...")
495+ "ocid1.generativeaimodel.oc1..."
496+ """
497+ # If it's already an OCID, return as-is (works across all regions)
498+ if model .startswith ("ocid1." ):
499+ return model
500+
501+ # Add "cohere." prefix if not present
502+ if not model .startswith ("cohere." ):
503+ return f"cohere.{ model } "
504+
505+ return model
506+
507+
462508def transform_request_to_oci (
463509 endpoint : str ,
464510 cohere_body : typing .Dict [str , typing .Any ],
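Note that the lookup falls back to the endpoint name itself, so only endpoints listed in `action_map` are renamed. A quick illustration grounded in the table above:

```python
action_map = {
    "chat": "chat",
    "chat_stream": "chat",
    "generate": "generateText",
    "generate_stream": "generateText",
    "rerank": "rerankText",
}

assert action_map.get("rerank", "rerank") == "rerankText"  # renamed for OCI
assert action_map.get("embed", "embed") == "embed"         # falls through unchanged
```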
@@ -475,9 +521,7 @@ def transform_request_to_oci(
     Returns:
         Transformed request body in OCI format
     """
-    model = cohere_body.get("model", "")
-    if not model.startswith("cohere."):
-        model = f"cohere.{model}"
+    model = normalize_model_for_oci(cohere_body.get("model", ""))
 
     if endpoint == "embed":
         # Transform Cohere input_type to OCI format
@@ -506,21 +550,96 @@ def transform_request_to_oci(
         return oci_body
 
     elif endpoint in ["chat", "chat_stream"]:
+        # Detect V1 vs V2 API based on the request body structure
+        is_v2 = "messages" in cohere_body  # V2 uses a messages array
+
+        # OCI uses a nested chatRequest structure
+        chat_request = {
+            "apiFormat": "COHEREV2" if is_v2 else "COHERE",
+        }
+
+        if is_v2:
+            # V2 API: uses a messages array.
+            # Transform Cohere V2 messages to OCI V2 format.
+            # Cohere sends: [{"role": "user", "content": "text"}]
+            # OCI expects: [{"role": "USER", "content": [{"type": "TEXT", "text": "..."}]}]
+            oci_messages = []
+            for msg in cohere_body["messages"]:
+                oci_msg = {
+                    "role": msg["role"].upper(),
+                }
+
+                # Transform content
+                if isinstance(msg.get("content"), str):
+                    # Simple string content -> wrap in an array
+                    oci_msg["content"] = [{"type": "TEXT", "text": msg["content"]}]
+                elif isinstance(msg.get("content"), list):
+                    # Already in array format (from tool calls, etc.)
+                    oci_msg["content"] = msg["content"]
+                else:
+                    oci_msg["content"] = msg.get("content", [])
+
+                # Add tool_calls if present
+                if "tool_calls" in msg:
+                    oci_msg["toolCalls"] = msg["tool_calls"]
+
+                oci_messages.append(oci_msg)
+
+            chat_request["messages"] = oci_messages
+
+            # V2 optional parameters (convert Cohere's snake_case names to OCI camelCase)
+            if "max_tokens" in cohere_body:
+                chat_request["maxTokens"] = cohere_body["max_tokens"]
+            if "temperature" in cohere_body:
+                chat_request["temperature"] = cohere_body["temperature"]
+            if "k" in cohere_body:
+                chat_request["topK"] = cohere_body["k"]
+            if "p" in cohere_body:
+                chat_request["topP"] = cohere_body["p"]
+            if "seed" in cohere_body:
+                chat_request["seed"] = cohere_body["seed"]
+            if "frequency_penalty" in cohere_body:
+                chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"]
+            if "presence_penalty" in cohere_body:
+                chat_request["presencePenalty"] = cohere_body["presence_penalty"]
+            if "stop_sequences" in cohere_body:
+                chat_request["stopSequences"] = cohere_body["stop_sequences"]
+            if "tools" in cohere_body:
+                chat_request["tools"] = cohere_body["tools"]
+            if "documents" in cohere_body:
+                chat_request["documents"] = cohere_body["documents"]
+            if "citation_options" in cohere_body:
+                chat_request["citationOptions"] = cohere_body["citation_options"]
+            if "safety_mode" in cohere_body:
+                chat_request["safetyMode"] = cohere_body["safety_mode"]
+        else:
+            # V1 API: uses a single message string
+            chat_request["message"] = cohere_body["message"]
+
+            # V1 optional parameters
+            if "temperature" in cohere_body:
+                chat_request["temperature"] = cohere_body["temperature"]
+            if "max_tokens" in cohere_body:
+                chat_request["maxTokens"] = cohere_body["max_tokens"]
+            if "preamble" in cohere_body:
+                chat_request["preambleOverride"] = cohere_body["preamble"]
+            if "chat_history" in cohere_body:
+                chat_request["chatHistory"] = cohere_body["chat_history"]
+
+        # Handle streaming for both versions
+        if "stream" in endpoint or cohere_body.get("stream"):
+            chat_request["isStream"] = True
+
+        # Top-level OCI request structure
         oci_body = {
-            "message": cohere_body["message"],
             "servingMode": {
                 "servingType": "ON_DEMAND",
                 "modelId": model,
             },
             "compartmentId": compartment_id,
-            "isStream": endpoint == "chat_stream" or cohere_body.get("stream", False),
+            "chatRequest": chat_request,
         }
-        if "chat_history" in cohere_body:
-            oci_body["chatHistory"] = cohere_body["chat_history"]
-        if "temperature" in cohere_body:
-            oci_body["temperature"] = cohere_body["temperature"]
-        if "max_tokens" in cohere_body:
-            oci_body["maxTokens"] = cohere_body["max_tokens"]
+
         return oci_body
 
     elif endpoint in ["generate", "generate_stream"]:
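A worked example of the V2 path (request values invented for illustration; `compartment_id` as the third positional parameter is an assumption, since the full signature is truncated in this diff):

```python
cohere_body = {
    "model": "command-a-03-2025",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 100,
}

oci_body = transform_request_to_oci("chat", cohere_body, "ocid1.compartment.oc1...")

# Per the mapping above, oci_body should come out as:
# {
#     "servingMode": {"servingType": "ON_DEMAND", "modelId": "cohere.command-a-03-2025"},
#     "compartmentId": "ocid1.compartment.oc1...",
#     "chatRequest": {
#         "apiFormat": "COHEREV2",
#         "messages": [{"role": "USER", "content": [{"type": "TEXT", "text": "Hello"}]}],
#         "maxTokens": 100,
#     },
# }
```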
@@ -540,17 +659,24 @@ def transform_request_to_oci(
         return oci_body
 
     elif endpoint == "rerank":
+        # OCI rerank uses a flat structure (not nested like chat)
+        # and "input" instead of "query"
         oci_body = {
-            "query": cohere_body["query"],
-            "documents": cohere_body["documents"],
             "servingMode": {
                 "servingType": "ON_DEMAND",
                 "modelId": model,
             },
             "compartmentId": compartment_id,
+            "input": cohere_body["query"],  # OCI uses "input", not "query"
+            "documents": cohere_body["documents"],
         }
+
+        # Add optional rerank parameters
         if "top_n" in cohere_body:
             oci_body["topN"] = cohere_body["top_n"]
+        if "max_chunks_per_doc" in cohere_body:
+            oci_body["maxChunksPerDocument"] = cohere_body["max_chunks_per_doc"]
+
         return oci_body
 
     return cohere_body
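The rerank renames in one place (example values invented; same signature assumption as above):

```python
cohere_body = {
    "model": "rerank-english-v3.0",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    "top_n": 1,
}

oci_body = transform_request_to_oci("rerank", cohere_body, "ocid1.compartment.oc1...")

# Renames performed by the branch above:
#   query              -> input
#   top_n              -> topN
#   max_chunks_per_doc -> maxChunksPerDocument
```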
@@ -603,14 +729,66 @@ def transform_oci_response_to_cohere(
603729 "meta" : meta ,
604730 }
605731
606- elif endpoint == "chat" :
607- return {
608- "text" : oci_response .get ("chatResponse" , {}).get ("text" , "" ),
609- "generation_id" : str (uuid .uuid4 ()),
610- "chat_history" : [],
611- "finish_reason" : oci_response .get ("finishReason" , "COMPLETE" ),
612- "meta" : {"api_version" : {"version" : "1" }},
613- }
732+ elif endpoint == "chat" or endpoint == "chat_stream" :
733+ chat_response = oci_response .get ("chatResponse" , {})
734+
735+ # Detect V2 response (has apiFormat field)
736+ is_v2 = chat_response .get ("apiFormat" ) == "COHEREV2"
737+
738+ if is_v2 :
739+ # V2 response transformation
740+ # Extract usage for V2
741+ usage_data = chat_response .get ("usage" , {})
742+ usage = {
743+ "tokens" : {
744+ "input_tokens" : usage_data .get ("inputTokens" , 0 ),
745+ "output_tokens" : usage_data .get ("completionTokens" , 0 ),
746+ },
747+ }
748+ if usage_data .get ("inputTokens" ) or usage_data .get ("completionTokens" ):
749+ usage ["billed_units" ] = {
750+ "input_tokens" : usage_data .get ("inputTokens" , 0 ),
751+ "output_tokens" : usage_data .get ("completionTokens" , 0 ),
752+ }
753+
754+ return {
755+ "id" : chat_response .get ("id" , str (uuid .uuid4 ())),
756+ "message" : chat_response .get ("message" , {}),
757+ "finish_reason" : chat_response .get ("finishReason" , "COMPLETE" ).lower (),
758+ "usage" : usage ,
759+ }
760+ else :
761+ # V1 response transformation
762+ # Build proper meta structure
763+ meta = {
764+ "api_version" : {"version" : "1" },
765+ }
766+
767+ # Add usage info if available
768+ if "usage" in chat_response and chat_response ["usage" ]:
769+ usage = chat_response ["usage" ]
770+ input_tokens = usage .get ("inputTokens" , 0 )
771+ output_tokens = usage .get ("outputTokens" , 0 )
772+
773+ meta ["billed_units" ] = {
774+ "input_tokens" : input_tokens ,
775+ "output_tokens" : output_tokens ,
776+ }
777+ meta ["tokens" ] = {
778+ "input_tokens" : input_tokens ,
779+ "output_tokens" : output_tokens ,
780+ }
781+
782+ return {
783+ "text" : chat_response .get ("text" , "" ),
784+ "generation_id" : oci_response .get ("modelId" , str (uuid .uuid4 ())),
785+ "chat_history" : chat_response .get ("chatHistory" , []),
786+ "finish_reason" : chat_response .get ("finishReason" , "COMPLETE" ),
787+ "citations" : chat_response .get ("citations" , []),
788+ "documents" : chat_response .get ("documents" , []),
789+ "search_queries" : chat_response .get ("searchQueries" , []),
790+ "meta" : meta ,
791+ }
614792
615793 elif endpoint == "generate" :
616794 return {
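Tracing the V1 branch with an invented OCI payload shows the meta structure it builds:

```python
oci_response = {
    "modelId": "cohere.command-r-08-2024",
    "chatResponse": {
        "apiFormat": "COHERE",
        "text": "Paris.",
        "finishReason": "COMPLETE",
        "usage": {"inputTokens": 12, "outputTokens": 3},
    },
}

cohere = transform_oci_response_to_cohere("chat", oci_response)
assert cohere["text"] == "Paris."
assert cohere["meta"]["billed_units"] == {"input_tokens": 12, "output_tokens": 3}
assert cohere["generation_id"] == "cohere.command-r-08-2024"  # falls back to modelId
```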
@@ -627,22 +805,57 @@ def transform_oci_response_to_cohere(
     }
 
     elif endpoint == "rerank":
-        results = oci_response.get("results", [])
+        # OCI returns a flat structure with documentRanks
+        document_ranks = oci_response.get("documentRanks", [])
+
         return {
-            "id": str(uuid.uuid4()),
+            "id": oci_response.get("id", str(uuid.uuid4())),
             "results": [
                 {
                     "index": r.get("index"),
                     "relevance_score": r.get("relevanceScore"),
                 }
-                for r in results
+                for r in document_ranks
             ],
             "meta": {"api_version": {"version": "1"}},
         }
 
     return oci_response
 
 
+def transform_oci_stream_wrapper(
+    stream: typing.Iterator[bytes], endpoint: str
+) -> typing.Iterator[bytes]:
+    """
+    Wrap an OCI stream and transform its events to Cohere format.
+
+    Args:
+        stream: Original OCI stream iterator
+        endpoint: Cohere endpoint name
+
+    Yields:
+        Bytes of transformed streaming events
+    """
+    buffer = b""
+    for chunk in stream:
+        buffer += chunk
+        while b"\n" in buffer:
+            line_bytes, buffer = buffer.split(b"\n", 1)
+            line = line_bytes.decode("utf-8").strip()
+
+            if line.startswith("data: "):
+                data_str = line[6:]  # Remove the "data: " prefix
+                if data_str.strip() == "[DONE]":
+                    # End of stream: stop emitting events entirely
+                    return
+
+                try:
+                    oci_event = json.loads(data_str)
+                    cohere_event = transform_stream_event(endpoint, oci_event)
+                    yield json.dumps(cohere_event).encode("utf-8") + b"\n"
+                except json.JSONDecodeError:
+                    continue
+
+
 def transform_oci_stream_response(
     response: httpx.Response, endpoint: str
 ) -> typing.Iterator[bytes]:
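Exercising the wrapper with hand-written SSE chunks (event payloads invented; the per-event mapping lives in `transform_stream_event`, which is not shown here):

```python
chunks = iter([
    b'data: {"text": "Hel"}\n',
    b'data: {"text": "lo"}\n',
    b"data: [DONE]\n",
])

for event_bytes in transform_oci_stream_wrapper(chunks, "chat_stream"):
    print(event_bytes)  # one transformed, newline-terminated event per iteration
```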