@@ -35,6 +35,14 @@ float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
3535 return __riscv_vfmv_f_s_f32m1_f32 (acc );
3636}
3737
38+ // Reduces a vector by summing all of it's elements into a single scalar integer
39+ uint64_t uint64_sum_vector_u64m8 (vuint64m8_t vec , size_t vl ) {
40+ vuint64m1_t acc = __riscv_vmv_s_x_u64m1 (0 , 1 );
41+ vl = __riscv_vsetvl_e32m8 (vl );
42+ acc = __riscv_vredsum_vs_u64m8_u64m1 (vec , acc , vl );
43+ return __riscv_vmv_x_s_u64m1_u64 (acc );
44+ }
45+
3846// MARK: - FLOAT32 -
3947
4048float float32_distance_l2_impl_rvv (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
@@ -317,12 +325,66 @@ float int8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
317325
318326// MARK: - BIT -
319327
320- float bit1_distance_hamming_rvv (const void * v1 , const void * v2 , int n ) {
321- printf ("bit1_distance_hamming_rvv: unimplemented\n" );
322- abort ();
323- return 0.0f ;
328+
329+ // Counts the number of set bits on each element of a vector register
330+ //
331+ // TODO: RISC-V natively supports vcpop.v for population count, but only with the
332+ // Zvbb extension, which we don't support yet. For everyone else, do a fallback implemetation.
333+ vuint64m8_t vpopcnt_u64m8 (vuint64m8_t v , size_t vl ) {
334+ // v = v - ((v >> 1) & 0x5555555555555555ULL);
335+ vuint64m8_t shr1 = __riscv_vsrl_vx_u64m8 (v , 1 , vl );
336+ vuint64m8_t and1 = __riscv_vand_vx_u64m8 (shr1 , 0x5555555555555555ULL , vl );
337+ v = __riscv_vsub_vv_u64m8 (v , and1 , vl );
338+
339+ // v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL);
340+ vuint64m8_t shr2 = __riscv_vsrl_vx_u64m8 (v , 2 , vl );
341+ vuint64m8_t and2 = __riscv_vand_vx_u64m8 (shr2 , 0x3333333333333333ULL , vl );
342+ vuint64m8_t and3 = __riscv_vand_vx_u64m8 (v , 0x3333333333333333ULL , vl );
343+ v = __riscv_vadd_vv_u64m8 (and2 , and3 , vl );
344+
345+ // v = (v + (v >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
346+ vuint64m8_t shr4 = __riscv_vsrl_vx_u64m8 (v , 4 , vl );
347+ vuint64m8_t add = __riscv_vadd_vv_u64m8 (v , shr4 , vl );
348+ v = __riscv_vand_vx_u64m8 (add , 0x0f0f0f0f0f0f0f0fULL , vl );
349+
350+ // v = (v * 0x0101010101010101ULL) >> 56;
351+ vuint64m8_t mul = __riscv_vmul_vx_u64m8 (v , 0x0101010101010101ULL , vl );
352+ v = __riscv_vsrl_vx_u64m8 (mul , 56 , vl );
353+
354+ return v ;
324355}
325356
357+
358+ float bit1_distance_hamming_rvv (const void * v1 , const void * v2 , int n ) {
359+ const uint8_t * a = (const uint8_t * )v1 ;
360+ const uint8_t * b = (const uint8_t * )v2 ;
361+
362+ // We accumulate the results into a vector register
363+ size_t vl = __riscv_vsetvl_e32m8 (n );
364+ vuint64m8_t vdistance = __riscv_vmv_s_x_u64m8 (0 , vl );
365+
366+ // Iterate by VL elements
367+ for (size_t i = n ; i > 0 ; i -= vl ) {
368+ // Use LMUL=8, we have 4 registers to work with.
369+ vl = __riscv_vsetvl_e64m8 (n );
370+
371+ // Load the vectors into the registers and cast them into a u64
372+ vuint64m8_t va = __riscv_vreinterpret_v_u8m8_u64m8 (__riscv_vle8_v_u8m8 (a , vl ));
373+ vuint64m8_t vb = __riscv_vreinterpret_v_u8m8_u64m8 (__riscv_vle8_v_u8m8 (b , vl ));
374+
375+ vuint64m8_t xor = __riscv_vxor_vv_u64m8 (va , vb , vl );
376+ vuint64m8_t popcnt = vpopcnt_u64m8 (xor , vl );
377+ vdistance = __riscv_vadd_vv_u64m8 (vdistance , popcnt , vl );
378+
379+ // Advance the a and b pointers to the next offset. Here we multiply by 8 because
380+ // the vectors are defined as u8, but VL is defined in elements of 64bits.
381+ a = & a [vl * 8 ];
382+ b = & b [vl * 8 ];
383+ }
384+
385+ // Copy the accumulator back into a scalar register
386+ return (float ) uint64_sum_vector_u64m8 (vdistance , vl );
387+ }
326388#endif
327389
328390// MARK: -
@@ -359,7 +421,7 @@ void init_distance_functions_rvv (void) {
359421 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
360422 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
361423
362- // dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
424+ dispatch_distance_table [VECTOR_DISTANCE_HAMMING ][VECTOR_TYPE_BIT ] = bit1_distance_hamming_rvv ;
363425
364426 distance_backend_name = "RVV" ;
365427#endif
0 commit comments