Skip to content

Commit f5849d6

Browse files
committed
JMH: Add benchmark for querying a PQ graph index
1 parent 18488b8 commit f5849d6

1 file changed

Lines changed: 152 additions & 0 deletions

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
2+
/*
3+
* Copyright DataStax, Inc.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package io.github.jbellis.jvector.bench;
18+
19+
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
20+
import io.github.jbellis.jvector.graph.GraphSearcher;
21+
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
22+
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
23+
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
24+
import io.github.jbellis.jvector.graph.SearchResult;
25+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
26+
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
27+
import io.github.jbellis.jvector.quantization.PQVectors;
28+
import io.github.jbellis.jvector.quantization.ProductQuantization;
29+
import io.github.jbellis.jvector.util.Bits;
30+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
31+
import io.github.jbellis.jvector.vector.VectorizationProvider;
32+
import io.github.jbellis.jvector.vector.types.VectorFloat;
33+
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
34+
import org.openjdk.jmh.annotations.*;
35+
import org.openjdk.jmh.infra.Blackhole;
36+
import org.slf4j.Logger;
37+
import org.slf4j.LoggerFactory;
38+
39+
import java.io.IOException;
40+
import java.util.ArrayList;
41+
import java.util.concurrent.TimeUnit;
42+
43+
/**
44+
* Benchmarks per-query search latency on a pre-built in-memory index with random vectors.
45+
* Index construction happens once per trial in @Setup; only the search is measured.
46+
*/
47+
@BenchmarkMode(Mode.AverageTime)
48+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
49+
@State(Scope.Thread)
50+
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"})
51+
@Warmup(iterations = 3)
52+
@Measurement(iterations = 5)
53+
@Threads(1)
54+
public class QueryTimeBenchmark {
55+
private static final Logger log = LoggerFactory.getLogger(QueryTimeBenchmark.class);
56+
private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport();
57+
58+
@Param({"768", "1536"})
59+
private int originalDimension;
60+
61+
@Param({"100000"})
62+
private int numBaseVectors;
63+
64+
@Param({"0", "16"})
65+
private int numberOfPQSubspaces;
66+
67+
@Param({"10"})
68+
private int topK;
69+
70+
private RandomAccessVectorValues ravv;
71+
private ImmutableGraphIndex graphIndex;
72+
private PQVectors pqVectors;
73+
74+
/** Query vectors rotated through on each invocation to avoid caching effects. */
75+
private VectorFloat<?>[] queryVectors;
76+
private int queryIndex;
77+
78+
private static final int NUM_QUERY_VECTORS = 1000;
79+
private static final int M = 32;
80+
private static final int BEAM_WIDTH = 100;
81+
82+
@Setup(Level.Trial)
83+
public void setup() throws IOException {
84+
// Build base vectors
85+
var baseVectors = new ArrayList<VectorFloat<?>>(numBaseVectors);
86+
for (int i = 0; i < numBaseVectors; i++) {
87+
baseVectors.add(createRandomVector(originalDimension));
88+
}
89+
ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension);
90+
91+
// Build index once — not measured
92+
final BuildScoreProvider buildScoreProvider;
93+
if (numberOfPQSubspaces > 0) {
94+
log.info("Building with PQ ({} subspaces), dim={}", numberOfPQSubspaces, originalDimension);
95+
ProductQuantization pq = ProductQuantization.compute(ravv, numberOfPQSubspaces, 256, true);
96+
pqVectors = (PQVectors) pq.encodeAll(ravv);
97+
buildScoreProvider = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqVectors);
98+
} else {
99+
log.info("Building with exact scorer, dim={}", originalDimension);
100+
pqVectors = null;
101+
buildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN);
102+
}
103+
104+
try (var builder = new GraphIndexBuilder(buildScoreProvider, ravv.dimension(), M, BEAM_WIDTH, 1.2f, 1.2f, true)) {
105+
graphIndex = builder.build(ravv);
106+
}
107+
108+
// Pre-generate query vectors so vector creation is not part of the measurement
109+
queryVectors = new VectorFloat<?>[NUM_QUERY_VECTORS];
110+
for (int i = 0; i < NUM_QUERY_VECTORS; i++) {
111+
queryVectors[i] = createRandomVector(originalDimension);
112+
}
113+
queryIndex = 0;
114+
}
115+
116+
@TearDown(Level.Trial)
117+
public void tearDown() {
118+
// graphIndex is AutoCloseable only if wrapped; nothing to do for ImmutableGraphIndex
119+
}
120+
121+
/**
122+
* Measures the time to execute a single query against the pre-built index.
123+
* A pool of pre-generated query vectors is cycled through
124+
*/
125+
@Benchmark
126+
public void queryBenchmark(Blackhole blackhole) throws IOException {
127+
VectorFloat<?> queryVector = queryVectors[queryIndex];
128+
queryIndex = (queryIndex + 1) % NUM_QUERY_VECTORS;
129+
130+
try (GraphSearcher searcher = new GraphSearcher(graphIndex)) {
131+
final SearchResult result;
132+
if (pqVectors != null) {
133+
var asf = pqVectors.precomputedScoreFunctionFor(queryVector, VectorSimilarityFunction.EUCLIDEAN);
134+
var reranker = ravv.rerankerFor(queryVector, VectorSimilarityFunction.EUCLIDEAN);
135+
var ssp = new DefaultSearchScoreProvider(asf, reranker);
136+
result = searcher.search(ssp, topK, topK * 2, 0.0f, 0.0f, Bits.ALL);
137+
} else {
138+
var ssp = DefaultSearchScoreProvider.exact(queryVector, VectorSimilarityFunction.EUCLIDEAN, ravv);
139+
result = searcher.search(ssp, topK, Bits.ALL);
140+
}
141+
blackhole.consume(result);
142+
}
143+
}
144+
145+
private VectorFloat<?> createRandomVector(int dimension) {
146+
VectorFloat<?> vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension);
147+
for (int i = 0; i < dimension; i++) {
148+
vector.set(i, (float) Math.random());
149+
}
150+
return vector;
151+
}
152+
}

0 commit comments

Comments
 (0)