55// Created by Afonso Bordado on 2026/02/19.
66//
77
8-
98#include "distance-rvv.h"
109#include "distance-cpu.h"
1110
1211#if defined(__riscv_vector )
1312#include <riscv_vector.h>
13+ #include <math.h>
14+ #include <stdio.h>
15+ #include <stdlib.h>
1416
1517extern distance_function_t dispatch_distance_table [VECTOR_DISTANCE_MAX ][VECTOR_TYPE_MAX ];
1618extern const char * distance_backend_name ;
1719
1820// MARK: - FLOAT32 -
1921
// Stub for float32_distance_l2_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float32_distance_l2_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
2427
// Stub for float32_distance_l2_squared_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float32_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float32_distance_l2_squared_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
2933
// Stub for float32_distance_l1_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float32_distance_l1_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float32_distance_l1_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
3439
// Stub for float32_distance_dot_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float32_distance_dot_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float32_distance_dot_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
3945
// Cosine distance (1 - cosine similarity) between two float32 arrays of
// `n` elements, computed with RISC-V Vector intrinsics.
//
// v1, v2: pointers to the float32 data (read only).
// n:      number of elements to process; nothing is read when n <= 0.
//
// Returns 1.0f when either input has zero magnitude; otherwise the similarity
// is clamped to [-1, 1] before being subtracted from 1.
float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
    const float *lhs = (const float *)v1;
    const float *rhs = (const float *)v2;

    // The three running sums are kept in element 0 of LMUL=1 vectors so that
    // they can be fed straight back into the reduction instructions.
    size_t vl = __riscv_vsetvl_e32m1(1);
    vfloat32m1_t sum_ab = __riscv_vfmv_v_f_f32m1(0.0f, vl);
    vfloat32m1_t sum_aa = __riscv_vfmv_v_f_f32m1(0.0f, vl);
    vfloat32m1_t sum_bb = __riscv_vfmv_v_f_f32m1(0.0f, vl);

    // Strip-mine the input: each pass handles as many elements as the
    // hardware grants for LMUL=8 (the m8 types), including the final short pass.
    while (n > 0) {
        vl = __riscv_vsetvl_e32m8(n);

        vfloat32m8_t chunk_a = __riscv_vle32_v_f32m8(lhs, vl);
        vfloat32m8_t chunk_b = __riscv_vle32_v_f32m8(rhs, vl);

        // a*b products, folded into the dot-product sum
        vfloat32m8_t prod_ab = __riscv_vfmul_vv_f32m8(chunk_a, chunk_b, vl);
        sum_ab = __riscv_vfredusum_vs_f32m8_f32m1(prod_ab, sum_ab, vl);

        // a*a and b*b products, folded into the squared-magnitude sums
        vfloat32m8_t prod_aa = __riscv_vfmul_vv_f32m8(chunk_a, chunk_a, vl);
        sum_aa = __riscv_vfredusum_vs_f32m8_f32m1(prod_aa, sum_aa, vl);
        vfloat32m8_t prod_bb = __riscv_vfmul_vv_f32m8(chunk_b, chunk_b, vl);
        sum_bb = __riscv_vfredusum_vs_f32m8_f32m1(prod_bb, sum_bb, vl);

        // Step past the elements consumed in this pass
        lhs += vl;
        rhs += vl;
        n -= vl;
    }

    // Move the sums into scalars to finish the computation.
    // TODO: With default flags this does not use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
    float dot = __riscv_vfmv_f_s_f32m1_f32(sum_ab);
    float norm_a = sqrtf(__riscv_vfmv_f_s_f32m1_f32(sum_aa));
    float norm_b = sqrtf(__riscv_vfmv_f_s_f32m1_f32(sum_bb));

    // A zero-magnitude input would make the division below divide by zero.
    if (norm_a == 0.0f || norm_b == 0.0f) {
        return 1.0f;
    }

    float similarity = dot / (norm_a * norm_b);
    if (similarity > 1.0f) {
        similarity = 1.0f;
    } else if (similarity < -1.0f) {
        similarity = -1.0f;
    }
    return 1.0f - similarity;
}
4496
4597
4698// MARK: - FLOAT16 -
4799
// Stub for float16_distance_l2_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float16_distance_l2_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float16_distance_l2_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
52105
// Stub for float16_distance_l2_squared_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float16_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float16_distance_l2_squared_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
57111
// Stub for float16_distance_l1_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float16_distance_l1_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float16_distance_l1_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
62117
// Stub for float16_distance_dot_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float16_distance_dot_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float16_distance_dot_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
67123
// Stub for float16_distance_cosine_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float float16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "float16_distance_cosine_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
72129
73130// MARK: - BFLOAT16 -
74131
// Stub for bfloat16_distance_l2_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float bfloat16_distance_l2_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "bfloat16_distance_l2_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
79137
// Stub for bfloat16_distance_l2_squared_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float bfloat16_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "bfloat16_distance_l2_squared_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
84143
// Stub for bfloat16_distance_l1_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float bfloat16_distance_l1_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "bfloat16_distance_l1_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
89149
// Stub for bfloat16_distance_dot_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float bfloat16_distance_dot_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "bfloat16_distance_dot_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
94155
// Stub for bfloat16_distance_cosine_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float bfloat16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "bfloat16_distance_cosine_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
99161
100162// MARK: - UINT8 -
101163
// Stub for uint8_distance_l2_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float uint8_distance_l2_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "uint8_distance_l2_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
106169
// Stub for uint8_distance_l2_squared_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float uint8_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "uint8_distance_l2_squared_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
111175
// Stub for uint8_distance_dot_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float uint8_distance_dot_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "uint8_distance_dot_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
116181
// Stub for uint8_distance_l1_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float uint8_distance_l1_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "uint8_distance_l1_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
121187
// Stub for uint8_distance_cosine_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float uint8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "uint8_distance_cosine_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
126193
127194// MARK: - INT8 -
128195
// Stub for int8_distance_l2_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float int8_distance_l2_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "int8_distance_l2_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
133201
// Stub for int8_distance_l2_squared_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float int8_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "int8_distance_l2_squared_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
138207
// Stub for int8_distance_dot_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float int8_distance_dot_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "int8_distance_dot_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
143213
// Stub for int8_distance_l1_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float int8_distance_l1_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "int8_distance_l1_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
148219
// Stub for int8_distance_cosine_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float int8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "int8_distance_cosine_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
153225
154226// MARK: - BIT -
155227
// Stub for bit1_distance_hamming_rvv: no RVV kernel exists yet.
// Calling it reports the missing implementation and terminates the process.
float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
    (void)v1; (void)v2; (void)n;    // unused until a kernel is written

    // Report on stderr rather than stdout: abort() is not required to flush
    // buffered streams, so a message sitting in stdout's buffer could be lost.
    fprintf(stderr, "bit1_distance_hamming_rvv: unimplemented\n");
    abort();
    return 0.0f;    // not reached; abort() does not return
}
160233
@@ -164,37 +237,37 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
164237
165238void init_distance_functions_rvv (void ) {
166239#if defined(__riscv_vector )
167- dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_rvv ;
168- dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F16 ] = float16_distance_l2_rvv ;
169- dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_BF16 ] = bfloat16_distance_l2_rvv ;
170- dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_rvv ;
171- dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_rvv ;
240+ // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
241+ // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
242+ // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
243+ // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
244+ // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
172245
173- dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_squared_rvv ;
174- dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F16 ] = float16_distance_l2_squared_rvv ;
175- dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_BF16 ] = bfloat16_distance_l2_squared_rvv ;
176- dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_squared_rvv ;
177- dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_squared_rvv ;
246+ // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
247+ // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
248+ // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
249+ // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
250+ // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
178251
179252 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F32 ] = float32_distance_cosine_rvv ;
180- dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F16 ] = float16_distance_cosine_rvv ;
181- dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_BF16 ] = bfloat16_distance_cosine_rvv ;
182- dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_U8 ] = uint8_distance_cosine_rvv ;
183- dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_I8 ] = int8_distance_cosine_rvv ;
253+ // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
254+ // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
255+ // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
256+ // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
184257
185- dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F32 ] = float32_distance_dot_rvv ;
186- dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F16 ] = float16_distance_dot_rvv ;
187- dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_BF16 ] = bfloat16_distance_dot_rvv ;
188- dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_U8 ] = uint8_distance_dot_rvv ;
189- dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_I8 ] = int8_distance_dot_rvv ;
258+ // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
259+ // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
260+ // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
261+ // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
262+ // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
190263
191- dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F32 ] = float32_distance_l1_rvv ;
192- dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F16 ] = float16_distance_l1_rvv ;
193- dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_BF16 ] = bfloat16_distance_l1_rvv ;
194- dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_U8 ] = uint8_distance_l1_rvv ;
195- dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_I8 ] = int8_distance_l1_rvv ;
264+ // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
265+ // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
266+ // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
267+ // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
268+ // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
196269
197- dispatch_distance_table [VECTOR_DISTANCE_HAMMING ][VECTOR_TYPE_BIT ] = bit1_distance_hamming_rvv ;
270+ // dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
198271
199272 distance_backend_name = "RVV" ;
200273#endif
0 commit comments