Skip to content

Commit 47100b1

Browse files
committed
distance-rvv: Add support for f16 distance functions
1 parent 3ee8612 commit 47100b1

1 file changed

Lines changed: 207 additions & 25 deletions

File tree

src/distance-rvv.c

Lines changed: 207 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,45 +20,79 @@ extern const char *distance_backend_name;
2020
// MARK: - UTILS -
2121

2222
// Reduces a vector by summing all of it's elements into a single scalar float
23-
float float32_sum_vector_f32m8(vfloat32m8_t vec, size_t vl) {
23+
static inline float float32_sum_vector_f32m8(vfloat32m8_t vec, size_t vl) {
2424
vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
2525
vl = __riscv_vsetvl_e32m8(vl);
2626
acc = __riscv_vfredusum_vs_f32m8_f32m1(vec, acc, vl);
2727
return __riscv_vfmv_f_s_f32m1_f32(acc);
2828
}
2929

3030
// Reduces a vector by summing all of it's elements into a single scalar float
31-
float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
31+
static inline float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
3232
vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
3333
vl = __riscv_vsetvl_e32m4(vl);
3434
acc = __riscv_vfredusum_vs_f32m4_f32m1(vec, acc, vl);
3535
return __riscv_vfmv_f_s_f32m1_f32(acc);
3636
}
3737

