Skip to content

Commit ce51bf3

Browse files
committed
added plot.py
1 parent 3df147d commit ce51bf3

1 file changed

Lines changed: 103 additions & 0 deletions

File tree

plot.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# This is based on https://github.com/matsui528/annbench/blob/main/plot.py
2+
import argparse
3+
import csv
4+
import matplotlib
5+
matplotlib.use('agg')
6+
import matplotlib.pyplot as plt
7+
import sys
8+
from itertools import cycle
9+
10+
from datasets import get_query_count
11+
12+
13+
marker = cycle(('p', '^', 'h', 'x', 'o', 's', '*', '+', 'D', '1', 'X'))
14+
linestyle = cycle((':', '-', '--'))
15+
16+
def draw(lines, xlabel, ylabel, title, filename, with_ctrl, width, height):
17+
"""
18+
Visualize search results and save them as an image
19+
Args:
20+
lines (list): search results. list of dict.
21+
xlabel (str): label of x-axis, usually "recall"
22+
ylabel (str): label of y-axis, usually "query per sec"
23+
title (str): title of the result_img
24+
filename (str): output file name of image
25+
with_ctrl (bool): show control parameters or not
26+
width (int): width of the figure
27+
height (int): height of the figure
28+
"""
29+
plt.figure(figsize=(width, height))
30+
31+
for line in lines:
32+
for key in ["xs", "ys", "label", "ctrls"]:
33+
assert key in line
34+
35+
for line in lines:
36+
plt.plot(line["xs"], line["ys"], label=line["label"], marker=next(marker), linestyle=next(linestyle))
37+
if with_ctrl:
38+
for x, y, ctrl in zip(line["xs"], line["ys"], line["ctrls"]):
39+
plt.annotate(text=str(ctrl), xy=(x, y),
40+
xytext=(x, y+50))
41+
42+
plt.xlabel(xlabel)
43+
plt.ylabel(ylabel)
44+
plt.grid(which="both")
45+
plt.yscale("log")
46+
plt.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
47+
plt.title(title)
48+
plt.savefig(filename, bbox_inches='tight')
49+
plt.cla()
50+
51+
def get_pareto_frontier(line):
52+
data = sorted(zip(line["ys"], line["xs"], line["ctrls"]),reverse=True)
53+
line["xs"] = []
54+
line["ys"] = []
55+
line["ctrls"] = []
56+
57+
cur = 0
58+
for y, x, label in data:
59+
if x > cur:
60+
cur = x
61+
line["xs"].append(x)
62+
line["ys"].append(y)
63+
line["ctrls"].append(label)
64+
65+
return line
66+
67+
if __name__ == "__main__":
68+
parser = argparse.ArgumentParser()
69+
parser.add_argument(
70+
"--task",
71+
choices=['task1', 'task2']
72+
)
73+
parser.add_argument("csvfile")
74+
args = parser.parse_args()
75+
76+
with open(args.csvfile, newline="") as csvfile:
77+
reader = csv.DictReader(csvfile)
78+
data = list(reader)
79+
80+
lines = {}
81+
for res in data:
82+
if res["task"] != args.task:
83+
continue
84+
dataset = res["dataset"]
85+
algo = res["algo"]
86+
label = dataset + algo
87+
if label not in lines:
88+
lines[label] = {
89+
"xs": [],
90+
"ys": [],
91+
"ctrls": [],
92+
"label": label,
93+
}
94+
lines[label]["xs"].append(float(res["recall"]))
95+
lines[label]["ys"].append(get_query_count(dataset, args.task)/float(res["querytime"])) # FIX query size hardcoded
96+
try:
97+
run_identifier = res["params"].split("query=")[1]
98+
except:
99+
run_identifier = res["params"]
100+
lines[label]["ctrls"].append(run_identifier)
101+
102+
draw([get_pareto_frontier(line) for line in lines.values()],
103+
"Recall", "QPS (1/s)", "Result", f"result_{args.task}.png", True, 10, 8)

0 commit comments

Comments
 (0)