Skip to content

Commit 362f102

Browse files
committed
add assert given failure
1 parent d3ca0c4 commit 362f102

5 files changed

Lines changed: 23 additions & 11 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
### Added
1010
* MECE principle in comment generation
11+
* if failed, assert the error message
1112

1213

1314
### Changed

ai_tutor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def gemini_qna(
8888
readme_file:pathlib.Path,
8989
api_key:str,
9090
explanation_in:str='Korean',
91-
) -> str:
91+
) -> Tuple[int, str]:
9292
'''
9393
Queries the Gemini API to provide explanations for failed pytest test cases.
9494
@@ -105,7 +105,7 @@ def gemini_qna(
105105
logging.info(f"Student files: {student_files}")
106106
logging.info(f"Readme file: {readme_file}")
107107

108-
consolidated_question = get_prompt(
108+
n_failed, consolidated_question = get_prompt(
109109
report_paths,
110110
student_files,
111111
readme_file,
@@ -114,17 +114,20 @@ def gemini_qna(
114114

115115
answers = ask_gemini(consolidated_question, api_key)
116116

117-
return answers
117+
return n_failed, answers
118118

119119

120120
def get_prompt(
121121
report_paths:Tuple[pathlib.Path],
122122
student_files:Tuple[pathlib.Path],
123123
readme_file:pathlib.Path,
124124
explanation_in:str,
125-
) -> str:
125+
) -> Tuple[int, str]:
126126
pytest_longrepr_list = collect_longrepr_from_multiple_reports(report_paths, explanation_in)
127127

128+
n_failed_tests = len(pytest_longrepr_list)
129+
130+
128131
def get_initial_instruction(questions:List[str],language:str) -> str:
129132
# Add the main directive or instruction based on whether there are failed tests
130133
if questions:
@@ -153,7 +156,7 @@ def get_initial_instruction(questions:List[str],language:str) -> str:
153156
# Join all questions into a single string
154157
prompt_str = "\n\n".join(prompt_list)
155158

156-
return prompt_str
159+
return n_failed_tests, prompt_str
157160

158161

159162
def collect_longrepr_from_multiple_reports(pytest_json_report_paths:Tuple[pathlib.Path], explanation_in:str) -> List[str]:

entrypoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def main() -> None:
3333

3434
explanation_in = os.environ['INPUT_EXPLANATION-IN']
3535

36-
feedback = ai_tutor.gemini_qna(
36+
n_failed, feedback = ai_tutor.gemini_qna(
3737
report_files,
3838
student_files,
3939
readme_file,
@@ -48,6 +48,8 @@ def main() -> None:
4848
out_string = f'feedback<<EOF\n{feedback}\nEOF'
4949
logging.info(f"Writing to GITHUB_OUTPUT: {f.write(out_string)} characters")
5050

51+
assert n_failed == 0, f'{n_failed} failed tests'
52+
5153

5254
def get_path_tuple(report_files_str:str) -> Tuple[pathlib.Path]:
5355
"""

tests/test_ai_tutor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,25 @@ def test_get_prompt__has__homework__msg__instruction(
246246
student_files=(sample_student_code_path,),
247247
readme_file=sample_readme_path,
248248
explanation_in=explanation_in,
249-
).lower()
249+
)
250+
251+
n_failed = result[0]
252+
prompe_text = result[1].lower()
250253

251254
assert any(
252255
map(
253-
lambda x: x in result,
256+
lambda x: x in prompe_text,
254257
homework
255258
)
256259
)
257-
assert msg in result
260+
assert msg in prompe_text
258261

259262
assert any(
260263
map(
261-
lambda x: x in result,
264+
lambda x: x in prompe_text,
262265
instruction
263266
)
264-
), f"Could not find instruction: {instruction} in result: {result}."
267+
), f"Could not find instruction: {instruction} in result: {prompe_text}."
265268

266269

267270
def test_load_locale(explanation_in:str, homework:Tuple[str]):

tests/test_integration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def test_main_argument_passing__all_exists(mock_gemini_qna, caplog, tmp_path) ->
4343
path_readme.touch()
4444
os.environ['INPUT_README-PATH'] = str(path_readme)
4545

46+
# mock return value
47+
mock_gemini_qna.return_value = (0, "This is the feedback message.")
48+
4649
entrypoint.main()
4750

4851
mock_gemini_qna.assert_called_once_with(

0 commit comments

Comments
 (0)