Skip to content

Commit abcb39e

Browse files
committed
working baseline for both task 1 and task 2
1 parent 06e5595 commit abcb39e

3 files changed

Lines changed: 42 additions & 62 deletions

File tree

datasets.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@ def download(src, dst):
88
print('downloading %s -> %s...' % (src, dst))
99
urlretrieve(src, dst)
1010

11-
def get_fn(kind):
12-
version = "ccnews-small"
13-
return os.path.join("data", kind, f"{version}.h5"), os.path.join('data', kind, 'gt', f'{version}.h5')
11+
def get_fn(dataset, task):
12+
return os.path.join("data", dataset, task, f"{dataset}.h5"), os.path.join('data', dataset, task, 'gt', f'gt_{dataset}.h5')
1413

15-
def prepare(kind):
16-
url = DATASETS['ccnews-small'][kind]['url']
17-
gt_url = DATASETS['ccnews-small'][kind]['gt_url']
18-
fn, gt_fn = get_fn(kind)
14+
def prepare(dataset, task):
15+
url = DATASETS[dataset][task]['url']
16+
gt_url = DATASETS[dataset][task]['gt_url']
17+
fn, gt_fn = get_fn(dataset, task)
1918

2019
download(url, fn)
2120
download(gt_url, gt_fn)
@@ -28,13 +27,15 @@ def prepare(kind):
2827
'data': lambda x: x['train'],
2928
'gt_url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true',
3029
'gt_I': lambda x: x['itest']['knns'],
30+
'k': 30,
3131
},
3232
'task2': {
3333
'url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/benchmark-dev-ccnews-fp16.h5?download=true',
3434
'queries': lambda x: x['train'],
3535
'data': lambda x: x['train'],
3636
'gt_url': 'https://huggingface.co/datasets/sadit/SISAP2025/resolve/main/allknn-benchmark-dev-ccnews.h5?download=true',
37-
'gt_I': lambda x: x['knns']
37+
'gt_I': lambda x: x['knns'],
38+
'k': 15,
3839
}
3940
}
4041
}

eval.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,17 @@
55
import csv
66
import glob
77
from pathlib import Path
8-
from datasets import DATASETS
9-
10-
def get_groundtruth(size="100K", private=False):
11-
# test
12-
gt_f = h5py.File(out_fn, "r")
13-
true_I = np.array(gt_f['knns'])
14-
gt_f.close()
15-
return true_I
8+
from datasets import DATASETS, get_fn, prepare
169

1710
def get_all_results(dirname):
18-
mask = [dirname + "/**/*.h5"]
19-
print("search for results matching:")
11+
mask = [dirname + "/**/*.h5", dirname + "/**/*/*.h5"]
12+
print("Searching for results matching:")
2013
print("\n".join(mask))
2114
for m in mask:
2215
for fn in glob.iglob(m):
2316
print(fn)
2417
f = h5py.File(fn, "r")
25-
if "knns" not in f or not ("data" in f or "data" in f.attrs):
18+
if "knns" not in f or not ("dataset" in f or "dataset" in f.attrs):
2619
print("Ignoring " + fn)
2720
f.close()
2821
continue
@@ -45,36 +38,38 @@ def get_recall(I, gt, k):
4538
parser.add_argument(
4639
"--results",
4740
help='directory in which results are stored',
48-
default="result"
41+
default="results"
4942
)
5043
parser.add_argument(
5144
'--private',
5245
help="private queries held out for evaluation",
5346
action='store_true',
5447
default=False
5548
)
56-
parser.add_argument(
57-
'--dataset',
58-
choices = ['ccnews-small'],
59-
default='ccnews-small',
60-
)
6149

6250
parser.add_argument("csvfile")
6351
args = parser.parse_args()
6452
true_I_cache = {}
6553

6654

67-
columns = ["data", "kind", "algo", "buildtime", "querytime", "params", "recall"]
55+
columns = ["dataset", "task", "algo", "buildtime", "querytime", "params", "recall"]
6856

