-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy path: eval_plan_solve.py
More file actions
217 lines (187 loc) · 10.1 KB
/
eval_plan_solve.py
File metadata and controls
217 lines (187 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import argparse, config, openai, os, random, re
from intercode.envs import (
BashEnv, SqlEnv, ACTION_EXEC
)
import simplejson as json
from tqdm import tqdm
from typing import Dict, List
from experiments.utils import TemplatePlanSolve, ACTION_PARSER_MAP
# Command-line interface for the Plan & Solve evaluation run.
parser = argparse.ArgumentParser(description='Plan & Solve evaluation for Intercode environment')
parser.add_argument('--data_path', type=str, help='path to dataset to evaluate on')
parser.add_argument('--env', choices=['sql', 'bash'], help='Intercode environment to run eval on')
parser.add_argument('--image_name', type=str, help='name of docker image to build environment with')
parser.add_argument('--log_dir', type=str, help='folder to save experiment run log file to')
parser.add_argument('--proportion', type=float, help="proportion of the dataset to use for evaluation")
parser.add_argument('--refine', action='store_true', help="whether to run refine step")
parser.add_argument('--refine_turns', type=int, help="number of turns to run refine step for")
parser.add_argument('--seed', type=int, help="seed for randomness")
parser.add_argument('--verbose', action='store_true', help="print out logs")
args = parser.parse_args()

# Human-readable setting name per environment, interpolated into LLM prompts
# via TemplatePlanSolve.
SETTING_MAP = {
    "sql": "MySQL Database",
    "bash": "Bourne Shell"
}
def preprocess_sql(record: Dict) -> List:
    """Build the setup commands run before a SQL task.

    Selects the task's database so subsequent queries execute against it.
    """
    return [f"use {record['db']}"]
# Set OpenAI API key from environment or config file.
api_key = os.environ.get("OPENAI_API_KEY")
# Fall back to a local `keys.cfg` in the working directory when the
# environment variable is absent or empty.
if (api_key is None or api_key == "") and os.path.isfile(os.path.join(os.getcwd(), "keys.cfg")):
    cfg = config.Config('keys.cfg')
    api_key = cfg["OPENAI_API_KEY"]
openai.api_key = api_key
def llm(messages, stop=("\n",)):
    """Query gpt-3.5-turbo with the given chat history and return the reply.

    Args:
        messages: list of chat messages ({"role": ..., "content": ...}) to
            send as the conversation so far.
        stop: kept for interface compatibility but NOT forwarded to the API
            call (matching the original behavior — forwarding it would
            truncate multi-line plans at the first newline). The original
            used a mutable list default (`["\n"]`), a classic Python
            pitfall; replaced with an immutable tuple.

    Returns:
        The message object of the first (and only, n=1) completion choice.
    """
    response = openai.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0,  # deterministic decoding for reproducible evals
        top_p=1,
        max_tokens=512,
        n=1,
    )
    return response.choices[0].message
