
Commit c976e35

refactor

Author: Yesid Cano Castro
1 parent: d782179

9 files changed: 356 additions & 746 deletions

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ logs/
 db.png

 # Notebooks
-*.ipynb
+#*.ipynb

 # Mac
 .DS_Store

eval/BERTscore_eval.ipynb

Lines changed: 0 additions & 248 deletions
This file was deleted.

eval/BERTscore_eval.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
# This code was adapted from https://github.com/MarvinIRW/Assessing-Answer-Accuracy-Hallucination-and-Document-Relevance-in-virtUOS-Chatbot/tree/main/code/eval

import os

import pandas as pd
from bert_score import score


def compute_bertscore(
    df: pd.DataFrame,
    reference_col: str,
    hypothesis_col: str,
    question_id_col: str,
    language: str,
    output_csv_path: str,
    mean_csv_path=None,
) -> pd.DataFrame:
    """
    Computes BERTScore for each row in `df`.
    """
    references = df[reference_col].astype(str).tolist()
    hypotheses = df[hypothesis_col].astype(str).tolist()
    assert len(references) == len(
        hypotheses
    ), "Mismatch in # of references vs. hypotheses"

    (P, R, F1), bert_hash = score(
        cands=hypotheses, refs=references, lang=language, verbose=True, return_hash=True
    )

    bert_df = pd.DataFrame(
        {
            question_id_col: df[question_id_col].values,
            "BERTScore_P": P.tolist(),
            "BERTScore_R": R.tolist(),
            "BERTScore_F1": F1.tolist(),
        }
    )

    system_f1_mean = bert_df["BERTScore_F1"].mean()
    print(f"[{language.upper()}] System-level BERTScore F1: {system_f1_mean:.3f}")
    print(f"[{language.upper()}] BERTScore hash code: {bert_hash}\n")

    bert_df.to_csv(output_csv_path, index=False, quoting=1)
    print(f"BERTScore results saved to: {output_csv_path}")

    if (
        mean_csv_path is not None
        and os.path.exists(mean_csv_path)
        and language is not None
    ):
        mean_eval = pd.read_csv(mean_csv_path)
        metric_name = f"BERTScore_F1_{language}"
        if metric_name not in mean_eval["metric"].values:
            mean_eval = pd.concat(
                [
                    mean_eval,
                    pd.DataFrame([{"metric": metric_name, "value": system_f1_mean}]),
                ],
                ignore_index=True,
            )
        mean_eval.to_csv(mean_csv_path, index=False)

    return bert_df


def run_bertscore_eval(config):
    csv_path_de = config.get("csv_path_de")
    mean_csv_path = config.get("csv_path_mean_bert", None)
    output_csv_de = config.get("output_csv_bert", None)
    if not csv_path_de or not output_csv_de:
        raise ValueError("csv_path_de and output_csv_bert must be set in config.")

    df_de = pd.read_csv(csv_path_de)
    bert_df_de = compute_bertscore(
        df=df_de,
        reference_col="human_answer",
        hypothesis_col="chatbot_answer",
        question_id_col="question_id_q",
        language="de",
        output_csv_path=output_csv_de,
        mean_csv_path=mean_csv_path,
    )
    return bert_df_de
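
For reference, a minimal driver sketch showing how run_bertscore_eval could be invoked. The config keys match the ones read by the function above, but the file paths and the import path are hypothetical placeholders, not part of this commit:

# Hypothetical usage sketch, not part of the commit. Assumes eval/ is on
# sys.path so the module imports directly; the CSV paths are placeholders.
from BERTscore_eval import run_bertscore_eval

config = {
    # Input CSV must provide human_answer, chatbot_answer, question_id_q columns.
    "csv_path_de": "data/qa_pairs_de.csv",
    # Per-question BERTScore P/R/F1 rows are written here.
    "output_csv_bert": "results/bertscore_de.csv",
    # Optional: existing metric/value summary CSV to append the system-level F1 to.
    "csv_path_mean_bert": None,
}

bert_df = run_bertscore_eval(config)
print(bert_df.head())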
