Skip to content

Commit e7bcfe4

Browse files
author
Mateusz
committed
fix: improve codex continuation logic and add regression tests
- Match the bare "droid" marker as a whole token (regex tokenization) instead of as a substring, preventing false positives on Android user agents; the other droid/factory patterns still use substring matching
- Add early return guard in history divergence detection
- Add test for mid-conversation history divergence replay
- Add test verifying Android user agents aren't misclassified as droid
1 parent 9717806 commit e7bcfe4

4 files changed

Lines changed: 216 additions & 30 deletions

File tree

src/connectors/openai_codex/continuation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import hashlib
66
import json
7+
import re
78
import time
89
from collections import OrderedDict
910
from dataclasses import dataclass
@@ -32,7 +33,6 @@
3233
"factory-cli",
3334
"factory_cli",
3435
"factorydroid",
35-
"droid",
3636
)
3737

3838

@@ -253,10 +253,17 @@ def _normalize_client_family(cls, candidate: str) -> str:
253253
"clinelike",
254254
}:
255255
return "cline_like"
256-
if any(pattern in lowered for pattern in _DROID_USER_AGENT_PATTERNS):
256+
if cls._is_droid_candidate(lowered):
257257
return "droid"
258258
return "generic"
259259

260+
@staticmethod
261+
def _is_droid_candidate(lowered: str) -> bool:
262+
if any(pattern in lowered for pattern in _DROID_USER_AGENT_PATTERNS):
263+
return True
264+
tokens = {token for token in re.split(r"[^a-z0-9]+", lowered) if token}
265+
return "droid" in tokens
266+
260267
@classmethod
261268
def _fingerprint_component(cls, value: Any) -> str | None:
262269
if value is None:

tests/unit/connectors/openai_codex/test_continuation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,23 @@ def test_continuation_key_uses_user_agent_header_family_signal() -> None:
134134
assert coordinator.resolve_previous_response_id(generic_context) is None
135135

136136

137+
def test_continuation_key_does_not_treat_android_user_agent_as_droid() -> None:
138+
coordinator = InMemoryCodexContinuationCoordinator(ttl_seconds=60, max_entries=4)
139+
android_context = _context("session-1")
140+
generic_context = _context("session-1")
141+
metadata = cast(dict[str, object], android_context.metadata)
142+
metadata["headers"] = {
143+
"User-Agent": (
144+
"Mozilla/5.0 (Linux; Android 14; Pixel 8) AppleWebKit/537.36 "
145+
"(KHTML, like Gecko) Chrome/124.0 Mobile Safari/537.36"
146+
)
147+
}
148+
149+
coordinator.record_response_id(android_context, "resp-generic")
150+
151+
assert coordinator.resolve_previous_response_id(generic_context) == "resp-generic"
152+
153+
137154
def test_continuation_key_separates_different_client_families_in_same_session() -> None:
138155
coordinator = InMemoryCodexContinuationCoordinator(ttl_seconds=60, max_entries=4)
139156
opencode_context = _context("session-1")

