@@ -166,6 +166,100 @@ def test_get_model_key_from_env_with_api_key(monkeypatch):
166166 assert api_key == "test-api-key"
167167
168168
169+ class TestExtractTokenUsage :
170+ """Tests for extract_token_usage multi-provider support."""
171+
172+ def test_gemini_format (self ):
173+ raw = {
174+ "usageMetadata" : {
175+ "promptTokenCount" : 150 ,
176+ "candidatesTokenCount" : 200 ,
177+ "totalTokenCount" : 350 ,
178+ }
179+ }
180+ result = entrypoint .extract_token_usage (raw )
181+ assert result ["input_tokens" ] == 150
182+ assert result ["output_tokens" ] == 200
183+ assert result ["total_tokens" ] == 350
184+
185+ def test_claude_format (self ):
186+ raw = {
187+ "usage" : {
188+ "input_tokens" : 100 ,
189+ "output_tokens" : 250 ,
190+ }
191+ }
192+ result = entrypoint .extract_token_usage (raw )
193+ assert result ["input_tokens" ] == 100
194+ assert result ["output_tokens" ] == 250
195+ assert result ["total_tokens" ] == 350 # computed
196+
197+ def test_openai_format (self ):
198+ raw = {
199+ "usage" : {
200+ "prompt_tokens" : 80 ,
201+ "completion_tokens" : 120 ,
202+ "total_tokens" : 200 ,
203+ }
204+ }
205+ result = entrypoint .extract_token_usage (raw )
206+ assert result ["input_tokens" ] == 80
207+ assert result ["output_tokens" ] == 120
208+ assert result ["total_tokens" ] == 200
209+
210+ def test_none_response (self ):
211+ result = entrypoint .extract_token_usage (None )
212+ assert result ["input_tokens" ] is None
213+ assert result ["output_tokens" ] is None
214+
215+ def test_empty_dict (self ):
216+ result = entrypoint .extract_token_usage ({})
217+ assert result ["input_tokens" ] is None
218+
219+ def test_missing_usage (self ):
220+ raw = {"id" : "123" , "choices" : []}
221+ result = entrypoint .extract_token_usage (raw )
222+ assert result ["input_tokens" ] is None
223+
224+
225+ class TestWriteTokenUsage :
226+ """Tests for write_token_usage file output."""
227+
228+ def test_writes_json (self , tmp_path ):
229+ class MockClient :
230+ last_raw_response = {
231+ "usageMetadata" : {
232+ "promptTokenCount" : 50 ,
233+ "candidatesTokenCount" : 100 ,
234+ "totalTokenCount" : 150 ,
235+ }
236+ }
237+ entrypoint .write_token_usage (MockClient (), "gemini-2.5-flash" , tmp_path )
238+ usage_file = tmp_path / "token_usage.json"
239+ assert usage_file .exists ()
240+ import json
241+ data = json .loads (usage_file .read_text ())
242+ assert data ["model" ] == "gemini-2.5-flash"
243+ assert data ["input_tokens" ] == 50
244+ assert data ["output_tokens" ] == 100
245+
246+ def test_creates_directory (self , tmp_path ):
247+ nested = tmp_path / "sub" / "dir"
248+ class MockClient :
249+ last_raw_response = {}
250+ entrypoint .write_token_usage (MockClient (), "claude" , nested )
251+ assert (nested / "token_usage.json" ).exists ()
252+
253+ def test_handles_none_response (self , tmp_path ):
254+ class MockClient :
255+ last_raw_response = None
256+ entrypoint .write_token_usage (MockClient (), "grok" , tmp_path )
257+ import json
258+ data = json .loads ((tmp_path / "token_usage.json" ).read_text ())
259+ assert data ["model" ] == "grok"
260+ assert data ["input_tokens" ] is None
261+
262+
169263if __name__ == '__main__' :
170264 pytest .main ([__file__ ])
171265
0 commit comments