@@ -20,64 +20,88 @@ extern const char *distance_backend_name;
2020// MARK: - UTILS -
2121
2222// Reduces a vector by summing all of it's elements into a single scalar float
// Horizontally adds every element of an f32 LMUL=8 vector, returning the scalar total.
static inline float float32_sum_vector_f32m8(vfloat32m8_t vec, size_t vl) {
    vl = __riscv_vsetvl_e32m8(vl);
    // Seed a one-element accumulator with 0.0f, then fold all active lanes into it.
    vfloat32m1_t total = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    total = __riscv_vfredusum_vs_f32m8_f32m1(vec, total, vl);
    // Extract element 0 of the reduction result.
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
2929
3030// Reduces a vector by summing all of it's elements into a single scalar float
// Horizontally adds every element of an f32 LMUL=4 vector, returning the scalar total.
static inline float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
    vl = __riscv_vsetvl_e32m4(vl);
    // Seed a one-element accumulator with 0.0f, then fold all active lanes into it.
    vfloat32m1_t total = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    total = __riscv_vfredusum_vs_f32m4_f32m1(vec, total, vl);
    // Extract element 0 of the reduction result.
    return __riscv_vfmv_f_s_f32m1_f32(total);
}
3737
3838// Reduces a vector by summing all of it's elements into a single scalar double
// Horizontally adds every element of an f64 LMUL=4 vector, returning the scalar total.
static inline double float64_sum_vector_f64m4(vfloat64m4_t vec, size_t vl) {
    vl = __riscv_vsetvl_e64m4(vl);
    // Seed a one-element accumulator with 0.0, then fold all active lanes into it.
    vfloat64m1_t total = __riscv_vfmv_v_f_f64m1(0.0, 1);
    total = __riscv_vfredusum_vs_f64m4_f64m1(vec, total, vl);
    // Extract element 0 of the reduction result.
    return __riscv_vfmv_f_s_f64m1_f64(total);
}
4545
4646// Reduces a vector by summing all of it's elements into a single scalar integer
// Horizontally adds every element of a u64 LMUL=8 vector, returning the scalar total.
static inline uint64_t uint64_sum_vector_u64m8(vuint64m8_t vec, size_t vl) {
    vl = __riscv_vsetvl_e64m8(vl);
    // Seed a one-element accumulator with 0, then fold all active lanes into it.
    vuint64m1_t total = __riscv_vmv_s_x_u64m1(0, 1);
    total = __riscv_vredsum_vs_u64m8_u64m1(vec, total, vl);
    // Extract element 0 of the reduction result.
    return __riscv_vmv_x_s_u64m1_u64(total);
}
5353
5454// Reduces a vector by summing all of it's elements into a single scalar integer
// Horizontally adds every element of a u32 LMUL=8 vector, returning the scalar total.
static inline uint32_t uint32_sum_vector_u32m8(vuint32m8_t vec, size_t vl) {
    vl = __riscv_vsetvl_e32m8(vl);
    // Seed a one-element accumulator with 0, then fold all active lanes into it.
    vuint32m1_t total = __riscv_vmv_s_x_u32m1(0, 1);
    total = __riscv_vredsum_vs_u32m8_u32m1(vec, total, vl);
    // Extract element 0 of the reduction result.
    return __riscv_vmv_x_s_u32m1_u32(total);
}
6161
6262// Reduces a vector by summing all of it's elements into a single scalar integer
// Horizontally adds every element of an i32 LMUL=8 vector, returning the scalar total.
static inline int32_t int32_sum_vector_i32m8(vint32m8_t vec, size_t vl) {
    vl = __riscv_vsetvl_e32m8(vl);
    // Seed a one-element accumulator with 0, then fold all active lanes into it.
    vint32m1_t total = __riscv_vmv_s_x_i32m1(0, 1);
    total = __riscv_vredsum_vs_i32m8_i32m1(vec, total, vl);
    // Extract element 0 of the reduction result.
    return __riscv_vmv_x_s_i32m1_i32(total);
}
6969
7070// Scalar-load fp16 payloads, convert to fp32, and pack as an f32m2 vector.
// Scalar-load fp16 payloads, convert to fp32, and pack as an f32m2 vector.
static inline vfloat32m2_t rvv_load_f16_as_f32m2(const uint16_t *src, size_t n) {
    size_t vl = __riscv_vsetvl_e32m2(n);
    float widened[vl]; // VLA sized by the active vector length
    size_t lane = 0;
    while (lane < vl) {
        widened[lane] = float16_to_float32(src[lane]);
        ++lane;
    }
    return __riscv_vle32_v_f32m2(widened, vl);
}
7777
// Scalar-load bf16 payloads, convert to fp32, and pack as an f32m8 vector.
static inline vfloat32m8_t rvv_load_bf16_as_f32m8(const uint16_t *src, size_t n) {
    size_t vl = __riscv_vsetvl_e32m8(n);
    float widened[vl]; // VLA sized by the active vector length
    size_t lane = 0;
    while (lane < vl) {
        widened[lane] = bfloat16_to_float32(src[lane]);
        ++lane;
    }
    return __riscv_vle32_v_f32m8(widened, vl);
}
85+
// Scalar-load bf16 payloads, convert to fp32, and pack as an f32m4 vector.
static inline vfloat32m4_t rvv_load_bf16_as_f32m4(const uint16_t *src, size_t n) {
    size_t vl = __riscv_vsetvl_e32m4(n);
    float widened[vl]; // VLA sized by the active vector length
    size_t lane = 0;
    while (lane < vl) {
        widened[lane] = bfloat16_to_float32(src[lane]);
        ++lane;
    }
    return __riscv_vle32_v_f32m4(widened, vl);
}
93+
// Scalar-load bf16 payloads, convert to fp32, and pack as an f32m2 vector.
static inline vfloat32m2_t rvv_load_bf16_as_f32m2(const uint16_t *src, size_t n) {
    size_t vl = __riscv_vsetvl_e32m2(n);
    float widened[vl]; // VLA sized by the active vector length
    size_t lane = 0;
    while (lane < vl) {
        widened[lane] = bfloat16_to_float32(src[lane]);
        ++lane;
    }
    return __riscv_vle32_v_f32m2(widened, vl);
}
101+
78102// Returns true if any lane has an fp16-style infinity mismatch:
79103// one side is Inf and the other is not, or both are Inf with different signs.
80- static inline bool rvv_has_f16_inf_mismatch_f64m4 (vfloat64m4_t va , vfloat64m4_t vb , size_t vl ) {
104+ static inline bool rvv_has_f16_inf_mismatch_f64m4 (vfloat64m4_t va , vfloat64m4_t vb , size_t vl ) {
81105 vuint64m4_t a_class = __riscv_vfclass_v_u64m4 (va , vl );
82106 vuint64m4_t b_class = __riscv_vfclass_v_u64m4 (vb , vl );
83107 vuint64m4_t a_inf_bits = __riscv_vand_vx_u64m4 (a_class , 0x81u , vl );
@@ -87,7 +111,7 @@ static inline bool rvv_has_f16_inf_mismatch_f64m4(vfloat64m4_t va, vfloat64m4_t
87111}
88112
89113// Returns mask of lanes where both vectors are not NaN.
90- static inline vbool16_t rvv_both_not_nan_f64m4 (vfloat64m4_t va , vfloat64m4_t vb , size_t vl ) {
114+ static inline vbool16_t rvv_both_not_nan_f64m4 (vfloat64m4_t va , vfloat64m4_t vb , size_t vl ) {
91115 vbool16_t a_not_nan = __riscv_vmfeq_vv_f64m4_b16 (va , va , vl );
92116 vbool16_t b_not_nan = __riscv_vmfeq_vv_f64m4_b16 (vb , vb , vl );
93117 return __riscv_vmand_mm_b16 (a_not_nan , b_not_nan , vl );
@@ -107,7 +131,7 @@ float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool
107131 // Iterate by VL elements
108132 for (size_t i = n ; i > 0 ; i -= vl ) {
109133 // Use LMUL=8, we have 4 registers to work with.
110- vl = __riscv_vsetvl_e32m8 (n );
134+ vl = __riscv_vsetvl_e32m8 (i );
111135
112136 // Load the vectors into the registers
113137 vfloat32m8_t va = __riscv_vle32_v_f32m8 (a , vl );
@@ -146,7 +170,7 @@ float float32_distance_l1_rvv (const void *v1, const void *v2, int n) {
146170 // Iterate by VL elements
147171 for (size_t i = n ; i > 0 ; i -= vl ) {
148172 // Use LMUL=8, we have 4 registers to work with.
149- vl = __riscv_vsetvl_e32m8 (n );
173+ vl = __riscv_vsetvl_e32m8 (i );
150174
151175 // Load the vectors into the registers
152176 vfloat32m8_t va = __riscv_vle32_v_f32m8 (a , vl );
@@ -427,34 +451,129 @@ float float16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
427451
428452// MARK: - BFLOAT16 -
429453
// Shared L2 kernel for bf16 vectors: widens each element to f64, accumulates the
// squared differences, and returns either the squared distance or its square root.
//
// Semantics (mirroring the fp16 kernels above, per this file's convention): any
// infinite difference lane short-circuits to +INFINITY, and NaN difference lanes
// are excluded from the accumulation.
static inline float bfloat16_distance_l2_impl_rvv(const void *v1, const void *v2, int n, bool use_sqrt) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    size_t vl = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t vsum = __riscv_vfmv_v_f_f64m4(0.0, vl);

    for (size_t i = n; i > 0;) {
        // Load as f32m2 and widen to f64m4 to avoid overflow in accumulation.
        vl = __riscv_vsetvl_e32m2(i);
        vfloat32m2_t va32 = rvv_load_bf16_as_f32m2(a, vl);
        vfloat32m2_t vb32 = rvv_load_bf16_as_f32m2(b, vl);
        vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4(va32, vl);
        vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4(vb32, vl);

        // e32m2 and e64m4 share the same SEW/LMUL ratio, so vl is unchanged here;
        // this vsetvl only re-establishes the element width for the f64 ops below.
        vl = __riscv_vsetvl_e64m4(vl);

        vfloat64m4_t vdiff = __riscv_vfsub_vv_f64m4(va, vb, vl);

        // If any diff lane is infinite, return +INFINITY.
        vuint64m4_t d_class = __riscv_vfclass_v_u64m4(vdiff, vl);
        vbool16_t d_inf = __riscv_vmsne_vx_u64m4_b16(__riscv_vand_vx_u64m4(d_class, 0x81u, vl), 0u, vl);
        if (__riscv_vfirst_m_b16(d_inf, vl) >= 0) return INFINITY;

        // Skip NaN diff lanes. Use the tail-undisturbed/mask-undisturbed policy:
        // the plain masked intrinsic is tail- and mask-agnostic, so a short final
        // iteration (vl < VLMAX) or a NaN lane could clobber vsum lanes with
        // all-ones (NaN) on some implementations — lanes the reduction below reads.
        vbool16_t not_nan = __riscv_vmfeq_vv_f64m4_b16(vdiff, vdiff, vl);
        vsum = __riscv_vfmacc_vv_f64m4_tumu(not_nan, vsum, vdiff, vdiff, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    double l2sq = float64_sum_vector_f64m4(vsum, n);
    // Take the square root in double precision, then narrow once at the end.
    return use_sqrt ? (float)sqrt(l2sq) : (float)l2sq;
}
490+
430491float bfloat16_distance_l2_rvv (const void * v1 , const void * v2 , int n ) {
431- printf ("bfloat16_distance_l2_rvv: unimplemented\n" );
432- abort ();
433- return 0.0f ;
492+ return bfloat16_distance_l2_impl_rvv (v1 , v2 , n , true);
434493}
435494
436495float bfloat16_distance_l2_squared_rvv (const void * v1 , const void * v2 , int n ) {
437- printf ("bfloat16_distance_l2_squared_rvv: unimplemented\n" );
438- abort ();
439- return 0.0f ;
496+ return bfloat16_distance_l2_impl_rvv (v1 , v2 , n , false);
440497}
441498
// L1 (Manhattan) distance between two bf16 vectors: sum of |a[i] - b[i]|,
// accumulated in fp32 after widening each bf16 element.
float bfloat16_distance_l1_rvv(const void *v1, const void *v2, int n) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    size_t vl = __riscv_vsetvlmax_e32m8();
    vfloat32m8_t vsum = __riscv_vfmv_v_f_f32m8(0.0f, vl);

    for (size_t i = n; i > 0;) {
        vl = __riscv_vsetvl_e32m8(i);
        vfloat32m8_t va = rvv_load_bf16_as_f32m8(a, vl);
        vfloat32m8_t vb = rvv_load_bf16_as_f32m8(b, vl);

        vfloat32m8_t vdiff = __riscv_vfsub_vv_f32m8(va, vb, vl);
        vfloat32m8_t vabs = __riscv_vfabs_v_f32m8(vdiff, vl);
        // Tail-undisturbed accumulate: a short final iteration (vl < VLMAX) must
        // not clobber the tail lanes of vsum, because the reduction below still
        // reads them (plain intrinsics are tail-agnostic and may fill the tail
        // with all-ones, i.e. NaN, on some implementations).
        vsum = __riscv_vfadd_vv_f32m8_tu(vsum, vsum, vabs, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    return float32_sum_vector_f32m8(vsum, n);
}
447522
// Dot-product distance between two bf16 vectors: the negated inner product,
// so that larger dot products map to smaller distances.
float bfloat16_distance_dot_rvv(const void *v1, const void *v2, int n) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    size_t vl = __riscv_vsetvlmax_e32m8();
    vfloat32m8_t vdot = __riscv_vfmv_v_f_f32m8(0.0f, vl);

    for (size_t i = n; i > 0;) {
        vl = __riscv_vsetvl_e32m8(i);
        vfloat32m8_t va = rvv_load_bf16_as_f32m8(a, vl);
        vfloat32m8_t vb = rvv_load_bf16_as_f32m8(b, vl);
        // Tail-undisturbed multiply-accumulate: a short final iteration must not
        // clobber vdot's tail lanes, which the reduction below still reads (plain
        // intrinsics are tail-agnostic and may overwrite the tail with all-ones).
        vdot = __riscv_vfmacc_vv_f32m8_tu(vdot, va, vb, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    float dot = float32_sum_vector_f32m8(vdot, n);
    return -dot;
}
453544
// Cosine distance (1 - cosine similarity) between two bf16 vectors.
// Returns 1.0f when either vector has zero norm; the similarity is clamped to
// [-1, 1] to absorb floating-point rounding before forming the distance.
float bfloat16_distance_cosine_rvv(const void *v1, const void *v2, int n) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    size_t vl = __riscv_vsetvlmax_e32m4();
    vfloat32m4_t vdot = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vnx = __riscv_vfmv_v_f_f32m4(0.0f, vl);
    vfloat32m4_t vny = __riscv_vfmv_v_f_f32m4(0.0f, vl);

    for (size_t i = n; i > 0;) {
        vl = __riscv_vsetvl_e32m4(i);
        vfloat32m4_t va = rvv_load_bf16_as_f32m4(a, vl);
        vfloat32m4_t vb = rvv_load_bf16_as_f32m4(b, vl);

        // Tail-undisturbed multiply-accumulates: a short final iteration must not
        // clobber the accumulators' tail lanes, which the reductions below still
        // read (plain intrinsics are tail-agnostic and may fill the tail with
        // all-ones, i.e. NaN, on some implementations).
        vdot = __riscv_vfmacc_vv_f32m4_tu(vdot, va, vb, vl);
        vnx = __riscv_vfmacc_vv_f32m4_tu(vnx, va, va, vl);
        vny = __riscv_vfmacc_vv_f32m4_tu(vny, vb, vb, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    float dot = float32_sum_vector_f32m4(vdot, n);
    float norm_x = float32_sum_vector_f32m4(vnx, n);
    float norm_y = float32_sum_vector_f32m4(vny, n);
    // A zero-norm vector has no defined direction: report maximal-orthogonal distance.
    if (norm_x == 0.0f || norm_y == 0.0f) return 1.0f;

    float cosine_similarity = dot / (sqrtf(norm_x) * sqrtf(norm_y));
    if (cosine_similarity > 1.0f) cosine_similarity = 1.0f;
    if (cosine_similarity < -1.0f) cosine_similarity = -1.0f;
    return 1.0f - cosine_similarity;
}
459578
460579// MARK: - UINT8 -
@@ -847,7 +966,7 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
847966 // Iterate by VL elements
848967 for (size_t i = n ; i > 0 ; i -= vl ) {
849968 // Use LMUL=8, we have 4 registers to work with.
850- vl = __riscv_vsetvl_e64m8 (n );
969+ vl = __riscv_vsetvl_e64m8 (i );
851970
852971 // Load the vectors into the registers and cast them into a u64 inplace
853972 vuint64m8_t va = __riscv_vreinterpret_v_u8m8_u64m8 (__riscv_vle8_v_u8m8 (a , vl ));
@@ -874,31 +993,31 @@ void init_distance_functions_rvv (void) {
874993#if defined(__riscv_v_intrinsic )
875994 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_rvv ;
876995 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F16 ] = float16_distance_l2_rvv ;
877- // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
996+ dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_BF16 ] = bfloat16_distance_l2_rvv ;
878997 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_rvv ;
879998 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_rvv ;
880999
8811000 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_squared_rvv ;
8821001 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F16 ] = float16_distance_l2_squared_rvv ;
883- // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
1002+ dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_BF16 ] = bfloat16_distance_l2_squared_rvv ;
8841003 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_squared_rvv ;
8851004 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_squared_rvv ;
8861005
8871006 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F32 ] = float32_distance_cosine_rvv ;
8881007 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F16 ] = float16_distance_cosine_rvv ;
889- // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
1008+ dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_BF16 ] = bfloat16_distance_cosine_rvv ;
8901009 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_U8 ] = uint8_distance_cosine_rvv ;
8911010 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_I8 ] = int8_distance_cosine_rvv ;
8921011
8931012 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F32 ] = float32_distance_dot_rvv ;
8941013 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F16 ] = float16_distance_dot_rvv ;
895- // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
1014+ dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_BF16 ] = bfloat16_distance_dot_rvv ;
8961015 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_U8 ] = uint8_distance_dot_rvv ;
8971016 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_I8 ] = int8_distance_dot_rvv ;
8981017
8991018 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F32 ] = float32_distance_l1_rvv ;
9001019 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F16 ] = float16_distance_l1_rvv ;
901- // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
1020+ dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_BF16 ] = bfloat16_distance_l1_rvv ;
9021021 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_U8 ] = uint8_distance_l1_rvv ;
9031022 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_I8 ] = int8_distance_l1_rvv ;
9041023
0 commit comments