Skip to content

Commit 85195c6

Browse files
committed
added script to show operating points
1 parent abcb39e commit 85195c6

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

show_operating_points.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import argparse
2+
import pandas as pd
3+
4+
# show best performing parameters exceeding threshold
5+
6+
if __name__ == "__main__":
7+
parser = argparse.ArgumentParser()
8+
parser.add_argument(
9+
'--algorithm',)
10+
parser.add_argument(
11+
'--threshold',
12+
default=0.9,
13+
help='minimum recall',
14+
type=float)
15+
parser.add_argument(
16+
'csv',
17+
metavar='CSV',
18+
help='input csv')
19+
parser.add_argument(
20+
'--task',
21+
choices=['task1', 'task2'],
22+
)
23+
24+
args = parser.parse_args()
25+
df = pd.read_csv(args.csv)
26+
df = df[df.task == args.task]
27+
28+
if args.algorithm:
29+
algorithms = [args.algorithm]
30+
else:
31+
algorithms = set(df.algo.values)
32+
for algo in algorithms:
33+
print(f'show {algo}')
34+
if (len(df[(df.recall > args.threshold) & (df.algo == algo)].groupby(['algo', 'dataset']).min()[['querytime']])) == 0:
35+
print("didn't exceed recall, print highest recall:")
36+
print(df[(df.algo == algo)].groupby(['algo', 'dataset']).max()[['recall', 'querytime']])
37+
38+
else:
39+
print(df[(df.recall > args.threshold) & (df.algo == algo)].groupby(['algo', 'dataset']).min()[['querytime']])
40+
41+
print("Overview passing threshold")
42+
43+
print(df[(df.recall >= args.threshold - 1e-6)][['algo', 'dataset', 'querytime', 'params']].sort_values(by=['dataset', 'algo', 'querytime']))
44+
45+
print("Overview NOT passing threshold")
46+
47+
print(df[(df.recall < args.threshold - 1e-6)][['algo', 'dataset', 'querytime', 'params']].sort_values(by=['dataset', 'algo', 'querytime']))

0 commit comments

Comments
 (0)