@@ -20,45 +20,79 @@ extern const char *distance_backend_name;
2020// MARK: - UTILS -
2121
2222// Reduces a vector by summing all of its elements into a single scalar float
23- float float32_sum_vector_f32m8 (vfloat32m8_t vec , size_t vl ) {
23+ static inline float float32_sum_vector_f32m8 (vfloat32m8_t vec , size_t vl ) {
2424 vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1 (0.0f , 1 );
2525 vl = __riscv_vsetvl_e32m8 (vl );
2626 acc = __riscv_vfredusum_vs_f32m8_f32m1 (vec , acc , vl );
2727 return __riscv_vfmv_f_s_f32m1_f32 (acc );
2828}
2929
3030// Reduces a vector by summing all of its elements into a single scalar float
31- float float32_sum_vector_f32m4 (vfloat32m4_t vec , size_t vl ) {
31+ static inline float float32_sum_vector_f32m4 (vfloat32m4_t vec , size_t vl ) {
3232 vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1 (0.0f , 1 );
3333 vl = __riscv_vsetvl_e32m4 (vl );
3434 acc = __riscv_vfredusum_vs_f32m4_f32m1 (vec , acc , vl );
3535 return __riscv_vfmv_f_s_f32m1_f32 (acc );
3636}
3737
38+ // Reduces a vector by summing all of its elements into a single scalar double
39+ static inline double float64_sum_vector_f64m4 (vfloat64m4_t vec , size_t vl ) {
40+ vfloat64m1_t acc = __riscv_vfmv_v_f_f64m1 (0.0 , 1 );
41+ vl = __riscv_vsetvl_e64m4 (vl );
42+ acc = __riscv_vfredusum_vs_f64m4_f64m1 (vec , acc , vl );
43+ return __riscv_vfmv_f_s_f64m1_f64 (acc );
44+ }
45+
3846// Reduces a vector by summing all of its elements into a single scalar integer
39- uint64_t uint64_sum_vector_u64m8 (vuint64m8_t vec , size_t vl ) {
47+ static inline uint64_t uint64_sum_vector_u64m8 (vuint64m8_t vec , size_t vl ) {
4048 vuint64m1_t acc = __riscv_vmv_s_x_u64m1 (0 , 1 );
4149 vl = __riscv_vsetvl_e64m8 (vl );
4250 acc = __riscv_vredsum_vs_u64m8_u64m1 (vec , acc , vl );
4351 return __riscv_vmv_x_s_u64m1_u64 (acc );
4452}
4553
4654// Reduces a vector by summing all of its elements into a single scalar integer
47- uint32_t uint32_sum_vector_u32m8 (vuint32m8_t vec , size_t vl ) {
55+ static inline uint32_t uint32_sum_vector_u32m8 (vuint32m8_t vec , size_t vl ) {
4856 vuint32m1_t acc = __riscv_vmv_s_x_u32m1 (0 , 1 );
4957 vl = __riscv_vsetvl_e32m8 (vl );
5058 acc = __riscv_vredsum_vs_u32m8_u32m1 (vec , acc , vl );
5159 return __riscv_vmv_x_s_u32m1_u32 (acc );
5260}
5361
5462// Reduces a vector by summing all of its elements into a single scalar integer
55- int32_t int32_sum_vector_i32m8 (vint32m8_t vec , size_t vl ) {
63+ static inline int32_t int32_sum_vector_i32m8 (vint32m8_t vec , size_t vl ) {
5664 vint32m1_t acc = __riscv_vmv_s_x_i32m1 (0 , 1 );
5765 vl = __riscv_vsetvl_e32m8 (vl );
5866 acc = __riscv_vredsum_vs_i32m8_i32m1 (vec , acc , vl );
5967 return __riscv_vmv_x_s_i32m1_i32 (acc );
6068}
6169
70+ // Scalar-load fp16 payloads, convert to fp32, and pack as an f32m2 vector.
71+ static inline vfloat32m2_t rvv_load_f16_as_f32m2 (const uint16_t * src , size_t n ) {
72+ size_t vl = __riscv_vsetvl_e32m2 (n );
73+ float lanes [vl ];
74+ for (size_t i = 0 ; i < vl ; ++ i ) lanes [i ] = float16_to_float32 (src [i ]);
75+ return __riscv_vle32_v_f32m2 (lanes , vl );
76+ }
77+
78+ // Returns true if any lane has an fp16-style infinity mismatch:
79+ // one side is Inf and the other is not, or both are Inf with different signs.
80+ static inline bool rvv_has_f16_inf_mismatch_f64m4 (vfloat64m4_t va , vfloat64m4_t vb , size_t vl ) {
81+ vuint64m4_t a_class = __riscv_vfclass_v_u64m4 (va , vl );
82+ vuint64m4_t b_class = __riscv_vfclass_v_u64m4 (vb , vl );
83+ vuint64m4_t a_inf_bits = __riscv_vand_vx_u64m4 (a_class , 0x81u , vl );
84+ vuint64m4_t b_inf_bits = __riscv_vand_vx_u64m4 (b_class , 0x81u , vl );
85+ vbool16_t inf_mismatch = __riscv_vmsne_vv_u64m4_b16 (a_inf_bits , b_inf_bits , vl );
86+ return __riscv_vfirst_m_b16 (inf_mismatch , vl ) >= 0 ;
87+ }
88+
89+ // Returns mask of lanes where both vectors are not NaN.
90+ static inline vbool16_t rvv_both_not_nan_f64m4 (vfloat64m4_t va , vfloat64m4_t vb , size_t vl ) {
91+ vbool16_t a_not_nan = __riscv_vmfeq_vv_f64m4_b16 (va , va , vl );
92+ vbool16_t b_not_nan = __riscv_vmfeq_vv_f64m4_b16 (vb , vb , vl );
93+ return __riscv_vmand_mm_b16 (a_not_nan , b_not_nan , vl );
94+ }
95+
6296
6397// MARK: - FLOAT32 -
6498
@@ -213,34 +247,182 @@ float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
213247
214248// MARK: - FLOAT16 -
215249
250+ static inline float float16_distance_l2_impl_rvv (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
251+ const uint16_t * a = (const uint16_t * )v1 ;
252+ const uint16_t * b = (const uint16_t * )v2 ;
253+
254+ size_t vl = __riscv_vsetvlmax_e64m4 ();
255+ vfloat64m4_t vsum = __riscv_vfmv_v_f_f64m4 (0.0 , vl );
256+
257+ for (size_t i = n ; i > 0 ;) {
258+ // Scalar-load fp16, convert to f32, then widen to f64.
259+ vl = __riscv_vsetvl_e32m2 (i );
260+ vfloat32m2_t va32 = rvv_load_f16_as_f32m2 (a , vl );
261+ vfloat32m2_t vb32 = rvv_load_f16_as_f32m2 (b , vl );
262+ vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4 (va32 , vl );
263+ vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4 (vb32 , vl );
264+
265+ vl = __riscv_vsetvl_e64m4 (vl );
266+
267+ // Return +Inf if there is an infinity mismatch.
268+ if (rvv_has_f16_inf_mismatch_f64m4 (va , vb , vl )) return INFINITY ;
269+
270+ // Skip NaN lanes in accumulation path.
271+ vbool16_t not_nan = rvv_both_not_nan_f64m4 (va , vb , vl );
272+
273+ vfloat64m4_t vdiff = __riscv_vfsub_vv_f64m4 (va , vb , vl );
274+ vsum = __riscv_vfmacc_vv_f64m4_m (not_nan , vsum , vdiff , vdiff , vl );
275+
276+ a += vl ;
277+ b += vl ;
278+ i -= vl ;
279+ }
280+
281+ double l2sq = float64_sum_vector_f64m4 (vsum , n );
282+ return use_sqrt ? sqrtf ((float )l2sq ) : (float )l2sq ;
283+ }
284+
216285float float16_distance_l2_rvv (const void * v1 , const void * v2 , int n ) {
217- printf ("float16_distance_l2_rvv: unimplemented\n" );
218- abort ();
219- return 0.0f ;
286+ return float16_distance_l2_impl_rvv (v1 , v2 , n , true);
220287}
221288
222289float float16_distance_l2_squared_rvv (const void * v1 , const void * v2 , int n ) {
223- printf ("float16_distance_l2_squared_rvv: unimplemented\n" );
224- abort ();
225- return 0.0f ;
290+ return float16_distance_l2_impl_rvv (v1 , v2 , n , false);
226291}
227292
228293float float16_distance_l1_rvv (const void * v1 , const void * v2 , int n ) {
229- printf ("float16_distance_l1_rvv: unimplemented\n" );
230- abort ();
231- return 0.0f ;
294+ const uint16_t * a = (const uint16_t * )v1 ;
295+ const uint16_t * b = (const uint16_t * )v2 ;
296+
297+ size_t vl = __riscv_vsetvlmax_e64m4 ();
298+ vfloat64m4_t vsum = __riscv_vfmv_v_f_f64m4 (0.0 , vl );
299+
300+ for (size_t i = n ; i > 0 ;) {
301+ // Scalar-load fp16, convert to f32, then widen to f64.
302+ vl = __riscv_vsetvl_e32m2 (i );
303+ vfloat32m2_t va32 = rvv_load_f16_as_f32m2 (a , vl );
304+ vfloat32m2_t vb32 = rvv_load_f16_as_f32m2 (b , vl );
305+ vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4 (va32 , vl );
306+ vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4 (vb32 , vl );
307+
308+ vl = __riscv_vsetvl_e64m4 (vl );
309+
310+ // Return +Inf if there is an infinity mismatch.
311+ if (rvv_has_f16_inf_mismatch_f64m4 (va , vb , vl )) return INFINITY ;
312+
313+ // Skip NaN lanes in accumulation path.
314+ vbool16_t not_nan = rvv_both_not_nan_f64m4 (va , vb , vl );
315+
316+ vfloat64m4_t vdiff = __riscv_vfsub_vv_f64m4 (va , vb , vl );
317+ vfloat64m4_t vabs = __riscv_vfabs_v_f64m4 (vdiff , vl );
318+ vsum = __riscv_vfadd_vv_f64m4_m (not_nan , vsum , vabs , vl );
319+
320+ a += vl ;
321+ b += vl ;
322+ i -= vl ;
323+ }
324+
325+ return (float )float64_sum_vector_f64m4 (vsum , n );
232326}
233327
234328float float16_distance_dot_rvv (const void * v1 , const void * v2 , int n ) {
235- printf ("float16_distance_dot_rvv: unimplemented\n" );
236- abort ();
237- return 0.0f ;
329+ const uint16_t * a = (const uint16_t * )v1 ;
330+ const uint16_t * b = (const uint16_t * )v2 ;
331+
332+ // Keep accumulation vectorized while preserving CPU NaN/Inf semantics.
333+ size_t vl = __riscv_vsetvlmax_e64m4 ();
334+ vfloat64m4_t vdot = __riscv_vfmv_v_f_f64m4 (0.0 , vl );
335+
336+ for (size_t i = n ; i > 0 ;) {
337+ // Scalar-load fp16, convert to f32, then widen to f64.
338+ vl = __riscv_vsetvl_e32m2 (i );
339+ vfloat32m2_t va32 = rvv_load_f16_as_f32m2 (a , vl );
340+ vfloat32m2_t vb32 = rvv_load_f16_as_f32m2 (b , vl );
341+ vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4 (va32 , vl );
342+ vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4 (vb32 , vl );
343+
344+ vl = __riscv_vsetvl_e64m4 (vl );
345+
346+ // not_nan = lanes where both sides are not NaN.
347+ vbool16_t not_nan = rvv_both_not_nan_f64m4 (va , vb , vl );
348+
349+ // Multiply once, then classify the product only.
350+ vfloat64m4_t vprod = __riscv_vfmul_vv_f64m4 (va , vb , vl );
351+
352+ // Try to find infinite values, if there are any, exit early
353+ vuint64m4_t p_class = __riscv_vfclass_v_u64m4 (vprod , vl );
354+ vbool16_t inf_pos = __riscv_vmsne_vx_u64m4_b16_m (not_nan , __riscv_vand_vx_u64m4_m (not_nan , p_class , 0x80u , vl ), 0u , vl );
355+ vbool16_t inf_neg = __riscv_vmsne_vx_u64m4_b16_m (not_nan , __riscv_vand_vx_u64m4_m (not_nan , p_class , 0x01u , vl ), 0u , vl );
356+ long first_pos = __riscv_vfirst_m_b16 (inf_pos , vl );
357+ long first_neg = __riscv_vfirst_m_b16 (inf_neg , vl );
358+ if (first_pos >= 0 || first_neg >= 0 ) {
359+ if (first_pos >= 0 && (first_neg < 0 || first_pos < first_neg )) return - INFINITY ;
360+ return INFINITY ;
361+ }
362+
363+ // Accumulate only valid lanes; NaN lanes are skipped.
364+ vdot = __riscv_vfadd_vv_f64m4_m (not_nan , vdot , vprod , vl );
365+
366+ a += vl ;
367+ b += vl ;
368+ i -= vl ;
369+ }
370+
371+ double dot = float64_sum_vector_f64m4 (vdot , n );
372+ return (float )(- dot );
238373}
239374
240375float float16_distance_cosine_rvv (const void * v1 , const void * v2 , int n ) {
241- printf ("float16_distance_cosine_rvv: unimplemented\n" );
242- abort ();
243- return 0.0f ;
376+ const uint16_t * a = (const uint16_t * )v1 ;
377+ const uint16_t * b = (const uint16_t * )v2 ;
378+
379+ size_t vl = __riscv_vsetvlmax_e64m4 ();
380+ vfloat64m4_t vdot = __riscv_vfmv_v_f_f64m4 (0.0 , vl );
381+ vfloat64m4_t vnx = __riscv_vfmv_v_f_f64m4 (0.0 , vl );
382+ vfloat64m4_t vny = __riscv_vfmv_v_f_f64m4 (0.0 , vl );
383+
384+ for (size_t i = n ; i > 0 ;) {
385+ // Scalar-load fp16, convert to f32, then widen to f64.
386+ vl = __riscv_vsetvl_e32m2 (i );
387+ vfloat32m2_t va32 = rvv_load_f16_as_f32m2 (a , vl );
388+ vfloat32m2_t vb32 = rvv_load_f16_as_f32m2 (b , vl );
389+ vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4 (va32 , vl );
390+ vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4 (vb32 , vl );
391+
392+ vl = __riscv_vsetvl_e64m4 (vl );
393+
394+ // Keep only lanes where both values are not NaN.
395+ vbool16_t not_nan = rvv_both_not_nan_f64m4 (va , vb , vl );
396+
397+ // Any infinity on a valid lane returns 1.0f.
398+ vuint64m4_t a_class = __riscv_vfclass_v_u64m4 (va , vl );
399+ vuint64m4_t b_class = __riscv_vfclass_v_u64m4 (vb , vl );
400+ vuint64m4_t ab_class = __riscv_vor_vv_u64m4 (a_class , b_class , vl );
401+ vbool16_t ab_inf = __riscv_vmsne_vx_u64m4_b16 (__riscv_vand_vx_u64m4 (ab_class , 0x81u , vl ), 0u , vl );
402+ vbool16_t any_inf = __riscv_vmand_mm_b16 (not_nan , ab_inf , vl );
403+ if (__riscv_vfirst_m_b16 (any_inf , vl ) >= 0 ) return 1.0f ;
404+
405+ // Accumulate dot and squared norms on valid lanes.
406+ vfloat64m4_t vprod = __riscv_vfmul_vv_f64m4 (va , vb , vl );
407+ vdot = __riscv_vfadd_vv_f64m4_m (not_nan , vdot , vprod , vl );
408+ vnx = __riscv_vfmacc_vv_f64m4_m (not_nan , vnx , va , va , vl );
409+ vny = __riscv_vfmacc_vv_f64m4_m (not_nan , vny , vb , vb , vl );
410+
411+ a += vl ;
412+ b += vl ;
413+ i -= vl ;
414+ }
415+
416+ double dot = float64_sum_vector_f64m4 (vdot , n );
417+ double nx = float64_sum_vector_f64m4 (vnx , n );
418+ double ny = float64_sum_vector_f64m4 (vny , n );
419+ double denom = sqrt (nx ) * sqrt (ny );
420+ if (!(denom > 0.0 ) || !isfinite (denom ) || !isfinite (dot )) return 1.0f ;
421+
422+ double cosv = dot / denom ;
423+ if (cosv > 1.0 ) cosv = 1.0 ;
424+ if (cosv < -1.0 ) cosv = -1.0 ;
425+ return (float )(1.0 - cosv );
244426}
245427
246428// MARK: - BFLOAT16 -
@@ -691,31 +873,31 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
691873void init_distance_functions_rvv (void ) {
692874#if defined(__riscv_v_intrinsic )
693875 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_rvv ;
694- // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
876+ dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F16 ] = float16_distance_l2_rvv ;
695877 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
696878 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_rvv ;
697879 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_rvv ;
698880
699881 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_squared_rvv ;
700- // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
882+ dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F16 ] = float16_distance_l2_squared_rvv ;
701883 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
702884 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_squared_rvv ;
703885 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_squared_rvv ;
704886
705887 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F32 ] = float32_distance_cosine_rvv ;
706- // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
888+ dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F16 ] = float16_distance_cosine_rvv ;
707889 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
708890 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_U8 ] = uint8_distance_cosine_rvv ;
709891 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_I8 ] = int8_distance_cosine_rvv ;
710892
711893 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F32 ] = float32_distance_dot_rvv ;
712- // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
894+ dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F16 ] = float16_distance_dot_rvv ;
713895 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
714896 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_U8 ] = uint8_distance_dot_rvv ;
715897 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_I8 ] = int8_distance_dot_rvv ;
716898
717899 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F32 ] = float32_distance_l1_rvv ;
718- // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
900+ dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F16 ] = float16_distance_l1_rvv ;
719901 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
720902 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_U8 ] = uint8_distance_l1_rvv ;
721903 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_I8 ] = int8_distance_l1_rvv ;
0 commit comments