Skip to content

Commit 06c82ac

Browse files
authored
Improve performance and checking of recall and precision measurements (#647)
* Performance-optimized version of AccuracyMetrics, with dead and noisy code removed. * AccuracyMetrics now treats duplicate entries in the ground truth or in the retrieved results as an error condition. * Added a check for null elements in the ground truth set.
1 parent 8c75f1b commit 06c82ac

2 files changed

Lines changed: 425 additions & 49 deletions

File tree

jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
package io.github.jbellis.jvector.example.util;
1818

1919
import io.github.jbellis.jvector.graph.SearchResult;
20-
21-
import java.util.Arrays;
20+
import java.util.HashSet;
2221
import java.util.List;
23-
import java.util.stream.Collectors;
24-
import java.util.stream.IntStream;
22+
import java.util.Set;
2523

2624
/**
2725
* Computes accuracy metrics, such as recall and mean average precision.
@@ -41,43 +39,54 @@ public static double recallFromSearchResults(List<? extends List<Integer>> gt, L
4139
if (gt.size() != retrieved.size()) {
4240
throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements");
4341
}
44-
Long correctCount = IntStream.range(0, gt.size())
45-
.mapToObj(i -> topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved))
46-
.reduce(0L, Long::sum);
42+
43+
long correctCount = 0;
44+
for (int i = 0; i < gt.size(); i++) {
45+
correctCount += topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved);
46+
}
47+
4748
return (double) correctCount / (kGT * gt.size());
4849
}
4950

50-
private static long topKCorrect(List<Integer> gt, List<Integer> retrieved, int kGT, int kRetrieved) {
51+
private static long topKCorrect(List<Integer> gt, SearchResult retrieved, int kGT, int kRetrieved) {
52+
// Argument validation: each failed check below throws IllegalArgumentException.
53+
var nodes = retrieved.getNodes();
5154
if (kGT > kRetrieved) {
5255
throw new IllegalArgumentException("kGT: " + kGT + " > kRetrieved: " + kRetrieved);
5356
}
5457
if (kGT > gt.size()) {
5558
throw new IllegalArgumentException("kGT: " + kGT + " > Gt size: " + gt.size());
5659
}
57-
if (kRetrieved > retrieved.size()) {
58-
throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + retrieved.size());
60+
if (kRetrieved > nodes.length) {
61+
throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + nodes.length);
5962
}
6063

61-
var gtView = crop(gt, kGT);
62-
var retrievedView = crop(retrieved, kRetrieved);
63-
64-
if (gtView.size() > retrieved.size()) {
65-
return gtView.stream().filter(retrievedView::contains).count();
66-
} else {
67-
return retrievedView.stream().filter(gtView::contains).count();
64+
// Build HashSet with explicit capacity to avoid rehashing.
65+
// Load factor is 0.75, so sized to kGT / 0.75.
66+
Set<Integer> gtSet = new HashSet<>((int) (kGT / 0.75f) + 1);
67+
for (int i = 0; i < kGT; i++) {
68+
Integer ord = gt.get(i);
69+
if (ord == null) {
70+
throw new IllegalArgumentException("Null ground truth ordinal in top-" + kGT + " at index " + i);
71+
}
72+
if (!gtSet.add(ord)) {
73+
throw new IllegalArgumentException("Duplicate ground truth ordinal in top-" + kGT + ": " + ord);
74+
}
6875
}
69-
}
7076

71-
private static long topKCorrect(List<Integer> gt, SearchResult retrieved, int kGT, int kRetrieved) {
72-
var temp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node)
73-
.boxed()
74-
.collect(Collectors.toList());
75-
return topKCorrect(gt, temp, kGT, kRetrieved);
76-
}
77+
Set<Integer> seenRetrieved = new HashSet<>((int) (kRetrieved / 0.75f) + 1);
78+
int hits = 0;
79+
for (int i = 0; i < kRetrieved; i++) {
80+
int p = nodes[i].node;
81+
if (!seenRetrieved.add(p)) {
82+
throw new IllegalArgumentException("Duplicate retrieved ordinal in top-" + kRetrieved + ": " + p);
83+
}
84+
if (gtSet.contains(p)) {
85+
hits++;
86+
}
87+
}
7788

78-
private static List<Integer> crop(List<Integer> list, int k) {
79-
int count = Math.min(list.size(), k);
80-
return list.subList(0, count);
89+
return hits;
8190
}
8291

8392
/**
@@ -89,33 +98,41 @@ private static List<Integer> crop(List<Integer> list, int k) {
8998
* @return the average precision
9099
*/
91100
public static double averagePrecisionAtK(List<Integer> gt, SearchResult retrieved, int k) {
92-
var retrievedTemp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node)
93-
.boxed()
94-
.collect(Collectors.toList());
95-
101+
var nodes = retrieved.getNodes();
96102
if (k > gt.size()) {
97103
throw new IllegalArgumentException("k: " + k + " > Gt size: " + gt.size());
98104
}
99-
if (k > retrievedTemp.size()) {
100-
throw new IllegalArgumentException("k: " + k + " > retrieved size: " + retrievedTemp.size());
105+
if (k > nodes.length) {
106+
throw new IllegalArgumentException("k: " + k + " > retrieved size: " + nodes.length);
101107
}
102108

103-
var gtView = crop(gt, k);
104-
var retrievedView = crop(retrievedTemp, k);
109+
// Presize the HashSet to k / 0.75 (+1) so that inserting k entries does not trigger a rehash.
110+
Set<Integer> gtSet = new HashSet<>((int) (k / 0.75f) + 1);
111+
for (int i = 0; i < k; i++) {
112+
Integer ord = gt.get(i);
113+
if (ord == null) {
114+
throw new IllegalArgumentException("Null ground truth ordinal in top-" + k + " at index " + i);
115+
}
116+
if (!gtSet.add(ord)) {
117+
throw new IllegalArgumentException("Duplicate ground truth ordinal in top-" + k + ": " + ord);
118+
}
119+
}
105120

121+
Set<Integer> seenRetrieved = new HashSet<>((int) (k / 0.75f) + 1);
106122
double score = 0.;
107-
int num_hits = 0;
108-
int i = 0;
109-
110-
for (var p : retrievedView) {
111-
if (gtView.contains(p) && !retrievedView.subList(0, i).contains(p)) {
112-
num_hits += 1;
113-
score += num_hits / (i + 1.0);
123+
int hits = 0;
124+
for (int i = 0; i < k; i++) {
125+
int p = nodes[i].node;
126+
if (!seenRetrieved.add(p)) {
127+
throw new IllegalArgumentException("Duplicate retrieved ordinal in top-" + k + ": " + p);
128+
}
129+
if (gtSet.contains(p)) {
130+
hits++;
131+
score += (double) hits / (i + 1);
114132
}
115-
i++;
116133
}
117134

118-
return score / gtView.size();
135+
return score / k;
119136
}
120137

121138
/**
@@ -130,10 +147,12 @@ public static double meanAveragePrecisionAtK(List<? extends List<Integer>> gt, L
130147
if (gt.size() != retrieved.size()) {
131148
throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements");
132149
}
133-
Double apk = IntStream.range(0, gt.size())
134-
.mapToObj(i -> averagePrecisionAtK(gt.get(i), retrieved.get(i), k))
135-
.reduce(0., Double::sum);
136-
return apk / gt.size();
137-
}
138150

151+
double totalAp = 0;
152+
for (int i = 0; i < gt.size(); i++) {
153+
totalAp += averagePrecisionAtK(gt.get(i), retrieved.get(i), k);
154+
}
155+
156+
return totalAp / gt.size();
157+
}
139158
}

0 commit comments

Comments
 (0)