1+ # This is based on https://github.com/matsui528/annbench/blob/main/plot.py
2+ import argparse
3+ import csv
4+ import matplotlib
5+ matplotlib .use ('agg' )
6+ import matplotlib .pyplot as plt
7+ import sys
8+ from itertools import cycle
9+
10+ from datasets import get_query_count
11+
12+
13+ marker = cycle (('p' , '^' , 'h' , 'x' , 'o' , 's' , '*' , '+' , 'D' , '1' , 'X' ))
14+ linestyle = cycle ((':' , '-' , '--' ))
15+
16+ def draw (lines , xlabel , ylabel , title , filename , with_ctrl , width , height ):
17+ """
18+ Visualize search results and save them as an image
19+ Args:
20+ lines (list): search results. list of dict.
21+ xlabel (str): label of x-axis, usually "recall"
22+ ylabel (str): label of y-axis, usually "query per sec"
23+ title (str): title of the result_img
24+ filename (str): output file name of image
25+ with_ctrl (bool): show control parameters or not
26+ width (int): width of the figure
27+ height (int): height of the figure
28+ """
29+ plt .figure (figsize = (width , height ))
30+
31+ for line in lines :
32+ for key in ["xs" , "ys" , "label" , "ctrls" ]:
33+ assert key in line
34+
35+ for line in lines :
36+ plt .plot (line ["xs" ], line ["ys" ], label = line ["label" ], marker = next (marker ), linestyle = next (linestyle ))
37+ if with_ctrl :
38+ for x , y , ctrl in zip (line ["xs" ], line ["ys" ], line ["ctrls" ]):
39+ plt .annotate (text = str (ctrl ), xy = (x , y ),
40+ xytext = (x , y + 50 ))
41+
42+ plt .xlabel (xlabel )
43+ plt .ylabel (ylabel )
44+ plt .grid (which = "both" )
45+ plt .yscale ("log" )
46+ plt .legend (bbox_to_anchor = (1.05 , 1.0 ), loc = "upper left" )
47+ plt .title (title )
48+ plt .savefig (filename , bbox_inches = 'tight' )
49+ plt .cla ()
50+
51+ def get_pareto_frontier (line ):
52+ data = sorted (zip (line ["ys" ], line ["xs" ], line ["ctrls" ]),reverse = True )
53+ line ["xs" ] = []
54+ line ["ys" ] = []
55+ line ["ctrls" ] = []
56+
57+ cur = 0
58+ for y , x , label in data :
59+ if x > cur :
60+ cur = x
61+ line ["xs" ].append (x )
62+ line ["ys" ].append (y )
63+ line ["ctrls" ].append (label )
64+
65+ return line
66+
67+ if __name__ == "__main__" :
68+ parser = argparse .ArgumentParser ()
69+ parser .add_argument (
70+ "--task" ,
71+ choices = ['task1' , 'task2' ]
72+ )
73+ parser .add_argument ("csvfile" )
74+ args = parser .parse_args ()
75+
76+ with open (args .csvfile , newline = "" ) as csvfile :
77+ reader = csv .DictReader (csvfile )
78+ data = list (reader )
79+
80+ lines = {}
81+ for res in data :
82+ if res ["task" ] != args .task :
83+ continue
84+ dataset = res ["dataset" ]
85+ algo = res ["algo" ]
86+ label = dataset + algo
87+ if label not in lines :
88+ lines [label ] = {
89+ "xs" : [],
90+ "ys" : [],
91+ "ctrls" : [],
92+ "label" : label ,
93+ }
94+ lines [label ]["xs" ].append (float (res ["recall" ]))
95+ lines [label ]["ys" ].append (get_query_count (dataset , args .task )/ float (res ["querytime" ])) # FIX query size hardcoded
96+ try :
97+ run_identifier = res ["params" ].split ("query=" )[1 ]
98+ except :
99+ run_identifier = res ["params" ]
100+ lines [label ]["ctrls" ].append (run_identifier )
101+
102+ draw ([get_pareto_frontier (line ) for line in lines .values ()],
103+ "Recall" , "QPS (1/s)" , "Result" , f"result_{ args .task } .png" , True , 10 , 8 )
0 commit comments