@@ -454,34 +454,173 @@ float uint8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
454454
455455// MARK: - INT8 -
456456
457+ float int8_distance_l2_impl_rvv (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
458+ const int8_t * a = (const int8_t * )v1 ;
459+ const int8_t * b = (const int8_t * )v2 ;
460+
461+ // We accumulate the results into a vector register
462+ size_t vl = __riscv_vsetvlmax_e32m8 ();
463+ vint32m8_t vl2 = __riscv_vmv_s_x_i32m8 (0 , vl );
464+
465+ // Iterate by VL elements
466+ for (size_t i = n ; i > 0 ; i -= vl ) {
467+ // Use LMUL=2 to start off, but we're going to widen this
468+ vl = __riscv_vsetvl_e8m2 (i );
469+
470+ // Load the vectors into the registers
471+ vint8m2_t va = __riscv_vle8_v_i8m2 (a , vl );
472+ vint8m2_t vb = __riscv_vle8_v_i8m2 (b , vl );
473+
474+ // Widen these values to 16bit signed
475+ vint16m4_t va_wide = __riscv_vwcvt_x_x_v_i16m4 (va , vl );
476+ vint16m4_t vb_wide = __riscv_vwcvt_x_x_v_i16m4 (vb , vl );
477+ vl = __riscv_vsetvl_e16m4 (i );
478+
479+ // L2 = (a[i] - b[i]) + acc
480+ vint32m8_t vdiff = __riscv_vwsub_vv_i32m8 (va_wide , vb_wide , vl );
481+ vl2 = __riscv_vmacc_vv_i32m8 (vl2 , vdiff , vdiff , vl );
482+
483+ // Advance the a and b pointers to the next offset
484+ a = & a [vl ];
485+ b = & b [vl ];
486+ }
487+
488+ // Copy the accumulators back into a scalar register
489+ float l2 = (float ) int32_sum_vector_i32m8 (vl2 , n );
490+ return use_sqrt ? sqrtf (l2 ) : l2 ;
491+ }
492+
// int8 L2 distance entry point (RVV).
// In this diff the '-' lines are the removed "unimplemented" stub (print and
// abort); the '+' line replaces it by forwarding to int8_distance_l2_impl_rvv
// with use_sqrt = true, i.e. the square-rooted variant.
457493float int8_distance_l2_rvv (const void * v1 , const void * v2 , int n ) {
458- printf ("int8_distance_l2_rvv: unimplemented\n" );
459- abort ();
460- return 0.0f ;
494+ return int8_distance_l2_impl_rvv (v1 , v2 , n , true);
461495}
462496
// int8 squared-L2 distance entry point (RVV).
// In this diff the '-' lines are the removed "unimplemented" stub (print and
// abort); the '+' line replaces it by forwarding to int8_distance_l2_impl_rvv
// with use_sqrt = false, so no square root is taken.
463497float int8_distance_l2_squared_rvv (const void * v1 , const void * v2 , int n ) {
464- printf ("int8_distance_l2_squared_rvv: unimplemented\n" );
465- abort ();
466- return 0.0f ;
498+ return int8_distance_l2_impl_rvv (v1 , v2 , n , false);
467499}
468500
469501float int8_distance_dot_rvv (const void * v1 , const void * v2 , int n ) {
470- printf ("int8_distance_dot_rvv: unimplemented\n" );
471- abort ();
472- return 0.0f ;
502+ const int8_t * a = (const int8_t * )v1 ;
503+ const int8_t * b = (const int8_t * )v2 ;
504+
505+ // We accumulate the results into a vector register
506+ size_t vl = __riscv_vsetvlmax_e32m8 ();
507+ vint32m8_t vdot = __riscv_vmv_s_x_i32m8 (0 , vl );
508+
509+ // Iterate by VL elements
510+ for (size_t i = n ; i > 0 ; i -= vl ) {
511+ // Use LMUL=2 to start off, but we're going to widen this
512+ vl = __riscv_vsetvl_e8m2 (i );
513+
514+ // Load the vectors into the registers
515+ vint8m2_t va = __riscv_vle8_v_i8m2 (a , vl );
516+ vint8m2_t vb = __riscv_vle8_v_i8m2 (b , vl );
517+
518+ // Widen these vectors to 16bit
519+ vint16m4_t va_wide = __riscv_vwcvt_x_x_v_i16m4 (va , vl );
520+ vint16m4_t vb_wide = __riscv_vwcvt_x_x_v_i16m4 (vb , vl );
521+
522+ // Now we're operating on 16 bit elements
523+ vl = __riscv_vsetvl_e16m4 (i );
524+
525+ // Do a widening multiply-accumulate to 32 bits
526+ vdot = __riscv_vwmacc_vv_i32m8 (vdot , va_wide , vb_wide , vl );
527+
528+ // Advance the a and b pointers to the next offset
529+ a = & a [vl ];
530+ b = & b [vl ];
531+ }
532+
533+ // Copy the accumulators back into a scalar register
534+ float dot = (float ) int32_sum_vector_i32m8 (vdot , n );
535+ return - dot ;
473536}
474537
475538float int8_distance_l1_rvv (const void * v1 , const void * v2 , int n ) {
476- printf ("int8_distance_l1_rvv: unimplemented\n" );
477- abort ();
478- return 0.0f ;
539+ const int8_t * a = (const int8_t * )v1 ;
540+ const int8_t * b = (const int8_t * )v2 ;
541+
542+ // We accumulate the results into a vector register
543+ size_t vl = __riscv_vsetvlmax_e32m8 ();
544+ vint32m8_t vl1 = __riscv_vmv_s_x_i32m8 (0 , vl );
545+
546+ // Iterate by VL elements
547+ for (size_t i = n ; i > 0 ; i -= vl ) {
548+ // Use LMUL=2 to start off, but we're going to widen this
549+ vl = __riscv_vsetvl_e8m2 (i );
550+
551+ // Load the vectors into the registers
552+ vint8m2_t va = __riscv_vle8_v_i8m2 (a , vl );
553+ vint8m2_t vb = __riscv_vle8_v_i8m2 (b , vl );
554+
555+ // Compute the absolute difference by getting the min and max and subtracting them.
556+ vint8m2_t vmin = __riscv_vmin_vv_i8m2 (va , vb , vl );
557+ vint8m2_t vmax = __riscv_vmax_vv_i8m2 (va , vb , vl );
558+ vint16m4_t vabs = __riscv_vwsub_vv_i16m4 (vmax , vmin , vl );
559+ vl = __riscv_vsetvl_e16m4 (i );
560+
561+ // Now widen it to 32bits and add to the accumulator
562+ vint32m8_t vwide = __riscv_vwcvt_x_x_v_i32m8 (vabs , vl );
563+ vl1 = __riscv_vadd_vv_i32m8 (vl1 , vwide , vl );
564+
565+ // Advance the a and b pointers to the next offset
566+ a = & a [vl ];
567+ b = & b [vl ];
568+ }
569+
570+ // Copy the accumulators back into a scalar register
571+ float l1 = (float ) int32_sum_vector_i32m8 (vl1 , n );
572+ return l1 ;
479573}
480574
481575float int8_distance_cosine_rvv (const void * v1 , const void * v2 , int n ) {
482- printf ("int8_distance_cosine_rvv: unimplemented\n" );
483- abort ();
484- return 0.0f ;
576+ const int8_t * a = (const int8_t * )v1 ;
577+ const int8_t * b = (const int8_t * )v2 ;
578+
579+ // We accumulate the results into a vector register
580+ size_t vl = __riscv_vsetvlmax_e32m8 ();
581+
582+ // Zero out the starting registers
583+ vint32m8_t vdot = __riscv_vmv_s_x_i32m8 (0 , vl );
584+ vint32m8_t vmagn_a = __riscv_vmv_s_x_i32m8 (0 , vl );
585+ vint32m8_t vmagn_b = __riscv_vmv_s_x_i32m8 (0 , vl );
586+
587+ // Iterate by VL elements
588+ for (size_t i = n ; i > 0 ; i -= vl ) {
589+ // Use LMUL=2 to start off, but we're going to widen this
590+ vl = __riscv_vsetvl_e8m2 (i );
591+
592+ // Load the vectors into the registers
593+ vint8m2_t va = __riscv_vle8_v_i8m2 (a , vl );
594+ vint8m2_t vb = __riscv_vle8_v_i8m2 (b , vl );
595+
596+ // Widen these values to 16bit signed
597+ vint16m4_t va_wide = __riscv_vwcvt_x_x_v_i16m4 (va , vl );
598+ vint16m4_t vb_wide = __riscv_vwcvt_x_x_v_i16m4 (vb , vl );
599+ vl = __riscv_vsetvl_e16m4 (i );
600+
601+ // Compute the dot product for the entire register (widening madd)
602+ vdot = __riscv_vwmacc_vv_i32m8 (vdot , va_wide , vb_wide , vl );
603+
604+ // Also calculate the magnitude value for both a and b (widening madd)
605+ vmagn_a = __riscv_vwmacc_vv_i32m8 (vmagn_a , va_wide , va_wide , vl );
606+ vmagn_b = __riscv_vwmacc_vv_i32m8 (vmagn_b , vb_wide , vb_wide , vl );
607+
608+ // Advance the a and b pointers to the next offset
609+ a = & a [vl ];
610+ b = & b [vl ];
611+ }
612+
613+ // Now do a final reduction on the registers to sum the remaining elements
614+ float dot = (float ) int32_sum_vector_i32m8 (vdot , n );
615+ float magn_a = sqrtf ((float ) int32_sum_vector_i32m8 (vmagn_a , n ));
616+ float magn_b = sqrtf ((float ) int32_sum_vector_i32m8 (vmagn_b , n ));
617+
618+ if (magn_a == 0.0f || magn_b == 0.0f ) return 1.0f ;
619+
620+ float cosine_similarity = dot / (magn_a * magn_b );
621+ if (cosine_similarity > 1.0f ) cosine_similarity = 1.0f ;
622+ if (cosine_similarity < -1.0f ) cosine_similarity = -1.0f ;
623+ return 1.0f - cosine_similarity ;
485624}
486625
487626// MARK: - BIT -
@@ -555,31 +694,31 @@ void init_distance_functions_rvv (void) {
555694 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
556695 // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
557696 dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_rvv ;
558- // dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
697+ dispatch_distance_table [VECTOR_DISTANCE_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_rvv ;
559698
560699 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_F32 ] = float32_distance_l2_squared_rvv ;
561700 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
562701 // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
563702 dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_U8 ] = uint8_distance_l2_squared_rvv ;
564- // dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
703+ dispatch_distance_table [VECTOR_DISTANCE_SQUARED_L2 ][VECTOR_TYPE_I8 ] = int8_distance_l2_squared_rvv ;
565704
566705 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_F32 ] = float32_distance_cosine_rvv ;
567706 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
568707 // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
569708 dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_U8 ] = uint8_distance_cosine_rvv ;
570- // dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
709+ dispatch_distance_table [VECTOR_DISTANCE_COSINE ][VECTOR_TYPE_I8 ] = int8_distance_cosine_rvv ;
571710
572711 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_F32 ] = float32_distance_dot_rvv ;
573712 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
574713 // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
575714 dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_U8 ] = uint8_distance_dot_rvv ;
576- // dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
715+ dispatch_distance_table [VECTOR_DISTANCE_DOT ][VECTOR_TYPE_I8 ] = int8_distance_dot_rvv ;
577716
578717 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_F32 ] = float32_distance_l1_rvv ;
579718 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
580719 // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
581720 dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_U8 ] = uint8_distance_l1_rvv ;
582- // dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
721+ dispatch_distance_table [VECTOR_DISTANCE_L1 ][VECTOR_TYPE_I8 ] = int8_distance_l1_rvv ;
583722
584723 dispatch_distance_table [VECTOR_DISTANCE_HAMMING ][VECTOR_TYPE_BIT ] = bit1_distance_hamming_rvv ;
585724
0 commit comments