Skip to content

Commit e30816a

Browse files
committed
update graph to RAVV mapping for buildScoreProvider
Signed-off-by: Samuel Herman <sherman8915@gmail.com>
1 parent f6d0b97 commit e30816a

4 files changed

Lines changed: 171 additions & 42 deletions

File tree

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -826,18 +826,12 @@ public static OnHeapGraphIndex buildAndMergeNewNodes(OnDiskGraphIndex onDiskGrap
826826
PhysicalCoreExecutor.pool(),
827827
ForkJoinPool.commonPool())) {
828828

829-
// Add each new vector incrementally
830-
final List<ForkJoinTask<?>> forkJoinTask = new ArrayList<>(newVectors.size());
831-
for (int i = startingNodeOffset; i < newVectors.size(); i++) {
832-
final int nodeId = i;
833-
final VectorFloat<?> vector = newVectors.getVector(graphToRavvOrdMap[nodeId]);
834-
835-
// The GraphIndexBuilder can add nodes to an existing index
836-
forkJoinTask.add(PhysicalCoreExecutor.pool().submit(() -> builder.addGraphNode(nodeId, vector)));
837-
}
838-
for (ForkJoinTask<?> task : forkJoinTask) {
839-
task.join();
840-
}
829+
var vv = newVectors.threadLocalSupplier();
830+
831+
// Construct the graph in parallel over the ordinals of the newly merged nodes
832+
PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
833+
builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord]));
834+
})).join();
841835

842836
builder.cleanup();
843837
return builder.getGraph();

jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,58 @@ public SearchScoreProvider diversityProviderFor(int node1) {
133133
};
134134
}
135135

136+
/**
137+
* Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction.
138+
*/
139+
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
140+
// We need two sources of vectors in order to perform diversity check comparisons without
141+
// colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
142+
var vectors = ravv.threadLocalSupplier();
143+
var vectorsCopy = ravv.threadLocalSupplier();
144+
145+
return new BuildScoreProvider() {
146+
@Override
147+
public boolean isExact() {
148+
return true;
149+
}
150+
151+
@Override
152+
public VectorFloat<?> approximateCentroid() {
153+
var vv = vectors.get();
154+
var centroid = vts.createFloatVector(vv.dimension());
155+
for (int i = 0; i < vv.size(); i++) {
156+
var v = vv.getVector(i);
157+
if (v != null) { // MapRandomAccessVectorValues is not necessarily dense
158+
VectorUtil.addInPlace(centroid, v);
159+
}
160+
}
161+
VectorUtil.scale(centroid, 1.0f / vv.size());
162+
return centroid;
163+
}
164+
165+
@Override
166+
public SearchScoreProvider searchProviderFor(VectorFloat<?> vector) {
167+
var vc = vectorsCopy.get();
168+
return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc);
169+
}
170+
171+
@Override
172+
public SearchScoreProvider searchProviderFor(int node1) {
173+
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
174+
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
175+
return searchProviderFor(v);
176+
}
177+
178+
@Override
179+
public SearchScoreProvider diversityProviderFor(int node1) {
180+
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
181+
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
182+
var vc = vectorsCopy.get();
183+
return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc);
184+
}
185+
};
186+
}
187+
136188
/**
137189
* Returns a BSP that performs approximate score comparisons using the given PQVectors,
138190
* with reranking performed using RandomAccessVectorValues (which is intended to be

jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,20 @@ public float similarityTo(int node2) {
7878
};
7979
return new DefaultSearchScoreProvider(sf);
8080
}
81+
82+
/**
83+
* A SearchScoreProvider for a single-pass search based on exact similarity.
84+
* Generally only suitable when your RandomAccessVectorValues is entirely in-memory,
85+
* e.g. during construction.
86+
*/
87+
public static DefaultSearchScoreProvider exact(VectorFloat<?> v, int[] graphToRavvOrdMap ,VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) {
88+
// don't use ESF.reranker, we need thread safety here
89+
var sf = new ScoreFunction.ExactScoreFunction() {
90+
@Override
91+
public float similarityTo(int node2) {
92+
return vsf.compare(v, ravv.getVector(graphToRavvOrdMap[node2]));
93+
}
94+
};
95+
return new DefaultSearchScoreProvider(sf);
96+
}
8197
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java

Lines changed: 97 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
2525
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
2626
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
27+
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
2728
import io.github.jbellis.jvector.util.Bits;
2829
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
2930
import io.github.jbellis.jvector.vector.VectorizationProvider;
3031
import io.github.jbellis.jvector.vector.types.VectorFloat;
3132
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
3233
import org.apache.logging.log4j.Logger;
3334
import org.junit.After;
35+
import org.junit.Assert;
3436
import org.junit.Before;
3537
import org.junit.Test;
3638

