|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | | -# TODO(@zhewang2001): Please refactor the corresponding code snippets and then upload it. |
| 5 | +import json |
| 6 | +import os |
| 7 | +import subprocess |
6 | 8 |
|
| 9 | +from datasets import load_dataset |
7 | 10 |
|
8 | | -def evaluate_cyberseceval(*args): |
9 | | - pass # placeholder |
| 11 | + |
| 12 | +def evaluate_cyberseceval(generation_path, task, purplellama_path=None): |
| 13 | + output_path = os.path.abspath(generation_path.replace(".jsonl", ".reformat.json")) |
| 14 | + task_dataset = load_dataset(task)["test"] |
| 15 | + task_dict = {item["task_id"]: item for item in task_dataset} |
| 16 | + |
| 17 | + with open(generation_path, "r") as f: |
| 18 | + data = [json.loads(line) for line in f] |
| 19 | + |
| 20 | + results = [] |
| 21 | + for idx, item in enumerate(data): |
| 22 | + task_id = item["task_id"] |
| 23 | + assistant_content = None |
| 24 | + for msg in item["messages"]: |
| 25 | + if msg["role"] == "assistant": |
| 26 | + assistant_content = msg["content"] |
| 27 | + break |
| 28 | + |
| 29 | + if task_id in task_dict: |
| 30 | + task_data = task_dict[task_id] |
| 31 | + model = generation_path.split("/")[-1].split(".")[0] |
| 32 | + |
| 33 | + result = { |
| 34 | + "prompt_id": idx, |
| 35 | + "pass_id": 0, |
| 36 | + "test_case_prompt": task_data["test_case_prompt"], |
| 37 | + "response": assistant_content, |
| 38 | + "cwe_identifier": task_data["cwe_identifier"], |
| 39 | + "language": task_data["language"], |
| 40 | + "line_text": task_data["line_text"], |
| 41 | + "origin_code": task_data["origin_code"], |
| 42 | + "variant": task_data.get("variant", ""), |
| 43 | + "rule": task_data.get("rule", ""), |
| 44 | + "repo": task_data.get("repo", ""), |
| 45 | + "model": model, |
| 46 | + } |
| 47 | + results.append(result) |
| 48 | + |
| 49 | + with open(output_path, "w") as f: |
| 50 | + json.dump(results, f, indent=4) |
| 51 | + |
| 52 | + original_pythonpath = os.environ.get("PYTHONPATH", "") |
| 53 | + current_dir = os.getcwd() |
| 54 | + |
| 55 | + if purplellama_path: |
| 56 | + os.chdir(purplellama_path) |
| 57 | + else: |
| 58 | + os.chdir("..") |
| 59 | + os.chdir("PurpleLlama") |
| 60 | + os.environ["PYTHONPATH"] = original_pythonpath + os.pathsep + os.getcwd() |
| 61 | + |
| 62 | + stat_path = output_path.replace(".reformat.json", ".stat.json") |
| 63 | + |
| 64 | + cmd = [ |
| 65 | + "python3", |
| 66 | + "-m", |
| 67 | + "CybersecurityBenchmarks.benchmark.run", |
| 68 | + "--benchmark=instruct", |
| 69 | + "--use-precomputed-responses", |
| 70 | + f"--response-path={output_path}", |
| 71 | + f"--stat-path={stat_path}", |
| 72 | + ] |
| 73 | + |
| 74 | + subprocess.run(cmd, check=True) |
| 75 | + os.chdir(current_dir) |
| 76 | + os.environ["PYTHONPATH"] = original_pythonpath |
0 commit comments