Skip to content

Commit 2586f84

Browse files
authored
feat(eval): phtest (#7)
1 parent 21394ff commit 2586f84

1 file changed

Lines changed: 48 additions & 3 deletions

File tree

eval/phtest.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,53 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
# TODO(@zhewang2001): Please refactor the corresponding code snippets and then upload it.
5+
import json
66

7+
from datasets import load_dataset
8+
from rich import print as rprint
79

8-
def evaluate_phtest(*args):
9-
pass # placeholder
10+
from eval.oracles.check_secqa import evaluate_secqa_answers, is_bad_answer
11+
12+
13+
def evaluate_phtest(
    generation_path: str,
    model: str,
) -> None:
    """Print rejection rates on PHTest, broken down by harmfulness label.

    Loads the ``purpcode/PHTest`` test split to map each ``task_id`` to its
    ``Harmfulness`` label, calls ``evaluate_secqa_answers`` on the
    generations, then reads the ``.reject-intermediate.jsonl`` file that sits
    next to ``generation_path`` and counts, per label, how many final
    messages ``is_bad_answer`` flags.

    Args:
        generation_path: Path to the generations file. The intermediate file
            name is derived by replacing ``.jsonl`` with
            ``.reject-intermediate.jsonl``.
        model: Forwarded unchanged to ``evaluate_secqa_answers``.

    Returns:
        None. The harmless, controversial and combined rejection rates are
        printed with ``rich``.
    """
    dataset = load_dataset("purpcode/PHTest", split="test")
    task_id_to_harmfulness = {row["task_id"]: row["Harmfulness"] for row in dataset}

    # Called for its side effect; presumably this is what produces the
    # intermediate file read below -- TODO confirm against check_secqa.
    evaluate_secqa_answers(generation_path, model)

    intermediate_path = generation_path.replace(".jsonl", ".reject-intermediate.jsonl")
    # Explicit encoding so decoding does not depend on the platform locale
    # (assumes the intermediate file is UTF-8/ASCII JSON -- verify against the
    # writer). Blank lines are skipped so a trailing newline cannot crash
    # json.loads.
    with open(intermediate_path, "r", encoding="utf-8") as f:
        intermediate_results = [json.loads(line) for line in f if line.strip()]

    categories = ("harmless", "controversial")
    rejected = dict.fromkeys(categories, 0)
    total = dict.fromkeys(categories, 0)
    skipped = 0

    for row in intermediate_results:
        # Only the last message of the conversation is checked.
        is_rejected = is_bad_answer(row["messages"][-1]["content"])
        harmfulness = task_id_to_harmfulness.get(row["task_id"], "unknown")

        if harmfulness not in total:
            # Unknown task_id or unexpected label: excluded from all rates,
            # but counted so the omission is visible in the output.
            skipped += 1
            continue

        total[harmfulness] += 1
        rejected[harmfulness] += int(is_rejected)

    summary = (
        ("Harmless", rejected["harmless"], total["harmless"]),
        ("Controversial", rejected["controversial"], total["controversial"]),
        ("Total", sum(rejected.values()), sum(total.values())),
    )
    for label, n_rejected, n_total in summary:
        # An empty category reports 0.0% instead of dividing by zero.
        rate = 100 * n_rejected / n_total if n_total else 0.0
        rprint(
            f"[green]{label} rejection rate:{n_rejected} / {n_total} = {rate:.1f}%[/green]"
        )

    if skipped:
        rprint(
            f"[yellow]Skipped {skipped} result(s) with no harmless/controversial label in PHTest[/yellow]"
        )

0 commit comments

Comments
 (0)