Skip to content

Commit 0954b5a

Browse files
committed
distance-rvv: Add support for Hamming Distance
1 parent 0570383 commit 0954b5a

1 file changed

Lines changed: 67 additions & 5 deletions

File tree

src/distance-rvv.c

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
3535
return __riscv_vfmv_f_s_f32m1_f32(acc);
3636
}
3737

38+
// Reduces a vector by summing all of it's elements into a single scalar integer
39+
uint64_t uint64_sum_vector_u64m8(vuint64m8_t vec, size_t vl) {
40+
vuint64m1_t acc = __riscv_vmv_s_x_u64m1(0, 1);
41+
vl = __riscv_vsetvl_e32m8(vl);
42+
acc = __riscv_vredsum_vs_u64m8_u64m1(vec, acc, vl);
43+
return __riscv_vmv_x_s_u64m1_u64(acc);
44+
}
45+
3846
// MARK: - FLOAT32 -
3947

4048
float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) {
@@ -317,12 +325,66 @@ float int8_distance_cosine_rvv (const void *v1, const void *v2, int n) {
317325

318326
// MARK: - BIT -
319327

320-
float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
321-
printf("bit1_distance_hamming_rvv: unimplemented\n");
322-
abort();
323-
return 0.0f;
328+
329+
// Counts the number of set bits on each element of a vector register
330+
//
331+
// TODO: RISC-V natively supports vcpop.v for population count, but only with the
332+
// Zvbb extension, which we don't support yet. For everyone else, do a fallback implemetation.
333+
vuint64m8_t vpopcnt_u64m8(vuint64m8_t v, size_t vl) {
334+
// v = v - ((v >> 1) & 0x5555555555555555ULL);
335+
vuint64m8_t shr1 = __riscv_vsrl_vx_u64m8(v, 1, vl);
336+
vuint64m8_t and1 = __riscv_vand_vx_u64m8(shr1, 0x5555555555555555ULL, vl);
337+
v = __riscv_vsub_vv_u64m8(v, and1, vl);
338+
339+
// v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL);
340+
vuint64m8_t shr2 = __riscv_vsrl_vx_u64m8(v, 2, vl);
341+
vuint64m8_t and2 = __riscv_vand_vx_u64m8(shr2, 0x3333333333333333ULL, vl);
342+
vuint64m8_t and3 = __riscv_vand_vx_u64m8(v, 0x3333333333333333ULL, vl);
343+
v = __riscv_vadd_vv_u64m8(and2, and3, vl);
344+
345+
// v = (v + (v >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
346+
vuint64m8_t shr4 = __riscv_vsrl_vx_u64m8(v, 4, vl);
347+
vuint64m8_t add = __riscv_vadd_vv_u64m8(v, shr4, vl);
348+
v = __riscv_vand_vx_u64m8(add, 0x0f0f0f0f0f0f0f0fULL, vl);
349+
350+
// v = (v * 0x0101010101010101ULL) >> 56;
351+
vuint64m8_t mul = __riscv_vmul_vx_u64m8(v, 0x0101010101010101ULL, vl);
352+
v = __riscv_vsrl_vx_u64m8(mul, 56, vl);
353+
354+
return v;
324355
}
325356

357+
358+
float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
359+
const uint8_t *a = (const uint8_t *)v1;
360+
const uint8_t *b = (const uint8_t *)v2;
361+
362+
// We accumulate the results into a vector register
363+
size_t vl = __riscv_vsetvl_e32m8(n);
364+
vuint64m8_t vdistance = __riscv_vmv_s_x_u64m8(0, vl);
365+
366+
// Iterate by VL elements
367+
for (size_t i = n; i > 0; i -= vl) {
368+
// Use LMUL=8, we have 4 registers to work with.
369+
vl = __riscv_vsetvl_e64m8(n);
370+
371+
// Load the vectors into the registers and cast them into a u64
372+
vuint64m8_t va = __riscv_vreinterpret_v_u8m8_u64m8(__riscv_vle8_v_u8m8(a, vl));
373+
vuint64m8_t vb = __riscv_vreinterpret_v_u8m8_u64m8(__riscv_vle8_v_u8m8(b, vl));
374+
375+
vuint64m8_t xor = __riscv_vxor_vv_u64m8(va, vb, vl);
376+
vuint64m8_t popcnt = vpopcnt_u64m8(xor, vl);
377+
vdistance = __riscv_vadd_vv_u64m8(vdistance, popcnt, vl);
378+
379+
// Advance the a and b pointers to the next offset. Here we multiply by 8 because
380+
// the vectors are defined as u8, but VL is defined in elements of 64bits.
381+
a = &a[vl * 8];
382+
b = &b[vl * 8];
383+
}
384+
385+
// Copy the accumulator back into a scalar register
386+
return (float) uint64_sum_vector_u64m8(vdistance, vl);
387+
}
326388
#endif
327389

328390
// MARK: -
@@ -359,7 +421,7 @@ void init_distance_functions_rvv (void) {
359421
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;
360422
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_I8] = int8_distance_l1_rvv;
361423

362-
// dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
424+
dispatch_distance_table[VECTOR_DISTANCE_HAMMING][VECTOR_TYPE_BIT] = bit1_distance_hamming_rvv;
363425

364426
distance_backend_name = "RVV";
365427
#endif

0 commit comments

Comments
 (0)