Skip to content

Commit 06e5595

Browse files
committed
split up functionality
1 parent 93239c8 commit 06e5595

3 files changed

Lines changed: 140 additions & 24 deletions

File tree

datasets.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
from urllib.request import urlretrieve
3+
from pathlib import Path
4+
5+
def download(src, dst):
    """Fetch *src* (a URL) to local path *dst*, creating parent directories.

    Skips the download entirely when *dst* already exists, so repeated
    calls are cheap no-ops.
    """
    target = Path(dst)
    if target.exists():
        return
    target.parent.mkdir(parents=True, exist_ok=True)
    print('downloading %s -> %s...' % (src, dst))
    urlretrieve(src, dst)
10+
11+
def get_fn(kind):
    """Return (data_path, groundtruth_path) for *kind* under data/<kind>/.

    The dataset version is currently fixed to 'ccnews-small'; the
    ground-truth file lives in a 'gt' subdirectory next to the data file.
    """
    version = "ccnews-small"
    base = os.path.join("data", kind)
    return os.path.join(base, f"{version}.h5"), os.path.join(base, 'gt', f'{version}.h5')
14+
15+
def prepare(kind, dataset='ccnews-small'):
    """Download the data and ground-truth files for a task, if missing.

    Args:
        kind: task name ('task1' or 'task2'), a key under DATASETS[dataset].
        dataset: dataset name to look up in DATASETS. Defaults to
            'ccnews-small', preserving the previous hard-coded behavior.
    """
    # Look the entry up once instead of indexing DATASETS twice.
    entry = DATASETS[dataset][kind]
    fn, gt_fn = get_fn(kind)

    # download() is a no-op when the target file already exists.
    download(entry['url'], fn)
    download(entry['gt_url'], gt_fn)
22+
23+
# Registry of benchmark datasets: dataset name -> task kind -> spec.
# Each spec provides:
#   'url'     -- download URL for the HDF5 data file
#   'queries' -- accessor: open h5py.File -> query vectors
#   'data'    -- accessor: open h5py.File -> database vectors
#   'gt_url'  -- download URL for the ground-truth HDF5 file
#   'gt_I'    -- accessor: open h5py.File -> ground-truth knn ids
DATASETS = {
    'ccnews-small': {
        # task1: held-out 'itest' queries searched against the train set;
        # ground truth lives in the same file under itest/knns.
        'task1': {
            'url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true',
            'queries': lambda x: x['itest']['queries'],
            'data': lambda x: x['train'],
            'gt_url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true',
            'gt_I': lambda x: x['itest']['knns'],
        },
        # task2: all-knn — the train set is queried against itself;
        # ground truth comes from a separate all-knn file.
        'task2': {
            'url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true',
            'queries': lambda x: x['train'],
            'data': lambda x: x['train'],
            'gt_url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/allknn-benchmark-dev-ccnews.h5?download=true',
            'gt_I': lambda x: x['knns']
        }
    }
}

eval.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import argparse
2+
import h5py
3+
import numpy as np
4+
import os
5+
import csv
6+
import glob
7+
from pathlib import Path
8+
from datasets import DATASETS
9+
10+
def get_groundtruth(size="100K", private=False, gt_fn=None):
    """Load the ground-truth neighbor ids ('knns') from an HDF5 file.

    Args:
        size: dataset size tag (currently unused; kept for interface
            stability). TODO confirm intended use.
        private: whether to evaluate against held-out private queries
            (currently unused).
        gt_fn: path to the ground-truth HDF5 file to read.

    Returns:
        numpy array of the ground-truth neighbor ids.

    Raises:
        ValueError: if *gt_fn* is not provided.
    """
    # The original body read an undefined name `out_fn`, so every call
    # raised NameError. Require the path explicitly and fail with a clear
    # message instead; use a context manager so the file closes on error.
    if gt_fn is None:
        raise ValueError("gt_fn (path to the ground-truth HDF5 file) is required")
    with h5py.File(gt_fn, "r") as gt_f:
        true_I = np.array(gt_f['knns'])
    return true_I
16+
17+
def get_all_results(dirname):
    """Yield open h5py files for every usable result under *dirname*.

    A usable result file contains a 'knns' dataset and a 'data' entry
    (either a dataset or a file attribute); other .h5 files are skipped
    with a message. Each yielded file is closed when the generator
    resumes, so consumers must finish reading a file before advancing
    the iterator.
    """
    mask = [dirname + "/**/*.h5"]
    print("search for results matching:")
    print("\n".join(mask))
    for m in mask:
        # recursive=True is required for the '**' wildcard to descend into
        # subdirectories; without it '**' behaves like a plain '*' and
        # nested result files are silently missed.
        for fn in glob.iglob(m, recursive=True):
            print(fn)
            f = h5py.File(fn, "r")
            if "knns" not in f or not ("data" in f or "data" in f.attrs):
                print("Ignoring " + fn)
                f.close()
                continue
            yield f
            f.close()
