Skip to content

Commit 160ab32

Browse files
committed
distance-rvv: Implement cosine distance
1 parent 286afc1 commit 160ab32

1 file changed

Lines changed: 126 additions & 53 deletions

File tree

src/distance-rvv.c

Lines changed: 126 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -5,156 +5,229 @@
55
// Created by Afonso Bordado on 2026/02/19.
66
//
77

8-
98
#include "distance-rvv.h"
109
#include "distance-cpu.h"
1110

1211
#if defined(__riscv_vector)
1312
#include <riscv_vector.h>
13+
#include <math.h>
14+
#include <stdio.h>
15+
#include <stdlib.h>
1416

1517
extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
1618
extern const char *distance_backend_name;
1719

1820
// MARK: - FLOAT32 -
1921

2022
float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
21-
panic("float32_distance_l2_rvv: unimplemented");
23+
printf("float32_distance_l2_rvv: unimplemented\n");
24+
abort();
2225
return 0.0f;
2326
}
2427

2528
float float32_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
26-
panic("float32_distance_l2_squared_rvv: unimplemented");
29+
printf("float32_distance_l2_squared_rvv: unimplemented\n");
30+
abort();
2731
return 0.0f;
2832
}
2933

3034
float float32_distance_l1_rvv (const void *v1, const void *v2, int n) {
31-
panic("float32_distance_l1_rvv: unimplemented");
35+
printf("float32_distance_l1_rvv: unimplemented\n");
36+
abort();
3237
return 0.0f;
3338
}
3439

3540
float float32_distance_dot_rvv (const void *v1, const void *v2, int n) {
36-
panic("float32_distance_dot_rvv: unimplemented");
41+
printf("float32_distance_dot_rvv: unimplemented\n");
42+
abort();
3743
return 0.0f;
3844
}
3945

4046
float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
41-
panic("float32_distance_cosine_rvv: unimplemented");
42-
return 0.0f;
47+
const float *a = (const float *)v1;
48+
const float *b = (const float *)v2;
49+
50+
// We accumulate the running sums in element 0 of LMUL=1 vector registers
51+
size_t vl = __riscv_vsetvl_e32m1(1);
52+
vfloat32m1_t dot_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
53+
vfloat32m1_t magn_a_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
54+
vfloat32m1_t magn_b_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
55+
56+
57+
// Iterate by VL elements
58+
for (; n > 0; n -= vl) {
59+
// Use LMUL=8: the 32 vector registers form 4 register groups to work with.
60+
// In practice we use 3 groups, and the last group gets split for the reduction accumulators
61+
vl = __riscv_vsetvl_e32m8(n);
62+
63+
// Load the vectors into the registers
64+
vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl);
65+
vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl);
66+
67+
// Compute the dot product for the entire register, and sum the
68+
// results into the accumulating register
69+
vfloat32m8_t vdot = __riscv_vfmul_vv_f32m8(va, vb, vl);
70+
dot_acc = __riscv_vfredusum_vs_f32m8_f32m1(vdot, dot_acc, vl);
71+
72+
// Also calculate the magnitude value for both a and b
73+
vfloat32m8_t magn_a = __riscv_vfmul_vv_f32m8(va, va, vl);
74+
magn_a_acc = __riscv_vfredusum_vs_f32m8_f32m1(magn_a, magn_a_acc, vl);
75+
vfloat32m8_t magn_b = __riscv_vfmul_vv_f32m8(vb, vb, vl);
76+
magn_b_acc = __riscv_vfredusum_vs_f32m8_f32m1(magn_b, magn_b_acc, vl);
77+
78+
// Advance the a and b pointers to the next offset
79+
a = &a[vl];
80+
b = &b[vl];
81+
}
82+
83+
// Copy the accumulators back into a scalar register, to finalize the calculations
84+
// TODO: With default flags this does not use the fsqrt.s/fmin.s/fmax.s instructions; we should fix that
85+
float dot = __riscv_vfmv_f_s_f32m1_f32(dot_acc);
86+
float magn_a = sqrtf(__riscv_vfmv_f_s_f32m1_f32(magn_a_acc));
87+
float magn_b = sqrtf(__riscv_vfmv_f_s_f32m1_f32(magn_b_acc));
88+
89+
if (magn_a == 0.0f || magn_b == 0.0f) return 1.0f;
90+
91+
float cosine_similarity = dot / (magn_a * magn_b);
92+
if (cosine_similarity > 1.0f) cosine_similarity = 1.0f;
93+
if (cosine_similarity < -1.0f) cosine_similarity = -1.0f;
94+
return 1.0f - cosine_similarity;
4395
}
4496

4597

4698
// MARK: - FLOAT16 -
4799

48100
float float16_distance_l2_rvv (const void *v1, const void *v2, int n) {
49-
panic("float16_distance_l2_rvv: unimplemented");
101+
printf("float16_distance_l2_rvv: unimplemented\n");
102+
abort();
50103
return 0.0f;
51104
}
52105

53106
float float16_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
54-
panic("float16_distance_l2_squared_rvv: unimplemented");
107+
printf("float16_distance_l2_squared_rvv: unimplemented\n");
108+
abort();
55109
return 0.0f;
56110
}
57111

58112
float float16_distance_l1_rvv (const void *v1, const void *v2, int n) {
59-
panic("float16_distance_l1_rvv: unimplemented");
113+
printf("float16_distance_l1_rvv: unimplemented\n");
114+
abort();
60115
return 0.0f;
61116
}
62117

63118
float float16_distance_dot_rvv (const void *v1, const void *v2, int n) {
64-
panic("float16_distance_dot_rvv: unimplemented");
119+
printf("float16_distance_dot_rvv: unimplemented\n");
120+
abort();
65121
return 0.0f;
66122
}
67123