tests/unit/connectors/openai_codex/test_executor_streaming.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,126 @@ async def handle_streaming_side_effect(
12961296
assert changed_tools[0]["name"] == "write_file"
12971297
assert captured_payloads[1]["instructions"] == "Full Codex bootstrap"
12981298

1299+
@pytest.mark.asyncio
1300+
async def test_execute_streaming_replays_when_history_diverges_mid_conversation(
1301+
self,
1302+
mock_base_connector,
1303+
mock_credential_manager,
1304+
sample_context,
1305+
) -> None:
1306+
continuation = InMemoryCodexContinuationCoordinator()
1307+
executor = ResponseExecutor(
1308+
mock_base_connector,
1309+
mock_credential_manager,
1310+
continuation_coordinator=continuation,
1311+
)
1312+
from src.connectors.openai_codex.contracts import CodexInputItem, CodexPayload
1313+
1314+
first_payload = CodexPayload(
1315+
model="gpt-5.1-codex",
1316+
input=[
1317+
CodexInputItem.model_validate(
1318+
{
1319+
"type": "message",
1320+
"role": "user",
1321+
"content": [{"type": "input_text", "text": "A"}],
1322+
}
1323+
),
1324+
CodexInputItem.model_validate(
1325+
{
1326+
"type": "message",
1327+
"role": "assistant",
1328+
"content": [{"type": "output_text", "text": "B"}],
1329+
}
1330+
),
1331+
CodexInputItem.model_validate(
1332+
{
1333+
"type": "message",
1334+
"role": "user",
1335+
"content": [{"type": "input_text", "text": "C"}],
1336+
}
1337+
),
1338+
],
1339+
tools=[],
1340+
tool_choice="auto",
1341+
parallel_tool_calls=False,
1342+
store=False,
1343+
stream=True,
1344+
include=[],
1345+
prompt_cache_key="test-key",
1346+
instructions="Full Codex bootstrap",
1347+
)
1348+
diverged_payload = first_payload.model_copy(
1349+
update={
1350+
"input": [
1351+
first_payload.input[0],
1352+
first_payload.input[1],
1353+
CodexInputItem.model_validate(
1354+
{
1355+
"type": "message",
1356+
"role": "user",
1357+
"content": [{"type": "input_text", "text": "X"}],
1358+
}
1359+
),
1360+
CodexInputItem.model_validate(
1361+
{
1362+
"type": "message",
1363+
"role": "assistant",
1364+
"content": [{"type": "output_text", "text": "D"}],
1365+
}
1366+
),
1367+
CodexInputItem.model_validate(
1368+
{
1369+
"type": "message",
1370+
"role": "user",
1371+
"content": [{"type": "input_text", "text": "E"}],
1372+
}
1373+
),
1374+
]
1375+
}
1376+
)
1377+
1378+
async def done_iterator(response_id: str):
1379+
yield ProcessedResponse(
1380+
content={"id": response_id, "output": []},
1381+
metadata={"event_type": "response.done", "done": True},
1382+
)
1383+
1384+
captured_payloads: list[dict[str, object]] = []
1385+
1386+
async def handle_streaming_side_effect(
1387+
url, payload_dict, headers, session_id, *args, **kwargs
1388+
):
1389+
captured_payloads.append(dict(payload_dict))
1390+
stream_handle = MagicMock()
1391+
stream_handle.headers = {}
1392+
stream_handle.cancel_callback = AsyncMock()
1393+
stream_handle.iterator = done_iterator(
1394+
"resp_first" if len(captured_payloads) == 1 else "resp_second"
1395+
)
1396+
return stream_handle
1397+
1398+
mock_base_connector._handle_streaming_response = AsyncMock(
1399+
side_effect=handle_streaming_side_effect
1400+
)
1401+
1402+
first_result = await executor.execute(first_payload, sample_context)
1403+
assert first_result.content is not None
1404+
async for _ in first_result.content:
1405+
pass
1406+
1407+
second_result = await executor.execute(diverged_payload, sample_context)
1408+
assert second_result.content is not None
1409+
async for _ in second_result.content:
1410+
pass
1411+
1412+
assert len(captured_payloads) == 2
1413+
assert "previous_response_id" not in captured_payloads[1]
1414+
assert captured_payloads[1]["instructions"] == "Full Codex bootstrap"
1415+
assert captured_payloads[1]["input"] == [
1416+
item.model_dump(exclude_none=True) for item in diverged_payload.input
1417+
]
1418+
12991419
@pytest.mark.asyncio
13001420
async def test_execute_streaming_records_terminal_response_id_in_continuation(
13011421
self,

tests/unit/connectors/openai_codex/test_executor_websocket.py

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Unit tests for ResponseExecutor WebSocket support."""
22

3-
from __future__ import annotations
4-
5-
from unittest.mock import AsyncMock, MagicMock, patch
6-
7-
import pytest
8-
from src.connectors.openai_codex.executor import _CodexTransportAdapter
9-
from src.core.common.exceptions import AuthenticationError
10-
from src.core.domain.responses import ProcessedResponse, StreamingResponseHandle
3+
from __future__ import annotations
4+
5+
from typing import Any, cast
6+
from unittest.mock import AsyncMock, MagicMock, patch
7+
8+
import pytest
9+
from src.connectors.contracts import ConnectorRequestContext
10+
from src.connectors.openai_codex.executor import _CodexTransportAdapter
11+
from src.core.common.exceptions import AuthenticationError
12+
from src.core.domain.responses import ProcessedResponse, StreamingResponseHandle
1113

1214

1315
@pytest.mark.asyncio
@@ -60,11 +62,13 @@ async def mock_send_response_create(*args, **kwargs):
6062
async for chunk in handle.iterator:
6163
chunks.append(chunk)
6264

63-
# Verify chunks
64-
assert len(chunks) == 2
65-
# Websocket transport adapter yields ProcessedResponse objects directly
66-
assert chunks[0].content["message"]["content"] == "Hello"
67-
assert chunks[1].content["message"]["content"] == "World"
65+
# Verify chunks
66+
assert len(chunks) == 2
67+
# Websocket transport adapter yields ProcessedResponse objects directly
68+
first_content = cast(dict[str, Any], chunks[0].content)
69+
second_content = cast(dict[str, Any], chunks[1].content)
70+
assert cast(dict[str, Any], first_content["message"])["content"] == "Hello"
71+
assert cast(dict[str, Any], second_content["message"])["content"] == "World"
6872

6973
async def test_recreates_websocket_client_when_auth_token_changes(self) -> None:
7074
"""Auth refresh retries must not reuse stale WebSocket credentials."""
@@ -113,17 +117,17 @@ async def test_initiate_websocket_streaming_no_auth(self) -> None:
113117

114118
url = "https://chatgpt.com/backend-api/codex/responses"
115119
payload = {"model": "gpt-4", "input": []}
116-
headers = {} # No authorization header
120+
headers: dict[str, str] = {} # No authorization header
117121
session_id = "test_session"
118122

119123
with pytest.raises(AuthenticationError, match="No API key"):
120124
await adapter.initiate_streaming_request(url, payload, headers, session_id)
121125

122-
async def test_http_fallback_when_websocket_disabled(self) -> None:
123-
"""Test fallback to HTTP/SSE when WebSocket is disabled."""
124-
mock_connector = MagicMock()
125-
mock_connector._handle_streaming_response = AsyncMock(
126-
return_value=StreamingResponseHandle(
126+
async def test_http_fallback_when_websocket_disabled(self) -> None:
127+
"""Test fallback to HTTP/SSE when WebSocket is disabled."""
128+
mock_connector = MagicMock()
129+
mock_connector._handle_streaming_response = AsyncMock(
130+
return_value=StreamingResponseHandle(
127131
iterator=AsyncMock(), headers={}, cancel_callback=AsyncMock()
128132
)
129133
)
@@ -140,15 +144,53 @@ async def test_http_fallback_when_websocket_disabled(self) -> None:
140144
)
141145

142146
# Verify HTTP/SSE method was called
143-
mock_connector._handle_streaming_response.assert_called_once_with(
144-
url, payload, headers, session_id, "responses"
145-
)
146-
assert isinstance(handle, StreamingResponseHandle)
147-
148-
async def test_cleanup_closes_websocket_client(self) -> None:
149-
"""Test cleanup properly disconnects WebSocket client."""
150-
mock_connector = MagicMock()
151-
adapter = _CodexTransportAdapter(mock_connector, use_websocket=True)
147+
mock_connector._handle_streaming_response.assert_called_once_with(
148+
url, payload, headers, session_id, "responses"
149+
)
150+
assert isinstance(handle, StreamingResponseHandle)
151+
152+
async def test_http_fallback_accepts_transport_metadata_kwargs(self) -> None:
153+
"""Transport adapter should accept the executor's keyword metadata contract."""
154+
mock_connector = MagicMock()
155+
mock_connector._handle_streaming_response = AsyncMock(
156+
return_value=StreamingResponseHandle(
157+
iterator=AsyncMock(), headers={}, cancel_callback=AsyncMock()
158+
)
159+
)
160+
161+
adapter = _CodexTransportAdapter(mock_connector, use_websocket=False)
162+
163+
url = "https://chatgpt.com/backend-api/codex/responses"
164+
payload = {"model": "gpt-4", "input": []}
165+
headers = {"Authorization": "Bearer test_key"}
166+
session_id = "test_session"
167+
request_context = ConnectorRequestContext(
168+
request_id="req-1",
169+
session_id="sess-1",
170+
client_host="127.0.0.1",
171+
extensions={},
172+
)
173+
174+
handle = await adapter.initiate_streaming_request(
175+
url,
176+
payload,
177+
headers,
178+
session_id,
179+
context=request_context,
180+
backend="openai-codex",
181+
model="gpt-4",
182+
key_name="openai-codex",
183+
)
184+
185+
mock_connector._handle_streaming_response.assert_called_once_with(
186+
url, payload, headers, session_id, "responses"
187+
)
188+
assert isinstance(handle, StreamingResponseHandle)
189+
190+
async def test_cleanup_closes_websocket_client(self) -> None:
191+
"""Test cleanup properly disconnects WebSocket client."""
192+
mock_connector = MagicMock()
193+
adapter = _CodexTransportAdapter(mock_connector, use_websocket=True)
152194

153195
# Create mock WebSocket client
154196
mock_ws_client = AsyncMock()

0 commit comments

Comments
 (0)