@@ -38,11 +38,28 @@ float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
3838// Reduces a vector by summing all of it's elements into a single scalar integer
3939uint64_t uint64_sum_vector_u64m8 (vuint64m8_t vec , size_t vl ) {
4040 vuint64m1_t acc = __riscv_vmv_s_x_u64m1 (0 , 1 );
41- vl = __riscv_vsetvl_e32m8 (vl );
41+ vl = __riscv_vsetvl_e64m8 (vl );
4242 acc = __riscv_vredsum_vs_u64m8_u64m1 (vec , acc , vl );
4343 return __riscv_vmv_x_s_u64m1_u64 (acc );
4444}
4545
46+ // Reduces a vector by summing all of it's elements into a single scalar integer
47+ uint32_t uint32_sum_vector_u32m8 (vuint32m8_t vec , size_t vl ) {
48+ vuint32m1_t acc = __riscv_vmv_s_x_u32m1 (0 , 1 );
49+ vl = __riscv_vsetvl_e32m8 (vl );
50+ acc = __riscv_vredsum_vs_u32m8_u32m1 (vec , acc , vl );
51+ return __riscv_vmv_x_s_u32m1_u32 (acc );
52+ }
53+
54+ // Reduces a vector by summing all of it's elements into a single scalar integer
55+ int32_t int32_sum_vector_i32m8 (vint32m8_t vec , size_t vl ) {
56+ vint32m1_t acc = __riscv_vmv_s_x_i32m1 (0 , 1 );
57+ vl = __riscv_vsetvl_e32m8 (vl );
58+ acc = __riscv_vredsum_vs_i32m8_i32m1 (vec , acc , vl );
59+ return __riscv_vmv_x_s_i32m1_i32 (acc );
60+ }
61+
62+
4663// MARK: - FLOAT32 -
4764
4865float float32_distance_l2_impl_rvv (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
@@ -80,7 +97,6 @@ float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
8097 return float32_distance_l2_impl_rvv (v1 , v2 , n , true);
8198}
8299
83-
// Thin entry point: forwards to float32_distance_l2_impl_rvv with its
// use_sqrt flag cleared (presumably yielding the squared L2 distance, as the
// name suggests -- the implementation is outside this view).
float float32_distance_l2_squared_rvv(const void *v1, const void *v2, int n) {
    const bool take_sqrt = false;
    return float32_distance_l2_impl_rvv(v1, v2, n, take_sqrt);
}
@@ -261,34 +277,179 @@ float bfloat16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
261277
262278// MARK: - UINT8 -
263279
280+ float uint8_distance_l2_impl_rvv (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
281+ const uint8_t * a = (const uint8_t * )v1 ;
282+ const uint8_t * b = (const uint8_t * )v2 ;
283+
284+ // We accumulate the results into a vector register
285+ size_t vl = __riscv_vsetvlmax_e32m8 ();
286+ vint32m8_t vl2 = __riscv_vmv_s_x_i32m8 (0 , vl );
287+
288+ // Iterate by VL elements
289+ for (size_t i = n ; i > 0 ; i -= vl ) {
290+ // Use LMUL=2 to start off, but we're going to widen this
291+ vl = __riscv_vsetvl_e8m2 (i );
292+
293+ // Load the vectors into the registers
294+ vuint8m2_t va = __riscv_vle8_v_u8m2 (a , vl );
295+ vuint8m2_t vb = __riscv_vle8_v_u8m2 (b , vl );
296+
297+ // Widen these values to 16bit unsigned
298+ vuint16m4_t va_wide = __riscv_vwcvtu_x_x_v_u16m4 (va , vl );
299+ vuint16m4_t vb_wide = __riscv_vwcvtu_x_x_v_u16m4 (vb , vl );
300+ vl = __riscv_vsetvl_e16m4 (i );
301+
302+ // Cast these to signed values
303+ vint16m4_t va_wides = __riscv_vreinterpret_v_u16m4_i16m4 (va_wide );
304+ vint16m4_t vb_wides = __riscv_vreinterpret_v_u16m4_i16m4 (vb_wide );
305+
306+ // L2 = (a[i] - b[i]) + acc
307+ // The subtract is signed, but the accumulate is unsigned
308+ vint32m8_t vdiff = __riscv_vwsub_vv_i32m8 (va_wides , vb_wides , vl );
309+ vl2 = __riscv_vmacc_vv_i32m8 (vl2 , vdiff , vdiff , vl );
310+
311+ // Advance the a and b pointers to the next offset
312+ a = & a [vl ];
313+ b = & b [vl ];
314+ }
315+
316+ // Copy the accumulators back into a scalar register
317+ float l2 = (float ) int32_sum_vector_i32m8 (vl2 , n );
318+ return use_sqrt ? sqrtf (l2 ) : l2 ;
319+ }
320+
// L2 distance entry point for uint8 vectors: forwards to
// uint8_distance_l2_impl_rvv with the final square root enabled.
float uint8_distance_l2_rvv(const void *v1, const void *v2, int n) {
    const bool take_sqrt = true;
    return uint8_distance_l2_impl_rvv(v1, v2, n, take_sqrt);
}
269324
// Squared L2 distance entry point for uint8 vectors: forwards to
// uint8_distance_l2_impl_rvv with the final square root disabled.
float uint8_distance_l2_squared_rvv(const void *v1, const void *v2, int n) {
    const bool take_sqrt = false;
    return uint8_distance_l2_impl_rvv(v1, v2, n, take_sqrt);
}
275328
276329float uint8_distance_dot_rvv (const void * v1 , const void * v2 , int n ) {
277- printf ("uint8_distance_dot_rvv: unimplemented\n" );
278- abort ();
279- return 0.0f ;
330+ const uint8_t * a = (const uint8_t * )v1 ;
331+ const uint8_t * b = (const uint8_t * )v2 ;
332+
333+ // We accumulate the results into a vector register
334+ size_t vl = __riscv_vsetvlmax_e32m8 ();
335+ vuint32m8_t vdot = __riscv_vmv_s_x_u32m8 (0 , vl );
336+
337+ // Iterate by VL elements
338+ for (size_t i = n ; i > 0 ; i -= vl ) {
339+ // Use LMUL=2 to start off, but we're going to widen this
340+ vl = __riscv_vsetvl_e8m2 (i );
341+
342+ // Load the vectors into the registers
343+ vuint8m2_t va = __riscv_vle8_v_u8m2 (a , vl );
344+ vuint8m2_t vb = __riscv_vle8_v_u8m2 (b , vl );
345+
346+ // Widen these vectors to 16bit
347+ vuint16m4_t va_wide = __riscv_vwcvtu_x_x_v_u16m4 (va , vl );
348+ vuint16m4_t vb_wide = __riscv_vwcvtu_x_x_v_u16m4 (vb , vl );
349+
350+ // Now we're operating on 16 bit elements
351+ vl = __riscv_vsetvl_e16m4 (i );
352+
353+ // Do a widening multiply-accumulate to 32 bits
354+ vdot = __riscv_vwmaccu_vv_u32m8 (vdot , va_wide , vb_wide , vl );
355+
356+ // Advance the a and b pointers to the next offset
357+ a = & a [vl ];
358+ b = & b [vl ];
359+ }
360+
361+ // Copy the accumulators back into a scalar register
362+ float dot = uint32_sum_vector_u32m8 (vdot , n );
363+ return - dot ;
280364}
281365
282366float uint8_distance_l1_rvv (const void * v1 , const void * v2 , int n ) {
283- printf ("uint8_distance_l1_rvv: unimplemented\n" );
284- abort ();
285- return 0.0f ;
367+ const uint8_t * a = (const uint8_t * )v1 ;
368+ const uint8_t * b = (const uint8_t * )v2 ;
369+
370+ // We accumulate the results into a vector register
371+ size_t vl = __riscv_vsetvlmax_e32m8 ();
372+ vuint32m8_t vl1 = __riscv_vmv_s_x_u32m8 (0 , vl );
373+
374+ // Iterate by VL elements
375+ for (size_t i = n ; i > 0 ; i -= vl ) {
376+ // Use LMUL=2 to start off, but we're going to widen this
377+ vl = __riscv_vsetvl_e8m2 (i );
378+
379+ // Load the vectors into the registers
380+ vuint8m2_t va = __riscv_vle8_v_u8m2 (a , vl );
381+ vuint8m2_t vb = __riscv_vle8_v_u8m2 (b , vl );
382+
383+ // Compute the absolute difference by getting the min and max and subtracting them.
384+ vuint8m2_t vmin = __riscv_vminu_vv_u8m2 (va , vb , vl );
385+ vuint8m2_t vmax = __riscv_vmaxu_vv_u8m2 (va , vb , vl );
386+ vuint16m4_t vabs = __riscv_vwsubu_vv_u16m4 (vmax , vmin , vl );
387+ vl = __riscv_vsetvl_e16m4 (i );
388+
389+ // Now widen it to 32bits and add to the accumulator
390+ vuint32m8_t vwide = __riscv_vwcvtu_x_x_v_u32m8 (vabs , vl );
391+ vl1 = __riscv_vadd_vv_u32m8 (vl1 , vwide , vl );
392+
393+ // Advance the a and b pointers to the next offset
394+ a = & a [vl ];
395+ b = & b [vl ];
396+ }
397+
398+ // Copy the accumulators back into a scalar register
399+ float l1 = uint32_sum_vector_u32m8 (vl1 , n );
400+ return l1 ;
286401}
287402
288403float uint8_distance_cosine_rvv (const void * v1 , const void * v2 , int n ) {
289- printf ("uint8_distance_cosine_rvv: unimplemented\n" );
290- abort ();
291- return 0.0f ;
404+ const uint8_t * a = (const uint8_t * )v1 ;
405+ const uint8_t * b = (const uint8_t * )v2 ;
406+
407+ // We accumulate the results into a vector register
408+ size_t vl = __riscv_vsetvlmax_e32m8 ();
409+
410+ // Zero out the starting registers
411+ vuint32m8_t vdot = __riscv_vmv_s_x_u32m8 (0 , vl );
412+ vuint32m8_t vmagn_a = __riscv_vmv_s_x_u32m8 (0 , vl );
413+ vuint32m8_t vmagn_b = __riscv_vmv_s_x_u32m8 (0 , vl );
414+
415+ // Iterate by VL elements
416+ for (size_t i = n ; i > 0 ; i -= vl ) {
417+ // Use LMUL=2 to start off, but we're going to widen this
418+ vl = __riscv_vsetvl_e8m2 (i );
419+
420+ // Load the vectors into the registers
421+ vuint8m2_t va = __riscv_vle8_v_u8m2 (a , vl );
422+ vuint8m2_t vb = __riscv_vle8_v_u8m2 (b , vl );
423+
424+ // Widen these values to 16bit unsigned
425+ vuint16m4_t va_wide = __riscv_vwcvtu_x_x_v_u16m4 (va , vl );
426+ vuint16m4_t vb_wide = __riscv_vwcvtu_x_x_v_u16m4 (vb , vl );
427+ vl = __riscv_vsetvl_e16m4 (i );
428+
429+ // Compute the dot product for the entire register (widening madd)
430+ vdot = __riscv_vwmaccu_vv_u32m8 (vdot , va_wide , vb_wide , vl );
431+
432+ // Also calculate the magnitude value for both a and b (widening madd)
433+ vmagn_a = __riscv_vwmaccu_vv_u32m8 (vmagn_a , va_wide , va_wide , vl );
434+ vmagn_b = __riscv_vwmaccu_vv_u32m8 (vmagn_b , vb_wide , vb_wide , vl );
435+
436+ // Advance the a and b pointers to the next offset
437+ a = & a [vl ];
438+ b = & b [vl ];
439+ }
440+
441+ // Now do a final reduction on the registers to sum the remaining elements
442+ // TODO: With default flags this does not always use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
443+ float dot = uint32_sum_vector_u32m8 (vdot , n );
444+ float magn_a = sqrtf (uint32_sum_vector_u32m8 (vmagn_a , n ));
445+ float magn_b = sqrtf (uint32_sum_vector_u32m8 (vmagn_b , n ));
446+
447+ if (magn_a == 0.0f || magn_b == 0.0f ) return 1.0f ;
448+
449+ float cosine_similarity = dot / (magn_a * magn_b );
450+ if (cosine_similarity > 1.0f ) cosine_similarity = 1.0f ;
451+ if (cosine_similarity < -1.0f ) cosine_similarity = -1.0f ;
452+ return 1.0f - cosine_similarity ;
292453}
293454
294455// MARK: - INT8 -
@@ -354,7 +515,6 @@ vuint64m8_t vpopcnt_u64m8(vuint64m8_t v, size_t vl) {
354515 return v ;
355516}
356517
357-
358518float bit1_distance_hamming_rvv (const void * v1 , const void * v2 , int n ) {
359519 const uint8_t * a = (const uint8_t * )v1 ;
360520 const uint8_t * b = (const uint8_t * )v2 ;
@@ -394,31 +554,31 @@ void init_distance_functions_rvv (void) {
394554 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_rvv ;
395555 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
396556 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
397- // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
557+ dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_rvv ;
398558 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
399559
400560 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_squared_rvv ;
401561 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
402562 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
403- // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
563+ dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_squared_rvv ;
404564 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
405565
406566 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F32 ] = float32_distance_cosine_rvv ;
407567 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
408568 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
409- // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
569+ dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_U8 ] = uint8_distance_cosine_rvv ;
410570 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
411571
412572 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F32 ] = float32_distance_dot_rvv ;
413573 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
414574 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
415- // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
575+ dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_U8 ] = uint8_distance_dot_rvv ;
416576 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
417577
418578 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F32 ] = float32_distance_l1_rvv ;
419579 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
420580 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
421- // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
581+ dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_U8 ] = uint8_distance_l1_rvv ;
422582 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
423583
424584 dispatch_distance_table [VECTOR_DISTANCE_HAMMING ][VECTOR_TYPE_BIT ] = bit1_distance_hamming_rvv ;
0 commit comments