1717package io .github .jbellis .jvector .example .util ;
1818
1919import io .github .jbellis .jvector .graph .SearchResult ;
20-
21- import java .util .Arrays ;
20+ import java .util .HashSet ;
2221import java .util .List ;
23- import java .util .stream .Collectors ;
24- import java .util .stream .IntStream ;
22+ import java .util .Set ;
2523
2624/**
2725 * Computes accuracy metrics, such as recall and mean average precision.
@@ -41,43 +39,54 @@ public static double recallFromSearchResults(List<? extends List<Integer>> gt, L
4139 if (gt .size () != retrieved .size ()) {
4240 throw new IllegalArgumentException ("Insufficient ground truth for the number of retrieved elements" );
4341 }
44- Long correctCount = IntStream .range (0 , gt .size ())
45- .mapToObj (i -> topKCorrect (gt .get (i ), retrieved .get (i ), kGT , kRetrieved ))
46- .reduce (0L , Long ::sum );
42+
43+ long correctCount = 0 ;
44+ for (int i = 0 ; i < gt .size (); i ++) {
45+ correctCount += topKCorrect (gt .get (i ), retrieved .get (i ), kGT , kRetrieved );
46+ }
47+
4748 return (double ) correctCount / (kGT * gt .size ());
4849 }
4950
50- private static long topKCorrect (List <Integer > gt , List <Integer > retrieved , int kGT , int kRetrieved ) {
51+ private static long topKCorrect (List <Integer > gt , SearchResult retrieved , int kGT , int kRetrieved ) {
52+ // Exception validation
53+ var nodes = retrieved .getNodes ();
5154 if (kGT > kRetrieved ) {
5255 throw new IllegalArgumentException ("kGT: " + kGT + " > kRetrieved: " + kRetrieved );
5356 }
5457 if (kGT > gt .size ()) {
5558 throw new IllegalArgumentException ("kGT: " + kGT + " > Gt size: " + gt .size ());
5659 }
57- if (kRetrieved > retrieved . size () ) {
58- throw new IllegalArgumentException ("kRetrieved: " + kRetrieved + " > retrieved size: " + retrieved . size () );
60+ if (kRetrieved > nodes . length ) {
61+ throw new IllegalArgumentException ("kRetrieved: " + kRetrieved + " > retrieved size: " + nodes . length );
5962 }
6063
61- var gtView = crop (gt , kGT );
62- var retrievedView = crop (retrieved , kRetrieved );
63-
64- if (gtView .size () > retrieved .size ()) {
65- return gtView .stream ().filter (retrievedView ::contains ).count ();
66- } else {
67- return retrievedView .stream ().filter (gtView ::contains ).count ();
64+ // Build HashSet with explicit capacity to avoid rehashing.
65+ // Load factor is 0.75, so sized to kGT / 0.75.
66+ Set <Integer > gtSet = new HashSet <>((int ) (kGT / 0.75f ) + 1 );
67+ for (int i = 0 ; i < kGT ; i ++) {
68+ Integer ord = gt .get (i );
69+ if (ord == null ) {
70+ throw new IllegalArgumentException ("Null ground truth ordinal in top-" + kGT + " at index " + i );
71+ }
72+ if (!gtSet .add (ord )) {
73+ throw new IllegalArgumentException ("Duplicate ground truth ordinal in top-" + kGT + ": " + ord );
74+ }
6875 }
69- }
7076
71- private static long topKCorrect (List <Integer > gt , SearchResult retrieved , int kGT , int kRetrieved ) {
72- var temp = Arrays .stream (retrieved .getNodes ()).mapToInt (nodeScore -> nodeScore .node )
73- .boxed ()
74- .collect (Collectors .toList ());
75- return topKCorrect (gt , temp , kGT , kRetrieved );
76- }
77+ Set <Integer > seenRetrieved = new HashSet <>((int ) (kRetrieved / 0.75f ) + 1 );
78+ int hits = 0 ;
79+ for (int i = 0 ; i < kRetrieved ; i ++) {
80+ int p = nodes [i ].node ;
81+ if (!seenRetrieved .add (p )) {
82+ throw new IllegalArgumentException ("Duplicate retrieved ordinal in top-" + kRetrieved + ": " + p );
83+ }
84+ if (gtSet .contains (p )) {
85+ hits ++;
86+ }
87+ }
7788
78- private static List <Integer > crop (List <Integer > list , int k ) {
79- int count = Math .min (list .size (), k );
80- return list .subList (0 , count );
89+ return hits ;
8190 }
8291
8392 /**
@@ -89,33 +98,41 @@ private static List<Integer> crop(List<Integer> list, int k) {
8998 * @return the average precision
9099 */
91100 public static double averagePrecisionAtK (List <Integer > gt , SearchResult retrieved , int k ) {
92- var retrievedTemp = Arrays .stream (retrieved .getNodes ()).mapToInt (nodeScore -> nodeScore .node )
93- .boxed ()
94- .collect (Collectors .toList ());
95-
101+ var nodes = retrieved .getNodes ();
96102 if (k > gt .size ()) {
97103 throw new IllegalArgumentException ("k: " + k + " > Gt size: " + gt .size ());
98104 }
99- if (k > retrievedTemp . size () ) {
100- throw new IllegalArgumentException ("k: " + k + " > retrieved size: " + retrievedTemp . size () );
105+ if (k > nodes . length ) {
106+ throw new IllegalArgumentException ("k: " + k + " > retrieved size: " + nodes . length );
101107 }
102108
103- var gtView = crop (gt , k );
104- var retrievedView = crop (retrievedTemp , k );
109+ // Sized hashset used for performance.
110+ Set <Integer > gtSet = new HashSet <>((int ) (k / 0.75f ) + 1 );
111+ for (int i = 0 ; i < k ; i ++) {
112+ Integer ord = gt .get (i );
113+ if (ord == null ) {
114+ throw new IllegalArgumentException ("Null ground truth ordinal in top-" + k + " at index " + i );
115+ }
116+ if (!gtSet .add (ord )) {
117+ throw new IllegalArgumentException ("Duplicate ground truth ordinal in top-" + k + ": " + ord );
118+ }
119+ }
105120
121+ Set <Integer > seenRetrieved = new HashSet <>((int ) (k / 0.75f ) + 1 );
106122 double score = 0. ;
107- int num_hits = 0 ;
108- int i = 0 ;
109-
110- for (var p : retrievedView ) {
111- if (gtView .contains (p ) && !retrievedView .subList (0 , i ).contains (p )) {
112- num_hits += 1 ;
113- score += num_hits / (i + 1.0 );
123+ int hits = 0 ;
124+ for (int i = 0 ; i < k ; i ++) {
125+ int p = nodes [i ].node ;
126+ if (!seenRetrieved .add (p )) {
127+ throw new IllegalArgumentException ("Duplicate retrieved ordinal in top-" + k + ": " + p );
128+ }
129+ if (gtSet .contains (p )) {
130+ hits ++;
131+ score += (double ) hits / (i + 1 );
114132 }
115- i ++;
116133 }
117134
118- return score / gtView . size () ;
135+ return score / k ;
119136 }
120137
121138 /**
@@ -130,10 +147,12 @@ public static double meanAveragePrecisionAtK(List<? extends List<Integer>> gt, L
130147 if (gt .size () != retrieved .size ()) {
131148 throw new IllegalArgumentException ("Insufficient ground truth for the number of retrieved elements" );
132149 }
133- Double apk = IntStream .range (0 , gt .size ())
134- .mapToObj (i -> averagePrecisionAtK (gt .get (i ), retrieved .get (i ), k ))
135- .reduce (0. , Double ::sum );
136- return apk / gt .size ();
137- }
138150
151+ double totalAp = 0 ;
152+ for (int i = 0 ; i < gt .size (); i ++) {
153+ totalAp += averagePrecisionAtK (gt .get (i ), retrieved .get (i ), k );
154+ }
155+
156+ return totalAp / gt .size ();
157+ }
139158}
0 commit comments