Skip to content

Commit 344b8c6

Browse files
committed
distance-rvv: Add support for bf16
1 parent 47100b1 commit 344b8c6

1 file changed

Lines changed: 151 additions & 32 deletions

File tree

src/distance-rvv.c

Lines changed: 151 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,64 +20,88 @@ extern const char *distance_backend_name;
2020
// MARK: - UTILS -
2121

2222
// Reduces a vector by summing all of it's elements into a single scalar float
23-
static inline float float32_sum_vector_f32m8(vfloat32m8_t vec, size_t vl) {
23+
static inline float float32_sum_vector_f32m8 (vfloat32m8_t vec, size_t vl) {
2424
vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
2525
vl = __riscv_vsetvl_e32m8(vl);
2626
acc = __riscv_vfredusum_vs_f32m8_f32m1(vec, acc, vl);
2727
return __riscv_vfmv_f_s_f32m1_f32(acc);
2828
}
2929

3030
// Reduces a vector by summing all of it's elements into a single scalar float
31-
static inline float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
31+
static inline float float32_sum_vector_f32m4 (vfloat32m4_t vec, size_t vl) {
3232
vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
3333
vl = __riscv_vsetvl_e32m4(vl);
3434
acc = __riscv_vfredusum_vs_f32m4_f32m1(vec, acc, vl);
3535
return __riscv_vfmv_f_s_f32m1_f32(acc);
3636
}
3737

3838
// Reduces a vector by summing all of it's elements into a single scalar double
39-
static inline double float64_sum_vector_f64m4(vfloat64m4_t vec, size_t vl) {
39+
static inline double float64_sum_vector_f64m4 (vfloat64m4_t vec, size_t vl) {
4040
vfloat64m1_t acc = __riscv_vfmv_v_f_f64m1(0.0, 1);
4141
vl = __riscv_vsetvl_e64m4(vl);
4242
acc = __riscv_vfredusum_vs_f64m4_f64m1(vec, acc, vl);
4343
return __riscv_vfmv_f_s_f64m1_f64(acc);
4444
}
4545

4646
// Reduces a vector by summing all of it's elements into a single scalar integer
47-
static inline uint64_t uint64_sum_vector_u64m8(vuint64m8_t vec, size_t vl) {
47+
static inline uint64_t uint64_sum_vector_u64m8 (vuint64m8_t vec, size_t vl) {
4848
vuint64m1_t acc = __riscv_vmv_s_x_u64m1(0, 1);
4949
vl = __riscv_vsetvl_e64m8(vl);
5050
acc = __riscv_vredsum_vs_u64m8_u64m1(vec, acc, vl);
5151
return __riscv_vmv_x_s_u64m1_u64(acc);
5252
}
5353

5454
// Reduces a vector by summing all of it's elements into a single scalar integer
55-
static inline uint32_t uint32_sum_vector_u32m8(vuint32m8_t vec, size_t vl) {
55+
static inline uint32_t uint32_sum_vector_u32m8 (vuint32m8_t vec, size_t vl) {
5656
vuint32m1_t acc = __riscv_vmv_s_x_u32m1(0, 1);
5757
vl = __riscv_vsetvl_e32m8(vl);
5858
acc = __riscv_vredsum_vs_u32m8_u32m1(vec, acc, vl);
5959
return __riscv_vmv_x_s_u32m1_u32(acc);
6060
}
6161

6262
// Reduces a vector by summing all of it's elements into a single scalar integer
63-
static inline int32_t int32_sum_vector_i32m8(vint32m8_t vec, size_t vl) {
63+
static inline int32_t int32_sum_vector_i32m8 (vint32m8_t vec, size_t vl) {
6464
vint32m1_t acc = __riscv_vmv_s_x_i32m1(0, 1);
6565
vl = __riscv_vsetvl_e32m8(vl);
6666
acc = __riscv_vredsum_vs_i32m8_i32m1(vec, acc, vl);
6767
return __riscv_vmv_x_s_i32m1_i32(acc);
6868
}
6969

7070
// Scalar-load fp16 payloads, convert to fp32, and pack as an f32m2 vector.
71-
static inline vfloat32m2_t rvv_load_f16_as_f32m2(const uint16_t *src, size_t n) {
71+
static inline vfloat32m2_t rvv_load_f16_as_f32m2 (const uint16_t *src, size_t n) {
7272
size_t vl = __riscv_vsetvl_e32m2(n);
7373
float lanes[vl];
7474
for (size_t i = 0; i < vl; ++i) lanes[i] = float16_to_float32(src[i]);
7575
return __riscv_vle32_v_f32m2(lanes, vl);
7676
}
7777

78+
// Scalar-load bf16 payloads, convert to fp32, and pack as an f32m8 vector.
79+
static inline vfloat32m8_t rvv_load_bf16_as_f32m8 (const uint16_t *src, size_t n) {
80+
size_t vl = __riscv_vsetvl_e32m8(n);
81+
float lanes[vl];
82+
for (size_t i = 0; i < vl; ++i) lanes[i] = bfloat16_to_float32(src[i]);
83+
return __riscv_vle32_v_f32m8(lanes, vl);
84+
}
85+
86+
// Scalar-load bf16 payloads, convert to fp32, and pack as an f32m4 vector.
87+
static inline vfloat32m4_t rvv_load_bf16_as_f32m4 (const uint16_t *src, size_t n) {
88+
size_t vl = __riscv_vsetvl_e32m4(n);
89+
float lanes[vl];
90+
for (size_t i = 0; i < vl; ++i) lanes[i] = bfloat16_to_float32(src[i]);
91+
return __riscv_vle32_v_f32m4(lanes, vl);
92+
}
93+
94+
// Scalar-load bf16 payloads, convert to fp32, and pack as an f32m2 vector.
95+
static inline vfloat32m2_t rvv_load_bf16_as_f32m2 (const uint16_t *src, size_t n) {
96+
size_t vl = __riscv_vsetvl_e32m2(n);
97+
float lanes[vl];
98+
for (size_t i = 0; i < vl; ++i) lanes[i] = bfloat16_to_float32(src[i]);
99+
return __riscv_vle32_v_f32m2(lanes, vl);
100+
}
101+
78102
// Returns true if any lane has an fp16-style infinity mismatch:
79103
// one side is Inf and the other is not, or both are Inf with different signs.
80-
static inline bool rvv_has_f16_inf_mismatch_f64m4(vfloat64m4_t va, vfloat64m4_t vb, size_t vl) {
104+
static inline bool rvv_has_f16_inf_mismatch_f64m4 (vfloat64m4_t va, vfloat64m4_t vb, size_t vl) {
81105
vuint64m4_t a_class = __riscv_vfclass_v_u64m4(va, vl);
82106
vuint64m4_t b_class = __riscv_vfclass_v_u64m4(vb, vl);
83107
vuint64m4_t a_inf_bits = __riscv_vand_vx_u64m4(a_class, 0x81u, vl);
@@ -87,7 +111,7 @@ static inline bool rvv_has_f16_inf_mismatch_f64m4(vfloat64m4_t va, vfloat64m4_t
87111
}
88112

89113
// Returns mask of lanes where both vectors are not NaN.
90-
static inline vbool16_t rvv_both_not_nan_f64m4(vfloat64m4_t va, vfloat64m4_t vb, size_t vl) {
114+
static inline vbool16_t rvv_both_not_nan_f64m4 (vfloat64m4_t va, vfloat64m4_t vb, size_t vl) {
91115
vbool16_t a_not_nan = __riscv_vmfeq_vv_f64m4_b16(va, va, vl);
92116
vbool16_t b_not_nan = __riscv_vmfeq_vv_f64m4_b16(vb, vb, vl);
93117
return __riscv_vmand_mm_b16(a_not_nan, b_not_nan, vl);
@@ -107,7 +131,7 @@ float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool
107131
// Iterate by VL elements
108132
for (size_t i = n; i > 0; i -= vl) {
109133
// Use LMUL=8, we have 4 registers to work with.
110-
vl = __riscv_vsetvl_e32m8(n);
134+
vl = __riscv_vsetvl_e32m8(i);
111135

112136
// Load the vectors into the registers
113137
vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl);
@@ -146,7 +170,7 @@ float float32_distance_l1_rvv (const void *v1, const void *v2, int n) {
146170
// Iterate by VL elements
147171
for (size_t i = n; i > 0; i -= vl) {
148172
// Use LMUL=8, we have 4 registers to work with.
149-
vl = __riscv_vsetvl_e32m8(n);
173+
vl = __riscv_vsetvl_e32m8(i);
150174

151175
// Load the vectors into the registers
152176
vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl);
@@ -427,34 +451,129 @@ float float16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
427451

428452
// MARK: - BFLOAT16 -
429453

454+
// Shared L2 kernel for bf16 vectors: accumulates sum((a[i]-b[i])^2) in f64 and
// optionally takes the square root at the end.
//
// v1, v2:   pointers to n raw bf16 elements each (stored as uint16_t payloads).
// n:        element count.
// use_sqrt: true -> Euclidean (L2) distance; false -> squared L2 distance.
//
// Special-value handling: returns +INFINITY as soon as any lane difference is
// infinite; lanes whose difference is NaN are masked out of the accumulation.
static inline float bfloat16_distance_l2_impl_rvv(const void *v1, const void *v2, int n, bool use_sqrt) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    // Zero the f64m4 accumulator across all VLMAX lanes.
    size_t vl = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t vsum = __riscv_vfmv_v_f_f64m4(0.0, vl);

    for (size_t i = n; i > 0;) {
        // Load as f32m2 and widen to f64m4 to avoid overflow in accumulation.
        vl = __riscv_vsetvl_e32m2(i);
        vfloat32m2_t va32 = rvv_load_bf16_as_f32m2(a, vl);
        vfloat32m2_t vb32 = rvv_load_bf16_as_f32m2(b, vl);
        vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4(va32, vl);
        vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4(vb32, vl);

        // Switch VL to e64/m4 for the widened operands (same lane count).
        vl = __riscv_vsetvl_e64m4(vl);

        vfloat64m4_t vdiff = __riscv_vfsub_vv_f64m4(va, vb, vl);

        // If any diff lane is infinite, return +INFINITY.
        // vfclass sets bit 0 (0x01) for -Inf and bit 7 (0x80) for +Inf;
        // masking with 0x81 selects both infinity classes.
        vuint64m4_t d_class = __riscv_vfclass_v_u64m4(vdiff, vl);
        vbool16_t d_inf = __riscv_vmsne_vx_u64m4_b16(__riscv_vand_vx_u64m4(d_class, 0x81u, vl), 0u, vl);
        if (__riscv_vfirst_m_b16(d_inf, vl) >= 0) return INFINITY;

        // Skip NaN diff lanes: NaN != NaN, so the self-equality mask is clear
        // exactly on NaN lanes, and the masked FMA leaves them untouched.
        vbool16_t not_nan = __riscv_vmfeq_vv_f64m4_b16(vdiff, vdiff, vl);
        vsum = __riscv_vfmacc_vv_f64m4_m(not_nan, vsum, vdiff, vdiff, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    // The sum helper clamps n to VLMAX internally, so all touched lanes reduce.
    double l2sq = float64_sum_vector_f64m4(vsum, n);
    return use_sqrt ? sqrtf((float)l2sq) : (float)l2sq;
}
490+
430491
// Euclidean (L2) distance between two bf16 vectors of n elements.
// Thin wrapper over the shared kernel with sqrt enabled.
float bfloat16_distance_l2_rvv (const void *v1, const void *v2, int n) {
    return bfloat16_distance_l2_impl_rvv(v1, v2, n, true);
}
435494

436495
// Squared L2 distance between two bf16 vectors of n elements.
// Thin wrapper over the shared kernel with sqrt disabled.
float bfloat16_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    return bfloat16_distance_l2_impl_rvv(v1, v2, n, false);
}
441498

442499
float bfloat16_distance_l1_rvv (const void *v1, const void *v2, int n) {
443-
printf("bfloat16_distance_l1_rvv: unimplemented\n");
444-
abort();
445-
return 0.0f;
500+
const uint16_t *a = (const uint16_t *)v1;
501+
const uint16_t *b = (const uint16_t *)v2;
502+
503+
size_t vl = __riscv_vsetvlmax_e32m8();
504+
vfloat32m8_t vsum = __riscv_vfmv_v_f_f32m8(0.0f, vl);
505+
506+
for (size_t i = n; i > 0;) {
507+
vl = __riscv_vsetvl_e32m8(i);
508+
vfloat32m8_t va = rvv_load_bf16_as_f32m8(a, vl);
509+
vfloat32m8_t vb = rvv_load_bf16_as_f32m8(b, vl);
510+
511+
vfloat32m8_t vdiff = __riscv_vfsub_vv_f32m8(va, vb, vl);
512+
vfloat32m8_t vabs = __riscv_vfabs_v_f32m8(vdiff, vl);
513+
vsum = __riscv_vfadd_vv_f32m8(vsum, vabs, vl);
514+
515+
a += vl;
516+
b += vl;
517+
i -= vl;
518+
}
519+
520+
return float32_sum_vector_f32m8(vsum, n);
446521
}
447522

448523
float bfloat16_distance_dot_rvv (const void *v1, const void *v2, int n) {
449-
printf("bfloat16_distance_dot_rvv: unimplemented\n");
450-
abort();
451-
return 0.0f;
524+
const uint16_t *a = (const uint16_t *)v1;
525+
const uint16_t *b = (const uint16_t *)v2;
526+
527+
size_t vl = __riscv_vsetvlmax_e32m8();
528+
vfloat32m8_t vdot = __riscv_vfmv_v_f_f32m8(0.0f, vl);
529+
530+
for (size_t i = n; i > 0;) {
531+
vl = __riscv_vsetvl_e32m8(i);
532+
vfloat32m8_t va = rvv_load_bf16_as_f32m8(a, vl);
533+
vfloat32m8_t vb = rvv_load_bf16_as_f32m8(b, vl);
534+
vdot = __riscv_vfmacc_vv_f32m8(vdot, va, vb, vl);
535+
536+
a += vl;
537+
b += vl;
538+
i -= vl;
539+
}
540+
541+
float dot = float32_sum_vector_f32m8(vdot, n);
542+
return -dot;
452543
}
453544

454545
float bfloat16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
455-
printf("bfloat16_distance_cosine_rvv: unimplemented\n");
456-
abort();
457-
return 0.0f;
546+
const uint16_t *a = (const uint16_t *)v1;
547+
const uint16_t *b = (const uint16_t *)v2;
548+
549+
size_t vl = __riscv_vsetvlmax_e32m4();
550+
vfloat32m4_t vdot = __riscv_vfmv_v_f_f32m4(0.0f, vl);
551+
vfloat32m4_t vnx = __riscv_vfmv_v_f_f32m4(0.0f, vl);
552+
vfloat32m4_t vny = __riscv_vfmv_v_f_f32m4(0.0f, vl);
553+
554+
for (size_t i = n; i > 0;) {
555+
vl = __riscv_vsetvl_e32m4(i);
556+
vfloat32m4_t va = rvv_load_bf16_as_f32m4(a, vl);
557+
vfloat32m4_t vb = rvv_load_bf16_as_f32m4(b, vl);
558+
559+
vdot = __riscv_vfmacc_vv_f32m4(vdot, va, vb, vl);
560+
vnx = __riscv_vfmacc_vv_f32m4(vnx, va, va, vl);
561+
vny = __riscv_vfmacc_vv_f32m4(vny, vb, vb, vl);
562+
563+
a += vl;
564+
b += vl;
565+
i -= vl;
566+
}
567+
568+
float dot = float32_sum_vector_f32m4(vdot, n);
569+
float norm_x = float32_sum_vector_f32m4(vnx, n);
570+
float norm_y = float32_sum_vector_f32m4(vny, n);
571+
if (norm_x == 0.0f || norm_y == 0.0f) return 1.0f;
572+
573+
float cosine_similarity = dot / (sqrtf(norm_x) * sqrtf(norm_y));
574+
if (cosine_similarity > 1.0f) cosine_similarity = 1.0f;
575+
if (cosine_similarity < -1.0f) cosine_similarity = -1.0f;
576+
return 1.0f - cosine_similarity;
458577
}
459578

460579
// MARK: - UINT8 -
@@ -847,7 +966,7 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
847966
// Iterate by VL elements
848967
for (size_t i = n; i > 0; i -= vl) {
849968
// Use LMUL=8, we have 4 registers to work with.
850-
vl = __riscv_vsetvl_e64m8(n);
969+
vl = __riscv_vsetvl_e64m8(i);
851970

852971
// Load the vectors into the registers and cast them into a u64 inplace
853972
vuint64m8_t va = __riscv_vreinterpret_v_u8m8_u64m8(__riscv_vle8_v_u8m8(a, vl));
@@ -874,31 +993,31 @@ void init_distance_functions_rvv (void) {
874993
#if defined(__riscv_v_intrinsic)
875994
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
876995
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
877-
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
996+
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
878997
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
879998
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
880999

8811000
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
8821001
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
883-
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
1002+
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
8841003
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
8851004
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
8861005

8871006
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F32] = float32_distance_cosine_rvv;
8881007
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
889-
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
1008+
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
8901009
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
8911010
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
8921011

8931012
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
8941013
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
895-
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
1014+
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
8961015
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
8971016
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
8981017

8991018
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
9001019
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
901-
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
1020+
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
9021021
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
9031022
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
9041023

0 commit comments

Comments
 (0)