Skip to content

Commit 851d71a

Browse files
committed
Fix/change options save_*
1 parent fa0d1d4 commit 851d71a

1 file changed

Lines changed: 102 additions & 70 deletions

File tree

ragability/ragability_eval.py

Lines changed: 102 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323
def make_grouping_func(
2424
df: pd.DataFrame,
2525
tags: Optional[List[str]] = None,
26-
fields: Optional[Dict[str,str]] = None ):
26+
fields: Optional[List[str]] = None):
2727
"""
2828
Create a function which can be used as an argument to the pandas group_by method on the given dataframe.
2929
This creates a binary grouping where one group consists of all the rows that match the given tags and field values,
3030
and another group which does not.
3131
"""
32+
# NOTE: for now this must be used for either tags or fields, not both
33+
assert not (tags and fields)
34+
3235
# if both tags and fields are None or empty, raise an Exception
3336
if tags is None and fields is None:
3437
raise Exception("No grouping criteria")
@@ -41,11 +44,15 @@ def the_groupby_func(index):
4144
if t not in tag_values:
4245
logger.debug(f"Tag {t} not in {tag_values} in {row} for groupby {tags}")
4346
return False
44-
if fields:
45-
for fname, fval in fields.items():
46-
if row[fname] != fval:
47-
return False
48-
return True
47+
return True
48+
elif fields:
49+
keys = []
50+
for fname in fields:
51+
keys.append(row[fname])
52+
groupname = ",".join(keys)
53+
return groupname
54+
else:
55+
raise Exception("No grouping criteria")
4956
return the_groupby_func
5057

5158

