
Commit 857067a

Author: Mateusz
fix(codex): preserve usage through legacy response conversion

Parent: 781192a

5 files changed: 160 additions & 33 deletions

dev/scripts/demo_codex_usage_reporting_fix.py

Lines changed: 24 additions & 1 deletion
@@ -31,6 +31,7 @@
 )
 from src.core.interfaces.response_processor_interface import ProcessedResponse
 from src.core.services.translation_service import TranslationService
+from src.core.transport.fastapi.response_adapters import domain_response_to_fastapi


 class _FakeTransportWithProviderUsage:
@@ -161,6 +162,26 @@ async def _run_demo() -> None:
                 "Non-streaming total_tokens is zero; expected > 0"
             )

+        legacy_fastapi_response = domain_response_to_fastapi(
+            non_stream_result
+        )
+        legacy_body = (
+            legacy_fastapi_response.body.tobytes()
+            if isinstance(legacy_fastapi_response.body, memoryview)
+            else legacy_fastapi_response.body
+        )
+        legacy_payload = json.loads(legacy_body.decode("utf-8"))
+        legacy_usage = legacy_payload.get("usage")
+        print("[legacy-non-stream] usage:", legacy_usage)
+        if not isinstance(legacy_usage, dict):
+            raise RuntimeError(
+                "Legacy OpenAI-compatible payload is missing usage"
+            )
+        if int(legacy_usage.get("total_tokens", 0)) <= 0:
+            raise RuntimeError(
+                "Legacy OpenAI-compatible total_tokens is zero; expected > 0"
+            )
+
         streaming_request = ChatRequest(
             model="openai-codex:gpt-5-codex",
             messages=[
@@ -215,7 +236,9 @@ async def _run_demo() -> None:
                 "Streaming total_tokens is zero; expected > 0"
             )

-        print("SUCCESS: Codex usage reporting is non-zero for both flows.")
+        print(
+            "SUCCESS: Codex usage reporting is non-zero for connector and legacy OpenAI frontend flows."
+        )
     finally:
         await backend.shutdown()
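A note on the memoryview branch above: Starlette/FastAPI responses normally expose body as bytes, but the demo defensively handles a memoryview as well before decoding. Below is a minimal, self-contained sketch of that normalization idiom; response_json is a hypothetical helper written for illustration, not part of this repository.

import json

from fastapi.responses import JSONResponse


def response_json(response) -> dict:
    # Normalize the raw body before decoding: usually bytes, but the
    # demo above guards for a memoryview on some code paths.
    body = (
        response.body.tobytes()
        if isinstance(response.body, memoryview)
        else response.body
    )
    return json.loads(body.decode("utf-8"))


# The same check the demo performs on the legacy payload.
resp = JSONResponse({"usage": {"total_tokens": 26}})
assert response_json(resp)["usage"]["total_tokens"] == 26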
src/core/app/controllers/chat_controller.py

Lines changed: 17 additions & 21 deletions
@@ -707,11 +707,7 @@ def _inject_reasoning_aliases(payload: object) -> object:

            raw_usage = metadata.get(
                "usage",
-                {
-                    "prompt_tokens": 0,
-                    "completion_tokens": 0,
-                    "total_tokens": 0,
-                },
+                None,
            )
            usage_summary = None
            if isinstance(raw_usage, UsageSummary):
@@ -720,15 +716,15 @@ def _inject_reasoning_aliases(payload: object) -> object:
                usage_summary = UsageSummary.from_dict(raw_usage)

            # Create the response using Pydantic model
-            response = ChatResponse(
+            chat_response = ChatResponse(
                id=response_id,
                created=created_val,
                model=model_name,
                choices=[choice],
                usage=usage_summary,
            )

-            return _inject_reasoning_aliases(response.model_dump())
+            return _inject_reasoning_aliases(chat_response.model_dump())

        if metadata:
            meta_role = metadata.get("role")  # type: ignore[arg-type]
@@ -781,7 +777,11 @@ def _inject_reasoning_aliases(payload: object) -> object:
            choice = ChatCompletionChoice(
                index=0,
                message=message,
-                finish_reason=finish_reason,  # type: ignore[arg-type]
+                finish_reason=(
+                    finish_reason
+                    if isinstance(finish_reason, str)
+                    else None
+                ),
            )

            from src.core.domain.usage_summary import UsageSummary
@@ -793,15 +793,15 @@ def _inject_reasoning_aliases(payload: object) -> object:
            elif isinstance(raw_usage, dict):
                usage_summary = UsageSummary.from_dict(raw_usage)

-            response = ChatResponse(
+            chat_response = ChatResponse(
                id=response_id,
                created=created_val,
                model=model_name,
                choices=[choice],
                usage=usage_summary,
            )

-            return response.model_dump()
+            return chat_response.model_dump()

        # Check if content is a JSON string of tool calls (common backend response format)
        if isinstance(content, str):
@@ -855,26 +855,22 @@ def _inject_reasoning_aliases(payload: object) -> object:
            if metadata:
                raw_usage = metadata.get("usage")
            else:
-                raw_usage = {
-                    "prompt_tokens": 0,
-                    "completion_tokens": 0,
-                    "total_tokens": 0,
-                }
+                raw_usage = None
            usage_summary = None
            if isinstance(raw_usage, UsageSummary):
                usage_summary = raw_usage
            elif isinstance(raw_usage, dict):
                usage_summary = UsageSummary.from_dict(raw_usage)

-            response = ChatResponse(
+            chat_response = ChatResponse(
                id=response_id,
                created=created_val,
                model=model_name,
                choices=[choice],
                usage=usage_summary,
            )

-            return response.model_dump()
+            return chat_response.model_dump()
        except (ValueError, TypeError) as e:
            if logger.isEnabledFor(TRACE_LEVEL):
                logger.log(
@@ -964,7 +960,7 @@ def _inject_reasoning_aliases(payload: object) -> object:
            )

            # Create the response using Pydantic model
-            response = ChatResponse(
+            chat_response = ChatResponse(
                id=content.get("id", f"chatcmpl-{_uuid.uuid4().hex[:16]}"),
                created=int(_time.time()),
                model=content.get(
@@ -974,7 +970,7 @@ def _inject_reasoning_aliases(payload: object) -> object:
                usage=UsageSummary.from_dict(openai_usage),
            )

-            return response
+            return chat_response

        import json as _json
        import time
@@ -1021,7 +1017,7 @@ def _inject_reasoning_aliases(payload: object) -> object:
        from src.core.domain.usage_summary import UsageSummary

        # Create the response using Pydantic model
-        response = ChatResponse(
+        chat_response = ChatResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:16]}",
            created=int(time.time()),
            model=getattr(domain_request, "model", "gpt-4"),
@@ -1035,7 +1031,7 @@ def _inject_reasoning_aliases(payload: object) -> object:
            ),
        )

-        return chat_response
+        return chat_response
    except Exception as e:
        if logger.isEnabledFor(logging.WARNING):
            logger.warning(
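The recurring edit in this file replaces a zeroed default usage dict with None. Under the old code, a response whose metadata carried no usage was serialized with a fabricated all-zero usage block, which downstream consumers could not tell apart from a genuine zero-token measurement; now absent usage stays absent and only provider-reported counts reach the payload. A minimal sketch of the default-value change, using plain dicts as stand-ins for the domain types:

def old_usage(metadata: dict) -> dict:
    # Old behavior: a missing "usage" key became a fabricated all-zero block.
    return metadata.get(
        "usage",
        {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    )


def new_usage(metadata: dict) -> dict | None:
    # New behavior: a missing "usage" key stays None, so serialization
    # cannot mistake absence for a real measurement.
    return metadata.get("usage", None)


assert old_usage({})["total_tokens"] == 0  # looks like real data
assert new_usage({}) is None               # absence preserved
assert new_usage({"usage": {"total_tokens": 26}}) == {"total_tokens": 26}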

src/core/domain/usage_summary.py

Lines changed: 16 additions & 3 deletions
@@ -48,9 +48,22 @@ def from_dict(cls, data: dict[str, Any]) -> UsageSummary:
        Returns:
            UsageSummary instance
        """
-        prompt_tokens = data.get("prompt_tokens")
-        completion_tokens = data.get("completion_tokens")
-        total_tokens = data.get("total_tokens")
+        prompt_tokens = data.get("prompt_tokens")
+        if not isinstance(prompt_tokens, int):
+            prompt_tokens = data.get("input_tokens")
+
+        completion_tokens = data.get("completion_tokens")
+        if not isinstance(completion_tokens, int):
+            completion_tokens = data.get("output_tokens")
+
+        total_tokens = data.get("total_tokens")
+        if not isinstance(total_tokens, int):
+            resolved_prompt = prompt_tokens if isinstance(prompt_tokens, int) else 0
+            resolved_completion = (
+                completion_tokens if isinstance(completion_tokens, int) else 0
+            )
+            computed_total = resolved_prompt + resolved_completion
+            total_tokens = computed_total if computed_total > 0 else None

        # Extract extensions
        # If "extensions" key exists, use it directly; otherwise extract all non-standard fields
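This from_dict change is the heart of the fix: Codex replies report usage with Responses-API field names (input_tokens/output_tokens), which the old reader ignored, so legacy-converted payloads lost their counts. The standalone sketch below mirrors the alias-and-compute logic from the diff so it can be exercised outside the repo; resolve_tokens is an illustrative stand-in for the relevant portion of UsageSummary.from_dict, not the actual method.

def resolve_tokens(data: dict) -> tuple[int | None, int | None, int | None]:
    # Canonical OpenAI names win; Responses-API aliases are the fallback.
    prompt = data.get("prompt_tokens")
    if not isinstance(prompt, int):
        prompt = data.get("input_tokens")

    completion = data.get("completion_tokens")
    if not isinstance(completion, int):
        completion = data.get("output_tokens")

    # Only synthesize a total when at least one component is a real count.
    total = data.get("total_tokens")
    if not isinstance(total, int):
        computed = (prompt if isinstance(prompt, int) else 0) + (
            completion if isinstance(completion, int) else 0
        )
        total = computed if computed > 0 else None

    return prompt, completion, total


assert resolve_tokens({"input_tokens": 11, "output_tokens": 5}) == (11, 5, 16)
assert resolve_tokens({"prompt_tokens": None}) == (None, None, None)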

tests/unit/core/app/controllers/test_chat_controller_content.py

Lines changed: 68 additions & 0 deletions
@@ -4,8 +4,12 @@

 import json
 from typing import Any
+from unittest.mock import AsyncMock, Mock

+import pytest
 from src.core.app.controllers.chat_controller import ChatController
+from src.core.domain.responses import ResponseEnvelope
+from src.core.domain.usage_summary import UsageSummary


 class TestCoerceMessageContentToText:
@@ -145,3 +149,67 @@ def test_coerce_message_content_to_text_prevents_stack_overflow(self) -> None:
         assert len(result) > 0
         # The result should contain some indication of the circular reference
         assert "Circular reference detected" in result
+
+
+class TestEnsureOpenAIChatSchemaUsage:
+    @pytest.mark.asyncio
+    async def test_tool_calls_schema_preserves_metadata_usage(self) -> None:
+        processor = AsyncMock()
+        processor.process_request = AsyncMock(
+            return_value=ResponseEnvelope(
+                content='[{"type": "function", "id": "call_1", "function": {"name": "do_work", "arguments": "{}"}}]',
+                metadata={
+                    "tool_calls": [
+                        {
+                            "id": "call_1",
+                            "type": "function",
+                            "function": {"name": "do_work", "arguments": "{}"},
+                        }
+                    ],
+                    "usage": {
+                        "input_tokens": 19,
+                        "output_tokens": 7,
+                        "total_tokens": 26,
+                    },
+                },
+            )
+        )
+
+        controller = ChatController(
+            request_processor=processor,
+            translation_service=None,
+            wire_capture=None,
+            metrics_initializer=None,
+        )
+
+        request = Mock()
+        request.body = AsyncMock(return_value=b"{}")
+        request.headers = {}
+        request.cookies = {}
+        request.url = Mock()
+        request.url.path = "/v1/chat/completions"
+        request.state = Mock()
+        request.app = Mock()
+        request.app.state = Mock()
+        request.app.state.service_provider = None
+
+        from src.core.domain.chat import ChatMessage, ChatRequest
+
+        request_data = ChatRequest(
+            model="openai-codex:gpt-5-codex",
+            messages=[ChatMessage(role="user", content="hello")],
+            stream=False,
+        )
+
+        response = await controller.handle_chat_completion(request, request_data)
+        body = (
+            response.body.tobytes()
+            if isinstance(response.body, memoryview)
+            else response.body
+        )
+        payload = json.loads(body.decode("utf-8"))
+
+        usage = UsageSummary.from_dict(payload["usage"])
+        assert usage.prompt_tokens == 19
+        assert usage.completion_tokens == 7
+        assert usage.total_tokens == 26

tests/unit/core/domain/test_usage_summary.py

Lines changed: 35 additions & 8 deletions
@@ -103,14 +103,41 @@ def test_usage_summary_from_dict(self) -> None:
         assert summary.total_tokens == 150
         assert summary.extensions == {"cost": 0.002}

-    def test_usage_summary_from_dict_with_none(self) -> None:
-        """Test creating UsageSummary from dictionary with None values."""
-        data = {
-            "prompt_tokens": None,
-            "completion_tokens": None,
-            "total_tokens": None,
-            "extensions": {},
-        }
+    def test_usage_summary_from_dict_supports_responses_api_fields(self) -> None:
+        """Responses API usage fields should populate canonical token counts."""
+        data: dict[str, int] = {
+            "input_tokens": 17,
+            "output_tokens": 9,
+            "total_tokens": 26,
+        }
+
+        summary = UsageSummary.from_dict(data)
+
+        assert summary.prompt_tokens == 17
+        assert summary.completion_tokens == 9
+        assert summary.total_tokens == 26
+        assert summary.extensions == {"input_tokens": 17, "output_tokens": 9}
+
+    def test_usage_summary_from_dict_computes_total_for_responses_api_fields(
+        self,
+    ) -> None:
+        """Responses API usage should compute total_tokens when omitted."""
+        data: dict[str, int] = {"input_tokens": 11, "output_tokens": 5}
+
+        summary = UsageSummary.from_dict(data)
+
+        assert summary.prompt_tokens == 11
+        assert summary.completion_tokens == 5
+        assert summary.total_tokens == 16
+
+    def test_usage_summary_from_dict_with_none(self) -> None:
+        """Test creating UsageSummary from dictionary with None values."""
+        data: dict[str, object] = {
+            "prompt_tokens": None,
+            "completion_tokens": None,
+            "total_tokens": None,
+            "extensions": {},
+        }
         summary = UsageSummary.from_dict(data)
         assert summary.prompt_tokens is None
         assert summary.completion_tokens is None
