-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathknn_utils.py
More file actions
283 lines (255 loc) · 12.2 KB
/
knn_utils.py
File metadata and controls
283 lines (255 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import os
import sys
import numpy as np
import faiss
import argparse
import struct
def read_fvecs(fname):
    """
    Read an .fvecs file and return its vectors as a float32 array of shape (n, d).

    Each record on disk is ``[d, x_1, ..., x_d]``: a 4-byte int32 dimension
    followed by ``d`` float32 values. All records must share the same ``d``.

    The file is read only once. The whole payload is loaded as float32, and the
    int32 header values are recovered by reinterpreting the same bytes (no
    second pass over the file as int32).

    Args:
        fname: Path to the .fvecs file (a leading ``~`` is expanded).

    Returns:
        A (n, d) float32 array. It is a view into the loaded buffer with the
        header column sliced off, so it is not C-contiguous.

    Raises:
        ValueError: if the file is empty, the header dimension is not positive,
            the file size is not a whole number of records, or the records
            disagree on their dimension.
    """
    fname = os.path.expanduser(fname)
    with open(fname, "rb") as f:
        data = np.fromfile(f, dtype=np.float32)
    if data.size == 0:
        raise ValueError(f"Empty fvecs file: {fname}")
    # The first 4 bytes hold the dimension as int32: reinterpret the bits, don't convert.
    dim = int(data[:1].view(np.int32)[0])
    if dim <= 0:
        raise ValueError(f"Invalid vector dimension {dim} in fvecs file: {fname}")
    row_len = dim + 1
    if data.size % row_len != 0:
        raise ValueError(
            f"Size of fvecs file {fname} is not a multiple of the record size "
            f"({row_len} x 4 bytes); the file may be truncated or corrupt."
        )
    records = data.reshape(-1, row_len)
    # Every record repeats the dimension header; a mismatch means a corrupt file
    # or a file with variable-length vectors, which this reader does not support.
    if not np.all(records.view(np.int32)[:, 0] == dim):
        raise ValueError(f"Inconsistent vector dimensions in fvecs file: {fname}")
    return records[:, 1:]  # Drop the header column.
def read_hdf5(fname, key="data"):
    """
    Load one dataset from an HDF5 file fully into memory.

    Args:
        fname: Path to the HDF5 file.
        key: Name of the dataset inside the file (defaults to "data").

    Returns:
        The dataset contents as a numpy array.

    Raises:
        ValueError: if ``key`` is not present in the file.
    """
    # Imported here rather than at module level, so h5py is only required
    # when HDF5 input is actually read.
    import h5py

    with h5py.File(fname, 'r') as handle:
        if key in handle:
            return np.array(handle[key])
        raise ValueError(f"Key '{key}' not found in HDF5 file: {fname}")
def read_hdf5_tensor(fname, key="data"):
    """
    Load an N-dimensional HDF5 dataset and flatten it into a 2-D matrix.

    All leading axes are collapsed into the row axis and the last axis is kept
    as the vector dimension. A dataset of shape (a, b, ..., d) therefore comes
    back with shape (a * b * ..., d), and a 1-D dataset of length d becomes a
    single row of shape (1, d).

    Args:
        fname: Path to the HDF5 file.
        key: Name of the dataset inside the file (defaults to "data").

    Returns:
        A 2-D numpy array of shape (rows, d).

    Raises:
        ValueError: if ``key`` is missing, or the dataset is a scalar (0-d) and
            therefore has no vector dimension.
    """
    # Imported here so h5py is only required when HDF5 input is actually read.
    import h5py

    with h5py.File(fname, 'r') as hf:
        if key not in hf:
            raise ValueError(f"Key '{key}' not found in HDF5 file: {fname}")
        tensor_arr = np.array(hf[key])
    if tensor_arr.ndim == 0:
        raise ValueError(
            f"Dataset '{key}' in HDF5 file {fname} is a scalar; expected at least 1 dimension.")
    # Compute the row count explicitly instead of using -1: reshape(-1, 0) is
    # ambiguous and fails when the last axis has length 0.
    rows = int(np.prod(tensor_arr.shape[:-1]))
    return tensor_arr.reshape(rows, tensor_arr.shape[-1])
def read_vectors(fname):
    """
    Load vectors from an fvecs file or from a dataset inside an HDF5 file.

    Accepted forms of ``fname`` (a leading ``~`` is expanded first):
      * ``file.h5:key`` / ``file.hdf5:key`` -- read dataset ``key`` from HDF5.
      * ``file.h5`` / ``file.hdf5``         -- read the default dataset "data".
      * any other path without a colon     -- read as an fvecs file.

    For ``path:key``, the string is split at the first colon that directly
    follows a ``.h5``/``.hdf5`` extension. A path containing a colon of its own
    (e.g. a Windows drive letter) is therefore still handled.

    Raises:
        ValueError: if ``fname`` contains a colon but is neither an HDF5
            ``path:key`` spec nor an HDF5 path.
    """
    hdf5_exts = ('.h5', '.hdf5')
    fname = os.path.expanduser(fname)
    # "file.h5:key": split where the HDF5 file path ends.
    for pos, ch in enumerate(fname):
        if ch == ':' and fname[:pos].endswith(hdf5_exts):
            return read_hdf5(fname[:pos], fname[pos + 1:])
    # Bare HDF5 path: use the default dataset key, as documented above.
    if fname.endswith(hdf5_exts):
        return read_hdf5(fname)
    if ':' in fname:
        raise ValueError("For HDF5, use the format 'file.h5:key'")
    return read_fvecs(fname)
def write_fvecs(fname, arr):
    """
    Write a 2-D array of shape (n, d) to an .fvecs file.

    Each row is stored as ``[d, x_1, ..., x_d]``: the dimension as a 4-byte
    int32 followed by the ``d`` values as float32.

    The dimension is written through an int32 view of the output buffer, so
    its bit pattern never passes through a float conversion. Small integers
    such as 128 are subnormal when reinterpreted as float32 and could be
    flushed to zero if FTZ/DAZ is active. An empty (n == 0) array produces an
    empty file.

    Args:
        fname: Output path (a leading ``~`` is expanded).
        arr: Array of shape (n, d); values are cast to float32.

    Raises:
        ValueError: if ``arr`` is not 2-D (from unpacking ``arr.shape``).
    """
    n, d = arr.shape
    fname = os.path.expanduser(fname)  # Expand tilde to full home directory path
    # fvecs format: [[dim, vec1...], [dim, vec2...]]
    records = np.empty((n, d + 1), dtype=np.float32)
    records.view(np.int32)[:, 0] = d  # raw int32 bits for the header column
    records[:, 1:] = arr  # casts to float32, same as arr.astype(np.float32)
    with open(fname, "wb") as f:
        records.tofile(f)
def write_ivecs(fname, ivecs):
    """
    Write a 2-D integer array of shape (n, k) to an .ivecs file.

    Every row is stored as ``[k, v_1, ..., v_k]`` with all entries as int32:
    the row length followed by the row's values.

    Args:
        fname: Output path (a leading ``~`` is expanded).
        ivecs: Array of shape (n, k); values are cast to int32.
    """
    rows, width = ivecs.shape
    out_path = os.path.expanduser(fname)
    with open(out_path, 'wb') as handle:
        length_col = np.full((rows, 1), width, dtype=np.int32)
        body = ivecs.astype(np.int32)
        np.hstack([length_col, body]).tofile(handle)
def count_zero_vectors(vecs, eps=0.0):
    """
    Return how many rows of ``vecs`` have an L2 norm of at most ``eps``.

    With the default ``eps=0.0`` only exactly-zero rows are counted.
    """
    row_norms = np.linalg.norm(vecs, axis=1)
    is_zero = row_norms <= eps
    return int(np.count_nonzero(is_zero))
def remove_zero_vectors(arr, name, eps=0.0):
    """
    Drop rows whose L2 norm is not strictly greater than ``eps``, and print a summary.

    Args:
        arr: 2-D array of row vectors.
        name: Label used in the printed summary (e.g. "base" or "query").
        eps: Norm threshold; only rows with norm > eps are kept. The default
            0.0 removes exact zeros only; a larger value also drops near-zero
            rows. A NaN norm fails the ``>`` test, so such rows are dropped too.

    Returns:
        A new array holding the kept rows in their original order.
    """
    survivors = np.linalg.norm(arr, axis=1) > eps
    total = arr.shape[0]
    kept = int(np.count_nonzero(survivors))
    dropped = total - kept
    if not dropped:
        print(f"Removed 0 zero vectors from {name}.")
    else:
        print(f"Removed {dropped} zero vectors from {name} (kept {kept} / {total}).")
    return arr[survivors]
def check_normalization(vecs, tol=1e-3):
    """
    Report whether every row of ``vecs`` has (approximately) unit L2 norm.

    A row counts as normalized when ``|norm - 1| < tol``. An empty input
    passes vacuously.

    Returns:
        A numpy boolean scalar.
    """
    deviation = np.abs(np.linalg.norm(vecs, axis=1) - 1)
    within_tol = deviation < tol
    return within_tol.all()
def build_index(base, d, metric, gpu_ids):
    """
    Build an exact (flat) FAISS index over ``base`` and return it populated.

    Args:
        base: Array of shape (n, d) holding the database vectors. It is
            converted to a C-contiguous float32 array before being added.
        d: Vector dimension.
        metric: "l2" (IndexFlatL2) or "ip" (IndexFlatIP).
        gpu_ids: List of integer GPU ids.
            - An empty list, or a negative first entry, selects the CPU index.
            - A single id selects that GPU.
            - Several ids shard the index across exactly those GPUs.

    Returns:
        The populated FAISS index (CPU or GPU).

    Raises:
        ValueError: if ``metric`` is not "l2" or "ip".
    """
    if metric == "l2":
        cpu_index = faiss.IndexFlatL2(d)
    elif metric == "ip":
        cpu_index = faiss.IndexFlatIP(d)
    else:
        raise ValueError("Unsupported metric: " + metric)

    if not gpu_ids or gpu_ids[0] < 0:
        print("Using device: cpu")
        index = cpu_index
    elif len(gpu_ids) == 1:
        print("Using device: cuda({})".format(gpu_ids[0]))
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, gpu_ids[0], cpu_index)
    else:
        print("Using devices:", ", ".join("cuda({})".format(g) for g in gpu_ids))
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True  # split the database across GPUs rather than replicating it
        # GpuMultipleClonerOptions has no device-list option, and
        # index_cpu_to_all_gpus uses every visible GPU. Pass the requested ids
        # explicitly so only those devices are used.
        index = faiss.index_cpu_to_gpus_list(cpu_index, co=co, gpus=gpu_ids)

    # FAISS operates on C-contiguous float32 data.
    index.add(np.ascontiguousarray(base, dtype=np.float32))
    return index