@@ -55,11 +62,11 @@ def get_args():
5562
"""
5663
parser = argparse.ArgumentParser(description='Evaluation of a ragability_check output file')
5764
parser.add_argument('--input', '-i', type=str, help='Input ragability_check output file', required=True)
58-
parser.add_argument('--save-json', '-o', type=str,
65+
parser.add_argument('--save_json', '-o', type=str,
5966
help='Output json or hjson', required=False)
6067
parser.add_argument('--config', '-c', type=str, help='Configuration file', required=False)
61-
parser.add_argument("--save-longdf", type=str, help="Save the long format dataframe to a csv or tsv file", required=False)
62-
parser.add_argument("--save-widedf", type=str, help="Save the wide format dataframe to a csv or tsv file", required=False)
68+
parser.add_argument("--save_longdf", type=str, help="Save the long format dataframe to a csv or tsv file", required=False)
69+
parser.add_argument("--save_widedf", type=str, help="Save the wide format dataframe to a csv or tsv file", required=False)
6370
parser.add_argument('--verbose', '-v', action="store_true",
6471
help='Be more verbose', required=False)
6572
parser.add_argument('--by_tags', nargs="+", type=str,
@@ -173,85 +180,110 @@ def run(config: dict):
173180
))
174181
logger.debug(f"Generated {len(dfrows)} rows for all LLMs")
175182

183+
# for each of the tag names mentioned in the config "by_tags" parameter, create a group for all rows
184+
# which do have the tag, labeling with the group name "tagname:yes" and for all rows which do not have the tag
185+
# labeling with the group name "tagname:no". For each of these groups, create the same metrics as for the "all"
186+
# group.
187+
if config.get("by_tags"):
188+
for tagname in config.get("by_tags"):
189+
logger.debug(f"Generating rows for grouping by tag {tagname}")
190+
for key, df in checkdfs.items():
191+
grouping_func = make_grouping_func(df, tags=[tagname])
192+
kind, metric = key.split(":")
193+
grouped = df.groupby(grouping_func)
194+
for group, groupdf in grouped:
195+
logger.debug(f"Grouping {key} by tag {tagname} and group {group}")
196+
if group:
197+
groupname = f"{tagname}:yes"
198+
else:
199+
groupname = f"{tagname}:no"
200+
for llm, llmgroup in groupdf.groupby("llm"):
201+
dfrows.append(dict(
202+
group=groupname,
203+
llm=llm,
204+
metric=f"{metric}:accuracy",
205+
value=sk.metrics.accuracy_score(llmgroup["target"].values, llmgroup["result"].values)
206+
))
207+
dfrows.append(dict(
208+
group=groupname,
209+
llm=llm,
210+
metric=f"{metric}:n",
211+
value=len(llmgroup)
212+
))
176213

177-
# now if we have grouping criteria, do the following: for each of the by_tags or by_qfields criteria,
178-
# create a grouping function to split the df into two groups, one that matches the criteria and one that does not.
179-
# Create the corresponding dataframes with the rows matching the criteria and the other with the rows not
180-
# matching the criteria. Then group each of these dataframes by LLM and calculate the accuracy and number of rows
181-
# for each metric.
182-
if config.get("by_tags") or config.get("by_qfields"):
183-
for groupbyname in ["by_tags", "by_qfields"]:
184-
groupbyvalues = config.get(groupbyname)
185-
logger.debug(f"Grouping by {groupbyname} with values {groupbyvalues}")
186-
if not groupbyvalues:
187-
continue
188-
n_rows4group = 0
189-
for groupbyvalue in groupbyvalues:
190-
logger.debug(f"Generating rows for grouping by {groupbyname} with value {groupbyvalue}")
191-
for key, df in checkdfs.items():
192-
if groupbyname == "by_tags":
193-
grouping_func = make_grouping_func(df, tags=[groupbyvalue])
214+
# for each of the field names mentioned in the config "by_qfields" parameter, find all the different
215+
# values of the field in the dataframe and create a group for each of these values, labeling with the group name
216+
# "fieldname:value" for all rows which have the value.
217+
# For each of these groups, create the same metrics as for the "all" group.
218+
if config.get("by_qfields"):
219+
for fieldname in config.get("by_qfields"):
220+
logger.debug(f"Generating rows for grouping by field {fieldname}")
221+
for key, df in checkdfs.items():
222+
grouping_func = make_grouping_func(df, fields=[fieldname])
223+
kind, metric = key.split(":")
224+
grouped = df.groupby(grouping_func)
225+
for group, groupdf in grouped:
226+
logger.debug(f"Grouping {key} by field {fieldname} and group {group}")
227+
if group:
228+
groupname = f"{fieldname}:{group}"
194229
else:
195-
# find all possible values of the field in the df
196-
fields = {groupbyvalue: v for v in df[groupbyvalue].unique()}
197-
grouping_func = make_grouping_func(df, fields=fields)
198-
kind, metric = key.split(":")
199-
grouped = df.groupby(grouping_func)
200-
for group, groupdf in grouped:
201-
logger.debug(f"Grouping {key} by {groupbyname} with value {groupbyvalue} and group {group}")
202-
if group:
203-
groupname = f"{groupbyvalue}:yes"
204-
else:
205-
groupname = f"{groupbyvalue}:no"
206-
for llm, llmgroup in groupdf.groupby("llm"):
207-
dfrows.append(dict(
208-
group=groupname,
209-
llm=llm,
210-
metric=f"{metric}:accuracy",
211-
value=sk.metrics.accuracy_score(llmgroup["target"].values, llmgroup["result"].values)
212-
))
213-
dfrows.append(dict(
214-
group=groupname,
215-
llm=llm,
216-
metric=f"{metric}:n",
217-
value=len(llmgroup)
218-
))
219-
n_rows4group += 2
220-
logger.debug(f"Generated {n_rows4group} rows for grouping by {groupbyname}")
230+
groupname = f"{fieldname}:no"
231+
for llm, llmgroup in groupdf.groupby("llm"):
232+
dfrows.append(dict(
233+
group=groupname,
234+
llm=llm,
235+
metric=f"{metric}:accuracy",
236+
value=sk.metrics.accuracy_score(llmgroup["target"].values, llmgroup["result"].values)
237+
))
238+
dfrows.append(dict(
239+
group=groupname,
240+
llm=llm,
241+
metric=f"{metric}:n",
242+
value=len(llmgroup)
243+
))
244+
221245
logger.debug(f"Generated {len(dfrows)} rows in total")
246+
# re-order the rows and sort by group, llm, metric
247+
dfrows = sorted(dfrows, key=lambda x: (x["group"], x["llm"], x["metric"]))
248+
222249
# create the long format dataframe from the list of rows
223250
dfout_long = pd.DataFrame(dfrows)
224-
if config.get("save-longdf"):
225-
if config["save-longdf"].endswith(".csv"):
226-
dfout_long.to_csv(config["save-longdf"], index=False)
227-
elif config["save-longdf"].endswith(".tsv"):
228-
dfout_long.to_csv(config["save-longdf"], index=False, sep="\t")
251+
logger.debug(f"Generated long format dataframe with {len(dfout_long)} rows and {len(dfout_long.columns)} columns")
252+
if config.get("save_longdf"):
253+
if config["save_longdf"].endswith(".csv"):
254+
dfout_long.to_csv(config["save_longdf"], index=False)
255+
elif config["save_longdf"].endswith(".tsv"):
256+
dfout_long.to_csv(config["save_longdf"], index=False, sep="\t")
229257
else:
230258
raise Exception(f"Error: Output file must end in .csv or .tsv, not {config['save-longdf']}")
231259
# now pivot the long format dataframe to the wide format
232260
dfout = dfout_long.pivot_table(index=["group", "llm"], columns="metric", values="value")
233261
dfout.reset_index(inplace=True)
234-
if config.get("save-widedf"):
235-
if config["save-widedf"].endswith(".csv"):
236-
dfout.to_csv(config["save-widedf"], index=False)
237-
elif config["save-widedf"].endswith(".tsv"):
238-
dfout.to_csv(config["save-widedf"], index=False, sep="\t")
262+
if config.get("save_widedf"):
263+
if config["save_widedf"].endswith(".csv"):
264+
dfout.to_csv(config["save_widedf"], index=False)
265+
elif config["save_widedf"].endswith(".tsv"):
266+
dfout.to_csv(config["save_widedf"], index=False, sep="\t")
239267
else:
240-
raise Exception(f"Error: Output file must end in .csv or .tsv, not {config['save-widedf']}")
268+
raise Exception(f"Error: Output file must end in .csv or .tsv, not {config['save_widedf']}")
241269
# if the output file is specified, save the dataframe as csv or tsv depending on the extension or
242270
# save a dictionary representation of the dataframe as json or hjson
243-
if config.get("save-json"):
244-
if config["output"].endswith(".json"):
245-
dfout.to_json(config["output"], orient="records")
246-
elif config["output"].endswith(".hjson"):
247-
with open(config["output"], "wt") as outfp:
271+
if config.get("save_json"):
272+
if config["save_json"].endswith(".json"):
273+
dfout.to_json(config["save_json"], orient="records")
274+
elif config["save_json"].endswith(".hjson"):
275+
with open(config["save_json"], "wt") as outfp:
248276
hjson.dump(dfout.to_dict(orient="records"), outfp)
249277
else:
250278
raise Exception(f"Error: Output file must end in .csv, .tsv, .json or .hjson, not {config['output']}")
251279
# if verbose is set, or no output file is specified, write the results to stdout using textual formatting of
252280
# the dataframe
253-
if config.get("verbose") or not config.get("output"):
254-
print(dfout_long.to_string())
281+
if config.get("verbose") or not config.get("save-json"):
282+
# create a copy of the dataframe with all the rows where the column "metric" has a value
283+
# which ends with :n removed
284+
dfout_long_metrics = dfout_long[~dfout_long["metric"].str.endswith(":n")]
285+
print(dfout_long_metrics.to_string())
286+
255287

256288

257289
def main():

0 commit comments

Comments
 (0)