Skip to content

Commit b78c63e

Browse files
committed
fix: Address Bugbot feedback for OCI client
- Add validation for direct credentials (user_id requires fingerprint and tenancy_id) - Emit message-end event for V2 streaming before [DONE]
1 parent 716d743 commit b78c63e

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/cohere/oci_client.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,13 @@ def _load_oci_config(
274274
return {"signer": signer, "auth_type": "resource_principal"}
275275

276276
elif kwargs.get("user_id"):
277-
# Direct credentials provided
277+
# Direct credentials provided - validate required fields
278+
required_fields = ["fingerprint", "tenancy_id"]
279+
missing = [f for f in required_fields if not kwargs.get(f)]
280+
if missing:
281+
raise ValueError(
282+
f"When providing oci_user_id, you must also provide: {', '.join('oci_' + f for f in missing)}"
283+
)
278284
config = {
279285
"user": kwargs["user_id"],
280286
"fingerprint": kwargs["fingerprint"],
@@ -942,7 +948,11 @@ def transform_oci_stream_wrapper(
942948
if line.startswith("data: "):
943949
data_str = line[6:] # Remove "data: " prefix
944950
if data_str.strip() == "[DONE]":
945-
# Return (not break) to stop the generator completely, preventing further chunk processing
951+
# Emit message-end event for V2 before stopping
952+
if is_v2:
953+
message_end_event = {"type": "message-end"}
954+
yield b"data: " + json.dumps(message_end_event).encode("utf-8") + b"\n\n"
955+
# Return to stop the generator completely
946956
return
947957

948958
try:

0 commit comments

Comments
 (0)