def _build_arg_parser():
    """Create the command-line parser for the ground-truth tool."""
    parser = argparse.ArgumentParser(
        description='Compute ground truth for nearest neighbor search using a GPU.')
    parser.add_argument('--base', type=str, required=True,
                        help='Path to the base vectors file (fvec or HDF5). For HDF5, use the format "file.h5:key".')
    parser.add_argument('--query', type=str, required=True,
                        help='Path to the query vectors file (fvec or HDF5). For HDF5, use the format "file.h5:key".')
    parser.add_argument('--output', type=str, required=True,
                        help='Output ivec file to write ground truth indices.')
    parser.add_argument('--num_base', type=int, default=0,
                        help='Number of base vectors for truncated dataset (if 0, skip truncation).')
    parser.add_argument('--num_query', type=int, default=0,
                        help='Number of query vectors for truncated dataset (if 0, skip truncation).')
    parser.add_argument('--remove_zeros', action='store_true', default=False,
                        help='If set, remove zero-norm vectors from both base and query.')
    parser.add_argument('--shuffle', action='store_true', default=False,
                        help='If set, shuffle both base and query vectors.')
    parser.add_argument('--normalize', action='store_true', default=False,
                        help='If set, normalize both base and query vectors.')
    parser.add_argument('--processed_base_out', type=str, default="",
                        help='Output file for processed base vectors (fvec file) if truncation or normalization is applied.')
    parser.add_argument('--processed_query_out', type=str, default="",
                        help='Output file for processed query vectors (fvec file) if truncation or normalization is applied.')
    parser.add_argument('--k', type=int, required=True,
                        help='Number of nearest neighbors to compute ground truth indices for.')
    parser.add_argument('--gpus', type=str, default="-1",
                        help='Comma-separated list of GPU ids to use. Use "-1" for CPU.')
    parser.add_argument('--metric', type=str, default='l2', choices=['l2', 'ip'],
                        help='Distance metric to use: "l2" or "ip".')
    return parser


