2323def make_grouping_func (
2424 df : pd .DataFrame ,
2525 tags : Optional [List [str ]] = None ,
26- fields : Optional [Dict [str , str ]] = None ):
26+ fields : Optional [List [str ]] = None ):
2727 """
2828 Create a function which can be used as an argument to the pandas group_by method on the given dataframe.
2929 This creates a binary grouping where one group consists of all the rows that match the given tags and field values,
3030 and another group which does not.
3131 """
32+ # NOTE: for now this must be used for either tags or fields, not both
33+ assert not (tags and fields )
34+
3235 # if both tags and fields are None or empty, raise an Exception
3336 if tags is None and fields is None :
3437 raise Exception ("No grouping criteria" )
@@ -41,11 +44,15 @@ def the_groupby_func(index):
4144 if t not in tag_values :
4245 logger .debug (f"Tag { t } not in { tag_values } in { row } for groupby { tags } " )
4346 return False
44- if fields :
45- for fname , fval in fields .items ():
46- if row [fname ] != fval :
47- return False
48- return True
47+ return True
48+ elif fields :
49+ keys = []
50+ for fname in fields :
51+ keys .append (row [fname ])
52+ groupname = "," .join (keys )
53+ return groupname
54+ else :
55+ raise Exception ("No grouping criteria" )
4956 return the_groupby_func
5057
5158
@@ -55,11 +62,11 @@ def get_args():
5562 """
5663 parser = argparse .ArgumentParser (description = 'Evaluation of a ragability_check output file' )
5764 parser .add_argument ('--input' , '-i' , type = str , help = 'Input ragability_check output file' , required = True )
58- parser .add_argument ('--save-json ' , '-o' , type = str ,
65+ parser .add_argument ('--save_json ' , '-o' , type = str ,
5966 help = 'Output json or hjson' , required = False )
6067 parser .add_argument ('--config' , '-c' , type = str , help = 'Configuration file' , required = False )
61- parser .add_argument ("--save-longdf " , type = str , help = "Save the long format dataframe to a csv or tsv file" , required = False )
62- parser .add_argument ("--save-widedf " , type = str , help = "Save the wide format dataframe to a csv or tsv file" , required = False )
68+ parser .add_argument ("--save_longdf " , type = str , help = "Save the long format dataframe to a csv or tsv file" , required = False )
69+ parser .add_argument ("--save_widedf " , type = str , help = "Save the wide format dataframe to a csv or tsv file" , required = False )
6370 parser .add_argument ('--verbose' , '-v' , action = "store_true" ,
6471 help = 'Be more verbose' , required = False )
6572 parser .add_argument ('--by_tags' , nargs = "+" , type = str ,
@@ -173,85 +180,110 @@ def run(config: dict):
173180 ))
174181 logger .debug (f"Generated { len (dfrows )} rows for all LLMs" )
175182
183+ # for each of the tag names mentioned in the config "by_tags" parameter, create a group for all rows
184+ # which do have the tag, labeling with the group name "tagname:yes" and for all rows which do not have the tag
185+ # labeling with the group name "tagname:no". For each of these groups, create the same metrics as for the "all"
186+ # group.
187+ if config .get ("by_tags" ):
188+ for tagname in config .get ("by_tags" ):
189+ logger .debug (f"Generating rows for grouping by tag { tagname } " )
190+ for key , df in checkdfs .items ():
191+ grouping_func = make_grouping_func (df , tags = [tagname ])
192+ kind , metric = key .split (":" )
193+ grouped = df .groupby (grouping_func )
194+ for group , groupdf in grouped :
195+ logger .debug (f"Grouping { key } by tag { tagname } and group { group } " )
196+ if group :
197+ groupname = f"{ tagname } :yes"
198+ else :
199+ groupname = f"{ tagname } :no"
200+ for llm , llmgroup in groupdf .groupby ("llm" ):
201+ dfrows .append (dict (
202+ group = groupname ,
203+ llm = llm ,
204+ metric = f"{ metric } :accuracy" ,
205+ value = sk .metrics .accuracy_score (llmgroup ["target" ].values , llmgroup ["result" ].values )
206+ ))
207+ dfrows .append (dict (
208+ group = groupname ,
209+ llm = llm ,
210+ metric = f"{ metric } :n" ,
211+ value = len (llmgroup )
212+ ))
176213
177- # now if we have grouping criteria, do the following: for each of the by_tags or by_qfields criteria,
178- # create a grouping function to split the df into two groups, one that matches the criteria and one that does not.
179- # Create the corresponding dataframes with the rows matching the criteria and the other with the rows not
180- # matching the criteria. Then group each of these dataframes by LLM and calculate the accuracy and number of rows
181- # for each metric.
182- if config .get ("by_tags" ) or config .get ("by_qfields" ):
183- for groupbyname in ["by_tags" , "by_qfields" ]:
184- groupbyvalues = config .get (groupbyname )
185- logger .debug (f"Grouping by { groupbyname } with values { groupbyvalues } " )
186- if not groupbyvalues :
187- continue
188- n_rows4group = 0
189- for groupbyvalue in groupbyvalues :
190- logger .debug (f"Generating rows for grouping by { groupbyname } with value { groupbyvalue } " )
191- for key , df in checkdfs .items ():
192- if groupbyname == "by_tags" :
193- grouping_func = make_grouping_func (df , tags = [groupbyvalue ])
214+ # for each of the field names mentioned in the config "by_qfields" parameter, find all the different
215+ # values of the field in the dataframe and create a group for each of these values, labeling with the group name
216+ # "fieldname:value" for all rows which have the value.
217+ # For each of these groups, create the same metrics as for the "all" group.
218+ if config .get ("by_qfields" ):
219+ for fieldname in config .get ("by_qfields" ):
220+ logger .debug (f"Generating rows for grouping by field { fieldname } " )
221+ for key , df in checkdfs .items ():
222+ grouping_func = make_grouping_func (df , fields = [fieldname ])
223+ kind , metric = key .split (":" )
224+ grouped = df .groupby (grouping_func )
225+ for group , groupdf in grouped :
226+ logger .debug (f"Grouping { key } by field { fieldname } and group { group } " )
227+ if group :
228+ groupname = f"{ fieldname } :{ group } "
194229 else :
195- # find all possible values of the field in the df
196- fields = {groupbyvalue : v for v in df [groupbyvalue ].unique ()}
197- grouping_func = make_grouping_func (df , fields = fields )
198- kind , metric = key .split (":" )
199- grouped = df .groupby (grouping_func )
200- for group , groupdf in grouped :
201- logger .debug (f"Grouping { key } by { groupbyname } with value { groupbyvalue } and group { group } " )
202- if group :
203- groupname = f"{ groupbyvalue } :yes"
204- else :
205- groupname = f"{ groupbyvalue } :no"
206- for llm , llmgroup in groupdf .groupby ("llm" ):
207- dfrows .append (dict (
208- group = groupname ,
209- llm = llm ,
210- metric = f"{ metric } :accuracy" ,
211- value = sk .metrics .accuracy_score (llmgroup ["target" ].values , llmgroup ["result" ].values )
212- ))
213- dfrows .append (dict (
214- group = groupname ,
215- llm = llm ,
216- metric = f"{ metric } :n" ,
217- value = len (llmgroup )
218- ))
219- n_rows4group += 2
220- logger .debug (f"Generated { n_rows4group } rows for grouping by { groupbyname } " )
230+ groupname = f"{ fieldname } :no"
231+ for llm , llmgroup in groupdf .groupby ("llm" ):
232+ dfrows .append (dict (
233+ group = groupname ,
234+ llm = llm ,
235+ metric = f"{ metric } :accuracy" ,
236+ value = sk .metrics .accuracy_score (llmgroup ["target" ].values , llmgroup ["result" ].values )
237+ ))
238+ dfrows .append (dict (
239+ group = groupname ,
240+ llm = llm ,
241+ metric = f"{ metric } :n" ,
242+ value = len (llmgroup )
243+ ))
244+
221245 logger .debug (f"Generated { len (dfrows )} rows in total" )
246+ # re-order the rows and sort by group, llm, metric
247+ dfrows = sorted (dfrows , key = lambda x : (x ["group" ], x ["llm" ], x ["metric" ]))
248+
222249 # create the long format dataframe from the list of rows
223250 dfout_long = pd .DataFrame (dfrows )
224- if config .get ("save-longdf" ):
225- if config ["save-longdf" ].endswith (".csv" ):
226- dfout_long .to_csv (config ["save-longdf" ], index = False )
227- elif config ["save-longdf" ].endswith (".tsv" ):
228- dfout_long .to_csv (config ["save-longdf" ], index = False , sep = "\t " )
251+ logger .debug (f"Generated long format dataframe with { len (dfout_long )} rows and { len (dfout_long .columns )} columns" )
252+ if config .get ("save_longdf" ):
253+ if config ["save_longdf" ].endswith (".csv" ):
254+ dfout_long .to_csv (config ["save_longdf" ], index = False )
255+ elif config ["save_longdf" ].endswith (".tsv" ):
256+ dfout_long .to_csv (config ["save_longdf" ], index = False , sep = "\t " )
229257 else :
230258 raise Exception (f"Error: Output file must end in .csv or .tsv, not { config ['save-longdf' ]} " )
231259 # now pivot the long format dataframe to the wide format
232260 dfout = dfout_long .pivot_table (index = ["group" , "llm" ], columns = "metric" , values = "value" )
233261 dfout .reset_index (inplace = True )
234- if config .get ("save-widedf " ):
235- if config ["save-widedf " ].endswith (".csv" ):
236- dfout .to_csv (config ["save-widedf " ], index = False )
237- elif config ["save-widedf " ].endswith (".tsv" ):
238- dfout .to_csv (config ["save-widedf " ], index = False , sep = "\t " )
262+ if config .get ("save_widedf " ):
263+ if config ["save_widedf " ].endswith (".csv" ):
264+ dfout .to_csv (config ["save_widedf " ], index = False )
265+ elif config ["save_widedf " ].endswith (".tsv" ):
266+ dfout .to_csv (config ["save_widedf " ], index = False , sep = "\t " )
239267 else :
240- raise Exception (f"Error: Output file must end in .csv or .tsv, not { config ['save-widedf ' ]} " )
268+ raise Exception (f"Error: Output file must end in .csv or .tsv, not { config ['save_widedf ' ]} " )
241269 # if the output file is specified, save the dataframe as csv or tsv depending on the extension or
242270 # save a dictionary representation of the dataframe as json or hjson
243- if config .get ("save-json " ):
244- if config ["output " ].endswith (".json" ):
245- dfout .to_json (config ["output " ], orient = "records" )
246- elif config ["output " ].endswith (".hjson" ):
247- with open (config ["output " ], "wt" ) as outfp :
271+ if config .get ("save_json " ):
272+ if config ["save_json " ].endswith (".json" ):
273+ dfout .to_json (config ["save_json " ], orient = "records" )
274+ elif config ["save_json " ].endswith (".hjson" ):
275+ with open (config ["save_json " ], "wt" ) as outfp :
248276 hjson .dump (dfout .to_dict (orient = "records" ), outfp )
249277 else :
250278 raise Exception (f"Error: Output file must end in .csv, .tsv, .json or .hjson, not { config ['output' ]} " )
251279 # if verbose is set, or no output file is specified, write the results to stdout using textual formatting of
252280 # the dataframe
253- if config .get ("verbose" ) or not config .get ("output" ):
254- print (dfout_long .to_string ())
281+ if config .get ("verbose" ) or not config .get ("save-json" ):
282+ # create a copy of the dataframe with all the rows where the column "metric" has a value
283+ # which ends with :n removed
284+ dfout_long_metrics = dfout_long [~ dfout_long ["metric" ].str .endswith (":n" )]
285+ print (dfout_long_metrics .to_string ())
286+
255287
256288
257289def main ():
0 commit comments