Skip to content

Commit 687a357

Browse files
committed
distance-rvv: Add support for u8 distance functions
1 parent 500ba02 commit 687a357

1 file changed

Lines changed: 183 additions & 23 deletions

File tree

src/distance-rvv.c

Lines changed: 183 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,28 @@ float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
3838
// Reduces a vector by summing all of it's elements into a single scalar integer
3939
uint64_t uint64_sum_vector_u64m8(vuint64m8_t vec, size_t vl) {
4040
vuint64m1_t acc = __riscv_vmv_s_x_u64m1(0, 1);
41-
vl = __riscv_vsetvl_e32m8(vl);
41+
vl = __riscv_vsetvl_e64m8(vl);
4242
acc = __riscv_vredsum_vs_u64m8_u64m1(vec, acc, vl);
4343
return __riscv_vmv_x_s_u64m1_u64(acc);
4444
}
4545

46+
// Reduces a vector by summing all of it's elements into a single scalar integer
47+
uint32_t uint32_sum_vector_u32m8(vuint32m8_t vec, size_t vl) {
48+
vuint32m1_t acc = __riscv_vmv_s_x_u32m1(0, 1);
49+
vl = __riscv_vsetvl_e32m8(vl);
50+
acc = __riscv_vredsum_vs_u32m8_u32m1(vec, acc, vl);
51+
return __riscv_vmv_x_s_u32m1_u32(acc);
52+
}
53+
54+
// Reduces a vector by summing all of it's elements into a single scalar integer
55+
int32_t int32_sum_vector_i32m8(vint32m8_t vec, size_t vl) {
56+
vint32m1_t acc = __riscv_vmv_s_x_i32m1(0, 1);
57+
vl = __riscv_vsetvl_e32m8(vl);
58+
acc = __riscv_vredsum_vs_i32m8_i32m1(vec, acc, vl);
59+
return __riscv_vmv_x_s_i32m1_i32(acc);
60+
}
61+
62+
4663
// MARK: - FLOAT32 -
4764

4865
float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) {
@@ -80,7 +97,6 @@ float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
8097
return float32_distance_l2_impl_rvv(v1, v2, n, true);
8198
}
8299

83-
84100
// Squared L2 distance between two float32 vectors of n elements.
// Delegates to the shared L2 implementation with the square root disabled.
float float32_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    const bool take_sqrt = false;
    return float32_distance_l2_impl_rvv(v1, v2, n, take_sqrt);
}
@@ -261,34 +277,179 @@ float bfloat16_distance_cosine_rvv (const void *v1, const void *v2, int n) {
261277

262278
// MARK: - UINT8 -
263279

280+
float uint8_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) {
281+
const uint8_t *a = (const uint8_t *)v1;
282+
const uint8_t *b = (const uint8_t *)v2;
283+
284+
// We accumulate the results into a vector register
285+
size_t vl = __riscv_vsetvlmax_e32m8();
286+
vint32m8_t vl2 = __riscv_vmv_s_x_i32m8(0, vl);
287+
288+
// Iterate by VL elements
289+
for (size_t i = n; i > 0; i -= vl) {
290+
// Use LMUL=2 to start off, but we're going to widen this
291+
vl = __riscv_vsetvl_e8m2(i);
292+
293+
// Load the vectors into the registers
294+
vuint8m2_t va = __riscv_vle8_v_u8m2(a, vl);
295+
vuint8m2_t vb = __riscv_vle8_v_u8m2(b, vl);
296+
297+
// Widen these values to 16bit unsigned
298+
vuint16m4_t va_wide = __riscv_vwcvtu_x_x_v_u16m4(va, vl);
299+
vuint16m4_t vb_wide = __riscv_vwcvtu_x_x_v_u16m4(vb, vl);
300+
vl = __riscv_vsetvl_e16m4(i);
301+
302+
// Cast these to signed values
303+
vint16m4_t va_wides = __riscv_vreinterpret_v_u16m4_i16m4(va_wide);
304+
vint16m4_t vb_wides = __riscv_vreinterpret_v_u16m4_i16m4(vb_wide);
305+
306+
// L2 = (a[i] - b[i]) + acc
307+
// The subtract is signed, but the accumulate is unsigned
308+
vint32m8_t vdiff = __riscv_vwsub_vv_i32m8(va_wides, vb_wides, vl);
309+
vl2 = __riscv_vmacc_vv_i32m8(vl2, vdiff, vdiff, vl);
310+
311+
// Advance the a and b pointers to the next offset
312+
a = &a[vl];
313+
b = &b[vl];
314+
}
315+
316+
// Copy the accumulators back into a scalar register
317+
float l2 = (float) int32_sum_vector_i32m8(vl2, n);
318+
return use_sqrt ? sqrtf(l2) : l2;
319+
}
320+
264321
// L2 (Euclidean) distance between two uint8 vectors of n elements.
// Delegates to the shared uint8 L2 implementation with the square root enabled.
float uint8_distance_l2_rvv (const void *v1, const void *v2, int n) {
    return uint8_distance_l2_impl_rvv(v1, v2, n, true);
}
269324