def _normalize_rows(arr):
    """Scale each row of ``arr`` to unit L2 norm; all-zero rows stay zero."""
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms[norms == 0] = 1  # Prevent division by zero.
    return arr / norms


def _write_processed_outputs(args, base, query):
    """
    Write the processed base/query vectors that the chosen options require.

    --remove_zeros, --normalize and --shuffle (or truncating both sets) apply
    to both arrays, so both output paths are required. Truncating only one set
    requires only that set's output path. With no processing, nothing is written.
    """
    truncate_base = args.num_base > 0
    truncate_query = args.num_query > 0
    if args.remove_zeros or args.normalize or args.shuffle or (truncate_base and truncate_query):
        if not args.processed_base_out or not args.processed_query_out:
            raise ValueError(
                "When removing zeros, normalization, shuffling, or truncation is applied, processed_base_out and processed_query_out must be provided. ")
        print("Writing processed base vectors to:", args.processed_base_out)
        write_fvecs(args.processed_base_out, base)
        print("Writing processed query vectors to:", args.processed_query_out)
        write_fvecs(args.processed_query_out, query)
    elif truncate_base:
        if not args.processed_base_out:
            raise ValueError(
                "When truncation is applied, processed_base_out must be provided.")
        print("Writing processed base vectors to:", args.processed_base_out)
        write_fvecs(args.processed_base_out, base)
    elif truncate_query:
        if not args.processed_query_out:
            raise ValueError(
                "When truncation is applied, processed_query_out must be provided.")
        print("Writing processed query vectors to:", args.processed_query_out)
        write_fvecs(args.processed_query_out, query)