31+
32+
def get_recall(I, gt, k):
    """Mean recall@k of result ids *I* against ground-truth ids *gt*.

    Both arguments are (n, >=k) integer arrays; each row's first k
    entries are compared as sets, so duplicate ids within a row are
    counted once.
    """
    assert k <= I.shape[1]
    assert len(I) == len(gt)

    n = len(I)
    hits = sum(
        len(set(found[:k]) & set(expected[:k]))
        for found, expected in zip(I, gt)
    )
    return hits / (n * k)
41+
42+
43+
if __name__ == "__main__":
    # CLI entry point: scan result files, compute recall@10 for each,
    # and write one CSV row per result file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--results", default="result",
                        help='directory in which results are stored')
    parser.add_argument('--private', action='store_true', default=False,
                        help="private queries held out for evaluation")
    parser.add_argument('--dataset', choices=['ccnews-small'],
                        default='ccnews-small')
    parser.add_argument("csvfile")
    args = parser.parse_args()

    true_I_cache = {}  # reserved for memoizing ground-truth loads; unused so far

    columns = ["data", "kind", "algo", "buildtime", "querytime", "params", "recall"]

    with open(args.csvfile, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=columns)
        writer.writeheader()
        for res in get_all_results(args.results):
            # NOTE(review): the 'data' attribute appears to hold the task
            # kind (used to index DATASETS['ccnews-small']) — confirm
            # against the writer in search.py.
            task = res.attrs["data"]
            row = dict(res.attrs)
            print(row)
            gt_I = np.array(DATASETS['ccnews-small'][task]['gt_I'](res))
            recall = get_recall(np.array(res["knns"]), gt_I, 10)
            row['recall'] = recall
            print(row["data"], row["algo"], row["params"], "=>", recall)
            writer.writerow(row)

search/search.py renamed to search.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,8 @@
44
import numpy as np
55
import os
66
from pathlib import Path
7-
from urllib.request import urlretrieve
87
import time
9-
10-
def download(src, dst):
11-
if not os.path.exists(dst):
12-
os.makedirs(Path(dst).parent, exist_ok=True)
13-
print('downloading %s -> %s...' % (src, dst))
14-
urlretrieve(src, dst)
15-
16-
def get_fn(kind):
17-
version = "ccnews-small"
18-
return os.path.join("data", kind, f"{version}.h5")
19-
20-
def prepare(kind):
21-
if kind == 'task2':
22-
url = "https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/allknn-benchmark-dev-ccnews.h5?download=true"
23-
if kind == 'task1':
24-
url = "https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true"
25-
fn = get_fn(kind)
26-
27-
download(url, fn)
8+
from datasets import DATASETS, prepare, get_fn
289

2910
def store_results(dst, algo, kind, D, I, buildtime, querytime, params):
3011
os.makedirs(Path(dst).parent, exist_ok=True)
@@ -43,17 +24,21 @@ def run(kind, params):
4324

4425
prepare(kind)
4526

46-
fn = get_fn(kind)
27+
fn, _ = get_fn(kind)
4728
f = h5py.File(fn)
48-
data = np.array(f['train'])
49-
queries = np.array(f['itest']['queries'])
29+
data = np.array(DATASETS['ccnews-small'][kind]['data'](f))
30+
queries = np.array(DATASETS['ccnews-small'][kind]['queries'](f))
5031
f.close()
5132

5233
n, d = data.shape
5334
k = params['k']
5435

5536
nlist = 1024 # number of clusters/centroids to build the IVF from
56-
index_identifier = f"IVF{nlist},SQfp16"
37+
if kind == 'task1':
38+
index_identifier = f"IVF{nlist},SQfp16"
39+
elif kind == 'task2':
40+
index_identifier = f"IVF{nlist},PQ{d//2}x4fs"
41+
5742
index = faiss.index_factory(d, index_identifier)
5843

5944
print(f"Training index on {data.shape}")
@@ -64,6 +49,10 @@ def run(kind, params):
6449
print(f"Done training in {elapsed_build}s.")
6550
assert index.is_trained
6651

52+
if kind == "task2":
53+
index = faiss.IndexRefineFlat(index, faiss.swig_ptr(data.astype('float32')))
54+
index.k_factor = 200
55+
6756
for nprobe in [1, 2, 5, 10, 20, 50, 100]:
6857
print(f"Starting search on {queries.shape} with nprobe={nprobe}")
6958
start = time.time()
@@ -87,6 +76,13 @@ def run(kind, params):
8776
default='task2'
8877
)
8978

79+
parser.add_argument(
80+
'--dataset',
81+
choices=[
82+
'ccnews-small',
83+
],
84+
default='ccnews-small'
85+
)
9086

9187
params = {
9288
'task1': {

0 commit comments

Comments
 (0)