diff --git a/README.md b/README.md index 0ca676b5..fdffcdd1 100644 --- a/README.md +++ b/README.md @@ -44,10 +44,10 @@ To evaluate only correctness of final answers (system responses), you can clone 1. Prepare an input TSV file with columns `Question`, `Reference answer` and `Actual answer` 1. Execute `poetry install --with llm` 1. Execute - ```= poetry run answer-correctness -i -o ``` + ```= poetry run answer-correctness -i -o -c ``` replacing `` by the variable used by your LLM provider to specify your LLM use key. Example: - ```OPENAI_API_KEY=XXX poetry run answer-correctness -i reference.tsv -o evaluations.tsv``` + ```OPENAI_API_KEY=XXX poetry run answer-correctness -i reference.tsv -o evaluations.tsv -c conf.yaml``` We plan to improve CLI support in future releases. @@ -104,7 +104,7 @@ The configuration has two sections: `llm` and `custom_evaluation`. Example: * `base_url`: (str) base URL for the generation model, alternative to the provider's default URL * `api_key`: (str) API key for the generation model, alternative to setting the environment variable corresponding to the provider (e.g. `OPENAI_API_KEY` for OpenAI) * `embedding`: required for [`answer_relevance`](#output-keys). - * `provider`: (str) name of the organiation providing the embedding model + * `provider`: (str) name of the organization providing the embedding model * `model`: (str) name of the embedding model * `custom_evaluations`: (list of the following maps) required nonempty for [custom evaluation](#custom-evaluation-custom-metrics). Each map has keys: * `name`: (str) name of the evaluation diff --git a/graphrag_eval/answer_correctness.py b/graphrag_eval/answer_correctness.py index 52b25426..1e8a5e4f 100644 --- a/graphrag_eval/answer_correctness.py +++ b/graphrag_eval/answer_correctness.py @@ -1,89 +1,17 @@ -import asyncio -import csv from pathlib import Path -from tqdm import tqdm +from pydantic import BaseModel, Field -from graphrag_eval import llm_factory -from graphrag_eval.evaluation import Config from graphrag_eval.util import compute_f1, singleton -IN_FILE_PATH = "../data/data-1.tsv" -PROMPT_FILE_PATH = Path(__file__).parent / "prompts" / "template.md" -OUT_FILE_PATH = "results/data-1.tsv" -OUT_FIELDS = ["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"] -LLM_PROVIDER = "openai" -LLM_MODEL = "gpt-4o-mini" -TEMPERATURE = 0.0 -MAX_TOKENS = 1024 +def load_default_prompt() -> str: + with open(Path(__file__).parent / "prompts" / "template.md", "r", encoding="utf-8") as f: + return f.read() -def parse_args() -> "argparse.Namespace": - from argparse import ArgumentParser, ArgumentTypeError - - def float_between_0_0_and_2_0(value): - try: - f = float(value) - except ValueError: - raise ArgumentTypeError(f"Invalid float value: {value}") - - if f <= 0.0 or f >= 2.0: - raise ArgumentTypeError(f"Value must be between 0.0 and 2.0, got {f}") - return f - - parser = ArgumentParser() - parser.add_argument("-i", "--in-file", type=str, default=IN_FILE_PATH) - parser.add_argument("-o", "--out-file", type=str, default=OUT_FILE_PATH) - parser.add_argument("-p", "--provider", type=str, default=LLM_PROVIDER) - parser.add_argument("-l", "--llm", type=str, default=LLM_MODEL) - parser.add_argument("-m", "--max-tokens", type=int, default=MAX_TOKENS) - parser.add_argument( - "-t", - "--temperature", - type=float_between_0_0_and_2_0, - default=TEMPERATURE - ) - return parser.parse_args() - -def compute_recall_precision_f1( - n_pos: int | None, - n_pred_pos: int | None, - n_true_pos: int | None, -) -> tuple[float | None, float | None, float | None]: - recall = None - precision = None - if n_true_pos is not None and n_pos: - recall = n_true_pos / n_pos - if n_true_pos is not None and n_pred_pos: - precision = n_true_pos / n_pred_pos - return recall, precision, compute_f1(recall, precision) - - -def extract_response_values( - response: str -) -> tuple[int | None, int | None, int | None, str, str]: - vals = response.split("\t") - n = len(vals) - if n < 4: - msg = f"Expected 4 tab-separated values: {response}" - return None, None, None, "", msg - vals = vals[:4] - try: - n_ref, n_actual, n_matching = map(int, vals[:3]) - except ValueError: - msg = f"Claims counts should be ints: {vals}" - return None, None, None, vals[3], msg - if any([ - n_ref < 1, - n_actual < 1, - n_matching < 0, - n_matching > n_ref, - n_matching > n_actual - ]): - msg = f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}" - return None, None, None, vals[3], msg - return n_ref, n_actual, n_matching, vals[3], "" +class AnswerCorrectnessConfig(BaseModel): + prompt: str = Field(default_factory=load_default_prompt) @singleton @@ -91,10 +19,10 @@ class AnswerCorrectnessEvaluator: def __init__( self, llm: "InstructorBaseRagasLLM", - prompt_file_path: str | Path = PROMPT_FILE_PATH, + config: AnswerCorrectnessConfig | None = None, ): - with open(prompt_file_path, encoding="utf-8") as f: - self.prompt_template = f.read() + self.config = config or AnswerCorrectnessConfig() + self.prompt_template = self.config.prompt self.llm = llm async def _agenerate(self, prompt): @@ -106,14 +34,17 @@ async def evaluate_answer( question: str, reference_answer: str, actual_answer: str - ): + ) -> tuple[int, int, int, str]: + if any(not s.strip() for s in [question, reference_answer, actual_answer]): + raise ValueError("The question of the reference or the actual answer is a blank " + "string!") prompt = self.prompt_template.format( question=question, reference_answer=reference_answer, - candidate_answer=actual_answer, + actual_answer=actual_answer, ) response_str = await self._agenerate(prompt) - return extract_response_values(response_str) + return self.extract_response_values(response_str) async def get_correctness_dict( self, @@ -121,22 +52,20 @@ async def get_correctness_dict( actual: dict, ): result = {"reference_answer": reference["reference_answer"]} - num_ref_claims, num_actual_claims, num_matching_claims, reason, error = \ - await self.evaluate_answer( - reference["question_text"], - reference["reference_answer"], - actual["actual_answer"], - ) - if error: - result["answer_eval_error"] = error - else: + try: + num_ref_claims, num_actual_claims, num_matching_claims, reason = \ + await self.evaluate_answer( + reference["question_text"], + reference["reference_answer"], + actual["actual_answer"], + ) result.update({ "answer_reference_claims_count": num_ref_claims, "answer_actual_claims_count": num_actual_claims, "answer_matching_claims_count": num_matching_claims, "answer_correctness_reason": reason, }) - recall, precision, f1 = compute_recall_precision_f1( + recall, precision, f1 = self.compute_recall_precision_f1( num_ref_claims, num_actual_claims, num_matching_claims ) if recall is not None: @@ -145,48 +74,45 @@ async def get_correctness_dict( result["answer_precision"] = precision if f1 is not None: result["answer_f1"] = f1 + except Exception as exc: + result["answer_eval_error"] = str(exc) return result - -async def evaluate_and_write( - in_file_path: str | Path, - out_file_path: str | Path, - config: "evaluation.Config", -) -> None: - ragas_llm = llm_factory.create_llm(config) - evaluator = AnswerCorrectnessEvaluator(llm=ragas_llm) - with open(in_file_path, encoding="utf-8") as f: - reader = csv.DictReader(f, delimiter="\t") - rows = [row for row in reader] - print(f"Writing results to {out_file_path}") - Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) - with open(out_file_path, "w", encoding="utf-8") as f: - writer = csv.writer(f, delimiter="\t") - writer.writerow(OUT_FIELDS) - for row in tqdm(rows): - vals = await evaluator.evaluate_answer( - row["Question"], - row["Reference answer"], - row["Actual answer"] - ) - writer.writerow(vals) - f.flush() - - -def main(): - args = parse_args() - config = Config( - llm=llm_factory.Config( - generation=llm_factory.GenerationConfig( - provider=args.provider, - model=args.llm, - temperature=args.temperature, - max_tokens=args.max_tokens, + @staticmethod + def compute_recall_precision_f1( + n_pos: int, + n_pred_pos: int, + n_true_pos: int, + ) -> tuple[float | None, float | None, float | None]: + recall = None + precision = None + if n_pos: + recall = n_true_pos / n_pos + if n_pred_pos: + precision = n_true_pos / n_pred_pos + return recall, precision, compute_f1(recall, precision) + + @staticmethod + def extract_response_values( + response: str + ) -> tuple[int, int, int, str]: + vals = response.split("\t") + n = len(vals) + if n < 4: + raise ValueError(f"Expected 4 tab-separated values: {response}") + vals = vals[:4] + try: + n_ref, n_actual, n_matching = map(int, vals[:3]) + except ValueError: + raise ValueError(f"Claims counts should be ints: {vals}") + if any([ + n_ref < 1, + n_actual < 1, + n_matching < 0, + n_matching > n_ref, + n_matching > n_actual + ]): + raise ValueError( + f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}" ) - ) - ) - asyncio.run(evaluate_and_write( - args.in_file, - args.out_file, - config, - )) + return n_ref, n_actual, n_matching, vals[3] diff --git a/graphrag_eval/cli/__init__.py b/graphrag_eval/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphrag_eval/cli/answer_correctness.py b/graphrag_eval/cli/answer_correctness.py new file mode 100644 index 00000000..87bfc729 --- /dev/null +++ b/graphrag_eval/cli/answer_correctness.py @@ -0,0 +1,61 @@ +import argparse +import asyncio +import csv +from argparse import ArgumentParser +from pathlib import Path + +from tqdm import tqdm + +from graphrag_eval import llm_factory +from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator +from graphrag_eval.evaluation import Config + + +def parse_args() -> argparse.Namespace: + parser = ArgumentParser() + parser.add_argument("-i", "--in-file", type=str, required=True) + parser.add_argument("-o", "--out-file", type=str, required=True) + parser.add_argument("-c", "--config-path", type=Path, required=True) + return parser.parse_args() + + +async def evaluate_and_write( + in_file_path: str | Path, + out_file_path: str | Path, + evaluator: AnswerCorrectnessEvaluator, +) -> None: + with open(in_file_path, encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter="\t") + rows = [row for row in reader] + print(f"Writing results to {out_file_path}") + Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) + with open(out_file_path, "w", encoding="utf-8") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerow(["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"]) + for row in tqdm(rows): + try: + vals = await evaluator.evaluate_answer( + row["Question"], + row["Reference answer"], + row["Actual answer"] + ) + vals = vals + ("",) + writer.writerow(vals) + except Exception as exc: + writer.writerow(["", "", "", "", str(exc)]) + f.flush() + + +def main(): + args = parse_args() + config = Config.parse(args.config_path) + ragas_llm = llm_factory.create_llm(config) + if ragas_llm is None: + raise ValueError("LLM must be configured to calculate the answer correctness!") + else: + evaluator = AnswerCorrectnessEvaluator(llm=ragas_llm) + asyncio.run(evaluate_and_write( + args.in_file, + args.out_file, + evaluator, + )) diff --git a/graphrag_eval/evaluation.py b/graphrag_eval/evaluation.py index 76c6600a..d3dfe80c 100644 --- a/graphrag_eval/evaluation.py +++ b/graphrag_eval/evaluation.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, model_validator from . import custom_evaluation +from .answer_correctness import AnswerCorrectnessConfig from .llm_factory import Config as LLMConfig, create_llm, create_embedder from .steps.evaluation import evaluate_steps @@ -12,6 +13,7 @@ class Config(BaseModel): llm: LLMConfig | None = None custom_evaluations: list[custom_evaluation.Config] | None \ = Field(default=None, min_length=1) + answer_correctness: AnswerCorrectnessConfig | None = None @model_validator(mode="after") def validate_config(self) -> "Config": @@ -75,7 +77,8 @@ async def run_evaluation( if "reference_answer" in question and ragas_llm: from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator answer_correctness_evaluator = AnswerCorrectnessEvaluator( - llm=ragas_llm + llm=ragas_llm, + config=config.answer_correctness, ) eval_result.update( await answer_correctness_evaluator.get_correctness_dict( diff --git a/graphrag_eval/llm_factory.py b/graphrag_eval/llm_factory.py index 53503fbd..87b0d803 100644 --- a/graphrag_eval/llm_factory.py +++ b/graphrag_eval/llm_factory.py @@ -6,8 +6,8 @@ class GenerationConfig(BaseModel): provider: str model: str - temperature: float = Field(ge=0.0, le=2.0) - max_tokens: int = Field(ge=1) + temperature: float = Field(default=0.0, ge=0.0, le=2.0) + max_tokens: int | None = Field(default=None, ge=1) model_config = ConfigDict(extra='allow') diff --git a/graphrag_eval/prompts/template.md b/graphrag_eval/prompts/template.md index 20f3f80b..642f6cdd 100644 --- a/graphrag_eval/prompts/template.md +++ b/graphrag_eval/prompts/template.md @@ -10,7 +10,7 @@ Below are a query, a reference response and a candidate response to it. {reference_answer} # Candidate response -{candidate_answer} +{actual_answer} # Output values * v1: Count of reference response claims diff --git a/pyproject.toml b/pyproject.toml index 37900b1d..a1a298d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ pyyaml = "6.0.3" optional = true [project.scripts] -answer-correctness = "graphrag_eval.answer_correctness:main" +answer-correctness = "graphrag_eval.cli.answer_correctness:main" [build-system] requires = ["poetry-core>=2.0.0"] diff --git a/tests-with-llm/cli/__init__.py b/tests-with-llm/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests-with-llm/cli/test_answer_correctness.py b/tests-with-llm/cli/test_answer_correctness.py new file mode 100644 index 00000000..8af70d09 --- /dev/null +++ b/tests-with-llm/cli/test_answer_correctness.py @@ -0,0 +1,54 @@ +import builtins +import io +from unittest.mock import MagicMock + +import pytest + +from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator +from graphrag_eval.cli import answer_correctness + + +@pytest.mark.asyncio +async def test_evaluate_answers(monkeypatch, tmp_path): + mock_prompt_content = "Prompt with {question} {reference_answer} {actual_answer}" + mock_input_content = "Question\tReference answer\tActual answer\nQ1\tRef\tAns\n" + + prompt_file_path = "prompt_file_path" + in_file_path = "in_file_path" + out_file_path = tmp_path / "out_file_name" + + # Mock open() + real_open = builtins.open + + def mock_open(path, *args, **kwargs): + str_path = str(path) + if str_path == prompt_file_path: + return io.StringIO(mock_prompt_content) + elif str_path == in_file_path: + return io.StringIO(mock_input_content) + return real_open(path, *args, **kwargs) + + monkeypatch.setattr(builtins, "open", mock_open) + answer_correctness_evaluator = AnswerCorrectnessEvaluator(llm=MagicMock()).__class__ + + async def mock_agenerate(self, prompt): + return "2\t2\t2\treason" + + monkeypatch.setattr( + answer_correctness_evaluator, + "_agenerate", + mock_agenerate + ) + monkeypatch.setattr(answer_correctness, "tqdm", lambda x: x) + + # Run + await answer_correctness.evaluate_and_write( + in_file_path, + out_file_path, + evaluator=AnswerCorrectnessEvaluator(llm=MagicMock()) + ) + + # Verify output file content + written = out_file_path.read_text().splitlines() + assert written[0].split("\t") == ["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"] + assert written[1].split("\t") == ["2", "2", "2", "reason", ""] diff --git a/tests-with-llm/steps/__init__.py b/tests-with-llm/steps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests-with-llm/test_answer_correctness.py b/tests-with-llm/test_answer_correctness.py index 0edcf013..d13df9a5 100644 --- a/tests-with-llm/test_answer_correctness.py +++ b/tests-with-llm/test_answer_correctness.py @@ -1,123 +1,94 @@ -import builtins -import io +import re +from typing import Any from unittest.mock import MagicMock import pytest +from pytest import raises -from graphrag_eval import answer_correctness, evaluation -from graphrag_eval.answer_correctness import ( - AnswerCorrectnessEvaluator, - extract_response_values, -) -from graphrag_eval.llm_factory import Config, GenerationConfig - - -def get_llm_config(): - return evaluation.Config( - llm=Config( - generation=GenerationConfig( - provider="openai", - model="gpt-4o-mini", - temperature=0.0, - max_tokens=1024, - ) - ) - ) +from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator def test_extract_response_values_expected_case(): response = "2\t3\t1\treason" - result = extract_response_values(response) - assert result == (2, 3, 1, "reason", "") - - -def test_extract_response_values_invalid_values(): - response = "0\t1\t1\treason" - result = extract_response_values(response) - assert result[4] - - response = "1\t0\t1\treason" - result = extract_response_values(response) - assert result[4] - - response = "1\t2\t-1\treason" - result = extract_response_values(response) - assert result[4] - - response = "1\t3\t2\treason" - result = extract_response_values(response) - assert result[4] - - response = "3\t1\t2\treason" - result = extract_response_values(response) - assert result[4] - - response = "3\t1\t2\treason" - result = extract_response_values(response) - assert result[4] + result = AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + assert result == (2, 3, 1, "reason") + + +@pytest.mark.parametrize( + "question, reference_answer, actual_answer", + [ + ("Is Sofia the capital of Bulgaria?", "Yes", ""), + ("Is Sofia the capital of Bulgaria?", "Yes", " "), + ("Is Sofia the capital of Bulgaria?", "Yes", "\n\t \r"), + ("Is Sofia the capital of Bulgaria?", "", "No"), + ("Is Sofia the capital of Bulgaria?", " ", "No"), + ("Is Sofia the capital of Bulgaria?", "\n\t \r", "No"), + ("", "Yes", "No"), + (" ", "Yes", "No"), + ("\n\t \r", "Yes", "No"), + ("", "Yes", ""), + ("", "", "No"), + ("", "", ""), + ], +) +@pytest.mark.asyncio +async def test_evaluate_answer_empty_strings( + question: str, reference_answer: str, actual_answer: str +): + with raises(ValueError, match="The question of the reference or the actual answer is a blank " + "string!"): + await AnswerCorrectnessEvaluator(llm=MagicMock()).evaluate_answer( + question, reference_answer, actual_answer + ) -def test_extract_response_values_non_int(): - response = "2\t2\tx\treason" - result = answer_correctness.extract_response_values(response) - assert result[4] +@pytest.mark.parametrize( + "n_ref, n_actual, n_matching", + [ + (0, 1, 1), + (1, 0, 1), + (15, 0, 0), + (1, 2, -1), + (1, 3, 2), + (3, 1, 2) + ], +) +def test_extract_response_values_invalid_values(n_ref: int, n_actual: int, n_matching: int): + response = f"{n_ref}\t{n_actual}\t{n_matching}\treason" + with raises(ValueError, + match=f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}"): + AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + + +@pytest.mark.parametrize( + "n_ref, n_actual, n_matching", + [ + (1, 1, "x"), + (1, "x", 1), + (1, "x", "y"), + ("x", 1, 1), + ("x", 1, "y"), + ("x", "y", 1), + ("x", "y", "z"), + ], +) +def test_extract_response_values_non_int(n_ref: Any, n_actual: Any, n_matching: Any): + response = f"{n_ref}\t{n_actual}\t{n_matching}\treason" + with raises(ValueError, + match=re.escape( + f"Claims counts should be ints: ['{n_ref}', '{n_actual}', '{n_matching}', " + f"'reason']" + )): + AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) def test_extract_response_values_too_few_values(): response = "2\t2\treason" - result = answer_correctness.extract_response_values(response) - # fewer than 4 values → error - assert result[4] + with raises(ValueError, match=f"Expected 4 tab-separated values: {response}"): + AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) def test_extract_response_values_too_many_values(): response = "2\t2\t2\treason\textra" - result = answer_correctness.extract_response_values(response) - # only first 4 should be taken - assert result == (2, 2, 2, "reason", "") - - -@pytest.mark.asyncio -async def test_evaluate_answers(monkeypatch, tmp_path): - mock_prompt_content = "Prompt with {question} {reference_answer} {candidate_answer}" - mock_input_content = "Question\tReference answer\tActual answer\nQ1\tRef\tAns\n" - - prompt_file_path = "prompt_file_path" - in_file_path = "in_file_path" - out_file_path = tmp_path / "out_file_name" - - # Mock open() - real_open = builtins.open - - def mock_open(path, *args, **kwargs): - str_path = str(path) - if str_path == prompt_file_path: - return io.StringIO(mock_prompt_content) - elif str_path == in_file_path: - return io.StringIO(mock_input_content) - return real_open(path, *args, **kwargs) - - monkeypatch.setattr(builtins, "open", mock_open) - answer_correctness_evaluator = AnswerCorrectnessEvaluator(llm=MagicMock()).__class__ - - async def mock_agenerate(self, prompt): - return "2\t2\t2\treason" - - monkeypatch.setattr( - answer_correctness_evaluator, - "_agenerate", - mock_agenerate - ) - monkeypatch.setattr(answer_correctness, "tqdm", lambda x: x) - - # Run - await answer_correctness.evaluate_and_write( - in_file_path, - out_file_path, - config=get_llm_config() - ) - - # Verify output file content - written = out_file_path.read_text().splitlines() - assert written[0].split("\t") == answer_correctness.OUT_FIELDS - assert written[1].split("\t") == ["2", "2", "2", "reason", ""] + result = AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + assert result == (2, 2, 2, "reason")