Skip to content

Commit 3ee8612

Browse files
committed
distance-rvv: Add support for i8 distance functions
1 parent 687a357 commit 3ee8612

1 file changed

Lines changed: 159 additions & 20 deletions

File tree

src/distance-rvv.c

Lines changed: 159 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -454,34 +454,173 @@ float uint8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
454454

455455
// MARK: - INT8 -
456456

457+
float int8_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) {
458+
const int8_t *a = (const int8_t *)v1;
459+
const int8_t *b = (const int8_t *)v2;
460+
461+
// We accumulate the results into a vector register
462+
size_t vl = __riscv_vsetvlmax_e32m8();
463+
vint32m8_t vl2 = __riscv_vmv_s_x_i32m8(0, vl);
464+
465+
// Iterate by VL elements
466+
for (size_t i = n; i > 0; i -= vl) {
467+
// Use LMUL=2 to start off, but we're going to widen this
468+
vl = __riscv_vsetvl_e8m2(i);
469+
470+
// Load the vectors into the registers
471+
vint8m2_t va = __riscv_vle8_v_i8m2(a, vl);
472+
vint8m2_t vb = __riscv_vle8_v_i8m2(b, vl);
473+
474+
// Widen these values to 16bit signed
475+
vint16m4_t va_wide = __riscv_vwcvt_x_x_v_i16m4(va, vl);
476+
vint16m4_t vb_wide = __riscv_vwcvt_x_x_v_i16m4(vb, vl);
477+
vl = __riscv_vsetvl_e16m4(i);
478+
479+
// L2 = (a[i] - b[i]) + acc
480+
vint32m8_t vdiff = __riscv_vwsub_vv_i32m8(va_wide, vb_wide, vl);
481+
vl2 = __riscv_vmacc_vv_i32m8(vl2, vdiff, vdiff, vl);
482+
483+
// Advance the a and b pointers to the next offset
484+
a = &a[vl];
485+
b = &b[vl];
486+
}
487+
488+
// Copy the accumulators back into a scalar register
489+
float l2 = (float) int32_sum_vector_i32m8(vl2, n);
490+
return use_sqrt ? sqrtf(l2) : l2;
491+
}
492+
457493
float int8_distance_l2_rvv (const void *v1, const void *v2, int n) {
458-
printf("int8_distance_l2_rvv: unimplemented\n");
459-
abort();
460-
return 0.0f;
494+
return int8_distance_l2_impl_rvv(v1, v2, n, true);
461495
}
462496

463497
float int8_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
464-
printf("int8_distance_l2_squared_rvv: unimplemented\n");
465-
abort();
466-
return 0.0f;
498+
return int8_distance_l2_impl_rvv(v1, v2, n, false);
467499
}
468500

469501
float int8_distance_dot_rvv (const void *v1, const void *v2, int n) {
470-
printf("int8_distance_dot_rvv: unimplemented\n");
471-
abort();
472-
return 0.0f;
502+
const int8_t *a = (const int8_t *)v1;
503+
const int8_t *b = (const int8_t *)v2;
504+
505+
// We accumulate the results into a vector register
506+
size_t vl = __riscv_vsetvlmax_e32m8();
507+
vint32m8_t vdot = __riscv_vmv_s_x_i32m8(0, vl);
508+
509+
// Iterate by VL elements
510+
for (size_t i = n; i > 0; i -= vl) {
511+
// Use LMUL=2 to start off, but we're going to widen this
512+
vl = __riscv_vsetvl_e8m2(i);
513+
514+
// Load the vectors into the registers
515+
vint8m2_t va = __riscv_vle8_v_i8m2(a, vl);
516+
vint8m2_t vb = __riscv_vle8_v_i8m2(b, vl);
517+
518+
// Widen these vectors to 16bit
519+
vint16m4_t va_wide = __riscv_vwcvt_x_x_v_i16m4(va, vl);
520+
vint16m4_t vb_wide = __riscv_vwcvt_x_x_v_i16m4(vb, vl);
521+
522+
// Now we're operating on 16 bit elements
523+
vl = __riscv_vsetvl_e16m4(i);
524+
525+
// Do a widening multiply-accumulate to 32 bits
526+
vdot = __riscv_vwmacc_vv_i32m8(vdot, va_wide, vb_wide, vl);
527+
528+
// Advance the a and b pointers to the next offset
529+
a = &a[vl];
530+
b = &b[vl];
531+
}
532+
533+
// Copy the accumulators back into a scalar register
534+
float dot = (float) int32_sum_vector_i32m8(vdot, n);
535+
return -dot;
473536
}
474537

