Skip to content

Commit 77d4945

Browse files
committed
JMH: Add benchmark for querying a FusedPQ index
1 parent f5849d6 commit 77d4945

1 file changed

Lines changed: 204 additions & 0 deletions

File tree

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.bench;
18+
19+
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
20+
import io.github.jbellis.jvector.graph.GraphSearcher;
21+
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
22+
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
23+
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
24+
import io.github.jbellis.jvector.graph.SearchResult;
25+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
26+
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
27+
import io.github.jbellis.jvector.quantization.PQVectors;
28+
import io.github.jbellis.jvector.quantization.ProductQuantization;
29+
import io.github.jbellis.jvector.util.Bits;
30+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
31+
import io.github.jbellis.jvector.vector.VectorizationProvider;
32+
import io.github.jbellis.jvector.vector.types.VectorFloat;
33+
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
34+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
35+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
36+
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
37+
import io.github.jbellis.jvector.graph.disk.feature.Feature;
38+
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
39+
import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
40+
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
41+
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
42+
import java.util.*;
43+
import org.openjdk.jmh.annotations.*;
44+
import org.openjdk.jmh.infra.Blackhole;
45+
import org.slf4j.Logger;
46+
import org.slf4j.LoggerFactory;
47+
48+
import java.io.IOException;
49+
import java.nio.file.Files;
50+
import java.nio.file.Path;
51+
import java.util.*;
52+
import java.util.concurrent.TimeUnit;
53+
import java.util.function.IntFunction;
54+
import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
55+
56+
@BenchmarkMode(Mode.AverageTime)
57+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
58+
@State(Scope.Benchmark)
59+
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=true"})
60+
@Warmup(iterations = 3)
61+
@Measurement(iterations = 5)
62+
@Threads(1)
63+
public class FusedPQQueryBenchmark {
64+
private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport();
65+
66+
private OnDiskGraphIndex index;
67+
private ArrayList<VectorFloat<?>> queryVectors;
68+
private Path indexPath;
69+
private Path tempDir;
70+
71+
@Param({"1536"})
72+
int dimension;
73+
74+
@Param({"96"})
75+
int pqM;
76+
77+
@Param({"100000"})
78+
int numBaseVectors;
79+
80+
@Param({"100"})
81+
int numQueryVectors;
82+
83+
@Param({"10"})
84+
int topK;
85+
86+
@Param({"100"})
87+
int efSearch;
88+
89+
@Setup(Level.Trial)
90+
public void setup() throws IOException {
91+
System.out.println("Setting up FusedPQ index...");
92+
93+
// 1. Create base vectors
94+
var baseVectors = new ArrayList<VectorFloat<?>>(numBaseVectors);
95+
for (int i = 0; i < numBaseVectors; i++) {
96+
baseVectors.add(createRandomVector(dimension));
97+
}
98+
RandomAccessVectorValues floatVectors = new ListRandomAccessVectorValues(baseVectors, dimension);
99+
100+
// 2. Create query vectors
101+
queryVectors = new ArrayList<>(numQueryVectors);
102+
for (int i = 0; i < numQueryVectors; i++) {
103+
queryVectors.add(createRandomVector(dimension));
104+
}
105+
106+
// 3. Compute PQ compression
107+
System.out.println("Computing PQ compression...");
108+
boolean centerData = false; // false for DOT_PRODUCT/COSINE
109+
var pq = ProductQuantization.compute(floatVectors, pqM, 256, centerData, UNWEIGHTED);
110+
var pqVectors = (PQVectors) pq.encodeAll(floatVectors);
111+
System.out.printf("PQ: %d subspaces, 256 clusters%n", pqM);
112+
113+
// 4. Build graph with PQ-compressed vectors
114+
System.out.println("Building graph...");
115+
int M = 16;
116+
int efConstruction = 100;
117+
float neighborOverflow = 1.2f;
118+
float alpha = 1.2f;
119+
boolean addHierarchy = true;
120+
boolean refineFinalGraph = true;
121+
122+
var bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.DOT_PRODUCT, pqVectors);
123+
var builder = new GraphIndexBuilder(bsp, dimension, M, efConstruction,
124+
neighborOverflow, alpha, addHierarchy, refineFinalGraph);
125+
var graph = builder.build(floatVectors);
126+
System.out.printf("Graph built: %d nodes%n", graph.size(0));
127+
128+
// 5. Write FusedPQ index to disk
129+
System.out.println("Writing FusedPQ index to disk...");
130+
tempDir = Files.createTempDirectory("fusedpq-bench");
131+
indexPath = tempDir.resolve("fusedpq-index");
132+
133+
var fusedPQFeature = new FusedPQ(graph.maxDegree(), pq);
134+
var inlineVectors = new InlineVectors(dimension);
135+
136+
try (var writer = new OnDiskGraphIndexWriter.Builder(graph, indexPath)
137+
.with(fusedPQFeature)
138+
.with(inlineVectors)
139+
.withMapper(new OrdinalMapper.IdentityMapper(floatVectors.size() - 1))
140+
.build()) {
141+
142+
var view = graph.getView();
143+
Map<FeatureId, IntFunction<Feature.State>> suppliers = new EnumMap<>(FeatureId.class);
144+
suppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, pqVectors, ordinal));
145+
suppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(floatVectors.getVector(ordinal)));
146+
147+
writer.write(suppliers);
148+
view.close();
149+
}
150+
151+
builder.close();
152+
System.out.printf("Index written: %.2f MB%n", Files.size(indexPath) / 1024.0 / 1024.0);
153+
154+
// 6. Load the index
155+
System.out.println("Loading index...");
156+
index = OnDiskGraphIndex.load(ReaderSupplierFactory.open(indexPath));
157+
System.out.println("Setup complete!");
158+
}
159+
160+
@TearDown(Level.Trial)
161+
public void tearDown() throws IOException {
162+
if (index != null) {
163+
index.close();
164+
}
165+
if (indexPath != null && Files.exists(indexPath)) {
166+
Files.deleteIfExists(indexPath);
167+
}
168+
if (tempDir != null && Files.exists(tempDir)) {
169+
Files.deleteIfExists(tempDir);
170+
}
171+
if (queryVectors != null) {
172+
queryVectors.clear();
173+
}
174+
}
175+
176+
@Benchmark
177+
public void queryFusedPQ(Blackhole blackhole) throws IOException {
178+
// Perform queries on all query vectors
179+
for (VectorFloat<?> queryVector : queryVectors) {
180+
try (var view = index.getView()) {
181+
var scoringView = (ImmutableGraphIndex.ScoringView) view;
182+
183+
// Get score functions - FusedPQ for approximate, then rerank
184+
var asf = scoringView.approximateScoreFunctionFor(queryVector, VectorSimilarityFunction.DOT_PRODUCT);
185+
var reranker = scoringView.rerankerFor(queryVector, VectorSimilarityFunction.DOT_PRODUCT);
186+
var ssp = new io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider(asf, reranker);
187+
188+
// Search
189+
var searcher = new GraphSearcher(index);
190+
SearchResult result = searcher.search(ssp, topK, efSearch, 1.0f, 0.0f, io.github.jbellis.jvector.util.Bits.ALL);
191+
192+
blackhole.consume(result);
193+
}
194+
}
195+
}
196+
197+
private VectorFloat<?> createRandomVector(int dimension) {
198+
VectorFloat<?> vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension);
199+
for (int i = 0; i < dimension; i++) {
200+
vector.set(i, (float) Math.random());
201+
}
202+
return vector;
203+
}
204+
}

0 commit comments

Comments
 (0)