|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | | -# TODO(@zhewang2001): Please refactor the corresponding code snippets and then upload it. |
| 5 | +import json |
6 | 6 |
|
| 7 | +from datasets import load_dataset |
| 8 | +from rich import print as rprint |
7 | 9 |
|
8 | | -def evaluate_phtest(*args): |
9 | | - pass # placeholder |
| 10 | +from eval.oracles.check_secqa import evaluate_secqa_answers, is_bad_answer |
| 11 | + |
| 12 | + |
def _rejection_rate_line(label: str, rejected: int, total: int) -> str:
    """Format one green rich-markup summary line; an empty group reports 0.0%."""
    percent = 100 * rejected / total if total else 0.0
    return f"[green]{label} rejection rate:{rejected} / {total} = {percent:.1f}%[/green]"


def evaluate_phtest(
    generation_path: str,
    model: str,
) -> None:
    """Print rejection rates on PHTest, split by prompt harmfulness category.

    Loads the ``purpcode/PHTest`` test split to map each ``task_id`` to its
    ``Harmfulness`` label, runs ``evaluate_secqa_answers`` on the generations,
    then reads the ``.reject-intermediate.jsonl`` file next to
    ``generation_path`` (presumably written by ``evaluate_secqa_answers`` --
    confirm against that function) and classifies the last message of every
    row with ``is_bad_answer``.

    Args:
        generation_path: Path to the generations ``.jsonl`` file.
        model: Forwarded unchanged to ``evaluate_secqa_answers``.

    Raises:
        FileNotFoundError: If the intermediate results file does not exist.
        json.JSONDecodeError: If a non-blank line of that file is not valid JSON.
        KeyError: If a row lacks ``task_id`` or ``messages[-1]["content"]``.
    """
    dataset = load_dataset("purpcode/PHTest", split="test")
    task_id_to_harmfulness = {row["task_id"]: row["Harmfulness"] for row in dataset}

    evaluate_secqa_answers(generation_path, model)

    intermediate_path = generation_path.replace(".jsonl", ".reject-intermediate.jsonl")
    # Explicit UTF-8 keeps decoding independent of the platform locale; blank
    # lines (e.g. a trailing newline) are skipped instead of crashing json.loads.
    with open(intermediate_path, "r", encoding="utf-8") as f:
        intermediate_results = [json.loads(line) for line in f if line.strip()]

    rejected = {"harmless": 0, "controversial": 0}
    totals = {"harmless": 0, "controversial": 0}
    uncategorized = 0

    for row in intermediate_results:
        task_id = row["task_id"]
        is_rejected = is_bad_answer(row["messages"][-1]["content"])
        harmfulness = task_id_to_harmfulness.get(task_id, "unknown")

        if harmfulness in totals:
            totals[harmfulness] += 1
            rejected[harmfulness] += int(is_rejected)
        else:
            # Task id missing from the dataset, or a category other than the
            # two reported ones: excluded from every rate below.
            uncategorized += 1

    if uncategorized:
        # Surface excluded rows so the totals are not silently computed over
        # fewer rows than the results file contains.
        rprint(
            f"[yellow]Skipped {uncategorized} row(s) with no harmless/controversial label[/yellow]"
        )

    rprint(_rejection_rate_line("Harmless", rejected["harmless"], totals["harmless"]))
    rprint(
        _rejection_rate_line(
            "Controversial", rejected["controversial"], totals["controversial"]
        )
    )
    rprint(
        _rejection_rate_line("Total", sum(rejected.values()), sum(totals.values()))
    )
0 commit comments