Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ To evaluate only correctness of final answers (system responses), you can clone
1. Prepare an input TSV file with columns `Question`, `Reference answer` and `Actual answer`
1. Execute `poetry install --with llm`
1. Execute
```<LLM_ACCESS_VARIABLE>=<your_api_key> poetry run answer-correctness -i <input_file.tsv> -o <output_file.tsv>```
```<LLM_ACCESS_VARIABLE>=<your_api_key> poetry run answer-correctness -i <input_file.tsv> -o <output_file.tsv> -c <config.yaml>```
replacing `<LLM_ACCESS_VARIABLE>` by the variable used by your LLM provider to specify your LLM use key.
Example:
```OPENAI_API_KEY=XXX poetry run answer-correctness -i reference.tsv -o evaluations.tsv```
```OPENAI_API_KEY=XXX poetry run answer-correctness -i reference.tsv -o evaluations.tsv -c conf.yaml```

We plan to improve CLI support in future releases.

Expand Down Expand Up @@ -104,7 +104,7 @@ The configuration has two sections: `llm` and `custom_evaluation`. Example:
* `base_url`: (str) base URL for the generation model, alternative to the provider's default URL
* `api_key`: (str) API key for the generation model, alternative to setting the environment variable corresponding to the provider (e.g. `OPENAI_API_KEY` for OpenAI)
* `embedding`: required for [`answer_relevance`](#output-keys).
* `provider`: (str) name of the organiation providing the embedding model
* `provider`: (str) name of the organization providing the embedding model
* `model`: (str) name of the embedding model
* `custom_evaluations`: (list of the following maps) required nonempty for [custom evaluation](#custom-evaluation-custom-metrics). Each map has keys:
* `name`: (str) name of the evaluation
Expand Down
198 changes: 62 additions & 136 deletions graphrag_eval/answer_correctness.py
Original file line number Diff line number Diff line change
@@ -1,100 +1,28 @@
import asyncio
import csv
from pathlib import Path

from tqdm import tqdm
from pydantic import BaseModel, Field

from graphrag_eval import llm_factory
from graphrag_eval.evaluation import Config
from graphrag_eval.util import compute_f1, singleton

IN_FILE_PATH = "../data/data-1.tsv"
PROMPT_FILE_PATH = Path(__file__).parent / "prompts" / "template.md"
OUT_FILE_PATH = "results/data-1.tsv"
OUT_FIELDS = ["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"]
LLM_PROVIDER = "openai"
LLM_MODEL = "gpt-4o-mini"
TEMPERATURE = 0.0
MAX_TOKENS = 1024

def load_default_prompt() -> str:
with open(Path(__file__).parent / "prompts" / "template.md", "r", encoding="utf-8") as f:
return f.read()

def parse_args() -> "argparse.Namespace":
from argparse import ArgumentParser, ArgumentTypeError

def float_between_0_0_and_2_0(value):
try:
f = float(value)
except ValueError:
raise ArgumentTypeError(f"Invalid float value: {value}")

if f <= 0.0 or f >= 2.0:
raise ArgumentTypeError(f"Value must be between 0.0 and 2.0, got {f}")
return f

parser = ArgumentParser()
parser.add_argument("-i", "--in-file", type=str, default=IN_FILE_PATH)
parser.add_argument("-o", "--out-file", type=str, default=OUT_FILE_PATH)
parser.add_argument("-p", "--provider", type=str, default=LLM_PROVIDER)
parser.add_argument("-l", "--llm", type=str, default=LLM_MODEL)
parser.add_argument("-m", "--max-tokens", type=int, default=MAX_TOKENS)
parser.add_argument(
"-t",
"--temperature",
type=float_between_0_0_and_2_0,
default=TEMPERATURE
)
return parser.parse_args()


def compute_recall_precision_f1(
n_pos: int | None,
n_pred_pos: int | None,
n_true_pos: int | None,
) -> tuple[float | None, float | None, float | None]:
recall = None
precision = None
if n_true_pos is not None and n_pos:
recall = n_true_pos / n_pos
if n_true_pos is not None and n_pred_pos:
precision = n_true_pos / n_pred_pos
return recall, precision, compute_f1(recall, precision)


def extract_response_values(
response: str
) -> tuple[int | None, int | None, int | None, str, str]:
vals = response.split("\t")
n = len(vals)
if n < 4:
msg = f"Expected 4 tab-separated values: {response}"
return None, None, None, "", msg
vals = vals[:4]
try:
n_ref, n_actual, n_matching = map(int, vals[:3])
except ValueError:
msg = f"Claims counts should be ints: {vals}"
return None, None, None, vals[3], msg
if any([
n_ref < 1,
n_actual < 1,
n_matching < 0,
n_matching > n_ref,
n_matching > n_actual
]):
msg = f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}"
return None, None, None, vals[3], msg
return n_ref, n_actual, n_matching, vals[3], ""
class AnswerCorrectnessConfig(BaseModel):
prompt: str = Field(default_factory=load_default_prompt)


@singleton
class AnswerCorrectnessEvaluator:
def __init__(
self,
llm: "InstructorBaseRagasLLM",
prompt_file_path: str | Path = PROMPT_FILE_PATH,
config: AnswerCorrectnessConfig | None = None,
):
with open(prompt_file_path, encoding="utf-8") as f:
self.prompt_template = f.read()
self.config = config or AnswerCorrectnessConfig()
self.prompt_template = self.config.prompt
self.llm = llm

async def _agenerate(self, prompt):
Expand All @@ -106,37 +34,38 @@ async def evaluate_answer(
question: str,
reference_answer: str,
actual_answer: str
):
) -> tuple[int, int, int, str]:
if any(not s.strip() for s in [question, reference_answer, actual_answer]):
raise ValueError("The question of the reference or the actual answer is a blank "
"string!")
prompt = self.prompt_template.format(
question=question,
reference_answer=reference_answer,
candidate_answer=actual_answer,
actual_answer=actual_answer,
)
response_str = await self._agenerate(prompt)
return extract_response_values(response_str)
return self.extract_response_values(response_str)

