Skip to content

Commit 4864f78

Browse files
author
stephanie gross
committed
WIP
1 parent c31d427 commit 4864f78

1 file changed

Lines changed: 275 additions & 0 deletions

File tree

ragability/ragability_check2.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Module to check responses against target facts and assign scores. This creates a result file with the scores which
5+
can then be used to calculate summary statistics in various ways.
6+
"""
7+
8+
import sys
9+
import json
10+
import argparse
11+
import datetime
12+
import hjson
13+
from logging import DEBUG
14+
from ragability.data import read_input_file, read_prompt_file
15+
from llms_wrapper.config import read_config_file, update_llm_config
16+
from ragability.logging import logger, set_logging_level, add_logging_file
17+
from llms_wrapper.llms import LLMS, LLM
18+
from ragability.utils import pp_config
19+
from ragability.checks2 import CHECKS
20+
21+
DEFAULT_PROMPT = {
22+
"system": "You are an expert analyzing responses and how they related to desired facts or properties of the responses. You will be given the response following RESPONSE: and before QUERY:, and a query telling you what to analyze after QUERY:",
23+
"user": "RESPONSE: ${response} QUERY: ${query}",
24+
}
25+
26+
# TODO: allow ${fact0} to ${fact9} as substitution fields.
27+
28+
29+
def get_args():
30+
"""
31+
Get the command line arguments
32+
"""
33+
parser = argparse.ArgumentParser(description='Check responses against target facts and assign scores')
34+
parser.add_argument('--input', '-i', type=str, help='Input file with the responses from ragability_query (or from config), jsonl, json, yaml', required=False)
35+
parser.add_argument('--output', '-o', type=str, help='Output file with the checking results (default: $DATETIME.out.jsonl), jsonl, json, yaml', required=False)
36+
parser.add_argument("--config", "-c", type=str, help="Config file with the LLM and other info, json, jsonl, yaml", required=False)
37+
parser.add_argument('--usellm', '-u', type=str, help='The alias of the configured LLM to use (use first one found)', required=False)
38+
parser.add_argument("--promptfile", "-pf", type=str, help="File with the prompt to use for the checking queries (or use config), jsonl, json, yaml", required=False)
39+
parser.add_argument("--all", "-a", action="store_true", help="Run all queries, even if they have a response", required=False)
40+
parser.add_argument("--logfile", "-f", type=str, help="Log file", required=False)
41+
parser.add_argument("--dry-run", "-n", action="store_true", help="Dry run, do not actually run the queries", required=False)
42+
parser.add_argument("--verbose", "-v", action="store_true", help="Be more verbose and inform what is happening", required=False)
43+
parser.add_argument("--debug", "-d", action="store_true", help="Debug mode", required=False)
44+
args_tmp = parser.parse_args()
45+
tmp = {}
46+
tmp.update(vars(args_tmp))
47+
args: dict = tmp
48+
49+
# if a config file is specified, read the config file using our config reading function and update the arguments.
50+
# The config data may contain:
51+
# - input: used only if not specified in the command line arguments
52+
# - output: used only if not specified in the command line arguments
53+
# - llm: added to the ones specified in the command line arguments
54+
# - prompt: used to add config info to the llms specified in the command line arguments
55+
if args["config"]:
56+
config = read_config_file(args["config"])
57+
config.update(args)
58+
args = config
59+
if not args["input"]:
60+
print("Error: Missing input file")
61+
parser.print_help()
62+
sys.exit(1)
63+
update_llm_config(args)
64+
# read the prompt file into memory, add prompts to the "prompts" key in the config, raise an error if the
65+
# prompt id is already in the config
66+
if args["promptfile"]:
67+
prompts = read_prompt_file(args["promptfile"]) # this is a list of dicts with key "pid" containing the id
68+
if "prompts" not in args:
69+
args["prompts"] = []
70+
for prompt in prompts:
71+
if prompt["pid"] in args["prompts"]:
72+
raise ValueError(f"Error: Prompt id {prompt['pid']} already in config")
73+
args["prompts"].append(prompt)
74+
# create a "prompts_dict" key in the config which is a dict mapping the prompt id to the prompt dict
75+
if args.get("prompts") is None:
76+
args["prompts"] = []
77+
args["prompts_dict"] = {prompt["pid"]: prompt for prompt in args.get("prompts", [])}
78+
return args
79+
80+
81+
def check_check(check: dict, example: dict, config: dict) -> bool:
82+
"""
83+
This returns True if the check is correct, False if it can be skipped or raises an exception if the error cannot be
84+
skipped.
85+
"""
86+
# make sure the func field is present and that the func field is a string
87+
# now if the func is not LLM, we can use the function directly, otherwise we need to query the LLM
88+
if "func" not in check:
89+
logger.warning(f"Warning: Missing 'func' field in check in example {example['qid']}")
90+
return False
91+
if "metrics" not in check:
92+
logger.warning(f"Warning: Missing 'metrics' field in check in example {example['qid']}")
93+
return False
94+
if not isinstance(check["func"], str):
95+
logger.warning(f"Warning: 'func' field in check must be a string in example {example['qid']}")
96+
return False
97+
# make sure the function is in the CHECKS dictionary
98+
if check["func"] not in CHECKS:
99+
logger.warning(f"Warning: Check function {check['func']} not in CHECKS in example {example['qid']}")
100+
return False
101+
func = CHECKS[check["func"]]
102+
# check if the number of parameters defined with "parms" matches the number of parameters required by the function
103+
nargs = func["nargs"]
104+
args = check.get("args", [])
105+
if nargs != len(args):
106+
logger.warning(f"Warning: Wrong number of positional arguments in check for function {func['func']} in example {example['qid']}: {len(args)} instead of {nargs}")
107+
return False
108+
if not config['all'] and "result" in check:
109+
logger.debug(f"Skipping check {check['query']} with result")
110+
return False
111+
return True
112+
113+
114+
def run_check(check, llm: LLM, example, config, debug=False):
115+
llmname = llm["alias"]
116+
cid = check.get("cid", "NOID")
117+
# check the check
118+
if not check_check(check, example, config):
119+
logger.debug(f"Skipping check in example {example['qid']}")
120+
return
121+
122+
# if there is a query in the check, invoke the checker LLM and use the response from the checker
123+
# as the response to check. If there is no query, use the response from the example as the response to check
124+
response = None # this will hold the string to check
125+
if "query" in check and check["query"] is not None:
126+
query = check["query"]
127+
check_for = check.get("check_for")
128+
# get the prompt id from the check, if there is none, use the default prompt, otherwise use the prompt
129+
# with that id in the config. If a pid is specified which is not present, this is an error
130+
if "pid" in check:
131+
if check["pid"] in config["prompts_dict"]:
132+
theprompt = config["prompts_dict"][check["pid"]].copy()
133+
else:
134+
logger.warning(f"Error: Prompt id {check['pid']} not found for example {example['qid']}")
135+
logger.debug(f"Have prompt ids {config['prompts_dict'].keys()}")
136+
check["error"] = f"Prompt id {check['pid']} not found"
137+
check["result"] = None
138+
return
139+
else:
140+
theprompt = DEFAULT_PROMPT.copy()
141+
for role, text in theprompt.items():
142+
text = text.replace("${query}", query)
143+
text = text.replace("${answer}", example["response"])
144+
if check_for:
145+
text = text.replace("${check_for}", check_for)
146+
theprompt[role] = text
147+
# check if we have a dry run, if yes, just log what we would do, otherwise query the LLM
148+
messages = llm.make_messages(prompt=theprompt)
149+
if config['dry_run']:
150+
logger.info(f"Would query checker-LLM {llmname} with messages: {messages}")
151+
response = ""
152+
error = "NOT RUN: DRY-RUN"
153+
return
154+
if config['verbose']:
155+
logger.info(f"Querying checker-LLM {llmname} for example {example['qid']} and check {cid}")
156+
ret = llm.query(messages=messages, return_cost=True, debug=config['debug'])
157+
response = ret.get("answer", "")
158+
check["cost"] = ret.get("cost", 0)
159+
check["response"] = response
160+
error = ret.get("error", "")
161+
check["llm"] = llmname
162+
# if we had an error with the checker LLM, log it and return, we cannot check the response
163+
if error:
164+
logger.warning(f"Error from checking LLM, cannot check: {error}")
165+
check["error"] = error
166+
check["result"] = None
167+
return
168+
else:
169+
response = example["response"]
170+
func_config = CHECKS[check["func"]]
171+
func = func_config["func"]
172+
nargs = func_config["nargs"]
173+
args = check.get("args", [])
174+
assert len(args) == nargs, f"Error: Wrong number of positional arguments in check for function {func['func']}: {len(args)} instead of {nargs}"
175+
kwargs = check.get("kwargs", {})
176+
try:
177+
result = func(response, *args, **kwargs)
178+
error = ""
179+
except Exception as e:
180+
logger.error(f"Error in check function {func}: {e}")
181+
result = None
182+
error = f"Error in check function {func}: {e}"
183+
check["result"] = result
184+
check["error"] = error
185+
186+
187+
def run(config: dict):
188+
# check the configuration: for checkking, we want exactly one LLM to be configured and we want
189+
# to have a single prompt or no promot configured. If no prompt is configured, a default prompt will be used.
190+
if len(config["llms"]) < 1:
191+
raise ValueError(f"Error: at least one LLM must be configured")
192+
# if usellm is configured, we want to use the LLM with that alias, otherwise we use the first in the list
193+
llmname = ""
194+
if config["usellm"]:
195+
for llm in config["llm"]:
196+
if llm["alias"] == config["usellm"]:
197+
thellmname = llm["alias"]
198+
break
199+
if not llmname:
200+
raise ValueError(f"Error: LLM with alias {config['usellm']} not found")
201+
else:
202+
llmname = config["llms"][0]["alias"]
203+
if len(config["prompts"]) == 0:
204+
theprompt = DEFAULT_PROMPT
205+
logger.warning(f"Warning: No prompt configured, using default prompt")
206+
# read the input file into memory, we do not expect it to be too large and we want to check the format
207+
# of all json lines
208+
inputs = read_input_file(config["input"])
209+
logger.info(f"Loaded {len(inputs)} queries from {config['input']}")
210+
logger.info(f"LLM to use: {llmname}")
211+
logger.info(f"Prompts found: {len(config['prompts_dict'])}")
212+
213+
# initialize the LLMS object with the configuration
214+
llms = LLMS(config)
215+
llm: LLM = llms[llmname]
216+
217+
if not config['output']:
218+
config['output'] = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.checked.hsjon"
219+
# write either a jsonl or json file, depending on the file extension
220+
if not config['output'].endswith(".json") and not config['output'].endswith(".jsonl") and not config[
221+
'output'].endswith(".hjson"):
222+
print(f"Error: Output file must end in .json, .jsonl or .hjson, not {config['output']}")
223+
n_errors = 0
224+
n_outputs = 0
225+
total_cost = 0
226+
with open(config['output'], 'w') as f:
227+
if config['output'].endswith(".json") or config['output'].endswith(".hjson"):
228+
f.write("[\n")
229+
for example in inputs:
230+
# check if the example has checks at all, give a warning if not
231+
if not "checks" in example or len(example["checks"]) == 0:
232+
logger.warning(f"Warning: No checks in example {example['qid']}")
233+
continue
234+
# if the example has an error, we cannot check it, so we skip it
235+
if example.get("error"):
236+
logger.warning(f"Skipping example {example['qid']} with error: {example['error']}")
237+
continue
238+
# now go through each of the checks: if we already have a check result, skip unless the --all option is given
239+
# if the function is LLM, we need to run the function on the result of querying the LLM, otherwise
240+
# we directly run the function on the response from the query stage
241+
for check in example["checks"]:
242+
run_check(check, llm, example, config, debug=config["debug"])
243+
cost = check.get("cost", 0)
244+
total_cost += cost
245+
if check.get("error"):
246+
logger.warning(f"Error in check {check['query']}: {check['error']}")
247+
n_errors += 1
248+
# write the example to the output file
249+
towrite = example
250+
n_outputs += 1
251+
if config['output'].endswith(".json"):
252+
f.write(json.dumps(towrite, indent=2) + "\n")
253+
elif config['output'].endswith(".hjson"):
254+
f.write(hjson.dumps(towrite, indent=2) + "\n")
255+
else:
256+
f.write(json.dumps(towrite) + "\n")
257+
if config['output'].endswith(".json") or config['output'].endswith(".hjson"):
258+
f.write("]\n")
259+
logger.info(f"Wrote {n_outputs} examples to {config['output']}, {n_errors} errors")
260+
logger.info(f"Total cost: {total_cost}")
261+
262+
263+
def main():
264+
args = get_args()
265+
if args["logfile"]:
266+
add_logging_file(args["logfile"])
267+
if args["debug"]:
268+
set_logging_level(DEBUG)
269+
ppargs = pp_config(args)
270+
logger.debug(f"Effective arguments: {ppargs}")
271+
run(args)
272+
273+
274+
if __name__ == '__main__':
275+
main()

0 commit comments

Comments
 (0)