
Commit 5da4634

beachdwellerclaude committed

feat: concise success feedback and token usage tracking

Shorten the all-tests-passed feedback to 3-5 sentences (#2) and write token_usage.json to OUTPUT-DIR for post-run analysis (#5). Token usage extraction is multi-provider: Gemini, Claude, and OpenAI-compatible response formats. 11 new tests cover both features.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

1 parent 68a18d9 · commit 5da4634
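
The token_usage.json artifact is meant for post-run analysis; as a minimal sketch of reading it back (the output path and field values here are illustrative, not from a real run):

    # Read the token-usage artifact after a run; path and values are illustrative.
    import json
    import pathlib

    usage = json.loads(pathlib.Path("output/token_usage.json").read_text())
    print(usage["model"], usage["input_tokens"], usage["output_tokens"], usage["total_tokens"])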

5 files changed

Lines changed: 223 additions & 2 deletions


entrypoint.py

Lines changed: 65 additions & 1 deletion

@@ -1,12 +1,13 @@
 #!/usr/bin/env python3
 # begin entrypoint.py

+import json
 import logging
 import os
 import pathlib
 import sys

-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple


 sys.path.insert(
@@ -79,6 +80,69 @@ def main(b_ask:bool=True) -> None:
     elif b_fail_expected:
         assert n_failed > 0, 'No failed tests detected when failure was expected'

+    # Write token usage to artifact directory if available
+    output_dir = os.getenv('INPUT_OUTPUT-DIR', '')
+    if output_dir and b_ask:
+        write_token_usage(client, model, pathlib.Path(output_dir))
+
+
+def extract_token_usage(raw_response: Optional[dict]) -> Dict[str, Any]:
+    """Extract token usage from LLM API response (best-effort, multi-provider).
+
+    Different providers return usage in different structures:
+        Gemini:      usageMetadata.promptTokenCount / candidatesTokenCount
+        Claude:      usage.input_tokens / output_tokens
+        OpenAI-like: usage.prompt_tokens / completion_tokens (Grok, NVIDIA, Perplexity)
+
+    Returns dict with input_tokens, output_tokens, total_tokens (None if unavailable).
+    """
+    if not raw_response or not isinstance(raw_response, dict):
+        return {"input_tokens": None, "output_tokens": None, "total_tokens": None}
+
+    # Gemini format
+    usage = raw_response.get("usageMetadata", {})
+    if usage:
+        return {
+            "input_tokens": usage.get("promptTokenCount"),
+            "output_tokens": usage.get("candidatesTokenCount"),
+            "total_tokens": usage.get("totalTokenCount"),
+        }
+
+    # Claude / OpenAI-compatible format
+    usage = raw_response.get("usage", {})
+    if usage:
+        input_t = usage.get("input_tokens") or usage.get("prompt_tokens")
+        output_t = usage.get("output_tokens") or usage.get("completion_tokens")
+        total_t = usage.get("total_tokens")
+        if total_t is None and input_t is not None and output_t is not None:
+            total_t = input_t + output_t
+        return {
+            "input_tokens": input_t,
+            "output_tokens": output_t,
+            "total_tokens": total_t,
+        }
+
+    return {"input_tokens": None, "output_tokens": None, "total_tokens": None}
+
+
+def write_token_usage(
+    client: 'LLMAPIClient',
+    model: str,
+    output_dir: pathlib.Path,
+) -> None:
+    """Write token_usage.json to output directory."""
+    usage = extract_token_usage(client.last_raw_response)
+    usage["model"] = model
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    usage_path = output_dir / "token_usage.json"
+    try:
+        with open(usage_path, "w", encoding="utf-8") as f:
+            json.dump(usage, f, indent=2)
+        logging.info(f"Token usage written to {usage_path}: {usage}")
+    except OSError as e:
+        logging.warning(f"Could not write token usage: {e}")
+

 def get_startwith(key:str, dictionary:dict) -> Any:
     result = None
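
To see the provider fallback in action, extract_token_usage can be exercised standalone; a minimal sketch, assuming entrypoint.py is importable (as the test suite below does) and using illustrative payloads:

    # Illustrative payloads exercising the multi-provider fallback.
    import entrypoint

    gemini = {"usageMetadata": {"promptTokenCount": 10,
                                "candidatesTokenCount": 20,
                                "totalTokenCount": 30}}
    openai_like = {"usage": {"prompt_tokens": 8, "completion_tokens": 4}}  # no total: computed as 8 + 4

    print(entrypoint.extract_token_usage(gemini))       # {'input_tokens': 10, 'output_tokens': 20, 'total_tokens': 30}
    print(entrypoint.extract_token_usage(openai_like))  # {'input_tokens': 8, 'output_tokens': 4, 'total_tokens': 12}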

llm_client.py

Lines changed: 2 additions & 0 deletions

@@ -50,6 +50,7 @@ def __init__(self, config: LLMConfig, retry_delay_sec: float = 5.0,
         self.max_retry_attempt = max_retry_attempt
         self.timeout_sec = timeout_sec
         self.logger = logging.getLogger(__name__)  # Logger for this module
+        self.last_raw_response = None  # Store last API response for token usage extraction

     def call_api(self, question: str) -> Optional[str]:
         """Send a question to the LLM API with retry and timeout handling.
@@ -99,6 +100,7 @@ def call_api(self, question: str) -> Optional[str]:
            try:
                # Parse JSON and extract response using config-specific method
                result = response.json()
+               self.last_raw_response = result
                return self.config.parse_response(result)
            except (ValueError, KeyError) as e:
                # Log parsing errors (invalid JSON or unexpected structure)
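
This attribute is the handoff point that entrypoint.write_token_usage reads after the run; a sketch with a stub standing in for a real LLMAPIClient (the stub, model name, and output directory are all illustrative):

    # Stub client standing in for LLMAPIClient; names here are illustrative.
    import pathlib
    import entrypoint

    class StubClient:
        last_raw_response = {"usage": {"input_tokens": 5, "output_tokens": 7}}

    entrypoint.write_token_usage(StubClient(), "example-model", pathlib.Path("out"))
    # out/token_usage.json now holds input/output/total token counts plus the model name.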

prompt.py

Lines changed: 5 additions & 1 deletion

@@ -93,7 +93,11 @@ def get_initial_instruction(questions: List[str], language: str) -> str:
     )
     return (
         f"{guardrail}\n"
-        f"In {language}, please comment on the student code given the assignment instruction."
+        f"All tests passed. In {language}, in 3-5 sentences:\n"
+        "1. Briefly note what the student did well.\n"
+        "2. Suggest one specific improvement if applicable "
+        "(e.g., efficiency, readability, edge cases).\n"
+        "Do not repeat test results. Do not assign or fabricate scores."
     )

     prompt_list = (

tests/test_entrypoint.py

Lines changed: 94 additions & 0 deletions

@@ -166,6 +166,100 @@ def test_get_model_key_from_env_with_api_key(monkeypatch):
     assert api_key == "test-api-key"


+class TestExtractTokenUsage:
+    """Tests for extract_token_usage multi-provider support."""
+
+    def test_gemini_format(self):
+        raw = {
+            "usageMetadata": {
+                "promptTokenCount": 150,
+                "candidatesTokenCount": 200,
+                "totalTokenCount": 350,
+            }
+        }
+        result = entrypoint.extract_token_usage(raw)
+        assert result["input_tokens"] == 150
+        assert result["output_tokens"] == 200
+        assert result["total_tokens"] == 350
+
+    def test_claude_format(self):
+        raw = {
+            "usage": {
+                "input_tokens": 100,
+                "output_tokens": 250,
+            }
+        }
+        result = entrypoint.extract_token_usage(raw)
+        assert result["input_tokens"] == 100
+        assert result["output_tokens"] == 250
+        assert result["total_tokens"] == 350  # computed
+
+    def test_openai_format(self):
+        raw = {
+            "usage": {
+                "prompt_tokens": 80,
+                "completion_tokens": 120,
+                "total_tokens": 200,
+            }
+        }
+        result = entrypoint.extract_token_usage(raw)
+        assert result["input_tokens"] == 80
+        assert result["output_tokens"] == 120
+        assert result["total_tokens"] == 200
+
+    def test_none_response(self):
+        result = entrypoint.extract_token_usage(None)
+        assert result["input_tokens"] is None
+        assert result["output_tokens"] is None
+
+    def test_empty_dict(self):
+        result = entrypoint.extract_token_usage({})
+        assert result["input_tokens"] is None
+
+    def test_missing_usage(self):
+        raw = {"id": "123", "choices": []}
+        result = entrypoint.extract_token_usage(raw)
+        assert result["input_tokens"] is None
+
+
+class TestWriteTokenUsage:
+    """Tests for write_token_usage file output."""
+
+    def test_writes_json(self, tmp_path):
+        class MockClient:
+            last_raw_response = {
+                "usageMetadata": {
+                    "promptTokenCount": 50,
+                    "candidatesTokenCount": 100,
+                    "totalTokenCount": 150,
+                }
+            }
+        entrypoint.write_token_usage(MockClient(), "gemini-2.5-flash", tmp_path)
+        usage_file = tmp_path / "token_usage.json"
+        assert usage_file.exists()
+        import json
+        data = json.loads(usage_file.read_text())
+        assert data["model"] == "gemini-2.5-flash"
+        assert data["input_tokens"] == 50
+        assert data["output_tokens"] == 100
+
+    def test_creates_directory(self, tmp_path):
+        nested = tmp_path / "sub" / "dir"
+        class MockClient:
+            last_raw_response = {}
+        entrypoint.write_token_usage(MockClient(), "claude", nested)
+        assert (nested / "token_usage.json").exists()
+
+    def test_handles_none_response(self, tmp_path):
+        class MockClient:
+            last_raw_response = None
+        entrypoint.write_token_usage(MockClient(), "grok", tmp_path)
+        import json
+        data = json.loads((tmp_path / "token_usage.json").read_text())
+        assert data["model"] == "grok"
+        assert data["input_tokens"] is None
+
+
 if __name__ == '__main__':
     pytest.main([__file__])

tests/test_prompt.py

Lines changed: 57 additions & 0 deletions

@@ -524,6 +524,63 @@ def test_collect_longrepr__compare_contents(collect_longrepr_result: List[str]):
     assert not missing_markers, f"Missing markers: {missing_markers}"


+@pytest.fixture
+def all_passing_report(tmp_path) -> pathlib.Path:
+    """Create a pytest JSON report where all tests pass."""
+    report = {
+        "tests": [
+            {"nodeid": "test_syntax::test_valid", "outcome": "passed",
+             "call": {"longrepr": None}},
+            {"nodeid": "test_results::test_calc_area", "outcome": "passed",
+             "call": {"longrepr": None}},
+        ]
+    }
+    path = tmp_path / "all_pass_report.json"
+    path.write_text(json.dumps(report))
+    return path
+
+
+def test_get_prompt__all_passing__concise_instruction(
+        all_passing_report: pathlib.Path,
+        sample_student_code_path: pathlib.Path,
+        sample_readme_path: pathlib.Path,
+):
+    """When all tests pass, the prompt should instruct concise feedback."""
+    n_failed, prompt_text = prompt.get_prompt(
+        report_paths=(all_passing_report,),
+        student_files=(sample_student_code_path,),
+        readme_file=sample_readme_path,
+        explanation_in="Korean",
+    )
+
+    assert n_failed == 0
+    # Should contain concise instructions
+    assert "3-5 sentences" in prompt_text
+    assert "Do not assign or fabricate scores" in prompt_text
+    # Should NOT contain the verbose "comment on the student code" instruction
+    assert "please comment on the student code" not in prompt_text
+
+
+def test_get_prompt__with_failures__has_directive(
+        sample_report_path: pathlib.Path,
+        sample_student_code_path: pathlib.Path,
+        sample_readme_path: pathlib.Path,
+):
+    """When tests fail, the prompt should contain the error directive."""
+    n_failed, prompt_text = prompt.get_prompt(
+        report_paths=(sample_report_path,),
+        student_files=(sample_student_code_path,),
+        readme_file=sample_readme_path,
+        explanation_in="Korean",
+    )
+
+    assert n_failed > 0
+    # Should contain failure-specific instruction
+    assert "mutually exclusively and collectively exhaustively" in prompt_text
+    # Should NOT contain the concise success instruction
+    assert "3-5 sentences" not in prompt_text
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
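
Both new test groups rely only on stock pytest (tmp_path plus the fixtures already defined in these files); a quick way to run just the touched test files, assuming pytest is installed and the command is run from the repository root:

    pytest tests/test_entrypoint.py tests/test_prompt.py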
