88#include "distance-rvv.h"
99#include "distance-cpu.h"
1010
11- #if defined(__riscv_vector )
11+ #if defined(__riscv_v_intrinsic )
1212#include <riscv_vector.h>
1313#include <math.h>
1414#include <stdio.h>
@@ -19,72 +19,171 @@ extern const char *distance_backend_name;
1919
2020// MARK: - FLOAT32 -
2121
22+ float float32_distance_l2_impl_rvv (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
23+ const float * a = (const float * )v1 ;
24+ const float * b = (const float * )v2 ;
25+
26+ // We accumulate the results into a vector register
27+ size_t vl = __riscv_vsetvl_e32m8 (n );
28+ vfloat32m8_t vl2 = __riscv_vfmv_v_f_f32m8 (0.0f , vl );
29+
30+ // Iterate by VL elements
31+ for (size_t i = n ; i > 0 ; i -= vl ) {
32+ // Use LMUL=8, we have 4 registers to work with.
33+ vl = __riscv_vsetvl_e32m8 (n );
34+
35+ // Load the vectors into the registers
36+ vfloat32m8_t va = __riscv_vle32_v_f32m8 (a , vl );
37+ vfloat32m8_t vb = __riscv_vle32_v_f32m8 (b , vl );
38+
39+ // L2 = (a[i] - b[i]) + acc
40+ vfloat32m8_t vdiff = __riscv_vfsub_vv_f32m8 (va , vb , vl );
41+ vl2 = __riscv_vfmacc_vv_f32m8 (vl2 , vdiff , vdiff , vl );
42+
43+ // Advance the a and b pointers to the next offset
44+ a = & a [vl ];
45+ b = & b [vl ];
46+ }
47+
48+ // Copy the accumulators back into a scalar register
49+ vfloat32m1_t vl2_acc = __riscv_vfmv_v_f_f32m1 (0.0f , 1 );
50+ vl = __riscv_vsetvl_e32m8 (n );
51+ vl2_acc = __riscv_vfredusum_vs_f32m8_f32m1 (vl2 , vl2_acc , vl );
52+
53+ float l2 = __riscv_vfmv_f_s_f32m1_f32 (vl2_acc );
54+ return use_sqrt ? sqrtf (l2 ) : l2 ;
55+ }
56+
57+
// Euclidean (L2) distance between two float32 vectors of n elements.
// Thin wrapper over float32_distance_l2_impl_rvv with the square root enabled.
float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
    return float32_distance_l2_impl_rvv(v1, v2, n, true);
}
2761
62+
// Squared Euclidean (L2) distance between two float32 vectors of n elements.
// Thin wrapper over float32_distance_l2_impl_rvv that skips the square root.
float float32_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    return float32_distance_l2_impl_rvv(v1, v2, n, false);
}
3366
3467float float32_distance_l1_rvv (const void * v1 , const void * v2 , int n ) {
35- printf ("float32_distance_l1_rvv: unimplemented\n" );
36- abort ();
37- return 0.0f ;
38- }
68+ const float * a = (const float * )v1 ;
69+ const float * b = (const float * )v2 ;
3970
40- float float32_distance_dot_rvv (const void * v1 , const void * v2 , int n ) {
41- printf ("float32_distance_dot_rvv: unimplemented\n" );
42- abort ();
43- return 0.0f ;
71+ // We accumulate the results into a vector register
72+ size_t vl = __riscv_vsetvl_e32m8 (n );
73+ vfloat32m8_t vsad = __riscv_vfmv_v_f_f32m8 (0.0f , vl );
74+
75+ // Iterate by VL elements
76+ for (size_t i = n ; i > 0 ; i -= vl ) {
77+ // Use LMUL=8, we have 4 registers to work with.
78+ vl = __riscv_vsetvl_e32m8 (n );
79+
80+ // Load the vectors into the registers
81+ vfloat32m8_t va = __riscv_vle32_v_f32m8 (a , vl );
82+ vfloat32m8_t vb = __riscv_vle32_v_f32m8 (b , vl );
83+
84+
85+ // SAD = abs(a[i] - b[i]) + acc
86+ vfloat32m8_t vdiff = __riscv_vfsub_vv_f32m8 (va , vb , vl );
87+ vfloat32m8_t vabs = __riscv_vfabs_v_f32m8 (vdiff , vl );
88+ vsad = __riscv_vfadd_vv_f32m8 (vsad , vabs , vl );
89+
90+ // Advance the a and b pointers to the next offset
91+ a = & a [vl ];
92+ b = & b [vl ];
93+ }
94+
95+ // Copy the accumulators back into a scalar register
96+ vfloat32m1_t vsad_acc = __riscv_vfmv_v_f_f32m1 (0.0f , 1 );
97+ vl = __riscv_vsetvl_e32m8 (n );
98+ vsad_acc = __riscv_vfredusum_vs_f32m8_f32m1 (vsad , vsad_acc , vl );
99+
100+ float sad = __riscv_vfmv_f_s_f32m1_f32 (vsad_acc );
101+ return sad ;
44102}
45103
46- float float32_distance_cosine_rvv (const void * v1 , const void * v2 , int n ) {
104+ float float32_distance_dot_rvv (const void * v1 , const void * v2 , int n ) {
47105 const float * a = (const float * )v1 ;
48106 const float * b = (const float * )v2 ;
49107
50- // We accumulate the results into a vecto register
51- size_t vl = __riscv_vsetvl_e32m1 (1 );
52- vfloat32m1_t dot_acc = __riscv_vfmv_v_f_f32m1 (0.0f , vl );
53- vfloat32m1_t magn_a_acc = __riscv_vfmv_v_f_f32m1 (0.0f , vl );
54- vfloat32m1_t magn_b_acc = __riscv_vfmv_v_f_f32m1 (0.0f , vl );
108+ // We accumulate the results into a vector register
109+ size_t vl = __riscv_vsetvl_e32m8 (n );
110+ vfloat32m8_t vdot = __riscv_vfmv_v_f_f32m8 (0.0f , vl );
55111
56-
57112 // Iterate by VL elements
58- for (; n > 0 ; n -= vl ) {
113+ for (size_t i = n ; i > 0 ; i -= vl ) {
59114 // Use LMUL=8, we have 4 registers to work with.
60- // In practice we use 3, and the last register gets split for the reduction operations
61- vl = __riscv_vsetvl_e32m8 (n );
115+ vl = __riscv_vsetvl_e32m8 (i );
62116
63117 // Load the vectors into the registers
64118 vfloat32m8_t va = __riscv_vle32_v_f32m8 (a , vl );
65119 vfloat32m8_t vb = __riscv_vle32_v_f32m8 (b , vl );
66120
67121 // Compute the dot product for the entire register, and sum the
68122 // results into the accumuating register
69- vfloat32m8_t vdot = __riscv_vfmul_vv_f32m8 (va , vb , vl );
70- dot_acc = __riscv_vfredusum_vs_f32m8_f32m1 (vdot , dot_acc , vl );
123+ vdot = __riscv_vfmacc_vv_f32m8 (vdot , va , vb , vl );
124+
125+ // Advance the a and b pointers to the next offset
126+ a = & a [vl ];
127+ b = & b [vl ];
128+ }
129+
130+ // Copy the accumulators back into a scalar register
131+ vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1 (0.0f , 1 );
132+ vl = __riscv_vsetvl_e32m8 (n );
133+ vdot_acc = __riscv_vfredusum_vs_f32m8_f32m1 (vdot , vdot_acc , vl );
134+
135+ float dot = __riscv_vfmv_f_s_f32m1_f32 (vdot_acc );
136+ return - dot ;
137+ }
138+
139+ float float32_distance_cosine_rvv (const void * v1 , const void * v2 , int n ) {
140+ const float * a = (const float * )v1 ;
141+ const float * b = (const float * )v2 ;
142+
143+ // Use LMUL=4, we have 8 registers to work with.
144+ size_t vl = __riscv_vsetvl_e32m4 (n );
145+
146+ // Zero out the starting registers
147+ vfloat32m4_t vdot = __riscv_vfmv_v_f_f32m4 (0.0f , vl );
148+ vfloat32m4_t vmagn_a = __riscv_vfmv_v_f_f32m4 (0.0f , vl );
149+ vfloat32m4_t vmagn_b = __riscv_vfmv_v_f_f32m4 (0.0f , vl );
150+
151+ // Iterate by VL elements
152+ for (size_t i = n ; i > 0 ; i -= vl ) {
153+ // Update VL with the remaining elements
154+ vl = __riscv_vsetvl_e32m4 (i );
155+
156+ // Load the vectors into the registers
157+ vfloat32m4_t va = __riscv_vle32_v_f32m4 (a , vl );
158+ vfloat32m4_t vb = __riscv_vle32_v_f32m4 (b , vl );
159+
160+ // Compute the dot product for the entire register
161+ vdot = __riscv_vfmacc_vv_f32m4 (vdot , va , vb , vl );
71162
72163 // Also calculate the magnitude value for both a and b
73- vfloat32m8_t magn_a = __riscv_vfmul_vv_f32m8 (va , va , vl );
74- magn_a_acc = __riscv_vfredusum_vs_f32m8_f32m1 (magn_a , magn_a_acc , vl );
75- vfloat32m8_t magn_b = __riscv_vfmul_vv_f32m8 (vb , vb , vl );
76- magn_b_acc = __riscv_vfredusum_vs_f32m8_f32m1 (magn_b , magn_b_acc , vl );
164+ vmagn_a = __riscv_vfmacc_vv_f32m4 (vmagn_a , va , va , vl );
165+ vmagn_b = __riscv_vfmacc_vv_f32m4 (vmagn_b , vb , vb , vl );
77166
78167 // Advance the a and b pointers to the next offset
79168 a = & a [vl ];
80169 b = & b [vl ];
81170 }
82171
172+ // Now do a final reduction on the registers to sum the remaining elements
173+ vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1 (0.0f , vl );
174+ vdot_acc = __riscv_vfredusum_vs_f32m4_f32m1 (vdot , vdot_acc , vl );
175+
176+ vfloat32m1_t vmagn_a_acc = __riscv_vfmv_v_f_f32m1 (0.0f , vl );
177+ vmagn_a_acc = __riscv_vfredusum_vs_f32m4_f32m1 (vmagn_a , vmagn_a_acc , vl );
178+
179+ vfloat32m1_t vmagn_b_acc = __riscv_vfmv_v_f_f32m1 (0.0f , vl );
180+ vmagn_b_acc = __riscv_vfredusum_vs_f32m4_f32m1 (vmagn_b , vmagn_b_acc , vl );
181+
83182 // Copy the accumulators back into a scalar register, to finalize the calculations
84183 // TODO: With default flags this does not use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
85- float dot = __riscv_vfmv_f_s_f32m1_f32 (dot_acc );
86- float magn_a = sqrtf (__riscv_vfmv_f_s_f32m1_f32 (magn_a_acc ));
87- float magn_b = sqrtf (__riscv_vfmv_f_s_f32m1_f32 (magn_b_acc ));
184+ float dot = __riscv_vfmv_f_s_f32m1_f32 (vdot_acc );
185+ float magn_a = sqrtf (__riscv_vfmv_f_s_f32m1_f32 (vmagn_a_acc ));
186+ float magn_b = sqrtf (__riscv_vfmv_f_s_f32m1_f32 (vmagn_b_acc ));
88187
89188 if (magn_a == 0.0f || magn_b == 0.0f ) return 1.0f ;
90189
@@ -94,7 +193,6 @@ float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
94193 return 1.0f - cosine_similarity ;
95194}
96195
97-
98196// MARK: - FLOAT16 -
99197
100198float float16_distance_l2_rvv (const void * v1 , const void * v2 , int n ) {
@@ -236,14 +334,14 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
236334// MARK: -
237335
238336void init_distance_functions_rvv (void ) {
239- #if defined(__riscv_vector )
240- // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
337+ #if defined(__riscv_v_intrinsic )
338+ dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_rvv ;
241339 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
242340 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
243341 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
244342 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
245343
246- // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
344+ dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_squared_rvv ;
247345 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
248346 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
249347 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
@@ -255,13 +353,13 @@ void init_distance_functions_rvv (void) {
255353 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
256354 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
257355
258- // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
356+ dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F32 ] = float32_distance_dot_rvv ;
259357 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
260358 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
261359 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
262360 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
263361
264- // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
362+ dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F32 ] = float32_distance_l1_rvv ;
265363 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
266364 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
267365 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
0 commit comments