Skip to content

Commit 96ce132

Browse files
committed
Merge branch 'feature/gemini-model-selection'
2 parents 6417731 + d33a198 commit 96ce132

6 files changed

Lines changed: 76 additions & 6 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010

11+
- Added the ability to specify the Gemini model when using the AI Tutor GitHub Action. A new `model` input has been added to the action, allowing users to select different Gemini models. The default model remains `gemini-1.5-flash-latest` for backward compatibility.
12+
1113

1214
### Changed
1315

action.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ inputs:
1515
description: 'API token for AI'
1616
required: true
1717
type: string
18+
model: # New input for the model
19+
description: 'The Gemini model to use (e.g., gemini-1.5-flash-latest)'
20+
required: false
21+
default: 'gemini-1.5-flash-latest'
22+
type: string
1823
student-files:
1924
description: "Comma-separated list of student's Python file paths or a glob pattern"
2025
required: false

ai_tutor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020

2121
@functools.lru_cache
22-
def url(api_key:str) -> str:
23-
return f'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key={api_key}'
22+
def url(api_key:str, model:str='gemini-1.5-flash-latest') -> str:
23+
return f'https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}'
2424

2525

2626
@functools.lru_cache
@@ -31,6 +31,7 @@ def header() -> HEADER:
3131
def ask_gemini(
3232
question: str,
3333
api_key:str,
34+
model:str='gemini-1.5-flash-latest',
3435
header:HEADER=header(),
3536
retry_delay_sec: float = 5.0,
3637
max_retry_attempt: int = 3,
@@ -60,7 +61,11 @@ def ask_gemini(
6061
logging.error(f"Timeout exceeded for question: {question}")
6162
break # Exit the loop on timeout
6263

63-
response = requests.post(url(api_key), headers=header, json=data)
64+
response = requests.post(
65+
url(api_key, model=model),
66+
headers=header,
67+
json=data
68+
)
6469

6570
if response.status_code == 200:
6671
result = response.json()
@@ -88,6 +93,7 @@ def gemini_qna(
8893
readme_file:pathlib.Path,
8994
api_key:str,
9095
explanation_in:str='Korean',
96+
model:str='gemini-1.5-flash-latest',
9197
) -> Tuple[int, str]:
9298
'''
9399
Queries the Gemini API to provide explanations for failed pytest test cases.
@@ -112,7 +118,7 @@ def gemini_qna(
112118
explanation_in
113119
)
114120

115-
answers = ask_gemini(consolidated_question, api_key)
121+
answers = ask_gemini(consolidated_question, api_key, model=model)
116122

117123
return n_failed, answers
118124

entrypoint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@ def main() -> None:
3030
assert readme_file.exists(), 'No README file'
3131

3232
api_key = os.environ['INPUT_API-KEY'].strip()
33-
3433
assert api_key, "Please check API-KEY"
3534

35+
model = os.getenv(
36+
'INPUT_MODEL', # Get model from environment
37+
'gemini-1.5-flash-latest' # use default if not provided
38+
)
3639
explanation_in = os.environ['INPUT_EXPLANATION-IN']
3740

3841
b_fail_expected = ('true' == os.getenv('INPUT_FAIL-EXPECTED', 'false').lower())
@@ -42,7 +45,8 @@ def main() -> None:
4245
student_files,
4346
readme_file,
4447
api_key,
45-
explanation_in
48+
explanation_in,
49+
model=model,
4650
)
4751

4852
print(feedback)

tests/test_ai_tutor.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import pathlib
33
import sys
4+
import urllib.parse as up
45

56
from typing import Callable, Dict, List, Tuple, Union
67

@@ -549,5 +550,55 @@ def test__exclude_common_contents__double_specific(
549550
assert end_marker.strip() not in result
550551

551552

553+
@pytest.fixture
554+
def test_api_key() -> str:
555+
return 'test_api_key'
556+
557+
558+
@pytest.fixture
559+
def expected_default_gemini_model() -> str:
560+
return 'gemini-1.5-flash-latest'
561+
562+
563+
def test_url__default_model(
564+
test_api_key:str,
565+
expected_default_gemini_model:str
566+
):
567+
result = ai_tutor.url(test_api_key)
568+
569+
result_parsed = up.urlparse(result)
570+
571+
path_parts = result_parsed.path.split('/')
572+
573+
b_found = False
574+
575+
for part in path_parts:
576+
if expected_default_gemini_model in part.split(':'):
577+
b_found = True
578+
break
579+
580+
assert b_found, f"Could not find {expected_default_gemini_model} in {path_parts}."
581+
582+
583+
def test_url__specific_model(
584+
test_api_key:str,
585+
):
586+
model = 'gemini-2.0-flash'
587+
result = ai_tutor.url(test_api_key, model=model)
588+
589+
result_parsed = up.urlparse(result)
590+
591+
path_parts = result_parsed.path.split('/')
592+
593+
b_found = False
594+
595+
for part in path_parts:
596+
if model in part.split(':'):
597+
b_found = True
598+
break
599+
600+
assert b_found, f"Could not find {expected_default_gemini_model} in {path_parts}."
601+
602+
552603
if '__main__' == __name__:
553604
pytest.main([__file__])

tests/test_integration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_main_argument_passing__all_exists(mock_gemini_qna, caplog, tmp_path) ->
2020
# Setup
2121
os.environ['INPUT_API-KEY'] = 'test_key'
2222
os.environ['INPUT_EXPLANATION-IN'] = 'Korean'
23+
os.environ['INPUT_MODEL'] = 'gemini-2.0-flash-exp'
2324

2425
os.environ['GITHUB_OUTPUT'] = str(tmp_path / 'output.txt')
2526

@@ -54,6 +55,7 @@ def test_main_argument_passing__all_exists(mock_gemini_qna, caplog, tmp_path) ->
5455
tmp_path / 'readme.txt',
5556
'test_key',
5657
'Korean',
58+
model='gemini-2.0-flash-exp',
5759
)
5860

5961
assert 'does not exist' not in caplog.text

0 commit comments

Comments
 (0)