async def get_correctness_dict(
self,
reference: dict,
actual: dict,
):
result = {"reference_answer": reference["reference_answer"]}
num_ref_claims, num_actual_claims, num_matching_claims, reason, error = \
await self.evaluate_answer(
reference["question_text"],
reference["reference_answer"],
actual["actual_answer"],
)
if error:
result["answer_eval_error"] = error
else:
try:
num_ref_claims, num_actual_claims, num_matching_claims, reason = \
await self.evaluate_answer(
reference["question_text"],
reference["reference_answer"],
actual["actual_answer"],
)
result.update({
"answer_reference_claims_count": num_ref_claims,
"answer_actual_claims_count": num_actual_claims,
"answer_matching_claims_count": num_matching_claims,
"answer_correctness_reason": reason,
})
recall, precision, f1 = compute_recall_precision_f1(
recall, precision, f1 = self.compute_recall_precision_f1(
num_ref_claims, num_actual_claims, num_matching_claims
)
if recall is not None:
Expand All @@ -145,48 +74,45 @@ async def get_correctness_dict(
result["answer_precision"] = precision
if f1 is not None:
result["answer_f1"] = f1
except Exception as exc:
result["answer_eval_error"] = str(exc)
return result


async def evaluate_and_write(
in_file_path: str | Path,
out_file_path: str | Path,
config: "evaluation.Config",
) -> None:
ragas_llm = llm_factory.create_llm(config)
evaluator = AnswerCorrectnessEvaluator(llm=ragas_llm)
with open(in_file_path, encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="\t")
rows = [row for row in reader]
print(f"Writing results to {out_file_path}")
Path(out_file_path).parent.mkdir(parents=True, exist_ok=True)
with open(out_file_path, "w", encoding="utf-8") as f:
writer = csv.writer(f, delimiter="\t")
writer.writerow(OUT_FIELDS)
for row in tqdm(rows):
vals = await evaluator.evaluate_answer(
row["Question"],
row["Reference answer"],
row["Actual answer"]
)
writer.writerow(vals)
f.flush()


def main():
args = parse_args()
config = Config(
llm=llm_factory.Config(
generation=llm_factory.GenerationConfig(
provider=args.provider,
model=args.llm,
temperature=args.temperature,
max_tokens=args.max_tokens,
@staticmethod
def compute_recall_precision_f1(
n_pos: int,
n_pred_pos: int,
n_true_pos: int,
) -> tuple[float | None, float | None, float | None]:
recall = None
precision = None
if n_pos:
recall = n_true_pos / n_pos
if n_pred_pos:
precision = n_true_pos / n_pred_pos
return recall, precision, compute_f1(recall, precision)

@staticmethod
def extract_response_values(
response: str
) -> tuple[int, int, int, str]:
vals = response.split("\t")
n = len(vals)
if n < 4:
raise ValueError(f"Expected 4 tab-separated values: {response}")
vals = vals[:4]
try:
n_ref, n_actual, n_matching = map(int, vals[:3])
except ValueError:
raise ValueError(f"Claims counts should be ints: {vals}")
if any([
n_ref < 1,
n_actual < 1,
n_matching < 0,
n_matching > n_ref,
n_matching > n_actual
]):
raise ValueError(
f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}"
)
)
)
asyncio.run(evaluate_and_write(
args.in_file,
args.out_file,
config,
))
return n_ref, n_actual, n_matching, vals[3]
Empty file added graphrag_eval/cli/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions graphrag_eval/cli/answer_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse
import asyncio
import csv
from argparse import ArgumentParser
from pathlib import Path

from tqdm import tqdm

from graphrag_eval import llm_factory
from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator
from graphrag_eval.evaluation import Config


def parse_args() -> argparse.Namespace:
parser = ArgumentParser()
parser.add_argument("-i", "--in-file", type=str, required=True)
parser.add_argument("-o", "--out-file", type=str, required=True)
parser.add_argument("-c", "--config-path", type=Path, required=True)
return parser.parse_args()


async def evaluate_and_write(
in_file_path: str | Path,
out_file_path: str | Path,
evaluator: AnswerCorrectnessEvaluator,
) -> None:
with open(in_file_path, encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="\t")
rows = [row for row in reader]
print(f"Writing results to {out_file_path}")
Path(out_file_path).parent.mkdir(parents=True, exist_ok=True)
with open(out_file_path, "w", encoding="utf-8") as f:
writer = csv.writer(f, delimiter="\t")
writer.writerow(["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"])
for row in tqdm(rows):
try:
vals = await evaluator.evaluate_answer(
row["Question"],
row["Reference answer"],
row["Actual answer"]
)
vals = vals + ("",)
writer.writerow(vals)
except Exception as exc:
writer.writerow(["", "", "", "", str(exc)])
f.flush()


def main():
args = parse_args()
config = Config.parse(args.config_path)
ragas_llm = llm_factory.create_llm(config)
if ragas_llm is None:
raise ValueError("LLM must be configured to calculate the answer correctness!")
else:
evaluator = AnswerCorrectnessEvaluator(llm=ragas_llm)
asyncio.run(evaluate_and_write(
args.in_file,
args.out_file,
evaluator,
))
5 changes: 4 additions & 1 deletion graphrag_eval/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import BaseModel, Field, model_validator

from . import custom_evaluation
from .answer_correctness import AnswerCorrectnessConfig
from .llm_factory import Config as LLMConfig, create_llm, create_embedder
from .steps.evaluation import evaluate_steps

Expand All @@ -12,6 +13,7 @@ class Config(BaseModel):
llm: LLMConfig | None = None
custom_evaluations: list[custom_evaluation.Config] | None \
= Field(default=None, min_length=1)
answer_correctness: AnswerCorrectnessConfig | None = None

@model_validator(mode="after")
def validate_config(self) -> "Config":
Expand Down Expand Up @@ -75,7 +77,8 @@ async def run_evaluation(
if "reference_answer" in question and ragas_llm:
from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator
answer_correctness_evaluator = AnswerCorrectnessEvaluator(
llm=ragas_llm
llm=ragas_llm,
config=config.answer_correctness,
)
eval_result.update(
await answer_correctness_evaluator.get_correctness_dict(
Expand Down
4 changes: 2 additions & 2 deletions graphrag_eval/llm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
class GenerationConfig(BaseModel):
provider: str
model: str
temperature: float = Field(ge=0.0, le=2.0)
max_tokens: int = Field(ge=1)
temperature: float = Field(default=0.0, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, ge=1)
model_config = ConfigDict(extra='allow')


Expand Down
2 changes: 1 addition & 1 deletion graphrag_eval/prompts/template.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Below are a query, a reference response and a candidate response to it.
{reference_answer}

# Candidate response
{candidate_answer}
{actual_answer}

# Output values
* v1: Count of reference response claims
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pyyaml = "6.0.3"
optional = true

[project.scripts]
answer-correctness = "graphrag_eval.answer_correctness:main"
answer-correctness = "graphrag_eval.cli.answer_correctness:main"

[build-system]
requires = ["poetry-core>=2.0.0"]
Expand Down
Empty file added tests-with-llm/cli/__init__.py
Empty file.
Loading
Loading