270325
// Squared L2 distance between two uint8 vectors of n elements.
// Delegates to the shared uint8 L2 implementation with the square root disabled.
float uint8_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    return uint8_distance_l2_impl_rvv(v1, v2, n, false);
}
275328

276329
float uint8_distance_dot_rvv (const void *v1, const void *v2, int n) {
277-
printf("uint8_distance_dot_rvv: unimplemented\n");
278-
abort();
279-
return 0.0f;
330+
const uint8_t *a = (const uint8_t *)v1;
331+
const uint8_t *b = (const uint8_t *)v2;
332+
333+
// We accumulate the results into a vector register
334+
size_t vl = __riscv_vsetvlmax_e32m8();
335+
vuint32m8_t vdot = __riscv_vmv_s_x_u32m8(0, vl);
336+
337+
// Iterate by VL elements
338+
for (size_t i = n; i > 0; i -= vl) {
339+
// Use LMUL=2 to start off, but we're going to widen this
340+
vl = __riscv_vsetvl_e8m2(i);
341+
342+
// Load the vectors into the registers
343+
vuint8m2_t va = __riscv_vle8_v_u8m2(a, vl);
344+
vuint8m2_t vb = __riscv_vle8_v_u8m2(b, vl);
345+
346+
// Widen these vectors to 16bit
347+
vuint16m4_t va_wide = __riscv_vwcvtu_x_x_v_u16m4(va, vl);
348+
vuint16m4_t vb_wide = __riscv_vwcvtu_x_x_v_u16m4(vb, vl);
349+
350+
// Now we're operating on 16 bit elements
351+
vl = __riscv_vsetvl_e16m4(i);
352+
353+
// Do a widening multiply-accumulate to 32 bits
354+
vdot = __riscv_vwmaccu_vv_u32m8(vdot, va_wide, vb_wide, vl);
355+
356+
// Advance the a and b pointers to the next offset
357+
a = &a[vl];
358+
b = &b[vl];
359+
}
360+
361+
// Copy the accumulators back into a scalar register
362+
float dot = uint32_sum_vector_u32m8(vdot, n);
363+
return -dot;
280364
}
281365

282366
float uint8_distance_l1_rvv (const void *v1, const void *v2, int n) {
283-
printf("uint8_distance_l1_rvv: unimplemented\n");
284-
abort();
285-
return 0.0f;
367+
const uint8_t *a = (const uint8_t *)v1;
368+
const uint8_t *b = (const uint8_t *)v2;
369+
370+
// We accumulate the results into a vector register
371+
size_t vl = __riscv_vsetvlmax_e32m8();
372+
vuint32m8_t vl1 = __riscv_vmv_s_x_u32m8(0, vl);
373+
374+
// Iterate by VL elements
375+
for (size_t i = n; i > 0; i -= vl) {
376+
// Use LMUL=2 to start off, but we're going to widen this
377+
vl = __riscv_vsetvl_e8m2(i);
378+
379+
// Load the vectors into the registers
380+
vuint8m2_t va = __riscv_vle8_v_u8m2(a, vl);
381+
vuint8m2_t vb = __riscv_vle8_v_u8m2(b, vl);
382+
383+
// Compute the absolute difference by getting the min and max and subtracting them.
384+
vuint8m2_t vmin = __riscv_vminu_vv_u8m2(va, vb, vl);
385+
vuint8m2_t vmax = __riscv_vmaxu_vv_u8m2(va, vb, vl);
386+
vuint16m4_t vabs = __riscv_vwsubu_vv_u16m4(vmax, vmin, vl);
387+
vl = __riscv_vsetvl_e16m4(i);
388+
389+
// Now widen it to 32bits and add to the accumulator
390+
vuint32m8_t vwide = __riscv_vwcvtu_x_x_v_u32m8(vabs, vl);
391+
vl1 = __riscv_vadd_vv_u32m8(vl1, vwide, vl);
392+
393+
// Advance the a and b pointers to the next offset
394+
a = &a[vl];
395+
b = &b[vl];
396+
}
397+
398+
// Copy the accumulators back into a scalar register
399+
float l1 = uint32_sum_vector_u32m8(vl1, n);
400+
return l1;
286401
}
287402