@@ -50,15 +52,16 @@
5052
public class OnHeapGraphIndexTest extends RandomizedTest {
5153
private final static Logger log = org.apache.logging.log4j.LogManager.getLogger(OnHeapGraphIndexTest.class);
5254
private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport();
53-
private static final int numBaseVectors = 100;
54-
private static final int numNewVectors = 100;
55-
private static final int numAllVectors = numBaseVectors + numNewVectors;
56-
private static final int dimension = 16;
55+
private static final int NUM_BASE_VECTORS = 100;
56+
private static final int NUM_NEW_VECTORS = 100;
57+
private static final int NUM_ALL_VECTORS = NUM_BASE_VECTORS + NUM_NEW_VECTORS;
58+
private static final int DIMENSION = 16;
5759
private static final int M = 8;
58-
private static final int beamWidth = 100;
59-
private static final float alpha = 1.2f;
60-
private static final float neighborOverflow = 1.2f;
61-
private static final boolean addHierarchy = false;
60+
private static final int BEAM_WIDTH = 100;
61+
private static final float ALPHA = 1.2f;
62+
private static final float NEIGHBOR_OVERFLOW = 1.2f;
63+
private static final boolean ADD_HIERARCHY = false;
64+
private static final int TOP_K = 10;
6265

6366
private Path testDirectory;
6467

@@ -68,6 +71,9 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
6871
private RandomAccessVectorValues baseVectorsRavv;
6972
private RandomAccessVectorValues newVectorsRavv;
7073
private RandomAccessVectorValues allVectorsRavv;
74+
private VectorFloat<?> queryVector;
75+
private int[] groundTruthBaseVectors;
76+
private int[] groundTruthAllVectors;
7177
private BuildScoreProvider baseBuildScoreProvider;
7278
private BuildScoreProvider newBuildScoreProvider;
7379
private BuildScoreProvider allBuildScoreProvider;
@@ -78,42 +84,47 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
7884
@Before
7985
public void setup() throws IOException {
8086
testDirectory = Files.createTempDirectory(this.getClass().getSimpleName());
81-
baseVectors = new ArrayList<>(numBaseVectors);
82-
newVectors = new ArrayList<>(numNewVectors);
83-
allVectors = new ArrayList<>(numAllVectors);
84-
for (int i = 0; i < numBaseVectors; i++) {
85-
VectorFloat<?> vector = createRandomVector(dimension);
87+
baseVectors = new ArrayList<>(NUM_BASE_VECTORS);
88+
newVectors = new ArrayList<>(NUM_NEW_VECTORS);
89+
allVectors = new ArrayList<>(NUM_ALL_VECTORS);
90+
for (int i = 0; i < NUM_BASE_VECTORS; i++) {
91+
VectorFloat<?> vector = createRandomVector(DIMENSION);
8692
baseVectors.add(vector);
8793
allVectors.add(vector);
8894
}
89-
for (int i = 0; i < numNewVectors; i++) {
90-
VectorFloat<?> vector = createRandomVector(dimension);
95+
for (int i = 0; i < NUM_NEW_VECTORS; i++) {
96+
VectorFloat<?> vector = createRandomVector(DIMENSION);
9197
newVectors.add(vector);
9298
allVectors.add(vector);
9399
}
94100

95101
// wrap the raw vectors in a RandomAccessVectorValues
96-
baseVectorsRavv = new ListRandomAccessVectorValues(baseVectors, dimension);
97-
newVectorsRavv = new ListRandomAccessVectorValues(newVectors, dimension);
98-
allVectorsRavv = new ListRandomAccessVectorValues(allVectors, dimension);
102+
baseVectorsRavv = new ListRandomAccessVectorValues(baseVectors, DIMENSION);
103+
newVectorsRavv = new ListRandomAccessVectorValues(newVectors, DIMENSION);
104+
allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION);
99105

106+
queryVector = createRandomVector(DIMENSION);
107+
groundTruthBaseVectors = getGroundTruth(baseVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN);
108+
groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN);
109+
110+
// score provider using the raw, in-memory vectors
100111
baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN);
101112
newBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(newVectorsRavv, VectorSimilarityFunction.EUCLIDEAN);
102113
allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN);
103114
var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider,
104115
baseVectorsRavv.dimension(),
105116
M, // graph degree
106-
beamWidth, // construction search depth
107-
neighborOverflow, // allow degree overflow during construction by this factor
108-
alpha, // relax neighbor diversity requirement by this factor
109-
addHierarchy); // add the hierarchy
117+
BEAM_WIDTH, // construction search depth
118+
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
119+
ALPHA, // relax neighbor diversity requirement by this factor
120+
ADD_HIERARCHY); // add the hierarchy
110121
var allGraphIndexBuilder = new GraphIndexBuilder(allBuildScoreProvider,
111122
allVectorsRavv.dimension(),
112123
M, // graph degree
113-
beamWidth, // construction search depth
114-
neighborOverflow, // allow degree overflow during construction by this factor
115-
alpha, // relax neighbor diversity requirement by this factor
116-
addHierarchy); // add the hierarchy
124+
BEAM_WIDTH, // construction search depth
125+
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
126+
ALPHA, // relax neighbor diversity requirement by this factor
127+
ADD_HIERARCHY); // add the hierarchy
117128