475538
float int8_distance_l1_rvv (const void *v1, const void *v2, int n) {
476-
printf("int8_distance_l1_rvv: unimplemented\n");
477-
abort();
478-
return 0.0f;
539+
const int8_t *a = (const int8_t *)v1;
540+
const int8_t *b = (const int8_t *)v2;
541+
542+
// We accumulate the results into a vector register
543+
size_t vl = __riscv_vsetvlmax_e32m8();
544+
vint32m8_t vl1 = __riscv_vmv_s_x_i32m8(0, vl);
545+
546+
// Iterate by VL elements
547+
for (size_t i = n; i > 0; i -= vl) {
548+
// Use LMUL=2 to start off, but we're going to widen this
549+
vl = __riscv_vsetvl_e8m2(i);
550+
551+
// Load the vectors into the registers
552+
vint8m2_t va = __riscv_vle8_v_i8m2(a, vl);
553+
vint8m2_t vb = __riscv_vle8_v_i8m2(b, vl);
554+
555+
// Compute the absolute difference by getting the min and max and subtracting them.
556+
vint8m2_t vmin = __riscv_vmin_vv_i8m2(va, vb, vl);
557+
vint8m2_t vmax = __riscv_vmax_vv_i8m2(va, vb, vl);
558+
vint16m4_t vabs = __riscv_vwsub_vv_i16m4(vmax, vmin, vl);
559+
vl = __riscv_vsetvl_e16m4(i);
560+
561+
// Now widen it to 32bits and add to the accumulator
562+
vint32m8_t vwide = __riscv_vwcvt_x_x_v_i32m8(vabs, vl);
563+
vl1 = __riscv_vadd_vv_i32m8(vl1, vwide, vl);
564+
565+
// Advance the a and b pointers to the next offset
566+
a = &a[vl];
567+
b = &b[vl];
568+
}
569+
570+
// Copy the accumulators back into a scalar register
571+
float l1 = (float) int32_sum_vector_i32m8(vl1, n);
572+
return l1;
479573
}
480574

481575
float int8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
482-
printf("int8_distance_cosine_rvv: unimplemented\n");
483-
abort();
484-
return 0.0f;
576+
const int8_t *a = (const int8_t *)v1;
577+
const int8_t *b = (const int8_t *)v2;
578+
579+
// We accumulate the results into a vector register
580+
size_t vl = __riscv_vsetvlmax_e32m8();
581+
582+
// Zero out the starting registers
583+
vint32m8_t vdot = __riscv_vmv_s_x_i32m8(0, vl);
584+
vint32m8_t vmagn_a = __riscv_vmv_s_x_i32m8(0, vl);
585+
vint32m8_t vmagn_b = __riscv_vmv_s_x_i32m8(0, vl);
586+
587+
// Iterate by VL elements
588+
for (size_t i = n; i > 0; i -= vl) {
589+
// Use LMUL=2 to start off, but we're going to widen this
590+
vl = __riscv_vsetvl_e8m2(i);
591+
592+
// Load the vectors into the registers
593+
vint8m2_t va = __riscv_vle8_v_i8m2(a, vl);
594+
vint8m2_t vb = __riscv_vle8_v_i8m2(b, vl);
595+
596+
// Widen these values to 16bit signed
597+
vint16m4_t va_wide = __riscv_vwcvt_x_x_v_i16m4(va, vl);
598+
vint16m4_t vb_wide = __riscv_vwcvt_x_x_v_i16m4(vb, vl);
599+
vl = __riscv_vsetvl_e16m4(i);
600+
601+
// Compute the dot product for the entire register (widening madd)
602+
vdot = __riscv_vwmacc_vv_i32m8(vdot, va_wide, vb_wide, vl);
603+
604+
// Also calculate the magnitude value for both a and b (widening madd)
605+
vmagn_a = __riscv_vwmacc_vv_i32m8(vmagn_a, va_wide, va_wide, vl);
606+
vmagn_b = __riscv_vwmacc_vv_i32m8(vmagn_b, vb_wide, vb_wide, vl);
607+
608+
// Advance the a and b pointers to the next offset
609+
a = &a[vl];
610+
b = &b[vl];
611+
}
612+
613+
// Now do a final reduction on the registers to sum the remaining elements
614+
float dot = (float) int32_sum_vector_i32m8(vdot, n);
615+
float magn_a = sqrtf((float) int32_sum_vector_i32m8(vmagn_a, n));
616+
float magn_b = sqrtf((float) int32_sum_vector_i32m8(vmagn_b, n));
617+
618+
if (magn_a == 0.0f || magn_b == 0.0f) return 1.0f;
619+
620+
float cosine_similarity = dot / (magn_a * magn_b);
621+
if (cosine_similarity > 1.0f) cosine_similarity = 1.0f;
622+
if (cosine_similarity < -1.0f) cosine_similarity = -1.0f;
623+
return 1.0f - cosine_similarity;
485624
}
486625

487626
// MARK: - BIT -
@@ -555,31 +694,31 @@ void init_distance_functions_rvv (void) {
555694
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
556695
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
557696
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
558-
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
697+
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
559698

560699
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
561700
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
562701
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
563702
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
564-
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
703+
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
565704

566705
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F32] = float32_distance_cosine_rvv;
567706
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
568707
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
569708
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
570-
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
709+
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
571710

572711
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
573712
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
574713
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
575714
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
576-
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
715+
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
577716

578717
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
579718
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
580719
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
581720
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
582-
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
721+
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
583722

584723
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
585724

0 commit comments

Comments
 (0)