Skip to content

Commit 3299edd

Browse files
author
Mateusz
committed
fix(codex): pass ConnectorRequestContext through HTTP transport adapter
Forward context to OpenAIConnector._handle_streaming_response so wire capture and log correlation match the WebSocket path. Update unit tests and _handle_streaming_response mocks to accept context kwargs. Made-with: Cursor
1 parent e7bcfe4 commit 3299edd

5 files changed

Lines changed: 92 additions & 74 deletions

File tree

src/connectors/openai_codex/executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,12 @@ async def initiate_streaming_request(
141141

142142
# Default to HTTP/SSE
143143
return await self._connector._handle_streaming_response( # type: ignore[misc]
144-
url, payload, headers, session_id, "responses"
144+
url,
145+
payload,
146+
headers,
147+
session_id,
148+
"responses",
149+
context=context,
145150
)
146151

147152
async def _initiate_websocket_streaming(

tests/unit/connectors/openai_codex/test_executor_non_streaming.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ async def second_iterator():
9090

9191
captured_payloads: list[dict[str, object]] = []
9292

93-
async def streaming_side_effect(url, payload_dict, headers, session_id, *args):
93+
async def streaming_side_effect(
94+
url, payload_dict, headers, session_id, *args, **kwargs
95+
):
9496
captured_payloads.append(dict(payload_dict))
9597
if len(captured_payloads) == 1:
9698
return first_handle

tests/unit/connectors/openai_codex/test_executor_streaming.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,9 @@ async def second_iterator():
815815

816816
captured_payloads: list[dict[str, object]] = []
817817

818-
async def streaming_side_effect(url, payload_dict, headers, session_id, *args):
818+
async def streaming_side_effect(
819+
url, payload_dict, headers, session_id, *args, **kwargs
820+
):
819821
captured_payloads.append(dict(payload_dict))
820822
if len(captured_payloads) == 1:
821823
return first_handle

tests/unit/connectors/openai_codex/test_executor_websocket.py

Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Unit tests for ResponseExecutor WebSocket support."""
22

3-
from __future__ import annotations
4-
5-
from typing import Any, cast
6-
from unittest.mock import AsyncMock, MagicMock, patch
7-
8-
import pytest
9-
from src.connectors.contracts import ConnectorRequestContext
10-
from src.connectors.openai_codex.executor import _CodexTransportAdapter
11-
from src.core.common.exceptions import AuthenticationError
12-
from src.core.domain.responses import ProcessedResponse, StreamingResponseHandle
3+
from __future__ import annotations
4+
5+
from typing import Any, cast
6+
from unittest.mock import AsyncMock, MagicMock, patch
7+
8+
import pytest
9+
from src.connectors.contracts import ConnectorRequestContext
10+
from src.connectors.openai_codex.executor import _CodexTransportAdapter
11+
from src.core.common.exceptions import AuthenticationError
12+
from src.core.domain.responses import ProcessedResponse, StreamingResponseHandle
1313

1414

1515
@pytest.mark.asyncio
@@ -62,13 +62,13 @@ async def mock_send_response_create(*args, **kwargs):
6262
async for chunk in handle.iterator:
6363
chunks.append(chunk)
6464

65-
# Verify chunks
66-
assert len(chunks) == 2
67-
# Websocket transport adapter yields ProcessedResponse objects directly
68-
first_content = cast(dict[str, Any], chunks[0].content)
69-
second_content = cast(dict[str, Any], chunks[1].content)
70-
assert cast(dict[str, Any], first_content["message"])["content"] == "Hello"
71-
assert cast(dict[str, Any], second_content["message"])["content"] == "World"
65+
# Verify chunks
66+
assert len(chunks) == 2
67+
# Websocket transport adapter yields ProcessedResponse objects directly
68+
first_content = cast(dict[str, Any], chunks[0].content)
69+
second_content = cast(dict[str, Any], chunks[1].content)
70+
assert cast(dict[str, Any], first_content["message"])["content"] == "Hello"
71+
assert cast(dict[str, Any], second_content["message"])["content"] == "World"
7272

7373
async def test_recreates_websocket_client_when_auth_token_changes(self) -> None:
7474
"""Auth refresh retries must not reuse stale WebSocket credentials."""
@@ -117,17 +117,17 @@ async def test_initiate_websocket_streaming_no_auth(self) -> None:
117117

118118
url = "https://chatgpt.com/backend-api/codex/responses"
119119
payload = {"model": "gpt-4", "input": []}
120-
headers: dict[str, str] = {} # No authorization header
120+
headers: dict[str, str] = {} # No authorization header
121121
session_id = "test_session"
122122

123123
with pytest.raises(AuthenticationError, match="No API key"):
124124
await adapter.initiate_streaming_request(url, payload, headers, session_id)
125125

126-
async def test_http_fallback_when_websocket_disabled(self) -> None:
127-
"""Test fallback to HTTP/SSE when WebSocket is disabled."""
128-
mock_connector = MagicMock()
129-
mock_connector._handle_streaming_response = AsyncMock(
130-
return_value=StreamingResponseHandle(
126+
async def test_http_fallback_when_websocket_disabled(self) -> None:
127+
"""Test fallback to HTTP/SSE when WebSocket is disabled."""
128+
mock_connector = MagicMock()
129+
mock_connector._handle_streaming_response = AsyncMock(
130+
return_value=StreamingResponseHandle(
131131
iterator=AsyncMock(), headers={}, cancel_callback=AsyncMock()
132132
)
133133
)
@@ -143,54 +143,59 @@ async def test_http_fallback_when_websocket_disabled(self) -> None:
143143
url, payload, headers, session_id
144144
)
145145

146-
# Verify HTTP/SSE method was called
147-
mock_connector._handle_streaming_response.assert_called_once_with(
148-
url, payload, headers, session_id, "responses"
149-
)
150-
assert isinstance(handle, StreamingResponseHandle)
151-
152-
async def test_http_fallback_accepts_transport_metadata_kwargs(self) -> None:
153-
"""Transport adapter should accept the executor's keyword metadata contract."""
154-
mock_connector = MagicMock()
155-
mock_connector._handle_streaming_response = AsyncMock(
156-
return_value=StreamingResponseHandle(
157-
iterator=AsyncMock(), headers={}, cancel_callback=AsyncMock()
158-
)
159-
)
160-
161-
adapter = _CodexTransportAdapter(mock_connector, use_websocket=False)
162-
163-
url = "https://chatgpt.com/backend-api/codex/responses"
164-
payload = {"model": "gpt-4", "input": []}
165-
headers = {"Authorization": "Bearer test_key"}
166-
session_id = "test_session"
167-
request_context = ConnectorRequestContext(
168-
request_id="req-1",
169-
session_id="sess-1",
170-
client_host="127.0.0.1",
171-
extensions={},
172-
)
173-
174-
handle = await adapter.initiate_streaming_request(
175-
url,
176-
payload,
177-
headers,
178-
session_id,
179-
context=request_context,
180-
backend="openai-codex",
181-
model="gpt-4",
182-
key_name="openai-codex",
183-
)
184-
185-
mock_connector._handle_streaming_response.assert_called_once_with(
186-
url, payload, headers, session_id, "responses"
187-
)
188-
assert isinstance(handle, StreamingResponseHandle)
189-
190-
async def test_cleanup_closes_websocket_client(self) -> None:
191-
"""Test cleanup properly disconnects WebSocket client."""
192-
mock_connector = MagicMock()
193-
adapter = _CodexTransportAdapter(mock_connector, use_websocket=True)
146+
# Verify HTTP/SSE method was called with wire-capture context slot
147+
mock_connector._handle_streaming_response.assert_called_once_with(
148+
url, payload, headers, session_id, "responses", context=None
149+
)
150+
assert isinstance(handle, StreamingResponseHandle)
151+
152+
async def test_http_fallback_accepts_transport_metadata_kwargs(self) -> None:
153+
"""Transport adapter should accept the executor's keyword metadata contract."""
154+
mock_connector = MagicMock()
155+
mock_connector._handle_streaming_response = AsyncMock(
156+
return_value=StreamingResponseHandle(
157+
iterator=AsyncMock(), headers={}, cancel_callback=AsyncMock()
158+
)
159+
)
160+
161+
adapter = _CodexTransportAdapter(mock_connector, use_websocket=False)
162+
163+
url = "https://chatgpt.com/backend-api/codex/responses"
164+
payload = {"model": "gpt-4", "input": []}
165+
headers = {"Authorization": "Bearer test_key"}
166+
session_id = "test_session"
167+
request_context = ConnectorRequestContext(
168+
request_id="req-1",
169+
session_id="sess-1",
170+
client_host="127.0.0.1",
171+
extensions={},
172+
)
173+
174+
handle = await adapter.initiate_streaming_request(
175+
url,
176+
payload,
177+
headers,
178+
session_id,
179+
context=request_context,
180+
backend="openai-codex",
181+
model="gpt-4",
182+
key_name="openai-codex",
183+
)
184+
185+
mock_connector._handle_streaming_response.assert_called_once_with(
186+
url,
187+
payload,
188+
headers,
189+
session_id,
190+
"responses",
191+
context=request_context,
192+
)
193+
assert isinstance(handle, StreamingResponseHandle)
194+
195+
async def test_cleanup_closes_websocket_client(self) -> None:
196+
"""Test cleanup properly disconnects WebSocket client."""
197+
mock_connector = MagicMock()
198+
adapter = _CodexTransportAdapter(mock_connector, use_websocket=True)
194199

195200
# Create mock WebSocket client
196201
mock_ws_client = AsyncMock()

tests/unit/connectors/test_openai_codex_codex_cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ async def streaming_side_effect(
622622
request_headers: dict[str, str],
623623
request_session_id: str,
624624
stream_format: str,
625+
**kwargs: Any,
625626
) -> StreamingResponseHandle:
626627
nonlocal call_count
627628
headers_seen.append(request_headers.get("Authorization"))
@@ -752,6 +753,7 @@ def handle_side_effect(
752753
request_headers: dict[str, str],
753754
request_session_id: str,
754755
stream_format: str,
756+
**kwargs: Any,
755757
) -> StreamingResponseHandle:
756758
headers_seen.append(request_headers.get("Authorization"))
757759
return stream_handles.pop(0)
@@ -849,6 +851,7 @@ async def streaming_side_effect(
849851
request_headers: dict[str, str],
850852
request_session_id: str,
851853
stream_format: str,
854+
**kwargs: Any,
852855
) -> StreamingResponseHandle:
853856
headers_seen.append(request_headers.get("Authorization"))
854857
raise HTTPException(status_code=401, detail="expired")
@@ -936,6 +939,7 @@ def handle_side_effect(
936939
request_headers: dict[str, str],
937940
request_session_id: str,
938941
stream_format: str,
942+
**kwargs: Any,
939943
) -> StreamingResponseHandle:
940944
return stream_handle
941945

0 commit comments

Comments
 (0)