118129
baseGraphIndex = baseGraphIndexBuilder.build(baseVectorsRavv);
119130
allGraphIndex = allGraphIndexBuilder.build(allVectorsRavv);
@@ -156,7 +167,7 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException {
156167
validateVectors(onDiskView, baseVectorsRavv);
157168
}
158169

159-
OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.convertToHeap(onDiskGraph, neighborsScoreCacheRead, baseBuildScoreProvider, neighborOverflow, alpha);
170+
OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.convertToHeap(onDiskGraph, neighborsScoreCacheRead, baseBuildScoreProvider, NEIGHBOR_OVERFLOW, ALPHA);
160171
TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex);
161172
TestUtil.assertGraphEquals(onDiskGraph, reconstructedOnHeapGraphIndex);
162173

@@ -178,12 +189,18 @@ public void testIncrementalInsertionFromOnDiskIndex() throws IOException {
178189
TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph);
179190
// We will create a trivial 1:1 mapping between the new graph and the ravv
180191
final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray();
181-
OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, numBaseVectors, graphToRavvOrdMap, beamWidth, neighborOverflow, alpha, addHierarchy);
192+
OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(onDiskGraph, neighborsScoreCache, allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY);
193+
194+
// Verify that the recall is similar
195+
float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K);
196+
float recallFromAllGraphIndex = calculateRecall(allGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K);
197+
Assert.assertEquals(recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.01f);
182198

199+
// Verify that the result sets overlap
183200
try (GraphSearcher reconstructedAllGraphSearcher = new GraphSearcher(reconstructedAllNodeOnHeapGraphIndex);
184201
GraphSearcher allGraphSearcher = new GraphSearcher(allGraphIndex)) {
185-
final int topK = 10;
186-
VectorFloat<?> queryVector = createRandomVector(dimension);
202+
final int topK = TOP_K;
203+
VectorFloat<?> queryVector = createRandomVector(DIMENSION);
187204
var resultFromReconstructed = reconstructedAllGraphSearcher.search(allBuildScoreProvider.searchProviderFor(queryVector), topK, Bits.ALL);
188205
var resultFromAll = allGraphSearcher.search(allBuildScoreProvider.searchProviderFor(queryVector), topK, Bits.ALL);
189206
log.info("Reconstructed result: {}, all result: {}", resultFromReconstructed, resultFromAll);
@@ -210,4 +227,54 @@ private VectorFloat<?> createRandomVector(int dimension) {
210227
}
211228
return vector;
212229
}
230+
231+
/**
232+
* Get the ground truth for a query vector
233+
* @param ravv the vectors to search
234+
* @param queryVector the query vector
235+
* @param topK the number of results to return
236+
* @param similarityFunction the similarity function to use
237+
238+
* @return the ground truth
239+
*/
240+
private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat<?> queryVector, int topK, VectorSimilarityFunction similarityFunction) {
241+
var exactResults = new ArrayList<SearchResult.NodeScore>();
242+
for (int i = 0; i < ravv.size(); i++) {
243+
float similarityScore = similarityFunction.compare(queryVector, ravv.getVector(i));
244+
exactResults.add(new SearchResult.NodeScore(i, similarityScore));
245+
}
246+
exactResults.sort((a, b) -> Float.compare(b.score, a.score));
247+
return exactResults.stream().limit(topK).mapToInt(nodeScore -> nodeScore.node).toArray();
248+
}
249+
250+
private static float calculateRecall(OnHeapGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat<?> queryVector, int[] groundTruth, int k) throws IOException {
251+
try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)){
252+
SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector);
253+
var searchResults = graphSearcher.search(ssp, k, Bits.ALL);
254+
var predicted = Arrays.stream(searchResults.getNodes()).mapToInt(nodeScore -> nodeScore.node).boxed().collect(Collectors.toSet());
255+
return calculateRecall(predicted, groundTruth, k);
256+
}
257+
}
258+
/**
259+
* Calculate the recall for a set of predicted results
260+
* @param predicted the predicted results
261+
* @param groundTruth the ground truth
262+
* @param k the number of results to consider
263+
* @return the recall
264+
*/
265+
private static float calculateRecall(Set<Integer> predicted, int[] groundTruth, int k) {
266+
int hits = 0;
267+
int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length));
268+
269+
for (int i = 0; i < actualK; i++) {
270+
for (int j = 0; j < actualK; j++) {
271+
if (predicted.contains(groundTruth[j])) {
272+
hits++;
273+
break;
274+
}
275+
}
276+
}
277+
278+
return ((float) hits) / (float) actualK;
279+
}
213280
}

0 commit comments

Comments
 (0)