68124
float float16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
69-
panic("float16_distance_cosine_rvv: unimplemented");
125+
printf("float16_distance_cosine_rvv: unimplemented\n");
126+
abort();
70127
return 0.0f;
71128
}
72129

73130
// MARK: - BFLOAT16 -
74131

75132
float bfloat16_distance_l2_rvv (const void *v1, const void *v2, int n) {
76-
panic("bfloat16_distance_l2_rvv: unimplemented");
133+
printf("bfloat16_distance_l2_rvv: unimplemented\n");
134+
abort();
77135
return 0.0f;
78136
}
79137

80138
float bfloat16_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
81-
panic("bfloat16_distance_l2_squared_rvv: unimplemented");
139+
printf("bfloat16_distance_l2_squared_rvv: unimplemented\n");
140+
abort();
82141
return 0.0f;
83142
}
84143

85144
float bfloat16_distance_l1_rvv (const void *v1, const void *v2, int n) {
86-
panic("bfloat16_distance_l1_rvv: unimplemented");
145+
printf("bfloat16_distance_l1_rvv: unimplemented\n");
146+
abort();
87147
return 0.0f;
88148
}
89149

90150
float bfloat16_distance_dot_rvv (const void *v1, const void *v2, int n) {
91-
panic("bfloat16_distance_dot_rvv: unimplemented");
151+
printf("bfloat16_distance_dot_rvv: unimplemented\n");
152+
abort();
92153
return 0.0f;
93154
}
94155

95156
float bfloat16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
96-
panic("bfloat16_distance_cosine_rvv: unimplemented");
157+
printf("bfloat16_distance_cosine_rvv: unimplemented\n");
158+
abort();
97159
return 0.0f;
98160
}
99161

100162
// MARK: - UINT8 -
101163

102164
float uint8_distance_l2_rvv (const void *v1, const void *v2, int n) {
103-
panic("uint8_distance_l2_rvv: unimplemented");
165+
printf("uint8_distance_l2_rvv: unimplemented\n");
166+
abort();
104167
return 0.0f;
105168
}
106169

107170
float uint8_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
108-
panic("uint8_distance_l2_squared_rvv: unimplemented");
171+
printf("uint8_distance_l2_squared_rvv: unimplemented\n");
172+
abort();
109173
return 0.0f;
110174
}
111175

112176
float uint8_distance_dot_rvv (const void *v1, const void *v2, int n) {
113-
panic("uint8_distance_dot_rvv: unimplemented");
177+
printf("uint8_distance_dot_rvv: unimplemented\n");
178+
abort();
114179
return 0.0f;
115180
}
116181

117182
float uint8_distance_l1_rvv (const void *v1, const void *v2, int n) {
118-
panic("uint8_distance_l1_rvv: unimplemented");
183+
printf("uint8_distance_l1_rvv: unimplemented\n");
184+
abort();
119185
return 0.0f;
120186
}
121187

122188
float uint8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
123-
panic("uint8_distance_cosine_rvv: unimplemented");
189+
printf("uint8_distance_cosine_rvv: unimplemented\n");
190+
abort();
124191
return 0.0f;
125192
}
126193

127194
// MARK: - INT8 -
128195

129196
float int8_distance_l2_rvv (const void *v1, const void *v2, int n) {
130-
panic("int8_distance_l2_rvv: unimplemented");
197+
printf("int8_distance_l2_rvv: unimplemented\n");
198+
abort();
131199
return 0.0f;
132200
}
133201

134202
float int8_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
135-
panic("int8_distance_l2_squared_rvv: unimplemented");
203+
printf("int8_distance_l2_squared_rvv: unimplemented\n");
204+
abort();
136205
return 0.0f;
137206
}
138207

139208
float int8_distance_dot_rvv (const void *v1, const void *v2, int n) {
140-
panic("int8_distance_dot_rvv: unimplemented");
209+
printf("int8_distance_dot_rvv: unimplemented\n");
210+
abort();
141211
return 0.0f;
142212
}
143213

144214
float int8_distance_l1_rvv (const void *v1, const void *v2, int n) {
145-
panic("int8_distance_l1_rvv: unimplemented");
215+
printf("int8_distance_l1_rvv: unimplemented\n");
216+
abort();
146217
return 0.0f;
147218
}
148219

149220
float int8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
150-
panic("int8_distance_cosine_rvv: unimplemented");
221+
printf("int8_distance_cosine_rvv: unimplemented\n");
222+
abort();
151223
return 0.0f;
152224
}
153225

154226
// MARK: - BIT -
155227

156228
float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
157-
panic("bit1_distance_hamming_rvv: unimplemented");
229+
printf("bit1_distance_hamming_rvv: unimplemented\n");
230+
abort();
158231
return 0.0f;
159232
}
160233

@@ -164,37 +237,37 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
164237

165238
void init_distance_functions_rvv (void) {
166239
#if defined(__riscv_vector)
167-
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
168-
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
169-
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
170-
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
171-
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
240+
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
241+
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
242+
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
243+
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
244+
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
172245

173-
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
174-
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
175-
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
176-
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
177-
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
246+
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
247+
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
248+
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
249+
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
250+
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
178251

179252
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F32] = float32_distance_cosine_rvv;
180-
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
181-
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
182-
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
183-
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
253+
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
254+
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
255+
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
256+
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
184257

185-
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
186-
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
187-
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
188-
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
189-
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
258+
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
259+
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
260+
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
261+
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
262+
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
190263

191-
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
192-
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
193-
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
194-
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
195-
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
264+
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
265+
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
266+
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
267+
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
268+
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
196269

197-
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
270+
// dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
198271

199272
distance_backend_name = "RVV";
200273
#endif

0 commit comments

Comments
 (0)