diff --git a/ccproxy/llms/formatters/anthropic_to_openai/streams.py b/ccproxy/llms/formatters/anthropic_to_openai/streams.py index 9710dc14..a249c4c5 100644 --- a/ccproxy/llms/formatters/anthropic_to_openai/streams.py +++ b/ccproxy/llms/formatters/anthropic_to_openai/streams.py @@ -15,6 +15,7 @@ ObfuscationTokenFactory, ToolCallState, ensure_identifier, + ensure_responses_function_call_identifiers, ) from ccproxy.llms.formatters.constants import ANTHROPIC_TO_OPENAI_FINISH_REASON from ccproxy.llms.formatters.context import ( @@ -558,6 +559,18 @@ def ensure_tool_state(block_index: int) -> ToolCallState: next_output_index += 1 return state + def normalize_tool_state_identifiers( + state: ToolCallState, + ) -> tuple[str, str]: + item_id, call_id = ensure_responses_function_call_identifiers( + item_id=state.item_id, + call_id=state.call_id, + fallback_index=state.index, + ) + state.item_id = item_id + state.call_id = call_id + return item_id, call_id + def emit_tool_item_added( block_index: int, state: ToolCallState ) -> list[openai_models.StreamEventType]: @@ -574,8 +587,7 @@ def emit_tool_item_added( if not state.call_id: state.call_id = tool_entry.get("id") - item_id = state.item_id or state.call_id or f"call_{state.index}" - state.item_id = item_id + item_id, call_id = normalize_tool_state_identifiers(state) name = state.name or "function" @@ -592,7 +604,7 @@ def emit_tool_item_added( status="in_progress", name=str(name), arguments="", - call_id=state.call_id, + call_id=call_id, ), ) ) @@ -606,7 +618,7 @@ def emit_tool_arguments_delta( sequence_counter += 1 event_sequence = sequence_counter state.add_arguments_part(delta_text) - item_identifier = str(state.item_id or f"call_{state.index}") + item_identifier, _ = normalize_tool_state_identifiers(state) return openai_models.ResponseFunctionCallArgumentsDeltaEvent( type="response.function_call_arguments.delta", sequence_number=event_sequence, @@ -631,8 +643,7 @@ def emit_tool_finalize( if not state.item_id: state.item_id = tool_entry.get("id") - item_id = state.item_id or state.call_id or f"call_{state.index}" - state.item_id = item_id + item_id, call_id = normalize_tool_state_identifiers(state) name = state.name or "function" args_str = "".join(state.arguments_parts) @@ -674,7 +685,7 @@ def emit_tool_finalize( status="completed", name=str(name), arguments=args_str, - call_id=state.call_id, + call_id=call_id, ), ) ) @@ -1103,6 +1114,7 @@ def make_response_object( state.call_id = tool_entry.get("id") if not state.item_id: state.item_id = state.call_id or f"call_{state.index}" + item_id, call_id = normalize_tool_state_identifiers(state) final_args = state.final_arguments if final_args is None: @@ -1123,10 +1135,10 @@ def make_response_object( state.output_index, openai_models.FunctionCallOutput( type="function_call", - id=state.item_id, + id=item_id, status="completed", name=state.name, - call_id=state.call_id, + call_id=call_id, arguments=final_args, ), ) @@ -1254,6 +1266,7 @@ def make_response_object( state.call_id = tool_entry.get("id") if not state.item_id: state.item_id = state.call_id or f"call_{state.index}" + item_id, call_id = normalize_tool_state_identifiers(state) final_args = state.final_arguments if final_args is None: combined = "".join(state.arguments_parts) @@ -1270,10 +1283,10 @@ def make_response_object( state.output_index, openai_models.FunctionCallOutput( type="function_call", - id=state.item_id, + id=item_id, status="completed", name=state.name, - call_id=state.call_id, + call_id=call_id, arguments=final_args, ), ) diff --git a/ccproxy/llms/formatters/common/__init__.py b/ccproxy/llms/formatters/common/__init__.py index 2eef09d4..d502b93f 100644 --- a/ccproxy/llms/formatters/common/__init__.py +++ b/ccproxy/llms/formatters/common/__init__.py @@ -1,6 +1,12 @@ """Shared helpers used by formatter adapters.""" -from .identifiers import ensure_identifier, normalize_suffix +from .identifiers import ( + ensure_identifier, + ensure_responses_function_call_identifiers, + normalize_responses_function_call_ids, + normalize_responses_sse_event_bytes, + normalize_suffix, +) from .streams import ( IndexedToolCallTracker, ObfuscationTokenFactory, @@ -29,6 +35,9 @@ __all__ = [ "ensure_identifier", + "ensure_responses_function_call_identifiers", + "normalize_responses_function_call_ids", + "normalize_responses_sse_event_bytes", "normalize_suffix", "THINKING_PATTERN", "THINKING_OPEN_PATTERN", diff --git a/ccproxy/llms/formatters/common/identifiers.py b/ccproxy/llms/formatters/common/identifiers.py index 10077807..2d0d5056 100644 --- a/ccproxy/llms/formatters/common/identifiers.py +++ b/ccproxy/llms/formatters/common/identifiers.py @@ -2,7 +2,18 @@ from __future__ import annotations +import json +import re import uuid +from typing import Any, TypeVar + + +_SAFE_ID_CHARS = re.compile(r"[^A-Za-z0-9_-]+") +_FUNCTION_ARGUMENT_EVENT_TYPES = { + "response.function_call_arguments.delta", + "response.function_call_arguments.done", +} +_Payload = TypeVar("_Payload") def normalize_suffix(identifier: str) -> str: @@ -45,4 +56,143 @@ def ensure_identifier(prefix: str, existing: str | None = None) -> tuple[str, st return f"{prefix}_{suffix}", suffix -__all__ = ["ensure_identifier", "normalize_suffix"] +def _safe_identifier_suffix(identifier: str | None, fallback: str) -> str: + """Return a compact suffix suitable for generated Responses identifiers.""" + + if isinstance(identifier, str) and identifier: + suffix = normalize_suffix(identifier) + suffix = _SAFE_ID_CHARS.sub("_", suffix).strip("_") + if suffix: + return suffix + return fallback + + +def ensure_responses_function_call_identifiers( + *, + item_id: str | None, + call_id: str | None, + fallback_index: int | str = 0, +) -> tuple[str, str]: + """Return OpenAI Responses-compatible function-call item and call IDs. + + Responses function-call output items use an ``fc_*`` item id. The model call + correlation id remains a distinct ``call_*`` value and is reused by + ``function_call_output`` input items on subsequent turns. + """ + + fallback = str(fallback_index) + item_suffix = _safe_identifier_suffix(item_id or call_id, fallback) + call_suffix = _safe_identifier_suffix(call_id or item_id, fallback) + + normalized_item_id = ( + item_id if isinstance(item_id, str) and item_id.startswith("fc_") else None + ) + if normalized_item_id is None: + normalized_item_id = f"fc_{item_suffix}" + + normalized_call_id = ( + call_id if isinstance(call_id, str) and call_id.startswith("call_") else None + ) + if normalized_call_id is None: + normalized_call_id = f"call_{call_suffix}" + + return normalized_item_id, normalized_call_id + + +def normalize_responses_function_call_ids(payload: _Payload) -> _Payload: + """Normalize Responses function-call ids in a JSON-like payload. + + This intentionally skips ``function_call_output`` entries; they are input + items carrying user code output and should keep their own item identifiers. + """ + + if isinstance(payload, list): + for item in payload: + normalize_responses_function_call_ids(item) + return payload + + if not isinstance(payload, dict): + return payload + + payload_type = payload.get("type") + fallback_value = payload.get("output_index", payload.get("index", 0)) + fallback_index = fallback_value if isinstance(fallback_value, int | str) else 0 + + if payload_type == "function_call": + item_id, call_id = ensure_responses_function_call_identifiers( + item_id=payload.get("id") if isinstance(payload.get("id"), str) else None, + call_id=payload.get("call_id") + if isinstance(payload.get("call_id"), str) + else None, + fallback_index=fallback_index, + ) + payload["id"] = item_id + payload["call_id"] = call_id + + elif ( + isinstance(payload_type, str) and payload_type in _FUNCTION_ARGUMENT_EVENT_TYPES + ): + item_id, call_id = ensure_responses_function_call_identifiers( + item_id=payload.get("item_id") + if isinstance(payload.get("item_id"), str) + else None, + call_id=payload.get("call_id") + if isinstance(payload.get("call_id"), str) + else None, + fallback_index=fallback_index, + ) + payload["item_id"] = item_id + if isinstance(payload.get("call_id"), str): + payload["call_id"] = call_id + + for value in payload.values(): + normalize_responses_function_call_ids(value) + + return payload + + +def normalize_responses_sse_event_bytes(event_data: bytes) -> bytes: + """Normalize a complete SSE event carrying a Responses JSON payload.""" + + try: + text = event_data.decode("utf-8") + except UnicodeDecodeError: + return event_data + + lines = text.splitlines() + passthrough_lines: list[str] = [] + data_lines: list[str] = [] + for line in lines: + if line.startswith("data:"): + data_value = line[5:] + if data_value.startswith(" "): + data_value = data_value[1:] + data_lines.append(data_value) + elif line: + passthrough_lines.append(line) + + if not data_lines: + return event_data + + data_payload = "\n".join(data_lines) + if data_payload.strip() == "[DONE]": + return event_data + + try: + parsed = json.loads(data_payload) + except json.JSONDecodeError: + return event_data + + normalized = normalize_responses_function_call_ids(parsed) + compact = json.dumps(normalized, ensure_ascii=False, separators=(",", ":")) + normalized_lines = [*passthrough_lines, f"data: {compact}", ""] + return ("\n".join(normalized_lines) + "\n").encode("utf-8") + + +__all__ = [ + "ensure_identifier", + "ensure_responses_function_call_identifiers", + "normalize_responses_function_call_ids", + "normalize_responses_sse_event_bytes", + "normalize_suffix", +] diff --git a/ccproxy/llms/formatters/openai_to_openai/responses.py b/ccproxy/llms/formatters/openai_to_openai/responses.py index d8ccd254..b1dc96b1 100644 --- a/ccproxy/llms/formatters/openai_to_openai/responses.py +++ b/ccproxy/llms/formatters/openai_to_openai/responses.py @@ -13,6 +13,7 @@ ThinkingSegment, convert_openai_completion_usage_to_responses_usage, convert_openai_responses_usage_to_completion_usage, + ensure_responses_function_call_identifiers, merge_thinking_segments, ) from ccproxy.llms.formatters.context import get_openai_thinking_xml @@ -551,13 +552,19 @@ def flush_message() -> None: arguments_value: str | dict[str, Any] | None = arguments else: arguments_value = str(arguments) if arguments is not None else None + source_call_id = getattr(tool_call, "id", None) + item_id, call_id = ensure_responses_function_call_identifiers( + item_id=source_call_id if isinstance(source_call_id, str) else None, + call_id=source_call_id if isinstance(source_call_id, str) else None, + fallback_index=idx, + ) outputs.append( openai_models.FunctionCallOutput( type="function_call", - id=getattr(tool_call, "id", f"call_{idx}"), + id=item_id, status="completed", name=name, - call_id=getattr(tool_call, "id", None), + call_id=call_id, arguments=arguments_value, ) ) diff --git a/ccproxy/llms/formatters/openai_to_openai/streams.py b/ccproxy/llms/formatters/openai_to_openai/streams.py index 8a435e85..3eb661ac 100644 --- a/ccproxy/llms/formatters/openai_to_openai/streams.py +++ b/ccproxy/llms/formatters/openai_to_openai/streams.py @@ -22,6 +22,7 @@ ToolCallState, ToolCallTracker, ensure_identifier, + ensure_responses_function_call_identifiers, ) from ccproxy.llms.formatters.context import ( get_last_instructions, @@ -1321,6 +1322,18 @@ def get_accumulator_entry(idx: int) -> dict[str, Any] | None: return entry return None + def normalize_tool_state_identifiers( + state: ToolCallState, + ) -> tuple[str, str]: + item_id, call_id = ensure_responses_function_call_identifiers( + item_id=state.item_id, + call_id=state.call_id, + fallback_index=state.index, + ) + state.item_id = item_id + state.call_id = call_id + return item_id, call_id + def emit_tool_item_added( state: ToolCallState, ) -> list[openai_models.StreamEventType]: @@ -1334,6 +1347,7 @@ def emit_tool_item_added( if not item_identifier: item_identifier = f"call_{state.index}" state.item_id = item_identifier + item_id, call_id = normalize_tool_state_identifiers(state) sequence_counter += 1 state.added_emitted = True return [ @@ -1342,12 +1356,12 @@ def emit_tool_item_added( sequence_number=sequence_counter, output_index=state.output_index, item=openai_models.OutputItem( - id=state.item_id, + id=item_id, type="function_call", status="in_progress", name=state.name, arguments="", - call_id=state.call_id, + call_id=call_id, ), ) ] @@ -1372,6 +1386,7 @@ def finalize_tool_calls() -> list[openai_models.StreamEventType]: state.item_id = ( candidate_id or state.call_id or f"call_{state.index}" ) + item_id, call_id = normalize_tool_state_identifiers(state) if not state.added_emitted: events.extend(emit_tool_item_added(state)) final_args = state.final_arguments @@ -1390,7 +1405,7 @@ def finalize_tool_calls() -> list[openai_models.StreamEventType]: openai_models.ResponseFunctionCallArgumentsDoneEvent( type="response.function_call_arguments.done", sequence_number=sequence_counter, - item_id=state.item_id, + item_id=item_id, output_index=state.output_index, arguments=final_args, ) @@ -1404,12 +1419,12 @@ def finalize_tool_calls() -> list[openai_models.StreamEventType]: sequence_number=sequence_counter, output_index=state.output_index, item=openai_models.OutputItem( - id=state.item_id, + id=item_id, type="function_call", status="completed", name=state.name, arguments=final_args, - call_id=state.call_id, + call_id=call_id, ), ) ) @@ -1648,14 +1663,16 @@ def make_response_object( ) or arguments_payload.get("obfuscated") if arguments_delta: state.add_arguments_part(arguments_delta) + item_id, _ = normalize_tool_state_identifiers( + state + ) sequence_counter += 1 event_sequence = sequence_counter yield ( delta_event_cls( type="response.function_call_arguments.delta", sequence_number=event_sequence, - item_id=state.item_id - or f"call_{state.index}", + item_id=item_id, output_index=state.output_index, delta=arguments_delta, ) @@ -1770,15 +1787,16 @@ def make_response_object( if accumulator_entry is not None: candidate_id = accumulator_entry.get("id") state.item_id = candidate_id or f"call_{state.index}" + item_id, call_id = normalize_tool_state_identifiers(state) completed_entries.append( ( state.output_index, openai_models.FunctionCallOutput( type="function_call", - id=state.item_id, + id=item_id, status="completed", name=state.name, - call_id=state.call_id, + call_id=call_id, arguments=state.final_arguments or "", ), ) diff --git a/ccproxy/llms/models/openai.py b/ccproxy/llms/models/openai.py index e510d0d4..f62b2712 100644 --- a/ccproxy/llms/models/openai.py +++ b/ccproxy/llms/models/openai.py @@ -233,9 +233,9 @@ class ChatCompletionRequest(LlmBaseModel): n: int | None = Field(default=1) parallel_tool_calls: bool | None = Field(default=None) presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0) - reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = Field( - default=None - ) + reasoning_effort: ( + Literal["minimal", "low", "medium", "high", "xhigh", "max"] | None + ) = Field(default=None) response_format: ResponseFormat | None = Field(default=None) seed: int | None = Field(default=None) stop: str | list[str] | None = Field(default=None) @@ -262,7 +262,7 @@ class ChatCompletionRequest(LlmBaseModel): class ResponseMessageReasoning(LlmBaseModel): - effort: Literal["minimal", "low", "medium", "high"] | None = None + effort: Literal["minimal", "low", "medium", "high", "xhigh", "max"] | None = None summary: Literal["auto", "detailed", "concise"] | None = None @@ -447,6 +447,7 @@ def validate_include(cls, v: list[str] | None) -> list[str] | None: previous_response_id: str | None = Field(default=None) prompt: dict[str, Any] | None = Field(default=None) prompt_cache_key: str | None = Field(default=None) + prompt_cache_retention: str | None = Field(default=None) reasoning: dict[str, Any] | None = Field(default=None) safety_identifier: str | None = Field(default=None) service_tier: str | None = Field(default=None) diff --git a/ccproxy/llms/streaming/accumulators.py b/ccproxy/llms/streaming/accumulators.py index 7d7b33f0..6b9b44f5 100644 --- a/ccproxy/llms/streaming/accumulators.py +++ b/ccproxy/llms/streaming/accumulators.py @@ -12,6 +12,10 @@ import structlog from pydantic import TypeAdapter, ValidationError +from ccproxy.llms.formatters.common import ( + ensure_responses_function_call_identifiers, + normalize_responses_function_call_ids, +) from ccproxy.llms.models import openai as openai_models @@ -598,6 +602,7 @@ def __init__(self) -> None: ] = {} self._reasoning_text: dict[tuple[str, int], list[str]] = {} self._function_arguments: dict[str, list[str]] = {} + self._function_item_id_aliases: dict[str, str] = {} self._latest_response: openai_models.ResponseObject | None = None self.completed_response: openai_models.ResponseObject | None = None self._sequence_counter = 0 @@ -766,13 +771,16 @@ def rebuild_response_object(self, response: dict[str, Any]) -> dict[str, Any]: if self.text_content: payload["text"] = self.text_content - return payload + return normalize_responses_function_call_ids(payload) def get_completed_response(self) -> dict[str, Any] | None: - """Return the final response payload captured from the stream, if any.""" + """Return the completed response merged with accumulated stream items.""" if isinstance(self.completed_response, openai_models.ResponseObject): - return self.completed_response.model_dump() + payload = self.completed_response.model_dump() + return normalize_responses_function_call_ids( + self.rebuild_response_object(payload) + ) return None def _coerce_stream_event( @@ -819,6 +827,7 @@ def _coerce_stream_event( def _record_output_item( self, output_index: int, item: openai_models.OutputItem ) -> None: + item = self._normalize_function_output_item(output_index, item) self._items[item.id] = item self._items_by_index[output_index] = item.id if item.text: @@ -827,6 +836,7 @@ def _record_output_item( def _merge_output_item( self, output_index: int, item: openai_models.OutputItem ) -> None: + item = self._normalize_function_output_item(output_index, item) existing = self._items.get(item.id) if existing is not None: merged = existing.model_copy(update=item.model_dump(exclude_unset=True)) @@ -837,6 +847,42 @@ def _merge_output_item( if merged.text: self.text_content = merged.text + def _normalize_function_output_item( + self, output_index: int, item: openai_models.OutputItem + ) -> openai_models.OutputItem: + if item.type != "function_call": + return item + + original_item_id = item.id + original_call_id = item.call_id + item_id, call_id = ensure_responses_function_call_identifiers( + item_id=original_item_id, + call_id=original_call_id, + fallback_index=output_index, + ) + + if original_item_id and original_item_id != item_id: + self._function_item_id_aliases[original_item_id] = item_id + if original_call_id and original_call_id != item_id: + self._function_item_id_aliases[original_call_id] = item_id + + return item.model_copy(update={"id": item_id, "call_id": call_id}) + + def _normalize_function_item_id( + self, item_id: str, fallback_index: int | str = 0 + ) -> str: + if item_id in self._function_item_id_aliases: + return self._function_item_id_aliases[item_id] + + normalized_item_id, _ = ensure_responses_function_call_identifiers( + item_id=item_id, + call_id=None, + fallback_index=fallback_index, + ) + if item_id != normalized_item_id: + self._function_item_id_aliases[item_id] = normalized_item_id + return normalized_item_id + def _accumulate_text_delta( self, *, item_id: str, content_index: int, delta: str ) -> None: @@ -861,17 +907,20 @@ def _update_output_item_text(self, item_id: str, text: str) -> None: self.text_content = text def _accumulate_function_arguments(self, item_id: str, delta: str) -> None: + item_id = self._normalize_function_item_id(item_id) args = self._function_arguments.setdefault(item_id, []) args.append(delta) combined = "".join(args) self._update_output_item_arguments(item_id, combined) def _finalize_function_arguments(self, item_id: str, arguments: str) -> None: + item_id = self._normalize_function_item_id(item_id) if arguments: self._function_arguments[item_id] = [arguments] self._update_output_item_arguments(item_id, arguments) def _update_output_item_arguments(self, item_id: str, arguments: str) -> None: + item_id = self._normalize_function_item_id(item_id) item = self._items.get(item_id) if item is None: return diff --git a/ccproxy/plugins/codex/adapter.py b/ccproxy/plugins/codex/adapter.py index 99cecdea..9dfc6e22 100644 --- a/ccproxy/plugins/codex/adapter.py +++ b/ccproxy/plugins/codex/adapter.py @@ -10,11 +10,17 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from ccproxy.auth.exceptions import OAuthTokenRefreshError +from ccproxy.core.constants import ( + FORMAT_OPENAI_RESPONSES, + UPSTREAM_ENDPOINT_OPENAI_RESPONSES, +) from ccproxy.core.logging import get_plugin_logger from ccproxy.core.plugins.interfaces import ( DetectionServiceProtocol, ProfiledTokenManagerProtocol, ) +from ccproxy.core.request_context import RequestContext +from ccproxy.llms.streaming.accumulators import ResponsesAccumulator from ccproxy.services.adapters.chain_composer import compose_from_chain from ccproxy.services.adapters.http_adapter import BaseHTTPAdapter from ccproxy.services.adapters.mock_adapter import MockAdapter @@ -28,10 +34,19 @@ ) from ccproxy.utils.model_mapper import restore_model_aliases +from .responses_state import CodexResponsesStateStore, ResponsesStateNotFoundError + logger = get_plugin_logger() +_CODEX_MODEL_REASONING_ALIASES = { + "gpt-5.5-high": "high", + "gpt-5.5-xhigh": "xhigh", + "gpt-5.5-max": "max", +} + + class CodexAdapter(BaseHTTPAdapter): """Simplified Codex adapter.""" @@ -47,6 +62,10 @@ def __init__( ProfiledTokenManagerProtocol, self.auth_manager ) self.base_url = self.config.base_url.rstrip("/") + self.responses_state_store = CodexResponsesStateStore( + max_entries=getattr(self.config, "responses_state_max_entries", 1024), + ttl_seconds=getattr(self.config, "responses_state_ttl_seconds", 3600), + ) async def handle_request( self, request: Request @@ -63,8 +82,10 @@ async def handle_request( return await MockAdapter(self.mock_handler).handle_request(request) endpoint = ctx.metadata.get("endpoint", "") + self._ensure_responses_accumulator(ctx, endpoint) body = await request.body() body = await self._map_request_model(ctx, body) + body = self._apply_model_alias_reasoning_effort(ctx, body) headers = extract_request_headers(request) # Determine client streaming intent from body flag (fallback to False) @@ -123,9 +144,12 @@ async def handle_request( }, ) - prepared_body, prepared_headers = await self.prepare_provider_request( - body, headers, endpoint - ) + try: + prepared_body, prepared_headers = await self.prepare_provider_request( + body, headers, endpoint, request_context=ctx + ) + except ResponsesStateNotFoundError as exc: + return self._responses_state_error_response(exc) logger.trace( "codex_adapter_prepared_provider_request", header_keys=list(prepared_headers.keys()), @@ -175,6 +199,9 @@ async def handle_request( request_context=ctx, provider_name="codex", ) + buffered_response = self._finalize_buffered_responses_state( + buffered_response, ctx + ) logger.trace( "codex_adapter_buffered_response_ready", status_code=buffered_response.status_code, @@ -254,7 +281,12 @@ async def get_target_url(self, endpoint: str) -> str: return f"{self.base_url}/responses" async def prepare_provider_request( - self, body: bytes, headers: dict[str, str], endpoint: str + self, + body: bytes, + headers: dict[str, str], + endpoint: str, + *, + request_context: RequestContext | None = None, ) -> tuple[bytes, dict[str, str]]: filtered_headers = await self.prepare_provider_headers(headers) @@ -291,9 +323,213 @@ async def prepare_provider_request( body_data.pop("instructions", None) body_data = self._sanitize_provider_body(body_data) + body_data = self._apply_responses_state_continuation( + body_data, + headers=headers, + endpoint=endpoint, + request_context=request_context, + ) return json.dumps(body_data).encode(), filtered_headers + def _apply_responses_state_continuation( + self, + body_data: dict[str, Any], + *, + headers: dict[str, str], + endpoint: str, + request_context: RequestContext | None, + ) -> dict[str, Any]: + if not self._should_manage_responses_state(endpoint, request_context): + return body_data + + prepared, scope, previous_response_id = ( + self.responses_state_store.prepare_payload(body_data, headers=headers) + ) + + if request_context is not None: + metadata = request_context.metadata + metadata["_codex_responses_state_scope"] = scope + metadata["_codex_responses_state_request"] = copy.deepcopy(prepared) + metadata["_codex_responses_stream_complete_callback"] = ( + self._record_responses_state_from_stream + ) + if previous_response_id: + metadata["_codex_responses_previous_response_id"] = previous_response_id + else: + metadata.pop("_codex_responses_previous_response_id", None) + + return prepared + + def _should_manage_responses_state( + self, endpoint: str, request_context: RequestContext | None + ) -> bool: + if endpoint != UPSTREAM_ENDPOINT_OPENAI_RESPONSES: + return False + if request_context is None: + return True + format_chain = getattr(request_context, "format_chain", None) or [] + return not format_chain or format_chain[0] == FORMAT_OPENAI_RESPONSES + + def _ensure_responses_accumulator( + self, request_context: RequestContext, endpoint: str + ) -> None: + if self._should_manage_responses_state(endpoint, request_context): + cast(Any, request_context)._tool_accumulator_class = ResponsesAccumulator + + def _responses_state_error_response( + self, exc: ResponsesStateNotFoundError + ) -> JSONResponse: + return JSONResponse(status_code=400, content=exc.to_openai_error()) + + def _responses_state_error_streaming_response( + self, exc: ResponsesStateNotFoundError + ) -> StreamingResponse: + error_bytes = json.dumps(exc.to_openai_error()).encode("utf-8") + + async def error_generator() -> Any: + yield error_bytes + + return StreamingResponse( + content=error_generator(), + status_code=400, + media_type="application/json", + ) + + def _finalize_buffered_responses_state( + self, response: Response, request_context: RequestContext + ) -> Response: + if not self._has_responses_state_request(request_context): + return response + if response.status_code >= 400: + return response + + body = ( + response.body if isinstance(response.body, bytes) else bytes(response.body) + ) + try: + payload = json.loads(body.decode("utf-8")) + except Exception: + return response + if not isinstance(payload, dict): + return response + + payload = self.record_responses_state_from_payload(payload, request_context) + if payload is None: + return response + + headers = filter_response_headers(dict(response.headers)) + return Response( + content=json.dumps(payload).encode("utf-8"), + status_code=response.status_code, + headers=headers, + media_type=response.media_type or "application/json", + ) + + def _has_responses_state_request(self, request_context: RequestContext) -> bool: + metadata = getattr(request_context, "metadata", None) + return isinstance(metadata, dict) and isinstance( + metadata.get("_codex_responses_state_request"), dict + ) + + def record_responses_state_from_payload( + self, payload: dict[str, Any], request_context: RequestContext + ) -> dict[str, Any] | None: + metadata = getattr(request_context, "metadata", None) + if not isinstance(metadata, dict): + return None + + request_payload = metadata.get("_codex_responses_state_request") + scope = metadata.get("_codex_responses_state_scope") + if not isinstance(request_payload, dict) or not isinstance(scope, str): + return None + if not isinstance(payload, dict) or payload.get("error"): + return None + + response_payload = copy.deepcopy(payload) + previous_response_id = metadata.get("_codex_responses_previous_response_id") + if ( + isinstance(previous_response_id, str) + and previous_response_id + and not response_payload.get("previous_response_id") + ): + response_payload["previous_response_id"] = previous_response_id + + self.responses_state_store.store_response( + scope=scope, + request_payload=request_payload, + response_payload=response_payload, + ) + return response_payload + + async def _record_responses_state_from_stream( + self, + request_context: RequestContext, + stream_accumulator: Any, + ) -> None: + if stream_accumulator is None: + return + + payload: dict[str, Any] | None = None + get_completed = getattr(stream_accumulator, "get_completed_response", None) + if callable(get_completed): + with contextlib.suppress(Exception): + completed = get_completed() + if isinstance(completed, dict): + payload = completed + + if payload is None: + rebuild = getattr(stream_accumulator, "rebuild_response_object", None) + if callable(rebuild): + with contextlib.suppress(Exception): + rebuilt = rebuild({}) + if isinstance(rebuilt, dict): + payload = rebuilt + + if payload is not None: + self.record_responses_state_from_payload(payload, request_context) + + def _apply_model_alias_reasoning_effort(self, ctx: Any, body: bytes) -> bytes: + """Apply reasoning effort implied by client-facing Codex model aliases.""" + + metadata = getattr(ctx, "metadata", None) + client_model = None + if isinstance(metadata, dict): + client_model = metadata.get("_last_client_model") + if not isinstance(client_model, str): + return body + + effort = _CODEX_MODEL_REASONING_ALIASES.get(client_model) + if effort is None: + return body + + try: + body_data = json.loads(body.decode()) if body else {} + except Exception: + return body + if not isinstance(body_data, dict): + return body + + is_responses_request = self._is_openai_responses_request(ctx) + if isinstance(body_data.get("reasoning"), dict): + reasoning = dict(body_data["reasoning"]) + reasoning.setdefault("effort", effort) + body_data["reasoning"] = reasoning + elif is_responses_request: + body_data["reasoning"] = {"effort": effort} + elif not body_data.get("reasoning_effort"): + body_data["reasoning_effort"] = effort + + return self._encode_json_body(body_data) + + def _is_openai_responses_request(self, ctx: Any) -> bool: + format_chain = getattr(ctx, "format_chain", None) + if format_chain is None: + return True + if not isinstance(format_chain, list | tuple): + return False + return not format_chain or format_chain[0] == FORMAT_OPENAI_RESPONSES + def _sanitize_provider_body(self, body_data: dict[str, Any]) -> dict[str, Any]: """Apply Codex-specific payload sanitization shared by all request paths.""" @@ -311,20 +547,56 @@ def _sanitize_provider_body(self, body_data: dict[str, Any]) -> dict[str, Any]: "max_tokens", "temperature", "metadata", + "stream_options", + "prompt_cache_retention", + "safety_identifier", ): body_data.pop(key, None) - list_input = body_data.get("input", []) - # Remove any input types that Codex does not support - body_data["input"] = [ - input for input in list_input if input.get("type") != "item_reference" - ] + input_value = body_data.get("input", []) + # Remove any input types that Codex does not support. Public Responses API + # input may be a plain string, but the Codex backend expects message items. + if isinstance(input_value, list): + body_data["input"] = [ + input_item + for input_item in input_value + if not ( + isinstance(input_item, dict) + and input_item.get("type") == "item_reference" + ) + ] + elif isinstance(input_value, str): + body_data["input"] = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input_value}], + } + ] # Remove any prefixed metadata fields that shouldn't be sent to the API body_data = self._remove_metadata_fields(body_data) + self._normalize_reasoning_effort(body_data) return body_data + def _normalize_reasoning_effort(self, body_data: dict[str, Any]) -> None: + """Clamp client-facing effort aliases to values accepted by Codex backend.""" + + reasoning_effort = body_data.pop("reasoning_effort", None) + reasoning = body_data.get("reasoning") + if isinstance(reasoning_effort, str) and reasoning_effort: + if reasoning_effort == "max": + reasoning_effort = "xhigh" + if isinstance(reasoning, dict): + reasoning.setdefault("effort", reasoning_effort) + else: + reasoning = {"effort": reasoning_effort} + body_data["reasoning"] = reasoning + + if isinstance(reasoning, dict) and reasoning.get("effort") == "max": + reasoning["effort"] = "xhigh" + async def prepare_provider_headers(self, headers: dict[str, str]) -> dict[str, str]: token_value = await self._resolve_access_token() @@ -497,10 +769,14 @@ async def handle_streaming( # Get context ctx = request.state.context self._ensure_tool_accumulator(ctx) + with contextlib.suppress(Exception): + ctx.metadata.setdefault("service_type", "codex") + self._ensure_responses_accumulator(ctx, endpoint) # Extract body and headers body = await request.body() body = await self._map_request_model(ctx, body) + body = self._apply_model_alias_reasoning_effort(ctx, body) headers = extract_request_headers(request) # Ensure format adapters are available when required @@ -549,9 +825,12 @@ async def error_generator() -> ( ) # Provider-specific preparation (adds auth, sets stream=true) - prepared_body, prepared_headers = await self.prepare_provider_request( - body, headers, endpoint - ) + try: + prepared_body, prepared_headers = await self.prepare_provider_request( + body, headers, endpoint, request_context=ctx + ) + except ResponsesStateNotFoundError as exc: + return self._responses_state_error_streaming_response(exc) # Get format adapter for streaming reverse conversion streaming_format_adapter = None diff --git a/ccproxy/plugins/codex/config.py b/ccproxy/plugins/codex/config.py index 6d9d71b5..61a4aa1c 100644 --- a/ccproxy/plugins/codex/config.py +++ b/ccproxy/plugins/codex/config.py @@ -121,6 +121,16 @@ class CodexSettings(ProviderConfig): buffer_non_streaming: bool = Field( default=True, description="Whether to buffer non-streaming requests" ) + responses_state_ttl_seconds: int = Field( + default=3600, + ge=1, + description="TTL for locally emulated OpenAI Responses continuation state", + ) + responses_state_max_entries: int = Field( + default=1024, + ge=1, + description="Maximum local OpenAI Responses continuation records to retain", + ) enable_format_registry: bool = Field( default=True, description="Whether to enable format adapter registry" ) diff --git a/ccproxy/plugins/codex/plugin.py b/ccproxy/plugins/codex/plugin.py index 1433f50b..69f254d1 100644 --- a/ccproxy/plugins/codex/plugin.py +++ b/ccproxy/plugins/codex/plugin.py @@ -22,6 +22,7 @@ from .adapter import CodexAdapter from .config import CodexSettings from .detection_service import CodexDetectionService +from .routes import openai_router as codex_openai_router from .routes import router as codex_router @@ -238,6 +239,7 @@ class CodexFactory(BaseProviderPluginFactory): credentials_manager_class = CodexTokenManager routers = [ RouterSpec(router=codex_router, prefix="/codex"), + RouterSpec(router=codex_openai_router, prefix=""), ] dependencies = ["oauth_codex"] optional_requires = ["pricing"] diff --git a/ccproxy/plugins/codex/responses_state.py b/ccproxy/plugins/codex/responses_state.py new file mode 100644 index 00000000..467e9ccc --- /dev/null +++ b/ccproxy/plugins/codex/responses_state.py @@ -0,0 +1,206 @@ +"""Local OpenAI Responses continuation state for the Codex backend.""" + +from __future__ import annotations + +import copy +import hashlib +import json +import time +from collections import OrderedDict +from dataclasses import dataclass +from threading import RLock +from typing import Any + +from ccproxy.llms.formatters.common import normalize_responses_function_call_ids + + +_FAILED_STATUSES = {"failed", "incomplete", "cancelled", "canceled"} + + +class ResponsesStateNotFoundError(ValueError): + """Raised when a local ``previous_response_id`` cannot be resolved.""" + + def __init__(self, response_id: str) -> None: + self.response_id = response_id + super().__init__(f"Unknown previous_response_id: {response_id}") + + def to_openai_error(self) -> dict[str, Any]: + return { + "error": { + "type": "invalid_request_error", + "message": str(self), + "param": "previous_response_id", + "code": "previous_response_not_found", + } + } + + +@dataclass +class ResponsesStateRecord: + scope: str + response_id: str + context_items: list[Any] + expires_at: float + + +class CodexResponsesStateStore: + """Bounded per-client state used to emulate ``previous_response_id`` locally.""" + + def __init__(self, *, max_entries: int = 1024, ttl_seconds: int = 3600) -> None: + self.max_entries = _positive_int(max_entries, default=1024) + self.ttl_seconds = _positive_int(ttl_seconds, default=3600) + self._records: OrderedDict[tuple[str, str], ResponsesStateRecord] = ( + OrderedDict() + ) + self._lock = RLock() + + def prepare_payload( + self, + payload: dict[str, Any], + *, + headers: dict[str, str], + ) -> tuple[dict[str, Any], str, str | None]: + """Return a provider payload with any local continuation expanded.""" + + scope = self.scope_for_headers(headers) + previous_response_id = payload.get("previous_response_id") + prepared = copy.deepcopy(payload) + + if previous_response_id is None: + return prepared, scope, None + + if not isinstance(previous_response_id, str) or not previous_response_id: + return prepared, scope, None + + record = self.get(scope, previous_response_id) + if record is None: + raise ResponsesStateNotFoundError(previous_response_id) + + current_input = _normalize_input_items(prepared.get("input", [])) + prepared["input"] = copy.deepcopy(record.context_items) + current_input + prepared.pop("previous_response_id", None) + return prepared, scope, previous_response_id + + def store_response( + self, + *, + scope: str, + request_payload: dict[str, Any], + response_payload: dict[str, Any], + ) -> bool: + """Store a completed or tool-pending Responses payload for continuation.""" + + if not isinstance(response_payload, dict): + return False + + response_id = response_payload.get("id") + if not isinstance(response_id, str) or not response_id: + return False + + status = response_payload.get("status") + if isinstance(status, str) and status.lower() in _FAILED_STATUSES: + return False + + output_items = response_payload.get("output") + if not isinstance(output_items, list): + output_items = [] + + input_items = _normalize_input_items(request_payload.get("input", [])) + context_items = input_items + _normalize_output_items(output_items) + if not context_items: + return False + + expires_at = time.monotonic() + self.ttl_seconds + record = ResponsesStateRecord( + scope=scope, + response_id=response_id, + context_items=context_items, + expires_at=expires_at, + ) + key = (scope, response_id) + with self._lock: + self._prune_locked(now=time.monotonic()) + self._records[key] = record + self._records.move_to_end(key) + while len(self._records) > self.max_entries: + self._records.popitem(last=False) + return True + + def get(self, scope: str, response_id: str) -> ResponsesStateRecord | None: + key = (scope, response_id) + now = time.monotonic() + with self._lock: + self._prune_locked(now=now) + record = self._records.get(key) + if record is None: + return None + if record.expires_at <= now: + self._records.pop(key, None) + return None + self._records.move_to_end(key) + return copy.deepcopy(record) + + def scope_for_headers(self, headers: dict[str, str]) -> str: + values: list[tuple[str, str]] = [] + for key in ( + "authorization", + "x-api-key", + "session_id", + "session-id", + "conversation_id", + "conversation-id", + "chatgpt-account-id", + "cf-connecting-ip", + ): + value = headers.get(key) + if isinstance(value, str) and value: + values.append((key, value)) + + if not values: + values.append(("anonymous", "anonymous")) + + raw = json.dumps(values, separators=(",", ":"), sort_keys=True) + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + def _prune_locked(self, *, now: float) -> None: + expired = [ + key for key, record in self._records.items() if record.expires_at <= now + ] + for key in expired: + self._records.pop(key, None) + + +def _normalize_input_items(value: Any) -> list[Any]: + if isinstance(value, list): + return copy.deepcopy(value) + if isinstance(value, str): + return [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": value}], + } + ] + if value is None: + return [] + return [copy.deepcopy(value)] + + +def _normalize_output_items(value: list[Any]) -> list[Any]: + normalized = copy.deepcopy(value) + normalize_responses_function_call_ids(normalized) + return normalized + + +def _positive_int(value: Any, *, default: int) -> int: + try: + return max(1, int(value)) + except (TypeError, ValueError): + return default + + +__all__ = [ + "CodexResponsesStateStore", + "ResponsesStateNotFoundError", + "ResponsesStateRecord", +] diff --git a/ccproxy/plugins/codex/routes.py b/ccproxy/plugins/codex/routes.py index cf605515..5b2381cc 100644 --- a/ccproxy/plugins/codex/routes.py +++ b/ccproxy/plugins/codex/routes.py @@ -36,6 +36,7 @@ from ccproxy.utils.model_mapper import restore_model_aliases from .config import CodexSettings +from .responses_state import ResponsesStateNotFoundError if TYPE_CHECKING: @@ -51,6 +52,7 @@ Depends(get_provider_config_dependency("codex", CodexSettings)), ] router = APIRouter() +openai_router = APIRouter() # Helper to handle adapter requests @@ -218,7 +220,10 @@ async def _sanitize_websocket_payload( body_bytes = json.dumps(provider_payload).encode("utf-8") body_bytes = await adapter._map_request_model(request_context, body_bytes) prepared_body, prepared_headers = await adapter.prepare_provider_request( - body_bytes, headers, UPSTREAM_ENDPOINT_OPENAI_RESPONSES + body_bytes, + headers, + UPSTREAM_ENDPOINT_OPENAI_RESPONSES, + request_context=request_context, ) sanitized_payload = json.loads(prepared_body.decode("utf-8")) return sanitized_payload, prepared_headers @@ -291,6 +296,26 @@ async def _send_websocket_event( ) +def _record_websocket_response_state( + adapter: "CodexAdapter", + event: dict[str, Any], + request_context: RequestContext, +) -> dict[str, Any]: + if event.get("type") != "response.completed": + return event + response_payload = event.get("response") + if not isinstance(response_payload, dict): + return event + updated = adapter.record_responses_state_from_payload( + response_payload, request_context + ) + if updated is None: + return event + patched = dict(event) + patched["response"] = updated + return patched + + def _serialize_codex_models(config: CodexSettings) -> list[dict[str, Any]]: models: list[dict[str, Any]] = [] for card in config.models_endpoint: @@ -364,9 +389,20 @@ async def _stream_websocket_response( websocket, adapter, provider_payload, request_context ) return - provider_payload, provider_headers = await _sanitize_websocket_payload( - adapter, provider_payload, request_headers, request_context - ) + try: + provider_payload, provider_headers = await _sanitize_websocket_payload( + adapter, provider_payload, request_headers, request_context + ) + except ResponsesStateNotFoundError as exc: + await _send_websocket_event( + websocket, + _make_websocket_terminal_event( + provider_payload, + error=exc.to_openai_error()["error"], + ), + request_context, + ) + return target_url = await adapter.get_target_url(UPSTREAM_ENDPOINT_OPENAI_RESPONSES) parsed_url = urlparse(target_url) @@ -411,11 +447,15 @@ async def _stream_websocket_response( for event in parser.feed(chunk): if event.get("type") in {"response.completed", "response.failed"}: saw_terminal_event = True + event = _record_websocket_response_state( + adapter, event, request_context + ) await _send_websocket_event(websocket, event, request_context) for event in parser.flush(): if event.get("type") in {"response.completed", "response.failed"}: saw_terminal_event = True + event = _record_websocket_response_state(adapter, event, request_context) await _send_websocket_event(websocket, event, request_context) if not saw_terminal_event: @@ -458,11 +498,13 @@ async def _stream_websocket_mock_response( for event in parser.feed(chunk): if event.get("type") in {"response.completed", "response.failed"}: saw_terminal_event = True + event = _record_websocket_response_state(adapter, event, request_context) await _send_websocket_event(websocket, event, request_context) for event in parser.flush(): if event.get("type") in {"response.completed", "response.failed"}: saw_terminal_event = True + event = _record_websocket_response_state(adapter, event, request_context) await _send_websocket_event(websocket, event, request_context) if not saw_terminal_event: @@ -552,6 +594,40 @@ async def codex_responses_legacy_websocket(websocket: WebSocket) -> None: await codex_responses_websocket(websocket) +@openai_router.post("/responses", response_model=None) +@with_format_chain( + [FORMAT_OPENAI_RESPONSES], endpoint=UPSTREAM_ENDPOINT_OPENAI_RESPONSES +) +async def openai_responses( + request: Request, + auth: ConditionalAuthDep, + adapter: CodexAdapterDep, +) -> StreamingResponse | Response | DeferredStreaming: + return await _codex_responses_handler(request, adapter) + + +@openai_router.post("/v1/responses", response_model=None) +@with_format_chain( + [FORMAT_OPENAI_RESPONSES], endpoint=UPSTREAM_ENDPOINT_OPENAI_RESPONSES +) +async def openai_v1_responses( + request: Request, + auth: ConditionalAuthDep, + adapter: CodexAdapterDep, +) -> StreamingResponse | Response | DeferredStreaming: + return await _codex_responses_handler(request, adapter) + + +@openai_router.websocket("/responses") +async def openai_responses_websocket(websocket: WebSocket) -> None: + await codex_responses_websocket(websocket) + + +@openai_router.websocket("/v1/responses") +async def openai_v1_responses_websocket(websocket: WebSocket) -> None: + await codex_responses_websocket(websocket) + + @router.post("/v1/chat/completions", response_model=None) @with_format_chain( [FORMAT_OPENAI_CHAT, FORMAT_OPENAI_RESPONSES], @@ -577,6 +653,16 @@ async def list_models( return {"object": "list", "data": openai_models, "models": codex_models} +@openai_router.get("/models", response_model=None) +@openai_router.get("/v1/models", response_model=None) +async def openai_list_models( + request: Request, + auth: ConditionalAuthDep, + config: CodexConfigDep, +) -> dict[str, Any]: + return await list_models(request, auth, config) + + @router.post("/v1/messages", response_model=None) @with_format_chain( [FORMAT_ANTHROPIC_MESSAGES, FORMAT_OPENAI_RESPONSES], diff --git a/ccproxy/streaming/buffer.py b/ccproxy/streaming/buffer.py index 2b4119bd..ae477a60 100644 --- a/ccproxy/streaming/buffer.py +++ b/ccproxy/streaming/buffer.py @@ -16,8 +16,10 @@ from ccproxy.core.plugins.hooks import HookEvent, HookManager from ccproxy.core.plugins.hooks.base import HookContext +from ccproxy.llms.formatters.common import normalize_responses_function_call_ids from ccproxy.llms.models import openai as openai_models from ccproxy.llms.streaming.accumulators import ResponsesAccumulator, StreamAccumulator +from ccproxy.streaming.errors import normalize_openai_error_payload if TYPE_CHECKING: @@ -383,10 +385,15 @@ async def _collect_and_parse_stream( request_id=request_id, category="streaming", ) - try: - error_data = json.loads(error_body) - except json.JSONDecodeError: - error_data = {"error": error_body.decode("utf-8", errors="ignore")} + if provider_name == "codex": + error_data = normalize_openai_error_payload(error_body) + else: + try: + error_data = json.loads(error_body) + except json.JSONDecodeError: + error_data = { + "error": error_body.decode("utf-8", errors="ignore") + } return error_data, status_code, response_headers # Collect all stream chunks @@ -631,7 +638,9 @@ async def _parse_collected_stream( completed=bool(completed_payload), ) if completed_payload is not None: - response_obj = completed_payload + response_obj = normalize_responses_function_call_ids( + completed_payload + ) return response_obj try: response_obj = accumulator_for_rebuild.rebuild_response_object( @@ -680,7 +689,7 @@ async def _parse_collected_stream( "total_tokens": usage.get("total_tokens", 0), } - return response_obj + return normalize_responses_function_call_ids(response_obj) # Try using the configured SSE parser first logger.debug( @@ -723,7 +732,7 @@ async def _parse_collected_stream( exc_info=e, ) - return parsed_data + return normalize_responses_function_call_ids(parsed_data) else: logger.warning( "sse_parser_returned_none", @@ -748,7 +757,7 @@ async def _parse_collected_stream( request_id=getattr(request_context, "request_id", None), category="streaming", ) - return parsed_json + return normalize_responses_function_call_ids(parsed_json) else: # If it's not a dict, wrap it logger.info( @@ -757,7 +766,7 @@ async def _parse_collected_stream( request_id=getattr(request_context, "request_id", None), category="streaming", ) - return {"data": parsed_json} + return normalize_responses_function_call_ids({"data": parsed_json}) except json.JSONDecodeError: pass @@ -771,7 +780,7 @@ async def _parse_collected_stream( request_id=getattr(request_context, "request_id", None), category="streaming", ) - return parsed_data + return normalize_responses_function_call_ids(parsed_data) except Exception as e: logger.debug( "generic_sse_parsing_failed", diff --git a/ccproxy/streaming/deferred.py b/ccproxy/streaming/deferred.py index b721a105..b3e4c51b 100644 --- a/ccproxy/streaming/deferred.py +++ b/ccproxy/streaming/deferred.py @@ -4,6 +4,7 @@ """ import contextlib +import inspect import json from collections.abc import AsyncGenerator, AsyncIterator, Callable from datetime import datetime @@ -13,9 +14,12 @@ import structlog from starlette.responses import JSONResponse, Response, StreamingResponse +from ccproxy.core.constants import FORMAT_ANTHROPIC_MESSAGES, FORMAT_OPENAI_RESPONSES from ccproxy.core.plugins.hooks import HookEvent, HookManager from ccproxy.core.plugins.hooks.base import HookContext +from ccproxy.llms.formatters.common import normalize_responses_sse_event_bytes from ccproxy.llms.streaming.accumulators import StreamAccumulator +from ccproxy.streaming.errors import normalize_openai_error_payload from ccproxy.streaming.sse import serialize_json_to_sse_stream from ccproxy.utils.model_mapper import restore_model_aliases @@ -233,6 +237,7 @@ async def body_generator() -> AsyncGenerator[bytes, None]: async def _emit_error_sse( error_obj: dict[str, Any], ) -> AsyncGenerator[bytes, None]: + error_obj = self._format_stream_error(error_obj) adapted: dict[str, Any] | None = None try: if self.handler_config and self.handler_config.response_adapter: @@ -260,9 +265,14 @@ async def _single() -> AsyncIterator[dict[str, Any]]: try: # Check for error status if response.status_code >= 400: - # Forward provider error body as-is (no SSE wrapping) raw_error = await response.aread() - yield raw_error + if self._should_normalize_openai_error(): + yield json.dumps( + normalize_openai_error_payload(raw_error) + ).encode("utf-8") + else: + # Forward non-Codex provider error body as-is. + yield raw_error return # Stream the response with optional SSE processing @@ -329,13 +339,24 @@ async def _single() -> AsyncIterator[dict[str, Any]]: and self.request_context.metadata.get("service_type") == "codex" ) - is_sse_format = "text/event-stream" in content_type or is_codex + is_openai_responses = bool( + self.request_context + and self.request_context.format_chain + and self.request_context.format_chain[0] + == "openai.responses" + ) + is_sse_format = ( + "text/event-stream" in content_type + or is_codex + or is_openai_responses + ) logger.debug( "streaming_no_format_adapter", content_type=content_type, is_codex=is_codex, is_sse_format=is_sse_format, + is_openai_responses=is_openai_responses, request_id=request_id, category="streaming_conversion", ) @@ -394,14 +415,29 @@ async def _single() -> AsyncIterator[dict[str, Any]]: error=str(e), ) + if is_codex or is_openai_responses: + event_data = ( + normalize_responses_sse_event_bytes( + event_data + ) + ) + # Yield the complete event self._record_sse_bytes(event_data) + event_data = ( + self._patch_responses_completed_sse_bytes( + event_data + ) + ) yield event_data # Yield any remaining data in buffer if sse_buffer: upstream_raw_chunks.append(sse_buffer) self._record_sse_bytes(sse_buffer) + sse_buffer = self._patch_responses_completed_sse_bytes( + sse_buffer + ) yield sse_buffer else: # Stream the raw response without SSE parsing @@ -618,6 +654,8 @@ async def _single() -> AsyncIterator[dict[str, Any]]: request_id=getattr(self.request_context, "request_id", None), ) + await self._notify_stream_complete() + # After the streaming context closes, optionally close the client we own if self._close_client_on_finish: with contextlib.suppress(Exception): @@ -808,6 +846,7 @@ async def _parse_sse_to_json_stream( and "type" not in json_obj ): json_obj["type"] = event_type + json_obj = self._patch_responses_completed_event(json_obj) yield json_obj except json.JSONDecodeError: continue @@ -840,6 +879,34 @@ async def _serialize_json_to_sse_stream( ): yield chunk + def _format_stream_error(self, error_obj: dict[str, Any]) -> dict[str, Any]: + """Normalize streaming error payloads for client-specific SSE schemas.""" + if isinstance(error_obj, dict) and error_obj.get("type"): + return error_obj + + format_chain = ( + self.request_context.format_chain + if self.request_context and self.request_context.format_chain + else [] + ) + client_format = format_chain[0] if format_chain else None + + if client_format == FORMAT_ANTHROPIC_MESSAGES: + return {"type": "error", "error": error_obj.get("error", error_obj)} + + return error_obj + + def _should_normalize_openai_error(self) -> bool: + if not self.request_context: + return False + + metadata = getattr(self.request_context, "metadata", None) + if isinstance(metadata, dict) and metadata.get("service_type") == "codex": + return True + + format_chain = getattr(self.request_context, "format_chain", None) or [] + return bool(format_chain and format_chain[0] == FORMAT_OPENAI_RESPONSES) + def _record_tool_event(self, event_name: str, payload: Any) -> None: if not self._stream_accumulator or not isinstance(payload, dict): return @@ -854,6 +921,79 @@ def _record_tool_event(self, event_name: str, payload: Any) -> None: request_id=getattr(self.request_context, "request_id", None), ) + def _patch_responses_completed_event(self, payload: Any) -> Any: + if not self._stream_accumulator or not isinstance(payload, dict): + return payload + if payload.get("type") != "response.completed": + return payload + response_payload = payload.get("response") + if not isinstance(response_payload, dict): + return payload + + rebuild_response = getattr( + self._stream_accumulator, "rebuild_response_object", None + ) + if not callable(rebuild_response): + return payload + + try: + rebuilt_response = rebuild_response(response_payload) + except Exception as exc: # pragma: no cover - defensive logging + logger.debug( + "responses_completed_event_patch_failed", + error=str(exc), + request_id=getattr(self.request_context, "request_id", None), + ) + return payload + + if not isinstance(rebuilt_response, dict): + return payload + + patched = dict(payload) + patched["response"] = rebuilt_response + return patched + + def _patch_responses_completed_sse_bytes(self, event_data: bytes) -> bytes: + if not self._stream_accumulator: + return event_data + + try: + text = event_data.decode("utf-8") + except UnicodeDecodeError: + return event_data + + lines = text.splitlines() + passthrough_lines: list[str] = [] + data_lines: list[str] = [] + for line in lines: + if line.startswith("data:"): + data_value = line[5:] + if data_value.startswith(" "): + data_value = data_value[1:] + data_lines.append(data_value) + elif line: + passthrough_lines.append(line) + + if not data_lines: + return event_data + + data_payload = "\n".join(data_lines) + if data_payload.strip() == "[DONE]": + return event_data + + try: + parsed = json.loads(data_payload) + except json.JSONDecodeError: + return event_data + + patched = self._patch_responses_completed_event(parsed) + if patched is parsed: + return event_data + + compact = json.dumps(patched, ensure_ascii=False, separators=(",", ":")) + patched_lines = [*passthrough_lines, f"data: {compact}", ""] + return ("\n".join(patched_lines) + "\n").encode("utf-8") + def _override_model_alias(self, payload: Any, model_value: str) -> None: if isinstance(payload, dict): for key, value in payload.items(): @@ -895,3 +1035,23 @@ def _record_sse_bytes(self, event_bytes: bytes) -> None: return self._record_tool_event(event_name, payload_obj) + + async def _notify_stream_complete(self) -> None: + if not self.request_context or not hasattr(self.request_context, "metadata"): + return + metadata = self.request_context.metadata + if not isinstance(metadata, dict): + return + callback = metadata.get("_codex_responses_stream_complete_callback") + if not callable(callback): + return + try: + result = callback(self.request_context, self._stream_accumulator) + if inspect.isawaitable(result): + await result + except Exception as exc: # pragma: no cover - defensive logging + logger.debug( + "stream_complete_callback_failed", + error=str(exc), + request_id=getattr(self.request_context, "request_id", None), + ) diff --git a/ccproxy/streaming/errors.py b/ccproxy/streaming/errors.py new file mode 100644 index 00000000..2d8e65eb --- /dev/null +++ b/ccproxy/streaming/errors.py @@ -0,0 +1,71 @@ +"""Error payload helpers for streaming provider paths.""" + +from __future__ import annotations + +import json +import re +from contextlib import suppress +from typing import Any + + +_UNSUPPORTED_PARAMETER_RE = re.compile(r"Unsupported parameter:\s*([A-Za-z0-9_.-]+)") + + +def normalize_openai_error_payload(payload: Any) -> dict[str, Any]: + """Return a minimal OpenAI-compatible error envelope.""" + + data = _decode_error_payload(payload) + if isinstance(data, dict) and isinstance(data.get("error"), dict): + error = dict(data["error"]) + message = error.get("message") + if not isinstance(message, str) or not message: + error["message"] = "upstream error" + error.setdefault("type", "invalid_request_error") + error.setdefault("param", None) + error.setdefault("code", None) + return {"error": error} + + message = _extract_error_message(data) + param = _extract_unsupported_parameter(message) + return { + "error": { + "type": "invalid_request_error", + "message": message or "upstream error", + "param": param, + "code": "unsupported_parameter" if param else None, + } + } + + +def _decode_error_payload(payload: Any) -> Any: + if isinstance(payload, bytes | bytearray | memoryview): + text = bytes(payload).decode("utf-8", errors="replace") + with suppress(json.JSONDecodeError): + return json.loads(text) + return text + return payload + + +def _extract_error_message(data: Any) -> str: + if isinstance(data, dict): + for key in ("detail", "message", "error_description"): + value = data.get(key) + if isinstance(value, str) and value: + return value + try: + return json.dumps(data, ensure_ascii=False) + except (TypeError, ValueError): + return str(data) + if isinstance(data, str): + return data + return str(data) if data is not None else "" + + +def _extract_unsupported_parameter(message: str) -> str | None: + match = _UNSUPPORTED_PARAMETER_RE.search(message) + if not match: + return None + return match.group(1) + + +__all__ = ["normalize_openai_error_payload"] diff --git a/tests/integration/test_streaming_converters.py b/tests/integration/test_streaming_converters.py index affac716..877b1967 100644 --- a/tests/integration/test_streaming_converters.py +++ b/tests/integration/test_streaming_converters.py @@ -130,7 +130,16 @@ async def test_openai_chat_stream_to_anthropic_sample() -> None: assert isinstance(tool_event.content_block, anthropic_models.ToolUseBlock), ( "expected ToolUseBlock" ) - assert tool_event.content_block.input, "tool input should be populated" + assert tool_event.content_block.input == {} + + tool_delta = next( + evt + for evt in streamed + if isinstance(evt, anthropic_models.ContentBlockDeltaEvent) + and getattr(evt.delta, "type", None) == "input_json_delta" + ) + assert isinstance(tool_delta.delta, anthropic_models.InputJsonDelta) + assert tool_delta.delta.partial_json, "tool input delta should be populated" message_delta = next( evt for evt in streamed if isinstance(evt, anthropic_models.MessageDeltaEvent) diff --git a/tests/plugins/codex/integration/test_codex_basic.py b/tests/plugins/codex/integration/test_codex_basic.py index efb1e1b7..fcdc8dca 100644 --- a/tests/plugins/codex/integration/test_codex_basic.py +++ b/tests/plugins/codex/integration/test_codex_basic.py @@ -21,6 +21,159 @@ from ccproxy.plugins.codex.models import CodexCacheData +CODEX_RESPONSES_URL = "https://chatgpt.com/backend-api/codex/responses" + + +def _base_response( + response_id: str, + *, + output: list[dict[str, Any]], + status: str = "completed", + previous_response_id: str | None = None, +) -> dict[str, Any]: + return { + "id": response_id, + "object": "response", + "created_at": 1234567890, + "model": "gpt-5", + "status": status, + "parallel_tool_calls": True, + "previous_response_id": previous_response_id, + "output": output, + "usage": { + "input_tokens": 10, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens": 5, + "output_tokens_details": {"reasoning_tokens": 0}, + "total_tokens": 15, + }, + } + + +def _message_output(text: str, *, item_id: str = "msg_1") -> dict[str, Any]: + return { + "type": "message", + "id": item_id, + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + + +def _function_call_output( + *, + item_id: str = "call_weather", + call_id: str = "call_weather", +) -> dict[str, Any]: + return { + "type": "function_call", + "id": item_id, + "call_id": call_id, + "name": "get_weather", + "arguments": '{"city":"Paris"}', + "status": "completed", + } + + +def _sse_bytes(events: list[dict[str, Any]]) -> bytes: + chunks = [] + for event in events: + chunks.append( + f"event: {event['type']}\n" + f"data: {json.dumps(event, separators=(',', ':'))}\n\n" + ) + chunks.append("data: [DONE]\n\n") + return "".join(chunks).encode("utf-8") + + +def _completed_sse(response: dict[str, Any]) -> bytes: + return _sse_bytes( + [ + { + "type": "response.completed", + "sequence_number": 0, + "response": response, + } + ] + ) + + +def _function_call_stream_sse( + response_id: str, *, completed_output: bool = True +) -> bytes: + function_call = _function_call_output() + response = _base_response( + response_id, output=[function_call] if completed_output else [] + ) + return _sse_bytes( + [ + { + "type": "response.created", + "sequence_number": 0, + "response": _base_response( + response_id, output=[], status="in_progress" + ), + }, + { + "type": "response.output_item.added", + "sequence_number": 1, + "output_index": 0, + "item": { + **function_call, + "arguments": "", + "status": "in_progress", + }, + }, + { + "type": "response.function_call_arguments.delta", + "sequence_number": 2, + "item_id": "call_weather", + "output_index": 0, + "delta": '{"city"', + }, + { + "type": "response.function_call_arguments.delta", + "sequence_number": 3, + "item_id": "call_weather", + "output_index": 0, + "delta": ':"Paris"}', + }, + { + "type": "response.function_call_arguments.done", + "sequence_number": 4, + "item_id": "call_weather", + "output_index": 0, + "arguments": '{"city":"Paris"}', + }, + { + "type": "response.output_item.done", + "sequence_number": 5, + "output_index": 0, + "item": function_call, + }, + { + "type": "response.completed", + "sequence_number": 6, + "response": response, + }, + ] + ) + + +def _sse_events(raw_body: bytes) -> list[dict[str, Any]]: + body = raw_body.decode("utf-8") + events: list[dict[str, Any]] = [] + for block in body.split("\n\n"): + data_lines = [ + line[6:] + for line in block.splitlines() + if line.startswith("data: ") and line[6:] != "[DONE]" + ] + if data_lines: + events.append(json.loads("".join(data_lines))) + return events + + @pytest.mark.asyncio @pytest.mark.integration @pytest.mark.codex @@ -159,6 +312,389 @@ async def test_codex_bypass_responses_streaming_emits_valid_openai_response_even assert body.strip().endswith("data: [DONE]") +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_root_openai_responses_paths_work( + codex_client: Any, + httpx_mock: Any, +) -> None: + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response("resp_root", output=[_message_output("OK")]) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response("resp_v1", output=[_message_output("OK")]) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + + root_resp = await codex_client.post( + "/responses", json={"model": "gpt-5", "input": "hello"} + ) + v1_resp = await codex_client.post( + "/v1/responses", json={"model": "gpt-5", "input": "hello"} + ) + + assert root_resp.status_code == 200, root_resp.text + assert v1_resp.status_code == 200, v1_resp.text + assert root_resp.json()["id"] == "resp_root" + assert v1_resp.json()["id"] == "resp_v1" + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_paths_tolerate_stream_options( + codex_client: Any, + httpx_mock: Any, +) -> None: + paths = ["/responses", "/v1/responses", "/codex/v1/responses"] + for index, path in enumerate(paths): + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response( + f"resp_stream_options_{index}", output=[_message_output("OK")] + ) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + + resp = await codex_client.post( + path, + json={ + "model": "gpt-5", + "stream": True, + "stream_options": {"include_usage": True}, + "prompt_cache_retention": "24h", + "safety_identifier": "user-123", + "input": "hello", + }, + ) + raw_body = await resp.aread() + + assert resp.status_code == 200, raw_body + events = _sse_events(raw_body) + assert events[-1]["type"] == "response.completed" + assert events[-1]["response"]["usage"]["total_tokens"] == 15 + + upstream_requests = httpx_mock.get_requests(url=CODEX_RESPONSES_URL) + assert len(upstream_requests) == len(paths) + for request in upstream_requests: + upstream = json.loads(request.content) + assert upstream["stream"] is True + assert "stream_options" not in upstream + assert "prompt_cache_retention" not in upstream + assert "safety_identifier" not in upstream + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_previous_response_id_tool_loop( + codex_client: Any, + httpx_mock: Any, +) -> None: + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response("resp_tool_1", output=[_function_call_output()]) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response("resp_final", output=[_message_output("Paris is 25C.")]) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + + headers = {"Authorization": "Bearer client-a"} + first = await codex_client.post( + "/responses", + headers=headers, + json={ + "model": "gpt-5", + "input": "What is the weather in Paris?", + "tools": [ + { + "type": "function", + "name": "get_weather", + "parameters": {"type": "object", "properties": {}}, + } + ], + "tool_choice": "required", + }, + ) + assert first.status_code == 200, first.text + first_payload = first.json() + assert first_payload["output"][0]["id"] == "fc_weather" + assert first_payload["output"][0]["call_id"] == "call_weather" + + second = await codex_client.post( + "/responses", + headers=headers, + json={ + "model": "gpt-5", + "previous_response_id": "resp_tool_1", + "input": [ + { + "type": "function_call_output", + "call_id": "call_weather", + "output": '{"temperature":"25C"}', + } + ], + }, + ) + + assert second.status_code == 200, second.text + second_payload = second.json() + assert second_payload["previous_response_id"] == "resp_tool_1" + assert second_payload["output"][0]["content"][0]["text"] == "Paris is 25C." + + upstream_requests = httpx_mock.get_requests(url=CODEX_RESPONSES_URL) + assert len(upstream_requests) == 2 + second_upstream = json.loads(upstream_requests[1].content) + assert "previous_response_id" not in second_upstream + assert [item["type"] for item in second_upstream["input"]] == [ + "message", + "function_call", + "function_call_output", + ] + assert second_upstream["input"][1]["id"] == "fc_weather" + assert second_upstream["input"][1]["call_id"] == "call_weather" + assert second_upstream["input"][2]["call_id"] == "call_weather" + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_streaming_tool_call_ids_and_continuation( + codex_client: Any, + httpx_mock: Any, +) -> None: + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_function_call_stream_sse("resp_stream_tool_1", completed_output=False), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response( + "resp_stream_final", output=[_message_output("Paris is 25C.")] + ) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + + headers = {"Authorization": "Bearer client-a"} + first = await codex_client.post( + "/responses", + headers=headers, + json={ + "model": "gpt-5", + "stream": True, + "stream_options": {"include_usage": True}, + "input": "What is the weather in Paris?", + "tools": [{"type": "function", "name": "get_weather"}], + "tool_choice": "required", + }, + ) + first_body = await first.aread() + assert first.status_code == 200, first_body + events = _sse_events(first_body) + arg_events = [ + event + for event in events + if event["type"].startswith("response.function_call_arguments.") + ] + assert arg_events + assert {event["item_id"] for event in arg_events} == {"fc_weather"} + completed = [event for event in events if event["type"] == "response.completed"][-1] + assert completed["response"]["output"][0]["id"] == "fc_weather" + assert completed["response"]["output"][0]["call_id"] == "call_weather" + assert completed["response"]["usage"]["total_tokens"] == 15 + + second = await codex_client.post( + "/responses", + headers=headers, + json={ + "model": "gpt-5", + "stream": True, + "stream_options": {"include_usage": True}, + "previous_response_id": "resp_stream_tool_1", + "input": [ + { + "type": "function_call_output", + "call_id": "call_weather", + "output": '{"temperature":"25C"}', + } + ], + }, + ) + second_body = await second.aread() + assert second.status_code == 200, second_body + + upstream_requests = httpx_mock.get_requests(url=CODEX_RESPONSES_URL) + first_upstream = json.loads(upstream_requests[0].content) + assert "stream_options" not in first_upstream + second_upstream = json.loads(upstream_requests[1].content) + assert "previous_response_id" not in second_upstream + assert "stream_options" not in second_upstream + assert second_upstream["input"][1]["type"] == "function_call" + assert second_upstream["input"][1]["id"] == "fc_weather" + assert second_upstream["input"][2]["type"] == "function_call_output" + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_unknown_previous_response_id_returns_openai_error( + codex_client: Any, +) -> None: + resp = await codex_client.post( + "/responses", + json={ + "model": "gpt-5", + "previous_response_id": "resp_missing", + "input": [ + { + "type": "function_call_output", + "call_id": "call_missing", + "output": "{}", + } + ], + }, + ) + + assert resp.status_code == 400 + payload = resp.json() + assert payload["error"]["type"] == "invalid_request_error" + assert payload["error"]["param"] == "previous_response_id" + assert payload["error"]["code"] == "previous_response_not_found" + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_provider_detail_error_returns_openai_error( + codex_client: Any, + httpx_mock: Any, +) -> None: + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + json={"detail": "Unsupported parameter: stream_options"}, + status_code=400, + headers={"content-type": "application/json"}, + ) + + resp = await codex_client.post( + "/responses", + json={ + "model": "gpt-5", + "stream": True, + "input": "hello", + "extra_rejected_field": True, + }, + ) + body = await resp.aread() + + assert resp.status_code == 400 + payload = json.loads(body) + assert payload["error"]["type"] == "invalid_request_error" + assert payload["error"]["message"] == "Unsupported parameter: stream_options" + assert payload["error"]["param"] == "stream_options" + assert payload["error"]["code"] == "unsupported_parameter" + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_buffered_provider_detail_error_returns_openai_error( + codex_client: Any, + httpx_mock: Any, +) -> None: + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + json={"detail": "Unsupported parameter: stream_options"}, + status_code=400, + headers={"content-type": "application/json"}, + ) + + resp = await codex_client.post( + "/responses", + json={ + "model": "gpt-5", + "input": "hello", + "extra_rejected_field": True, + }, + ) + + assert resp.status_code == 400 + payload = resp.json() + assert payload["error"]["type"] == "invalid_request_error" + assert payload["error"]["message"] == "Unsupported parameter: stream_options" + assert payload["error"]["param"] == "stream_options" + assert payload["error"]["code"] == "unsupported_parameter" + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.codex +async def test_openai_responses_state_isolated_by_authorization( + codex_client: Any, + httpx_mock: Any, +) -> None: + httpx_mock.add_response( + url=CODEX_RESPONSES_URL, + content=_completed_sse( + _base_response("resp_client_a", output=[_function_call_output()]) + ), + status_code=200, + headers={"content-type": "text/event-stream"}, + ) + + first = await codex_client.post( + "/responses", + headers={"Authorization": "Bearer client-a"}, + json={"model": "gpt-5", "input": "weather?", "tools": []}, + ) + assert first.status_code == 200, first.text + + second = await codex_client.post( + "/responses", + headers={"Authorization": "Bearer client-b"}, + json={ + "model": "gpt-5", + "previous_response_id": "resp_client_a", + "input": [ + { + "type": "function_call_output", + "call_id": "call_weather", + "output": "{}", + } + ], + }, + ) + assert second.status_code == 400 + assert second.json()["error"]["code"] == "previous_response_not_found" + + # Module-scoped client to avoid per-test startup cost # Use module-level async loop for all tests here pytestmark = pytest.mark.asyncio(loop_scope="module") diff --git a/tests/plugins/codex/unit/test_adapter.py b/tests/plugins/codex/unit/test_adapter.py index dededd10..72cd9c47 100644 --- a/tests/plugins/codex/unit/test_adapter.py +++ b/tests/plugins/codex/unit/test_adapter.py @@ -616,6 +616,127 @@ def test_sanitize_provider_body_strips_metadata( assert cleaned["stream"] is True assert cleaned["store"] is False + def test_sanitize_provider_body_normalizes_string_input( + self, adapter: CodexAdapter + ) -> None: + """Responses API string input should be normalized for Codex backend.""" + body = {"model": "gpt-5.5", "input": "Reply exactly OK"} + + cleaned = adapter._sanitize_provider_body(body) + + assert cleaned["input"] == [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Reply exactly OK"}], + } + ] + assert cleaned["stream"] is True + assert cleaned["store"] is False + + def test_apply_model_alias_reasoning_effort_for_chat_alias( + self, adapter: CodexAdapter + ) -> None: + """GPT-5.5 effort aliases should set effort while using the base model.""" + ctx = Mock() + ctx.metadata = { + "_last_client_model": "gpt-5.5-xhigh", + "_last_provider_model": "gpt-5.5", + } + body = json.dumps( + { + "model": "gpt-5.5", + "messages": [{"role": "user", "content": "Hello"}], + } + ).encode() + + result = adapter._apply_model_alias_reasoning_effort(ctx, body) + result_data = json.loads(result.decode()) + + assert result_data["model"] == "gpt-5.5" + assert result_data["reasoning_effort"] == "xhigh" + + def test_apply_model_alias_reasoning_effort_for_responses_alias( + self, adapter: CodexAdapter + ) -> None: + """Responses requests should use the Responses-native reasoning field.""" + ctx = Mock() + ctx.format_chain = ["openai.responses"] + ctx.metadata = { + "_last_client_model": "gpt-5.5-high", + "_last_provider_model": "gpt-5.5", + } + body = json.dumps( + { + "model": "gpt-5.5", + "input": "Hello", + } + ).encode() + + result = adapter._apply_model_alias_reasoning_effort(ctx, body) + result_data = json.loads(result.decode()) + + assert result_data["model"] == "gpt-5.5" + assert result_data["reasoning"] == {"effort": "high"} + assert "reasoning_effort" not in result_data + + def test_apply_model_alias_reasoning_effort_preserves_explicit_effort( + self, adapter: CodexAdapter + ) -> None: + """Explicit request effort should win over model-alias defaults.""" + ctx = Mock() + ctx.metadata = { + "_last_client_model": "gpt-5.5-max", + "_last_provider_model": "gpt-5.5", + } + body = json.dumps( + { + "model": "gpt-5.5", + "messages": [{"role": "user", "content": "Hello"}], + "reasoning_effort": "high", + } + ).encode() + + result = adapter._apply_model_alias_reasoning_effort(ctx, body) + result_data = json.loads(result.decode()) + + assert result_data["reasoning_effort"] == "high" + + def test_sanitize_provider_body_clamps_max_reasoning_effort( + self, adapter: CodexAdapter + ) -> None: + """Codex backend currently accepts xhigh but rejects max.""" + body = { + "model": "gpt-5.5", + "input": [{"type": "message", "role": "user", "content": []}], + "reasoning": {"effort": "max", "summary": "auto"}, + } + + cleaned = adapter._sanitize_provider_body(body) + + assert cleaned["reasoning"] == {"effort": "xhigh", "summary": "auto"} + + def test_sanitize_provider_body_converts_reasoning_effort( + self, adapter: CodexAdapter + ) -> None: + """Codex Responses backend rejects chat-only reasoning_effort.""" + body = { + "model": "gpt-5.5", + "input": [{"type": "message", "role": "user", "content": []}], + "reasoning_effort": "high", + "stream_options": {"include_usage": True}, + "prompt_cache_retention": "24h", + "safety_identifier": "user-123", + } + + cleaned = adapter._sanitize_provider_body(body) + + assert cleaned["reasoning"] == {"effort": "high"} + assert "reasoning_effort" not in cleaned + assert "stream_options" not in cleaned + assert "prompt_cache_retention" not in cleaned + assert "safety_identifier" not in cleaned + def test_get_instructions_default(self, adapter: CodexAdapter) -> None: """Test default instructions when no detection service data.""" instructions = adapter._get_instructions() diff --git a/tests/plugins/codex/unit/test_responses_state.py b/tests/plugins/codex/unit/test_responses_state.py new file mode 100644 index 00000000..56619e7a --- /dev/null +++ b/tests/plugins/codex/unit/test_responses_state.py @@ -0,0 +1,107 @@ +import pytest + +from ccproxy.plugins.codex.responses_state import ( + CodexResponsesStateStore, + ResponsesStateNotFoundError, +) + + +def test_responses_state_expands_previous_response_id_tool_loop() -> None: + store = CodexResponsesStateStore(max_entries=8, ttl_seconds=60) + headers = {"authorization": "Bearer client-a"} + scope = store.scope_for_headers(headers) + + stored = store.store_response( + scope=scope, + request_payload={ + "input": "What is the weather in Paris?", + "model": "gpt-5", + }, + response_payload={ + "id": "resp_tool_1", + "status": "completed", + "output": [ + { + "type": "function_call", + "id": "call_weather", + "call_id": "call_weather", + "name": "get_weather", + "arguments": '{"city":"Paris"}', + "status": "completed", + } + ], + }, + ) + assert stored is True + + prepared, prepared_scope, previous_response_id = store.prepare_payload( + { + "model": "gpt-5", + "previous_response_id": "resp_tool_1", + "input": [ + { + "type": "function_call_output", + "call_id": "call_weather", + "output": '{"temperature":"25C"}', + } + ], + }, + headers=headers, + ) + + assert prepared_scope == scope + assert previous_response_id == "resp_tool_1" + assert "previous_response_id" not in prepared + assert [item["type"] for item in prepared["input"]] == [ + "message", + "function_call", + "function_call_output", + ] + assert prepared["input"][1]["id"] == "fc_weather" + assert prepared["input"][1]["call_id"] == "call_weather" + assert prepared["input"][2]["call_id"] == "call_weather" + + +def test_responses_state_is_scoped_by_client_auth() -> None: + store = CodexResponsesStateStore(max_entries=8, ttl_seconds=60) + scope = store.scope_for_headers({"authorization": "Bearer client-a"}) + assert store.store_response( + scope=scope, + request_payload={"input": "hello"}, + response_payload={ + "id": "resp_a", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "hi"}], + } + ], + }, + ) + + with pytest.raises(ResponsesStateNotFoundError): + store.prepare_payload( + {"previous_response_id": "resp_a", "input": "continue"}, + headers={"authorization": "Bearer client-b"}, + ) + + +def test_responses_state_does_not_store_failed_response() -> None: + store = CodexResponsesStateStore(max_entries=8, ttl_seconds=60) + scope = store.scope_for_headers({"authorization": "Bearer client-a"}) + assert ( + store.store_response( + scope=scope, + request_payload={"input": "hello"}, + response_payload={ + "id": "resp_failed", + "status": "failed", + "output": [], + }, + ) + is False + ) diff --git a/tests/unit/llms/formatters/test_anthropic_to_openai_helpers.py b/tests/unit/llms/formatters/test_anthropic_to_openai_helpers.py index 560a5ad4..b2f4702f 100644 --- a/tests/unit/llms/formatters/test_anthropic_to_openai_helpers.py +++ b/tests/unit/llms/formatters/test_anthropic_to_openai_helpers.py @@ -253,6 +253,26 @@ async def gen() -> AsyncIterator[Any]: assert "response.function_call_arguments.delta" in event_types assert "response.function_call_arguments.done" in event_types + function_events = [ + evt + for evt in events + if getattr(evt, "type", "") + in {"response.output_item.added", "response.output_item.done"} + and getattr(getattr(evt, "item", None), "type", "") == "function_call" + ] + assert function_events + for evt in function_events: + item = evt.item # type: ignore[union-attr] + assert item.id == "fc_1" + assert item.call_id == "call_1" + + argument_item_ids = [ + getattr(evt, "item_id", "") + for evt in events + if getattr(evt, "type", "").startswith("response.function_call_arguments.") + ] + assert argument_item_ids and all(item_id == "fc_1" for item_id in argument_item_ids) + reasoning_deltas = [ getattr(evt, "delta", "") for evt in events @@ -283,6 +303,8 @@ async def gen() -> AsyncIterator[Any]: for out in complete_response.output if getattr(out, "type", "") == "function_call" ) + assert getattr(function_output, "id", "") == "fc_1" + assert getattr(function_output, "call_id", "") == "call_1" assert getattr(function_output, "name", "") == "get_weather" assert ( getattr(function_output, "arguments", "") diff --git a/tests/unit/llms/formatters/test_openai_to_openai_reasoning.py b/tests/unit/llms/formatters/test_openai_to_openai_reasoning.py index e5c6d660..386a555e 100644 --- a/tests/unit/llms/formatters/test_openai_to_openai_reasoning.py +++ b/tests/unit/llms/formatters/test_openai_to_openai_reasoning.py @@ -101,6 +101,44 @@ async def test_chat_to_responses_extracts_thinking() -> None: assert response.reasoning.summary # type: ignore[union-attr] +@pytest.mark.asyncio +async def test_chat_response_tool_calls_use_responses_function_call_item_ids() -> None: + chat_response = openai_models.ChatCompletionResponse( + id="chatcmpl-tools", + object="chat.completion", + created=0, + model="gpt-test", + choices=[ + openai_models.Choice( + index=0, + finish_reason="tool_calls", + message=openai_models.ResponseMessage( + role="assistant", + content=None, + tool_calls=[ + openai_models.ToolCall( + id="call_weather", + type="function", + function=openai_models.FunctionCall( + name="get_weather", + arguments='{"city":"Paris"}', + ), + ) + ], + ), + ) + ], + ) + + response = await convert__openai_chat_to_openai_responses__response(chat_response) + + function_call = next( + item for item in response.output if getattr(item, "type", "") == "function_call" + ) + assert getattr(function_call, "id", "") == "fc_weather" + assert getattr(function_call, "call_id", "") == "call_weather" + + def _get_type(entry: object) -> str | None: return getattr(entry, "type", None) @@ -585,10 +623,17 @@ async def source() -> AsyncIterator[dict[str, Any]]: and getattr(getattr(evt, "item", None), "type", "") == "function_call" ) fn_added_item = fn_added.item # type: ignore[union-attr] - assert fn_added_item.id == "call_abc" + assert fn_added_item.id == "fc_abc" assert fn_added_item.name == "get_weather" assert fn_added_item.call_id == "call_abc" + argument_item_ids = [ + getattr(evt, "item_id", "") + for evt in events + if getattr(evt, "type", "").startswith("response.function_call_arguments.") + ] + assert argument_item_ids == ["fc_abc", "fc_abc", "fc_abc", "fc_abc"] + args_done = next( evt for evt in events @@ -602,7 +647,8 @@ async def source() -> AsyncIterator[dict[str, Any]]: assert completed_response.parallel_tool_calls is True assert len(completed_response.output) == 1 fn_output = completed_response.output[0] - assert getattr(fn_output, "id", "") == "call_abc" + assert getattr(fn_output, "id", "") == "fc_abc" + assert getattr(fn_output, "call_id", "") == "call_abc" assert getattr(fn_output, "name", "") == "get_weather" assert getattr(fn_output, "arguments", "") == '{"city": "New York"}' tool_calls = getattr(completed_response, "tool_calls", []) or [] diff --git a/tests/unit/llms/streaming/test_accumulators.py b/tests/unit/llms/streaming/test_accumulators.py index 7633b6ca..7d37f9ac 100644 --- a/tests/unit/llms/streaming/test_accumulators.py +++ b/tests/unit/llms/streaming/test_accumulators.py @@ -397,11 +397,16 @@ def test_responses_accumulator_rebuild_response() -> None: assert "output" in rebuilt output_items = rebuilt["output"] assert isinstance(output_items, list) + output_function_call = next( + item for item in output_items if item["type"] == "function_call" + ) + assert output_function_call["id"] == "fc_2" + assert output_function_call["call_id"] == "call_1" # Verify function calls metadata assert "tool_calls" in rebuilt function_call = rebuilt["tool_calls"][0] - assert function_call["id"] == "item_2" + assert function_call["id"] == "fc_2" assert function_call["type"] == "function_call" assert function_call["call_id"] == "call_1" assert function_call["function"]["name"] == "test_function" @@ -594,7 +599,10 @@ def test_responses_accumulator_uses_completed_response_payload() -> None: assert rebuilt["usage"]["total_tokens"] == 3 completed_copy = accumulator.get_completed_response() - assert completed_copy == completed_event.response.model_dump() + assert completed_copy is not None + assert completed_copy["id"] == completed_event.response.id + assert completed_copy["usage"]["total_tokens"] == 3 + assert completed_copy["output"][0]["content"][0]["text"] == "Final text" if completed_copy is not None: completed_copy["status"] = "mutated" retrieved_response = accumulator.get_completed_response() diff --git a/tests/unit/llms/test_openai_responses_request_models.py b/tests/unit/llms/test_openai_responses_request_models.py new file mode 100644 index 00000000..3ee5c922 --- /dev/null +++ b/tests/unit/llms/test_openai_responses_request_models.py @@ -0,0 +1,32 @@ +from ccproxy.llms.models import openai as openai_models + + +def test_response_request_accepts_previous_response_id_and_function_output() -> None: + request = openai_models.ResponseRequest.model_validate( + { + "model": "gpt-5.5", + "previous_response_id": "resp_123", + "input": [ + { + "type": "function_call_output", + "call_id": "call_weather", + "output": '{"temperature":"25C"}', + } + ], + "reasoning": {"effort": "xhigh"}, + "stream": True, + "stream_options": {"include_usage": True}, + "prompt_cache_retention": "24h", + "safety_identifier": "user-123", + } + ) + + assert request.previous_response_id == "resp_123" + assert isinstance(request.input, list) + assert request.input[0]["type"] == "function_call_output" + assert request.reasoning == {"effort": "xhigh"} + assert request.stream is True + assert request.stream_options is not None + assert request.stream_options.include_usage is True + assert request.prompt_cache_retention == "24h" + assert request.safety_identifier == "user-123" diff --git a/tests/unit/streaming/test_buffer_parse_responses.py b/tests/unit/streaming/test_buffer_parse_responses.py index ce83ed70..919a0af3 100644 --- a/tests/unit/streaming/test_buffer_parse_responses.py +++ b/tests/unit/streaming/test_buffer_parse_responses.py @@ -14,6 +14,7 @@ import httpx import pytest +from ccproxy.llms.formatters.common import normalize_responses_sse_event_bytes from ccproxy.llms.models import openai as openai_models from ccproxy.llms.streaming.accumulators import ResponsesAccumulator from ccproxy.streaming.buffer import StreamingBufferService @@ -29,6 +30,57 @@ def _sse(event_type: str, payload: dict[str, Any]) -> bytes: return f"event: {event_type}\ndata: {json.dumps(body)}\n\n".encode() +def test_responses_sse_event_bytes_normalizes_function_call_item_id() -> None: + raw = _sse( + "response.output_item.added", + { + "sequence_number": 1, + "output_index": 0, + "item": { + "type": "function_call", + "id": "call_weather", + "call_id": "call_weather", + "name": "get_weather", + "arguments": "", + }, + }, + ) + + normalized = normalize_responses_sse_event_bytes(raw) + payload_line = next( + line for line in normalized.decode().splitlines() if line.startswith("data: ") + ) + payload = json.loads(payload_line.removeprefix("data: ")) + + assert payload["item"]["id"] == "fc_weather" + assert payload["item"]["call_id"] == "call_weather" + + +def test_responses_sse_event_bytes_tolerates_non_string_type_fields() -> None: + raw = _sse( + "response.output_item.added", + { + "sequence_number": 1, + "output_index": 0, + "item": { + "type": { + "kind": "json_schema", + }, + "id": "item_schema", + "content": [], + }, + }, + ) + + normalized = normalize_responses_sse_event_bytes(raw) + payload_line = next( + line for line in normalized.decode().splitlines() if line.startswith("data: ") + ) + payload = json.loads(payload_line.removeprefix("data: ")) + + assert payload["item"]["type"] == {"kind": "json_schema"} + + @pytest.fixture def buffer() -> StreamingBufferService: return StreamingBufferService(http_client=httpx.AsyncClient()) @@ -211,3 +263,106 @@ async def test_parse_collected_stream_preserves_rebuilt_output( ) validated = openai_models.ResponseObject.model_validate(parsed) assert validated.output + + +@pytest.mark.asyncio +async def test_parse_collected_stream_normalizes_function_call_item_ids( + buffer: StreamingBufferService, +) -> None: + response_dict: dict[str, Any] = { + "id": "resp_tools", + "object": "response", + "created_at": 0, + "status": "in_progress", + "model": "gpt-5-codex", + "parallel_tool_calls": False, + "output": [], + } + chunks = [ + _sse( + "response.created", + {"sequence_number": 1, "response": response_dict}, + ), + _sse( + "response.output_item.added", + { + "sequence_number": 2, + "output_index": 0, + "item": { + "type": "function_call", + "id": "call_weather", + "status": "in_progress", + "name": "get_weather", + "arguments": "", + "call_id": "call_weather", + }, + }, + ), + _sse( + "response.function_call_arguments.delta", + { + "sequence_number": 3, + "item_id": "call_weather", + "output_index": 0, + "delta": '{"city":"Paris"}', + }, + ), + _sse( + "response.function_call_arguments.done", + { + "sequence_number": 4, + "item_id": "call_weather", + "output_index": 0, + "arguments": '{"city":"Paris"}', + }, + ), + _sse( + "response.output_item.done", + { + "sequence_number": 5, + "output_index": 0, + "item": { + "type": "function_call", + "id": "call_weather", + "status": "completed", + "name": "get_weather", + "arguments": '{"city":"Paris"}', + "call_id": "call_weather", + }, + }, + ), + _sse( + "response.completed", + { + "sequence_number": 6, + "response": { + **response_dict, + "status": "completed", + "output": [ + { + "type": "function_call", + "id": "call_weather", + "status": "completed", + "name": "get_weather", + "arguments": '{"city":"Paris"}', + "call_id": "call_weather", + } + ], + }, + }, + ), + ] + + parsed = await buffer._parse_collected_stream( + chunks=chunks, + handler_config=None, # type: ignore[arg-type] + request_context=_Ctx(), # type: ignore[arg-type] + ) + + assert parsed is not None + function_call = next( + item for item in parsed["output"] if item["type"] == "function_call" + ) + assert function_call["id"] == "fc_weather" + assert function_call["call_id"] == "call_weather" + openai_models.ResponseObject.model_validate(parsed)