Skip to content

Commit 29a84ca

Browse files
committed
feature: Allow specifying the Gemini model in the API URL
This commit introduces the ability to specify the Gemini model to be used when calling the generative language API. Previously, the model was hardcoded to `gemini-1.5-flash-latest`. This change modifies the `url` function to accept an optional `model` parameter, allowing users to select different Gemini models (e.g., different versions or specialized models) if needed. The default remains `gemini-1.5-flash-latest` for backward compatibility.
1 parent 6417731 commit 29a84ca

5 files changed

Lines changed: 74 additions & 6 deletions

File tree

action.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ inputs:
1515
description: 'API token for AI'
1616
required: true
1717
type: string
18+
model: # New input for the model
19+
description: 'The Gemini model to use (e.g., gemini-1.5-flash-latest)'
20+
required: false
21+
default: 'gemini-1.5-flash-latest'
22+
type: string
1823
student-files:
1924
description: "Comma-separated list of student's Python file paths or a glob pattern"
2025
required: false

ai_tutor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020

2121
@functools.lru_cache
22-
def url(api_key:str) -> str:
23-
return f'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key={api_key}'
22+
def url(api_key:str, model:str='gemini-1.5-flash-latest') -> str:
23+
return f'https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}'
2424

2525

2626
@functools.lru_cache
@@ -31,6 +31,7 @@ def header() -> HEADER:
3131
def ask_gemini(
3232
question: str,
3333
api_key:str,
34+
model:str='gemini-1.5-flash-latest',
3435
header:HEADER=header(),
3536
retry_delay_sec: float = 5.0,
3637
max_retry_attempt: int = 3,
@@ -60,7 +61,11 @@ def ask_gemini(
6061
logging.error(f"Timeout exceeded for question: {question}")
6162
break # Exit the loop on timeout
6263

63-
response = requests.post(url(api_key), headers=header, json=data)
64+
response = requests.post(
65+
url(api_key, model=model),
66+
headers=header,
67+
json=data
68+
)
6469

6570
if response.status_code == 200:
6671
result = response.json()
@@ -88,6 +93,7 @@ def gemini_qna(
8893
readme_file:pathlib.Path,
8994
api_key:str,
9095
explanation_in:str='Korean',
96+
model:str='gemini-1.5-flash-latest',
9197
) -> Tuple[int, str]:
9298
'''
9399
Queries the Gemini API to provide explanations for failed pytest test cases.
@@ -112,7 +118,7 @@ def gemini_qna(
112118
explanation_in
113119
)
114120

115-
answers = ask_gemini(consolidated_question, api_key)
121+
answers = ask_gemini(consolidated_question, api_key, model=model)
116122

117123
return n_failed, answers
118124

entrypoint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@ def main() -> None:
3030
assert readme_file.exists(), 'No README file'
3131

3232
api_key = os.environ['INPUT_API-KEY'].strip()
33-
3433
assert api_key, "Please check API-KEY"
3534

35+
model = os.getenv(
36+
'INPUT_MODEL', # Get model from environment
37+
'gemini-1.5-flash-latest' # use default if not provided
38+
)
3639
explanation_in = os.environ['INPUT_EXPLANATION-IN']
3740

3841
b_fail_expected = ('true' == os.getenv('INPUT_FAIL-EXPECTED', 'false').lower())
@@ -42,7 +45,8 @@ def main() -> None:
4245
student_files,
4346
readme_file,
4447
api_key,
45-
explanation_in
48+
explanation_in,
49+
model=model,
4650
)
4751

4852
print(feedback)

tests/test_ai_tutor.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import pathlib
33
import sys
4+
import urllib.parse as up
45

56
from typing import Callable, Dict, List, Tuple, Union
67

@@ -549,5 +550,55 @@ def test__exclude_common_contents__double_specific(
549550
assert end_marker.strip() not in result
550551

551552

553+
@pytest.fixture
554+
def test_api_key() -> str:
555+
return 'test_api_key'
556+
557+
558+
@pytest.fixture
559+
def expected_default_gemini_model() -> str:
560+
return 'gemini-1.5-flash-latest'
561+
562+
563+
def test_url__default_model(
564+
test_api_key:str,
565+
expected_default_gemini_model:str
566+
):
567+
result = ai_tutor.url(test_api_key)
568+
569+
result_parsed = up.urlparse(result)
570+
571+
path_parts = result_parsed.path.split('/')
572+
573+
b_found = False
574+
575+
for part in path_parts:
576+
if expected_default_gemini_model in part.split(':'):
577+
b_found = True
578+
break
579+
580+
assert b_found, f"Could not find {expected_default_gemini_model} in {path_parts}."
581+
582+
583+
def test_url__specific_model(
584+
test_api_key:str,
585+
):
586+
model = 'gemini-2.0-flash'
587+
result = ai_tutor.url(test_api_key, model=model)
588+
589+
result_parsed = up.urlparse(result)
590+
591+
path_parts = result_parsed.path.split('/')
592+
593+
b_found = False
594+
595+
for part in path_parts:
596+
if model in part.split(':'):
597+
b_found = True
598+
break
599+
600+
assert b_found, f"Could not find {expected_default_gemini_model} in {path_parts}."
601+
602+
552603
if '__main__' == __name__:
553604
pytest.main([__file__])

tests/test_integration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_main_argument_passing__all_exists(mock_gemini_qna, caplog, tmp_path) ->
2020
# Setup
2121
os.environ['INPUT_API-KEY'] = 'test_key'
2222
os.environ['INPUT_EXPLANATION-IN'] = 'Korean'
23+
os.environ['INPUT_MODEL'] = 'gemini-2.0-flash-exp'
2324

2425
os.environ['GITHUB_OUTPUT'] = str(tmp_path / 'output.txt')
2526

@@ -54,6 +55,7 @@ def test_main_argument_passing__all_exists(mock_gemini_qna, caplog, tmp_path) ->
5455
tmp_path / 'readme.txt',
5556
'test_key',
5657
'Korean',
58+
model='gemini-2.0-flash-exp',
5759
)
5860

5961
assert 'does not exist' not in caplog.text

0 commit comments

Comments
 (0)