38+
// Reduces a vector by summing all of it's elements into a single scalar double
39+
static inline double float64_sum_vector_f64m4(vfloat64m4_t vec, size_t vl) {
40+
vfloat64m1_t acc = __riscv_vfmv_v_f_f64m1(0.0, 1);
41+
vl = __riscv_vsetvl_e64m4(vl);
42+
acc = __riscv_vfredusum_vs_f64m4_f64m1(vec, acc, vl);
43+
return __riscv_vfmv_f_s_f64m1_f64(acc);
44+
}
45+
3846
// Reduces a vector by summing all of it's elements into a single scalar integer
39-
uint64_t uint64_sum_vector_u64m8(vuint64m8_t vec, size_t vl) {
47+
static inline uint64_t uint64_sum_vector_u64m8(vuint64m8_t vec, size_t vl) {
4048
vuint64m1_t acc = __riscv_vmv_s_x_u64m1(0, 1);
4149
vl = __riscv_vsetvl_e64m8(vl);
4250
acc = __riscv_vredsum_vs_u64m8_u64m1(vec, acc, vl);
4351
return __riscv_vmv_x_s_u64m1_u64(acc);
4452
}
4553

4654
// Reduces a vector by summing all of it's elements into a single scalar integer
47-
uint32_t uint32_sum_vector_u32m8(vuint32m8_t vec, size_t vl) {
55+
static inline uint32_t uint32_sum_vector_u32m8(vuint32m8_t vec, size_t vl) {
4856
vuint32m1_t acc = __riscv_vmv_s_x_u32m1(0, 1);
4957
vl = __riscv_vsetvl_e32m8(vl);
5058
acc = __riscv_vredsum_vs_u32m8_u32m1(vec, acc, vl);
5159
return __riscv_vmv_x_s_u32m1_u32(acc);
5260
}
5361

5462
// Reduces a vector by summing all of it's elements into a single scalar integer
55-
int32_t int32_sum_vector_i32m8(vint32m8_t vec, size_t vl) {
63+
static inline int32_t int32_sum_vector_i32m8(vint32m8_t vec, size_t vl) {
5664
vint32m1_t acc = __riscv_vmv_s_x_i32m1(0, 1);
5765
vl = __riscv_vsetvl_e32m8(vl);
5866
acc = __riscv_vredsum_vs_i32m8_i32m1(vec, acc, vl);
5967
return __riscv_vmv_x_s_i32m1_i32(acc);
6068
}
6169

70+
// Scalar-load fp16 payloads, convert to fp32, and pack as an f32m2 vector.
71+
static inline vfloat32m2_t rvv_load_f16_as_f32m2(const uint16_t *src, size_t n) {
72+
size_t vl = __riscv_vsetvl_e32m2(n);
73+
float lanes[vl];
74+
for (size_t i = 0; i < vl; ++i) lanes[i] = float16_to_float32(src[i]);
75+
return __riscv_vle32_v_f32m2(lanes, vl);
76+
}
77+
78+
// Returns true if any lane has an fp16-style infinity mismatch:
79+
// one side is Inf and the other is not, or both are Inf with different signs.
80+
static inline bool rvv_has_f16_inf_mismatch_f64m4(vfloat64m4_t va, vfloat64m4_t vb, size_t vl) {
81+
vuint64m4_t a_class = __riscv_vfclass_v_u64m4(va, vl);
82+
vuint64m4_t b_class = __riscv_vfclass_v_u64m4(vb, vl);
83+
vuint64m4_t a_inf_bits = __riscv_vand_vx_u64m4(a_class, 0x81u, vl);
84+
vuint64m4_t b_inf_bits = __riscv_vand_vx_u64m4(b_class, 0x81u, vl);
85+
vbool16_t inf_mismatch = __riscv_vmsne_vv_u64m4_b16(a_inf_bits, b_inf_bits, vl);
86+
return __riscv_vfirst_m_b16(inf_mismatch, vl) >= 0;
87+
}
88+
89+
// Returns mask of lanes where both vectors are not NaN.
90+
static inline vbool16_t rvv_both_not_nan_f64m4(vfloat64m4_t va, vfloat64m4_t vb, size_t vl) {
91+
vbool16_t a_not_nan = __riscv_vmfeq_vv_f64m4_b16(va, va, vl);
92+
vbool16_t b_not_nan = __riscv_vmfeq_vv_f64m4_b16(vb, vb, vl);
93+
return __riscv_vmand_mm_b16(a_not_nan, b_not_nan, vl);
94+
}
95+
6296

6397
// MARK: - FLOAT32 -
6498

@@ -213,34 +247,182 @@ float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
213247

214248
// MARK: - FLOAT16 -
215249

250+
static inline float float16_distance_l2_impl_rvv(const void *v1, const void *v2, int n, bool use_sqrt) {
251+
const uint16_t *a = (const uint16_t *)v1;
252+
const uint16_t *b = (const uint16_t *)v2;
253+
254+
size_t vl = __riscv_vsetvlmax_e64m4();
255+
vfloat64m4_t vsum = __riscv_vfmv_v_f_f64m4(0.0, vl);
256+
257+
for (size_t i = n; i > 0;) {
258+
// Scalar-load fp16, convert to f32, then widen to f64.
259+
vl = __riscv_vsetvl_e32m2(i);
260+
vfloat32m2_t va32 = rvv_load_f16_as_f32m2(a, vl);
261+
vfloat32m2_t vb32 = rvv_load_f16_as_f32m2(b, vl);
262+
vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4(va32, vl);
263+
vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4(vb32, vl);
264+
265+
vl = __riscv_vsetvl_e64m4(vl);
266+
267+
// Return +Inf if there is an infinity mismatch.
268+
if (rvv_has_f16_inf_mismatch_f64m4(va, vb, vl)) return INFINITY;
269+
270+
// Skip NaN lanes in accumulation path.
271+
vbool16_t not_nan = rvv_both_not_nan_f64m4(va, vb, vl);
272+
273+
vfloat64m4_t vdiff = __riscv_vfsub_vv_f64m4(va, vb, vl);
274+
vsum = __riscv_vfmacc_vv_f64m4_m(not_nan, vsum, vdiff, vdiff, vl);
275+
276+
a += vl;
277+
b += vl;
278+
i -= vl;
279+
}
280+
281+
double l2sq = float64_sum_vector_f64m4(vsum, n);
282+
return use_sqrt ? sqrtf((float)l2sq) : (float)l2sq;
283+
}
284+
216285
// Euclidean (L2) distance between two fp16 vectors of length n.
// Thin wrapper over the shared kernel with the sqrt step enabled.
float float16_distance_l2_rvv (const void *v1, const void *v2, int n) {
    return float16_distance_l2_impl_rvv(v1, v2, n, true);
}
221288

222289
// Squared L2 distance between two fp16 vectors of length n.
// Thin wrapper over the shared kernel with the sqrt step disabled.
float float16_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    return float16_distance_l2_impl_rvv(v1, v2, n, false);
}
227292

228293
// L1 (Manhattan) distance between two fp16 vectors of length n.
// Accumulates |a[i] - b[i]| in f64 (m4 groups). Any infinity mismatch
// returns +Inf immediately; lanes where either side is NaN are skipped.
float float16_distance_l1_rvv (const void *v1, const void *v2, int n) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    // Zero the full-width f64m4 accumulator once, up front.
    size_t vl = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t vsum = __riscv_vfmv_v_f_f64m4(0.0, vl);

    for (size_t i = n; i > 0;) {
        // Scalar-load fp16, convert to f32, then widen to f64.
        // e32m2 and e64m4 share the same lane count, so vl carries over.
        vl = __riscv_vsetvl_e32m2(i);
        vfloat32m2_t va32 = rvv_load_f16_as_f32m2(a, vl);
        vfloat32m2_t vb32 = rvv_load_f16_as_f32m2(b, vl);
        vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4(va32, vl);
        vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4(vb32, vl);

        vl = __riscv_vsetvl_e64m4(vl);

        // Return +Inf if there is an infinity mismatch.
        if (rvv_has_f16_inf_mismatch_f64m4(va, vb, vl)) return INFINITY;

        // Skip NaN lanes in accumulation path.
        vbool16_t not_nan = rvv_both_not_nan_f64m4(va, vb, vl);

        vfloat64m4_t vdiff = __riscv_vfsub_vv_f64m4(va, vb, vl);
        vfloat64m4_t vabs = __riscv_vfabs_v_f64m4(vdiff, vl);
        vsum = __riscv_vfadd_vv_f64m4_m(not_nan, vsum, vabs, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    // Final horizontal sum; vsetvl inside the helper clamps n to VLMAX,
    // which covers every lane the loop could have touched.
    return (float)float64_sum_vector_f64m4(vsum, n);
}
233327

234328
// Dot-product distance between two fp16 vectors of length n: returns the
// NEGATED dot product, so higher similarity sorts as smaller distance
// (NOTE(review): presumably to match the scalar backend's convention —
// confirm against the CPU implementation).
// Lanes where either side is NaN are excluded from the accumulation; an
// infinite product on a valid lane short-circuits to ±Inf.
float float16_distance_dot_rvv (const void *v1, const void *v2, int n) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    // Keep accumulation vectorized while preserving CPU NaN/Inf semantics.
    size_t vl = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t vdot = __riscv_vfmv_v_f_f64m4(0.0, vl);

    for (size_t i = n; i > 0;) {
        // Scalar-load fp16, convert to f32, then widen to f64.
        // e32m2 and e64m4 share the same lane count, so vl carries over.
        vl = __riscv_vsetvl_e32m2(i);
        vfloat32m2_t va32 = rvv_load_f16_as_f32m2(a, vl);
        vfloat32m2_t vb32 = rvv_load_f16_as_f32m2(b, vl);
        vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4(va32, vl);
        vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4(vb32, vl);

        vl = __riscv_vsetvl_e64m4(vl);

        // not_nan = lanes where both sides are not NaN.
        vbool16_t not_nan = rvv_both_not_nan_f64m4(va, vb, vl);

        // Multiply once, then classify the product only.
        vfloat64m4_t vprod = __riscv_vfmul_vv_f64m4(va, vb, vl);

        // Try to find infinite values; if there are any, exit early.
        // vfclass bit 7 = +Inf, bit 0 = -Inf; both checks are masked so
        // NaN-producing lanes cannot trigger the early exit.
        vuint64m4_t p_class = __riscv_vfclass_v_u64m4(vprod, vl);
        vbool16_t inf_pos = __riscv_vmsne_vx_u64m4_b16_m(not_nan, __riscv_vand_vx_u64m4_m(not_nan, p_class, 0x80u, vl), 0u, vl);
        vbool16_t inf_neg = __riscv_vmsne_vx_u64m4_b16_m(not_nan, __riscv_vand_vx_u64m4_m(not_nan, p_class, 0x01u, vl), 0u, vl);
        long first_pos = __riscv_vfirst_m_b16(inf_pos, vl);
        long first_neg = __riscv_vfirst_m_b16(inf_neg, vl);
        if (first_pos >= 0 || first_neg >= 0) {
            // The EARLIEST infinite lane within this chunk decides the sign;
            // a +Inf product first means the (negated) result is -Inf.
            // NOTE(review): this mimics sequential scalar evaluation order
            // within a chunk — verify against the CPU backend for inputs
            // where the first infinity appears in a later chunk.
            if (first_pos >= 0 && (first_neg < 0 || first_pos < first_neg)) return -INFINITY;
            return INFINITY;
        }

        // Accumulate only valid lanes; NaN lanes are skipped.
        vdot = __riscv_vfadd_vv_f64m4_m(not_nan, vdot, vprod, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    double dot = float64_sum_vector_f64m4(vdot, n);
    // Negate so larger dot products map to smaller distances.
    return (float)(-dot);
}
239374

240375
// Cosine distance (1 - cosine similarity) between two fp16 vectors of
// length n, accumulating dot product and both squared norms in f64.
// Any infinity on a non-NaN lane, a zero/non-finite denominator, or a
// non-finite dot product yields 1.0f (the "no similarity" fallback);
// NaN lanes are excluded from all three accumulators.
float float16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
    const uint16_t *a = (const uint16_t *)v1;
    const uint16_t *b = (const uint16_t *)v2;

    // Zero the three full-width f64m4 accumulators once, up front:
    // dot product, squared norm of a, squared norm of b.
    size_t vl = __riscv_vsetvlmax_e64m4();
    vfloat64m4_t vdot = __riscv_vfmv_v_f_f64m4(0.0, vl);
    vfloat64m4_t vnx = __riscv_vfmv_v_f_f64m4(0.0, vl);
    vfloat64m4_t vny = __riscv_vfmv_v_f_f64m4(0.0, vl);

    for (size_t i = n; i > 0;) {
        // Scalar-load fp16, convert to f32, then widen to f64.
        // e32m2 and e64m4 share the same lane count, so vl carries over.
        vl = __riscv_vsetvl_e32m2(i);
        vfloat32m2_t va32 = rvv_load_f16_as_f32m2(a, vl);
        vfloat32m2_t vb32 = rvv_load_f16_as_f32m2(b, vl);
        vfloat64m4_t va = __riscv_vfwcvt_f_f_v_f64m4(va32, vl);
        vfloat64m4_t vb = __riscv_vfwcvt_f_f_v_f64m4(vb32, vl);

        vl = __riscv_vsetvl_e64m4(vl);

        // Keep only lanes where both values are not NaN.
        vbool16_t not_nan = rvv_both_not_nan_f64m4(va, vb, vl);

        // Any infinity on a valid lane returns 1.0f.
        // vfclass bit 7 = +Inf, bit 0 = -Inf; OR the classes so one vand
        // covers either input being infinite.
        vuint64m4_t a_class = __riscv_vfclass_v_u64m4(va, vl);
        vuint64m4_t b_class = __riscv_vfclass_v_u64m4(vb, vl);
        vuint64m4_t ab_class = __riscv_vor_vv_u64m4(a_class, b_class, vl);
        vbool16_t ab_inf = __riscv_vmsne_vx_u64m4_b16(__riscv_vand_vx_u64m4(ab_class, 0x81u, vl), 0u, vl);
        vbool16_t any_inf = __riscv_vmand_mm_b16(not_nan, ab_inf, vl);
        if (__riscv_vfirst_m_b16(any_inf, vl) >= 0) return 1.0f;

        // Accumulate dot and squared norms on valid lanes.
        vfloat64m4_t vprod = __riscv_vfmul_vv_f64m4(va, vb, vl);
        vdot = __riscv_vfadd_vv_f64m4_m(not_nan, vdot, vprod, vl);
        vnx = __riscv_vfmacc_vv_f64m4_m(not_nan, vnx, va, va, vl);
        vny = __riscv_vfmacc_vv_f64m4_m(not_nan, vny, vb, vb, vl);

        a += vl;
        b += vl;
        i -= vl;
    }

    // Horizontal sums, then the cosine formula in double precision.
    double dot = float64_sum_vector_f64m4(vdot, n);
    double nx = float64_sum_vector_f64m4(vnx, n);
    double ny = float64_sum_vector_f64m4(vny, n);
    double denom = sqrt(nx) * sqrt(ny);
    // Degenerate cases (zero-norm input, overflow to Inf, NaN) fall back to 1.0f.
    if (!(denom > 0.0) || !isfinite(denom) || !isfinite(dot)) return 1.0f;

    // Clamp to [-1, 1] to absorb rounding before forming the distance.
    double cosv = dot / denom;
    if (cosv > 1.0) cosv = 1.0;
    if (cosv < -1.0) cosv = -1.0;
    return (float)(1.0 - cosv);
}
245427

246428
// MARK: - BFLOAT16 -
@@ -691,31 +873,31 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
691873
void init_distance_functions_rvv (void) {
692874
#if defined(__riscv_v_intrinsic)
693875
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
694-
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
876+
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
695877
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
696878
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
697879
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
698880

699881
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
700-
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
882+
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
701883
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
702884
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
703885
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
704886

705887
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F32] = float32_distance_cosine_rvv;
706-
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
888+
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
707889
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
708890
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
709891
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
710892

711893
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
712-
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
894+
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
713895
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
714896
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
715897
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
716898

717899
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
718-
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
900+
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
719901
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
720902
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
721903
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;

0 commit comments

Comments
 (0)