class ExperimentWrapper():
    """Plan & Solve evaluation harness for an InterCode environment.

    Builds the environment from parsed CLI args, asks the LLM for a numbered
    plan, executes each plan step as an environment action, and optionally
    enters a multi-turn self-refinement phase when the plan does not reach
    full reward. Per-task results accumulate in ``self.log_data`` and are
    flushed to ``self.log_path`` as JSON.
    """

    def __init__(self, args):
        self.args = args
        # Set environment (No logging for env)
        self.env = None
        if args.env == 'sql':
            self.env = SqlEnv(image_name=args.image_name,
                data_path=args.data_path, preprocess=preprocess_sql)
        elif args.env == 'bash':
            self.env = BashEnv(image_name=args.image_name,
                data_path=args.data_path)
        else:
            raise ValueError(f'Environment {args.env} not recognized')
        # Define log file name and path
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir, exist_ok=True)
        log_file_name = f"{self.env.name}_plan_solve.json"
        # Refine runs get a distinct file name so they don't clobber base runs
        if args.refine and args.refine_turns:
            log_file_name = f"{self.env.name}_plan_solve_refine_{args.refine_turns}_turns.json"
        self.log_path = os.path.join(args.log_dir, log_file_name)
        self.log_data = {}
        # Define dialogue, template, parser
        self.dialogue = []
        self.template = TemplatePlanSolve(self.args.env.upper(), SETTING_MAP[self.args.env])
        self.parser = ACTION_PARSER_MAP[self.args.env]

    def run_expr(self):
        """Run the evaluation over the (optionally subsampled) dataset.

        The ``finally`` clause always writes accumulated logs and closes the
        environment, even on KeyboardInterrupt or an unexpected error.
        """
        try:
            indices = range(len(self.env.data_loader))
            # NOTE(review): seed=0 or proportion=0.0 disables subsampling
            # (falsy check), and the trailing `[6:]` silently drops the first
            # six sampled tasks — looks like a leftover resume hack; confirm.
            if self.args.seed and self.args.proportion:
                indices = random.Random(self.args.seed).choices(list(indices),
                    k=int(len(indices) * self.args.proportion))[6:]
            for idx in tqdm(indices, disable=self.args.verbose):
                # Reset variables per task
                self.env.reset(idx)
                observation, self.dialogue = None, []
                turn_history = {"actions": [], "observations": [], "rewards": [], "steps": [], "valid_action": []}
                record = self.env.data_loader.get(idx)
                if self.args.verbose:
                    print(f'------\nQuery {idx}: {self.env.query}')
                # Get plan
                self.dialogue.append({"role": "system", "content": self.template.get_init_msg()})
                self.dialogue.append({"role": "user", "content": self.template.get_query_msg(self.env.query)})
                action = llm(self.dialogue)
                # Split the model's numbered list ("1. ... 2. ...") into steps
                plan = re.findall("\d+\.\s(.*?)(?=\s\d+\.|\Z)", action.content, re.DOTALL)
                if self.args.verbose:
                    print(f"Plan: {plan}")
                self.dialogue.append({"role": "system", "content": self.template.get_execute_plan_msg()})
                observation, info = None, None
                # Execute the plan one step at a time
                for idx_plan in range(len(plan)):
                    step = plan[idx_plan]
                    # Truncate oversized observations before feeding them back
                    # to the LLM to keep the prompt within budget
                    if isinstance(observation, str) and len(observation) > 1000:
                        observation = observation[:1000]
                    elif isinstance(observation, list) and len(observation) > 25:
                        observation = observation[:25]
                    observation = f"Step: {step}" if observation is None else f"Observation: {observation}\nStep: {step}"
                    self.dialogue.append({"role": "user", "content": str(observation)})
                    action = llm(self.dialogue)
                    action_parsed, is_code = self.parser(action.content)
                    if not is_code:
                        # Model emitted no runnable code; ask it to retry
                        observation = self.template.get_retry_msg()
                        valid_action, reward, done = False, 0, False
                    else:
                        observation, _, _, info = self.env.step(action_parsed)
                        valid_action = info[ACTION_EXEC]
                        # Submit after every step to score partial progress
                        _, reward, done, info = self.env.step("submit")
                    self.dialogue.append(action)
                    if self.args.verbose:
                        print(f"- Step {idx_plan}: {step}")
                        print(f"-- Action: {action_parsed}")
                        if isinstance(observation, str) and observation.startswith(f'No {self.args.env.upper()} code'):
                            print(f"-- Observation: (meta) No code output, policy's template's retry message was invoked")
                        else:
                            print(f"-- Observation: {observation}")
                    turn_history["steps"].append(step)
                    turn_history["actions"].append(action_parsed)
                    turn_history["rewards"].append(reward)
                    turn_history["observations"].append(str(observation)) # To avoid serialization issues
                    turn_history["valid_action"].append(valid_action)
                # Enter self-refine mode if enabled
                # NOTE(review): `reward` here is the last plan step's reward;
                # if the plan regex matched nothing, `reward` would be
                # unbound and this line would raise — confirm upstream
                # guarantees a non-empty plan.
                if reward != 1 and self.args.refine and self.args.refine_turns:
                    if self.args.verbose:
                        print("Plan finished execution, entering refine mode")
                    self.dialogue.append({"role": "system", "content": self.template.get_after_plan_msg()})
                    for turn in range(self.args.refine_turns):
                        # Tighter observation-truncation budget during refinement
                        if isinstance(observation, str) and len(observation) > 250:
                            observation = observation[:250]
                        elif isinstance(observation, list) and len(observation) > 25:
                            observation = observation[:25]
                        self.dialogue.append({"role": "user", "content": str(observation)})
                        action = llm(self.dialogue)
                        action_parsed, is_code = self.parser(action.content)
                        if not is_code:
                            # NOTE(review): unlike the plan loop, this branch
                            # does not reset valid_action/reward — they carry
                            # over from the previous turn; confirm intentional.
                            observation = self.template.get_retry_msg()
                        else:
                            observation, _, _, info = self.env.step(action_parsed)
                            valid_action = info[ACTION_EXEC]
                            _, reward, done, info = self.env.step("submit")
                        self.dialogue.append(action)
                        if self.args.verbose:
                            print(f"- Step (Refine): {turn + 1}")
                            print(f"-- Action: {action_parsed}")
                            if isinstance(observation, str) and observation.startswith(f'No {self.args.env.upper()} code'):
                                print(f"-- Observation: (meta) No code output, policy's template's retry message was invoked")
                            else:
                                print(f"-- Observation: {observation}")
                        turn_history["actions"].append(action_parsed)
                        turn_history["rewards"].append(reward)
                        turn_history["observations"].append(str(observation)) # To avoid serialization issues
                        turn_history["valid_action"].append(valid_action)
                        # Stop refining as soon as the task is solved
                        if reward == 1:
                            break
                # Logging: record the best reward seen across all turns
                max_reward, max_reward_idx = 0, -1
                if len(turn_history["rewards"]) > 0:
                    max_reward = max(turn_history["rewards"])
                    max_reward_idx = turn_history["rewards"].index(max_reward)
                log_episode = {
                    "environment": self.env.name,
                    "dataset": self.args.data_path,
                    "task_id": idx,
                    "query": self.env.query,
                    "turn_history": turn_history,
                    "info": info,
                    "summary": {
                        "max_reward": max_reward,
                        "max_reward_idx": max_reward_idx,
                    }
                }
                # SQL tasks carry a per-query difficulty rating
                if "hardness" in record:
                    log_episode["hardness"] = record["hardness"]
                self.log_data[idx] = log_episode
                if self.args.verbose:
                    print(f"Query {idx} Finished\n-Reward: {max_reward}")
        except KeyboardInterrupt:
            print("Keyboard interrupt detected")
        finally:
            # Always persist logs and tear down the environment
            with open(self.log_path, "w") as fp:
                json.dump({
                    "meta": vars(self.args),
                    "logs": self.log_data
                }, fp, indent=2)
            self.env.close()
if __name__ == '__main__':
    # Build the experiment harness from the parsed CLI args and run it.
    runner = ExperimentWrapper(args)
    runner.run_expr()