Skip to content

Commit 452769d

Browse files
committed
Ignore checks where an error occured for evaluation
1 parent 3dedd00 commit 452769d

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

ragability/ragability_eval.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def run(config: dict):
105105
n_errors_per_llm[llm] += 1
106106
continue
107107
for check in q["checks"]:
108+
error = check.get("error")
109+
if error:
110+
nc_errors += 1
111+
nc_errors_per_llm[llm] += 1
112+
continue
108113
func = check["func"]
109114
funcdef = CHECKS.get(func)
110115
kind = funcdef["kind"]
@@ -121,6 +126,8 @@ def run(config: dict):
121126
qid=q["qid"],
122127
tags=q["tags"],
123128
llm=llm,
129+
func=func,
130+
kind=kind,
124131
)
125132
# add any non=standard fields from the query to the row
126133
for k, v in q.items():
@@ -166,11 +173,24 @@ def run(config: dict):
166173
for key, df in checkdfs.items():
167174
kind, metric = key.split(":")
168175
for llm, llmgroup in df.groupby("llm"):
176+
try:
177+
score_value = sk.metrics.accuracy_score(llmgroup["target"].values, llmgroup["result"].values)
178+
except Exception as e:
179+
logger.error(f"Error: {e} in calculating metric {metric} for {llm} in {key}")
180+
# print the rows from the df where the result value is None or NaN and make sure all
181+
# columns are printed properly! For this, we need to convert each row to a dictionary
182+
# of column name / value pairs and print each dictionary in a separate line.
183+
for idx, row in llmgroup.iterrows():
184+
# only print the rows where the result is None or NaN
185+
if row["result"] is None or row["result"] != row["result"]:
186+
logger.error(f"Row {idx}: {dict(row)}")
187+
# set it to NaN if there is an error
188+
score_value = float("nan")
169189
dfrows.append(dict(
170190
group="all",
171191
llm=llm,
172192
metric=f"{metric}:accuracy",
173-
value=sk.metrics.accuracy_score(llmgroup["target"].values, llmgroup["result"].values)
193+
value=score_value
174194
))
175195
dfrows.append(dict(
176196
group="all",

0 commit comments

Comments
 (0)