288403
float uint8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
289-
printf("uint8_distance_cosine_rvv: unimplemented\n");
290-
abort();
291-
return 0.0f;
404+
const uint8_t *a = (const uint8_t *)v1;
405+
const uint8_t *b = (const uint8_t *)v2;
406+
407+
// We accumulate the results into a vector register
408+
size_t vl = __riscv_vsetvlmax_e32m8();
409+
410+
// Zero out the starting registers
411+
vuint32m8_t vdot = __riscv_vmv_s_x_u32m8(0, vl);
412+
vuint32m8_t vmagn_a = __riscv_vmv_s_x_u32m8(0, vl);
413+
vuint32m8_t vmagn_b = __riscv_vmv_s_x_u32m8(0, vl);
414+
415+
// Iterate by VL elements
416+
for (size_t i = n; i > 0; i -= vl) {
417+
// Use LMUL=2 to start off, but we're going to widen this
418+
vl = __riscv_vsetvl_e8m2(i);
419+
420+
// Load the vectors into the registers
421+
vuint8m2_t va = __riscv_vle8_v_u8m2(a, vl);
422+
vuint8m2_t vb = __riscv_vle8_v_u8m2(b, vl);
423+
424+
// Widen these values to 16bit unsigned
425+
vuint16m4_t va_wide = __riscv_vwcvtu_x_x_v_u16m4(va, vl);
426+
vuint16m4_t vb_wide = __riscv_vwcvtu_x_x_v_u16m4(vb, vl);
427+
vl = __riscv_vsetvl_e16m4(i);
428+
429+
// Compute the dot product for the entire register (widening madd)
430+
vdot = __riscv_vwmaccu_vv_u32m8(vdot, va_wide, vb_wide, vl);
431+
432+
// Also calculate the magnitude value for both a and b (widening madd)
433+
vmagn_a = __riscv_vwmaccu_vv_u32m8(vmagn_a, va_wide, va_wide, vl);
434+
vmagn_b = __riscv_vwmaccu_vv_u32m8(vmagn_b, vb_wide, vb_wide, vl);
435+
436+
// Advance the a and b pointers to the next offset
437+
a = &a[vl];
438+
b = &b[vl];
439+
}
440+
441+
// Now do a final reduction on the registers to sum the remaining elements
442+
// TODO: With default flags this does not always use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
443+
float dot = uint32_sum_vector_u32m8(vdot, n);
444+
float magn_a = sqrtf(uint32_sum_vector_u32m8(vmagn_a, n));
445+
float magn_b = sqrtf(uint32_sum_vector_u32m8(vmagn_b, n));
446+
447+
if (magn_a == 0.0f || magn_b == 0.0f) return 1.0f;
448+
449+
float cosine_similarity = dot / (magn_a * magn_b);
450+
if (cosine_similarity > 1.0f) cosine_similarity = 1.0f;
451+
if (cosine_similarity < -1.0f) cosine_similarity = -1.0f;
452+
return 1.0f - cosine_similarity;
292453
}
293454

294455
// MARK: - INT8 -
@@ -354,7 +515,6 @@ vuint64m8_t vpopcnt_u64m8(vuint64m8_t v, size_t vl) {
354515
return v;
355516
}
356517

357-
358518
float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
359519
const uint8_t *a = (const uint8_t *)v1;
360520
const uint8_t *b = (const uint8_t *)v2;
@@ -394,31 +554,31 @@ void init_distance_functions_rvv (void) {
394554
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
395555
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
396556
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
397-
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
557+
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
398558
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
399559

400560
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
401561
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
402562
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
403-
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
563+
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
404564
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_I8] = int8_distance_l2_squared_rvv;
405565

406566
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F32] = float32_distance_cosine_rvv;
407567
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_F16] = float16_distance_cosine_rvv;
408568
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_BF16] = bfloat16_distance_cosine_rvv;
409-
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
569+
dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
410570
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
411571

412572
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
413573
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
414574
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
415-
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
575+
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
416576
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
417577

418578
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
419579
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
420580
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
421-
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
581+
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
422582
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
423583

424584
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;

0 commit comments

Comments
 (0)