6957
with open(args.csvfile, 'w', newline='') as csvfile:
7058
writer = csv.DictWriter(csvfile, fieldnames=columns)
7159
writer.writeheader()
7260
for res in get_all_results(args.results):
73-
data = res.attrs["data"]
61+
dataset = res.attrs["dataset"]
62+
task = res.attrs["task"]
63+
assert dataset in DATASETS and task in DATASETS[dataset]
64+
prepare(dataset, task)
7465
d = dict(res.attrs)
75-
print(d)
76-
gt_I = np.array(DATASETS['ccnews-small'][data]['gt_I'](res))
77-
recall = get_recall(np.array(res["knns"]), gt_I, 10)
66+
# print(d)
67+
_, gt_f = get_fn(dataset, task)
68+
print(f"Using groundtruth in {gt_f}")
69+
f = h5py.File(gt_f)
70+
gt_I = np.array(DATASETS[dataset][task]['gt_I'](f))
71+
f.close()
72+
recall = get_recall(np.array(res["knns"]), gt_I, DATASETS[dataset][task]['k'])
7873
d['recall'] = recall
79-
print(d["data"], d["algo"], d["params"], "=>", recall)
74+
print(d["dataset"], d["task"], d["algo"], d["params"], "=>", recall)
8075
writer.writerow(d)

search.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,36 @@
77
import time
88
from datasets import DATASETS, prepare, get_fn
99

10-
def store_results(dst, algo, kind, D, I, buildtime, querytime, params):
10+
def store_results(dst, algo, dataset, task, D, I, buildtime, querytime, params):
1111
os.makedirs(Path(dst).parent, exist_ok=True)
1212
f = h5py.File(dst, 'w')
1313
f.attrs['algo'] = algo
14-
f.attrs['data'] = kind
14+
f.attrs['dataset'] = dataset
15+
f.attrs['task'] = task
1516
f.attrs['buildtime'] = buildtime
1617
f.attrs['querytime'] = querytime
1718
f.attrs['params'] = params
1819
f.create_dataset('knns', I.shape, dtype=I.dtype)[:] = I
1920
f.create_dataset('dists', D.shape, dtype=D.dtype)[:] = D
2021
f.close()
2122

22-
def run(kind, params):
23-
print("Running", kind)
23+
def run(dataset, task, k):
24+
print(f'Running {task} on {dataset}')
2425

25-
prepare(kind)
26+
prepare(dataset, task)
2627

27-
fn, _ = get_fn(kind)
28+
fn, _ = get_fn(dataset, task)
2829
f = h5py.File(fn)
29-
data = np.array(DATASETS['ccnews-small'][kind]['data'](f))
30-
queries = np.array(DATASETS['ccnews-small'][kind]['queries'](f))
30+
data = np.array(DATASETS[dataset][task]['data'](f))
31+
queries = np.array(DATASETS[dataset][task]['queries'](f))
3132
f.close()
3233

3334
n, d = data.shape
34-
k = params['k']
35+
if task == 'task2':
36+
k = k + 1 # need to search for one more NN since we cannot remove self-loop
3537

3638
nlist = 1024 # number of clusters/centroids to build the IVF from
37-
if kind == 'task1':
38-
index_identifier = f"IVF{nlist},SQfp16"
39-
elif kind == 'task2':
40-
index_identifier = f"IVF{nlist},PQ{d//2}x4fs"
39+
index_identifier = f"IVF{nlist},SQfp16"
4140

4241
index = faiss.index_factory(d, index_identifier)
4342

@@ -49,10 +48,6 @@ def run(kind, params):
4948
print(f"Done training in {elapsed_build}s.")
5049
assert index.is_trained
5150

52-
if kind == "task2":
53-
index = faiss.IndexRefineFlat(index, faiss.swig_ptr(data.astype('float32')))
54-
index.k_factor = 200
55-
5651
for nprobe in [1, 2, 5, 10, 20, 50, 100]:
5752
print(f"Starting search on {queries.shape} with nprobe={nprobe}")
5853
start = time.time()
@@ -65,7 +60,7 @@ def run(kind, params):
6560

6661
identifier = f"index=({index_identifier}),query=(nprobe={nprobe})"
6762

68-
store_results(os.path.join("result/", kind, f"{identifier}.h5"), "faissIVF", kind, D, I, elapsed_build, elapsed_search, identifier)
63+
store_results(os.path.join("results/", dataset, task, f"{identifier}.h5"), "faissIVF", dataset, task, D, I, elapsed_build, elapsed_search, identifier)
6964

7065
if __name__ == "__main__":
7166

@@ -78,22 +73,11 @@ def run(kind, params):
7873

7974
parser.add_argument(
8075
'--dataset',
81-
choices=[
82-
'ccnews-small',
83-
],
76+
choices=DATASETS.keys(),
8477
default='ccnews-small'
8578
)
8679

87-
params = {
88-
'task1': {
89-
"k": 30,
90-
},
91-
'task2': {
92-
"k": 15,
93-
}
94-
}
9580

9681
args = parser.parse_args()
97-
98-
run(args.task, params[args.task])
82+
run(args.dataset, args.task, DATASETS[args.dataset][args.task]['k'])
9983

0 commit comments

Comments
 (0)