-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathFoMA_Eval.py
More file actions
178 lines (154 loc) · 8.14 KB
/
FoMA_Eval.py
File metadata and controls
178 lines (154 loc) · 8.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import argparse
import logging
import os
import sys
from generate_answers import process_data
from verify_answers import verify_answers
from evaluate_results import monte_carlo_evaluate
def parse_args():
parser = argparse.ArgumentParser(description="Pipeline for Lean theorem proof (generation, verification, and evaluation)")
# File paths
parser.add_argument("--input_file", default=None,
help="Path to the initial input file")
parser.add_argument("--generated_file", default=None,
help="Path to the output file containing generated answers")
parser.add_argument("--verification_file", default=None,
help="Path to the output file containing verification results")
parser.add_argument("--evaluation_file", default=None,
help="Path to the output file containing evaluation results")
# Task control
parser.add_argument("--auto_dl", action="store_true", default=True,
help="Automatically download dataset")
parser.add_argument("--generate", action="store_true", default=False,
help="Enable generation of answers")
parser.add_argument("--verify", action="store_true", default=False,
help="Enable verification of generated answers")
parser.add_argument("--evaluate", action="store_true", default=False,
help="Enable evaluation of verification results")
parser.add_argument("--datasets", default="FomaMATH-All",
help="Choose dataset version: FomaMATH-All or FomaMATH-Lite")
# Generation parameters - Add all parameters from the first script
parser.add_argument("--model", default=None,
help="Path to the model used for generating answers.")
parser.add_argument("--n", type=int, default=200,
help="Number of answers to generate per process via vllm.")
parser.add_argument("--nums_answer", type=int, default=3200,
help="Number of answers to generate per question.")
# Verification parameters
parser.add_argument("--repl_path", default="./repl",
help="Path to the Lean REPL used for verification")
parser.add_argument("--lean_env_path", default="./repl/test/Mathlib",
help="Path to the Lean environment used for verification")
parser.add_argument("--num_batches", default=32, type=int,
help="Number of parallel batches for verification")
parser.add_argument("--session_timeout", default=600, type=int,
help="Timeout for interactive sessions in seconds")
parser.add_argument("--expect_timeout", default=120, type=int,
help="Timeout for the expect command in seconds")
# Evaluation parameters
parser.add_argument("--n_simulations", default=50, type=int,
help="Number of Monte Carlo simulations")
parser.add_argument("--n_processes", default=50, type=int,
help="Number of parallel processes for Monte Carlo simulation")
parser.add_argument("--custom_sample_sizes", default=None, type=str,
help="Custom sampling sizes as a comma-separated list (e.g., '1,5,10,50,100')")
return parser.parse_args()
def set_up_logging(level=logging.INFO):
"""Set up logging with the specified level."""
logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - %(message)s')
def main():
args = parse_args()
# set_up_logging()
# Ensure at least one task is selected
if not (args.generate or args.verify or args.evaluate):
print("Please select at least one task (--generate, --verify, or --evaluate)")
return
# Step 0: Download datasets
if args.auto_dl:
from datasets import load_dataset
if args.datasets == "FomaMATH-All":
input_dataset_id = "SphereLab/FormalMATH-All"
elif args.datasets == "FomaMATH-Lite":
input_dataset_id = "SphereLab/FormalMATH-Lite"
else:
raise ValueError(f"Unknown dataset: {args.datasets}")
input_dataset_branch = "main"
local_dataset_path = "./data/"
os.makedirs(local_dataset_path, exist_ok=True)
try:
args.input_file = os.path.join(local_dataset_path, "FomaMATH.json")
args.generated_file = os.path.join(local_dataset_path, "FomaMATH_generated.json")
args.verification_file = os.path.join(local_dataset_path, "FomaMATH_verification.json")
args.evaluation_file = os.path.join(local_dataset_path, "FomaMATH_evaluation.json")
ds = load_dataset(input_dataset_id, split="train", revision=input_dataset_branch)
ds.to_json(args.input_file)
print(f"Dataset has been saved to: {local_dataset_path}")
except Exception as e:
print(f"Error occurred while downloading dataset: {e}")
# Step 1: Generate answers
if args.generate:
try:
print(f"Generating answers using model {args.model}")
process_data(
model_path=args.model,
input_file=args.input_file,
output_file=args.generated_file,
batch_size=args.n,
num_answers=args.nums_answer
)
print(f"Answers have been generated and saved to {args.generated_file}")
except Exception as e:
logging.error(f"Error during answer generation: {e}")
return
# Step 2: Verify answers
if args.verify:
try:
print("Starting verification of answers")
# Use the generated file as input if answers were generated, otherwise use the provided input file
verification_input = args.generated_file
# Check if the input file exists
if not os.path.exists(verification_input):
print(f"Verification input file {verification_input} does not exist. Please check the path or generate answers first.")
if args.evaluate:
print("Answer verification failed, proceeding to evaluation")
else:
return
verify_answers(
input_file=verification_input,
output_file=args.verification_file,
repl_path=args.repl_path,
lean_env_path=args.lean_env_path,
num_batches=args.num_batches,
session_timeout=args.session_timeout,
expect_timeout=args.expect_timeout
)
print(f"Verification complete. Results have been saved to {args.verification_file}")
except Exception as e:
logging.error(f"Error during answer verification: {e}")
return
# Step 3: Evaluate verification results
if args.evaluate:
try:
print("Starting evaluation of verification results")
# Check if the verification result file exists
if not os.path.exists(args.verification_file):
print(f"Verification result file {args.verification_file} does not exist. Please check the path or verify answers first.")
return
# Parse custom sampling sizes
sample_sizes = None
if args.custom_sample_sizes:
sample_sizes = [int(size) for size in args.custom_sample_sizes.split(',')]
monte_carlo_evaluate(
input_filepath=args.verification_file,
output_filepath=args.evaluation_file,
sample_sizes=sample_sizes,
n_simulations=args.n_simulations,
n_processes=args.n_processes
)
print(f"Evaluation complete. Results have been saved to {args.evaluation_file}")
except Exception as e:
logging.error(f"Error during evaluation: {e}")
return
print(f"All requested tasks have been completed successfully! You can check your successful rate in {args.evaluation_file}")
if __name__ == "__main__":
main()