Skip to content

Commit 36775d4

Browse files
author
Mateusz
committed
Retry incompatible Codex tool calls after text output
1 parent: 740fed0 · commit: 36775d4

2 files changed

Lines changed: 134 additions & 56 deletions

File tree

src/connectors/openai_codex/executor.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -823,12 +823,12 @@ async def _streaming_iterator() -> AsyncIterator[ProcessedResponse]:
823823
context,
824824
)
825825
)
826-
if incompatible_tools and not visible_output_emitted:
827-
if (
828-
incompatible_tool_retries
829-
< self._max_incompatible_tool_retries
830-
):
831-
retry_for_incompatible_tools = True
826+
if incompatible_tools:
827+
if (
828+
incompatible_tool_retries
829+
< self._max_incompatible_tool_retries
830+
):
831+
retry_for_incompatible_tools = True
832832
restart_stream = True
833833
incompatible_tool_retries += 1
834834
current_payload_dict = self._append_incompatible_tool_retry_steering(

tests/unit/connectors/openai_codex/test_executor_streaming.py

Lines changed: 128 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -345,10 +345,10 @@ async def test_execute_streaming_handshake_maps_instruction_invalid_error(
345345
pass
346346

347347
assert exc_info.value.status_code == 400
348-
assert isinstance(exc_info.value.detail, dict)
349-
detail = exc_info.value.detail
350-
assert detail.get("error") == "codex_instructions_invalid"
351-
assert "prompt_mode" in str(detail.get("suggestion", ""))
348+
assert isinstance(exc_info.value.detail, dict)
349+
detail = exc_info.value.detail
350+
assert detail.get("error") == "codex_instructions_invalid"
351+
assert "prompt_mode" in str(detail.get("suggestion", ""))
352352

353353
@pytest.mark.asyncio
354354
async def test_execute_streaming_handshake_uses_retry_after_from_error_detail(
@@ -536,11 +536,11 @@ async def test_execute_streaming_handshake_429_usage_limit_notifies_when_rotatio
536536

537537
assert exc_info.value.status_code == 429
538538
mock_credential_manager.notify_codex_usage_limit_unrecovered.assert_awaited_once()
539-
await_args = (
540-
mock_credential_manager.notify_codex_usage_limit_unrecovered.await_args
541-
)
542-
assert await_args is not None
543-
notify_kw = cast(dict[str, Any], await_args.kwargs)
539+
await_args = (
540+
mock_credential_manager.notify_codex_usage_limit_unrecovered.await_args
541+
)
542+
assert await_args is not None
543+
notify_kw = cast(dict[str, Any], await_args.kwargs)
544544
assert notify_kw["upstream_detail"] == detail
545545
assert notify_kw["pool_exhaustion_confirmed"] is True
546546

@@ -663,7 +663,7 @@ async def test_execute_streaming_handshake_auth_retry_exhausted(
663663

664664
# Exception is raised when consuming the stream
665665
assert result.content is not None
666-
content = result.content
666+
content = result.content
667667
with pytest.raises(HTTPException) as exc_info:
668668
async for _ in content:
669669
pass
@@ -928,7 +928,7 @@ async def auth_error_iterator():
928928

929929
# Should raise after retries exhausted
930930
assert result.content is not None
931-
content = result.content
931+
content = result.content
932932
with pytest.raises(HTTPException) as exc_info:
933933
async for _ in content:
934934
pass
@@ -962,7 +962,7 @@ async def test_execute_streaming_refresh_fails(
962962

963963
# Exception is raised when consuming the stream after refresh fails
964964
assert result.content is not None
965-
content = result.content
965+
content = result.content
966966
with pytest.raises(HTTPException) as exc_info:
967967
async for _ in content:
968968
pass
@@ -1111,6 +1111,84 @@ async def streaming_side_effect(
11111111
assert matching[-1].retry_reason == "incompatible_tools"
11121112
assert matching[-1].response_id == "resp_retry_123"
11131113

1114+
async def test_execute_streaming_retries_incompatible_tool_call_after_text_output(
1115+
self, executor, sample_context, streaming_payload
1116+
) -> None:
1117+
"""Incompatible tool retries should still fire even after brief text output."""
1118+
compatibility_layer = MagicMock(spec=ICompatibilityLayer)
1119+
compatibility_layer.detect_incompatible_tool_calls.return_value = [
1120+
"apply_patch"
1121+
]
1122+
compatibility_layer.append_incompatible_tool_steering.side_effect = (
1123+
lambda payload_dict, incompatible_tools, context: {
1124+
**payload_dict,
1125+
"instructions": "retry steering",
1126+
}
1127+
)
1128+
executor._compatibility_layer = compatibility_layer
1129+
1130+
first_handle = MagicMock()
1131+
first_handle.headers = {}
1132+
first_handle.cancel_callback = AsyncMock()
1133+
1134+
async def first_iterator():
1135+
yield ProcessedResponse(
1136+
content={"choices": [{"delta": {"content": "Working on it."}}]},
1137+
metadata={},
1138+
)
1139+
yield ProcessedResponse(
1140+
content={
1141+
"type": "response.output_item.added",
1142+
"item": {"type": "function_call", "name": "apply_patch"},
1143+
}
1144+
)
1145+
1146+
first_handle.iterator = first_iterator()
1147+
1148+
second_handle = MagicMock()
1149+
second_handle.headers = {}
1150+
second_handle.cancel_callback = AsyncMock()
1151+
1152+
async def second_iterator():
1153+
yield ProcessedResponse(
1154+
content={"choices": [{"delta": {"content": "Using native edit."}}]},
1155+
metadata={},
1156+
)
1157+
1158+
second_handle.iterator = second_iterator()
1159+
1160+
captured_payloads: list[dict[str, object]] = []
1161+
1162+
async def streaming_side_effect(
1163+
url, payload_dict, headers, session_id, *args, **kwargs
1164+
):
1165+
captured_payloads.append(dict(payload_dict))
1166+
if len(captured_payloads) == 1:
1167+
return first_handle
1168+
return second_handle
1169+
1170+
executor._base_connector._handle_streaming_response = AsyncMock(
1171+
side_effect=streaming_side_effect
1172+
)
1173+
1174+
result = await executor.execute(streaming_payload, sample_context)
1175+
chunks = [
1176+
chunk
1177+
async for chunk in cast(AsyncIterator[ProcessedResponse], result.content)
1178+
]
1179+
1180+
assert len(chunks) == 2
1181+
assert chunks[0].content == {
1182+
"choices": [{"delta": {"content": "Working on it."}}]
1183+
}
1184+
assert chunks[1].content == {
1185+
"choices": [{"delta": {"content": "Using native edit."}}]
1186+
}
1187+
assert len(captured_payloads) == 2
1188+
assert captured_payloads[1]["instructions"] == "retry steering"
1189+
first_handle.cancel_callback.assert_awaited()
1190+
compatibility_layer.append_incompatible_tool_steering.assert_called_once()
1191+
11141192
async def test_conversation_id_preserved_across_streaming_retries(
11151193
self,
11161194
executor,
@@ -2016,7 +2094,7 @@ async def handle_streaming_side_effect(
20162094

20172095
first_result = await executor.execute(first_payload, sample_context)
20182096
assert first_result.content is not None
2019-
first_stream = first_result.content
2097+
first_stream = first_result.content
20202098
first_chunk = await anext(first_stream)
20212099
assert isinstance(first_chunk, ProcessedResponse)
20222100
await cast(Any, first_stream).aclose()
@@ -2456,10 +2534,10 @@ async def test_normalize_processed_stream_chunk_marks_tool_call_emission(
24562534
assert content["choices"][0]["finish_reason"] == "tool_calls"
24572535

24582536
@pytest.mark.asyncio
2459-
async def test_normalize_processed_stream_chunk_marks_function_call_done_as_tool_output(
2460-
self,
2461-
mock_base_connector,
2462-
mock_credential_manager,
2537+
async def test_normalize_processed_stream_chunk_marks_function_call_done_as_tool_output(
2538+
self,
2539+
mock_base_connector,
2540+
mock_credential_manager,
24632541
) -> None:
24642542
mock_base_connector.translation_service = TranslationService()
24652543
executor = ResponseExecutor(
@@ -2479,39 +2557,39 @@ async def test_normalize_processed_stream_chunk_marks_function_call_done_as_tool
24792557
normalized = executor._normalize_processed_stream_chunk(chunk)
24802558

24812559
assert normalized.metadata.get("tool_call_emitted") is True
2482-
assert normalized.metadata.get("finish_reason") == "tool_calls"
2483-
content = cast(dict[str, Any], normalized.content)
2484-
assert content["choices"][0]["delta"] == {}
2485-
2486-
@pytest.mark.asyncio
2487-
async def test_normalize_processed_stream_chunk_overrides_falsey_tool_markers(
2488-
self,
2489-
mock_base_connector,
2490-
mock_credential_manager,
2491-
) -> None:
2492-
mock_base_connector.translation_service = TranslationService()
2493-
executor = ResponseExecutor(
2494-
mock_base_connector,
2495-
mock_credential_manager,
2496-
)
2497-
2498-
chunk = ProcessedResponse(
2499-
content={
2500-
"type": "response.function_call_arguments.done",
2501-
"item_id": "fc_ws_tool",
2502-
"arguments": '{"command":["bash","-lc","git status --short"]}',
2503-
},
2504-
metadata={
2505-
"event_type": "response.function_call_arguments.done",
2506-
"tool_call_emitted": False,
2507-
"finish_reason": None,
2508-
},
2509-
)
2510-
2511-
normalized = executor._normalize_processed_stream_chunk(chunk)
2512-
2513-
assert normalized.metadata.get("tool_call_emitted") is True
2514-
assert normalized.metadata.get("finish_reason") == "tool_calls"
2560+
assert normalized.metadata.get("finish_reason") == "tool_calls"
2561+
content = cast(dict[str, Any], normalized.content)
2562+
assert content["choices"][0]["delta"] == {}
2563+
2564+
@pytest.mark.asyncio
2565+
async def test_normalize_processed_stream_chunk_overrides_falsey_tool_markers(
2566+
self,
2567+
mock_base_connector,
2568+
mock_credential_manager,
2569+
) -> None:
2570+
mock_base_connector.translation_service = TranslationService()
2571+
executor = ResponseExecutor(
2572+
mock_base_connector,
2573+
mock_credential_manager,
2574+
)
2575+
2576+
chunk = ProcessedResponse(
2577+
content={
2578+
"type": "response.function_call_arguments.done",
2579+
"item_id": "fc_ws_tool",
2580+
"arguments": '{"command":["bash","-lc","git status --short"]}',
2581+
},
2582+
metadata={
2583+
"event_type": "response.function_call_arguments.done",
2584+
"tool_call_emitted": False,
2585+
"finish_reason": None,
2586+
},
2587+
)
2588+
2589+
normalized = executor._normalize_processed_stream_chunk(chunk)
2590+
2591+
assert normalized.metadata.get("tool_call_emitted") is True
2592+
assert normalized.metadata.get("finish_reason") == "tool_calls"
25152593

25162594
@pytest.mark.asyncio
25172595
async def test_normalize_processed_stream_chunk_marks_local_shell_item_done_as_tool_output(

0 commit comments

Comments (0)