def main():
    """
    Command-line entry point: compute exact k-NN ground truth and write it as ivecs.

    Pipeline:
      1. Load base and query vectors, and report zero-vector and normalization stats.
      2. Optionally remove zero vectors, shuffle (fixed seed 42), truncate and normalize.
      3. Write the processed vectors when required.
      4. Build a flat FAISS index over the base vectors and search the k
         nearest neighbors of every query.
      5. Write the neighbor ids to --output.

    Raises:
        ValueError: on mismatched dimensions, empty sets after zero removal,
            truncation beyond the available vectors, an invalid --k, or missing
            processed-output paths.
    """
    parser = _build_arg_parser()
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)  # prints "usage:" line + options
        sys.exit(2)
    args = parser.parse_args()
    gpu_ids = [int(x) for x in args.gpus.split(',')]

    # Load base and query vectors from files
    print("Loading base vectors from:", args.base)
    base = read_vectors(args.base)
    print(f"Loaded {base.shape[0]} base vectors of dimension {base.shape[1]}.")
    print("Loading query vectors from:", args.query)
    query = read_vectors(args.query)
    print(f"Loaded {query.shape[0]} query vectors of dimension {query.shape[1]}.")

    # Ensure dimensions match
    d = base.shape[1]
    if query.shape[1] != d:
        raise ValueError("Dimension mismatch: base vectors have dimension {} but query vectors have dimension {}."
                         .format(d, query.shape[1]))

    # Check for zero vectors (before any shuffling/truncation/normalization)
    base_zero = count_zero_vectors(base)  # exact zeros
    query_zero = count_zero_vectors(query)  # exact zeros
    print(f"Base zero vectors: {base_zero} / {base.shape[0]}")
    print(f"Query zero vectors: {query_zero} / {query.shape[0]}")
    if args.remove_zeros:
        print("Removing zero vectors from both base and query.")
        base = remove_zero_vectors(base, "base")
        query = remove_zero_vectors(query, "query")
        if base.shape[0] == 0:
            raise ValueError("All base vectors were zero after removal.")
        if query.shape[0] == 0:
            raise ValueError("All query vectors were zero after removal.")

    # Report normalization of the (possibly zero-filtered) input.
    base_normalized = check_normalization(base)
    query_normalized = check_normalization(query)
    print("Base vectors normalized:", "Yes" if base_normalized else "No")
    print("Query vectors normalized:", "Yes" if query_normalized else "No")

    # Shuffle the full dataset before truncation so a truncated subset is a random sample.
    if args.shuffle:
        print("Shuffling both base and query vectors.")
        np.random.seed(42)  # For reproducibility
        np.random.shuffle(base)
        np.random.shuffle(query)

    # Process any truncations.
    if args.num_base > 0:
        if args.num_base > base.shape[0]:
            raise ValueError("Truncated base size exceeds full dataset size.")
        base = base[:args.num_base]
        print(f"Using truncated base: {args.num_base} vectors.")
    if args.num_query > 0:
        if args.num_query > query.shape[0]:
            raise ValueError("Truncated query size exceeds full dataset size.")
        query = query[:args.num_query]
        print(f"Using truncated query: {args.num_query} vectors.")

    # Apply normalization if requested (to both base and query).
    if args.normalize:
        base = _normalize_rows(base)
        query = _normalize_rows(query)
        print("Normalized both base and query vectors.")

    # Validate k before anything is written. FAISS pads results with id -1 when
    # k exceeds the number of indexed vectors, which would silently corrupt the
    # ground truth.
    if args.k <= 0:
        raise ValueError("k must be a positive integer.")
    if args.k > base.shape[0]:
        raise ValueError(
            f"k ({args.k}) exceeds the number of base vectors ({base.shape[0]}).")

    # Require processed output filenames when processing is applied, and write them.
    _write_processed_outputs(args, base, query)

    # Creating index and adding base vectors.
    print("Adding base vectors to the index...")
    index = build_index(base, d, args.metric, gpu_ids)

    # Perform the search for each query; only the neighbor ids are kept.
    print("Performing nearest neighbor search for k =", args.k)
    _, indices = index.search(query, args.k)
    print("Search completed.")

    # Write the ground truth indices to the output ivec file.
    print("Writing results to output file:", args.output)
    write_ivecs(args.output, indices)
    print("Done.")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()