Skip to content

Commit 93239c8

Browse files
committed
task 1
0 parents  commit 93239c8

1 file changed

Lines changed: 103 additions & 0 deletions

File tree

search/search.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import argparse
2+
import faiss
3+
import h5py
4+
import numpy as np
5+
import os
6+
from pathlib import Path
7+
from urllib.request import urlretrieve
8+
import time
9+
10+
def download(src, dst):
11+
if not os.path.exists(dst):
12+
os.makedirs(Path(dst).parent, exist_ok=True)
13+
print('downloading %s -> %s...' % (src, dst))
14+
urlretrieve(src, dst)
15+
16+
def get_fn(kind):
17+
version = "ccnews-small"
18+
return os.path.join("data", kind, f"{version}.h5")
19+
20+
def prepare(kind):
21+
if kind == 'task2':
22+
url = "https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/allknn-benchmark-dev-ccnews.h5?download=true"
23+
if kind == 'task1':
24+
url = "https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true"
25+
fn = get_fn(kind)
26+
27+
download(url, fn)
28+
29+
def store_results(dst, algo, kind, D, I, buildtime, querytime, params):
30+
os.makedirs(Path(dst).parent, exist_ok=True)
31+
f = h5py.File(dst, 'w')
32+
f.attrs['algo'] = algo
33+
f.attrs['data'] = kind
34+
f.attrs['buildtime'] = buildtime
35+
f.attrs['querytime'] = querytime
36+
f.attrs['params'] = params
37+
f.create_dataset('knns', I.shape, dtype=I.dtype)[:] = I
38+
f.create_dataset('dists', D.shape, dtype=D.dtype)[:] = D
39+
f.close()
40+
41+
def run(kind, params):
42+
print("Running", kind)
43+
44+
prepare(kind)
45+
46+
fn = get_fn(kind)
47+
f = h5py.File(fn)
48+
data = np.array(f['train'])
49+
queries = np.array(f['itest']['queries'])
50+
f.close()
51+
52+
n, d = data.shape
53+
k = params['k']
54+
55+
nlist = 1024 # number of clusters/centroids to build the IVF from
56+
index_identifier = f"IVF{nlist},SQfp16"
57+
index = faiss.index_factory(d, index_identifier)
58+
59+
print(f"Training index on {data.shape}")
60+
start = time.time()
61+
index.train(data)
62+
index.add(data)
63+
elapsed_build = time.time() - start
64+
print(f"Done training in {elapsed_build}s.")
65+
assert index.is_trained
66+
67+
for nprobe in [1, 2, 5, 10, 20, 50, 100]:
68+
print(f"Starting search on {queries.shape} with nprobe={nprobe}")
69+
start = time.time()
70+
index.nprobe = nprobe
71+
D, I = index.search(queries, k)
72+
elapsed_search = time.time() - start
73+
print(f"Done searching in {elapsed_search}s.")
74+
75+
I = I + 1 # FAISS is 0-indexed, groundtruth is 1-indexed
76+
77+
identifier = f"index=({index_identifier}),query=(nprobe={nprobe})"
78+
79+
store_results(os.path.join("result/", kind, f"{identifier}.h5"), "faissIVF", kind, D, I, elapsed_build, elapsed_search, identifier)
80+
81+
if __name__ == "__main__":
82+
83+
parser = argparse.ArgumentParser()
84+
parser.add_argument(
85+
"--task",
86+
choices=['task1', 'task2'],
87+
default='task2'
88+
)
89+
90+
91+
params = {
92+
'task1': {
93+
"k": 30,
94+
},
95+
'task2': {
96+
"k": 15,
97+
}
98+
}
99+
100+
args = parser.parse_args()
101+
102+
run(args.task, params[args.task])
103+

0 commit comments

Comments
 (0)