Skip to content

Commit ae023f6

Browse files
committed
distance-rvv: Implement all distance functions for f32
1 parent 8686ee2 commit ae023f6

1 file changed

Lines changed: 138 additions & 40 deletions

File tree

src/distance-rvv.c

Lines changed: 138 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "distance-rvv.h"
99
#include "distance-cpu.h"
1010

11-
#if defined(__riscv_vector)
11+
#if defined(__riscv_v_intrinsic)
1212
#include <riscv_vector.h>
1313
#include <math.h>
1414
#include <stdio.h>
@@ -19,72 +19,171 @@ extern const char *distance_backend_name;
1919

2020
// MARK: - FLOAT32 -
2121

22+
/**
 * Shared RVV kernel for the float32 L2 distances.
 *
 * Sums (a[i] - b[i])^2 over the first n elements of v1 and v2, which are
 * read as arrays of float.
 *
 * @param v1        first input vector (read as const float *)
 * @param v2        second input vector (read as const float *)
 * @param n         number of float elements to compare; values <= 0 yield 0
 * @param use_sqrt  when true the square root of the sum is returned,
 *                  otherwise the raw sum of squares
 * @return the (optionally square-rooted) sum of squared differences
 */
float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) {
    const float *a = (const float *)v1;
    const float *b = (const float *)v2;

    // A non-positive count means there is nothing to compare; converting it
    // to size_t unchecked would wrap around to a huge length.
    size_t remaining = (n > 0) ? (size_t)n : 0;

    // Zero every lane of the LMUL=8 accumulator so that the final reduction
    // can safely cover the whole register group.
    const size_t vlmax = __riscv_vsetvlmax_e32m8();
    vfloat32m8_t vsum = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);

    // Iterate by VL elements
    while (remaining > 0) {
        // VL must be derived from what is left, not from the total length,
        // otherwise the last pass loads past the end of the inputs.
        size_t vl = __riscv_vsetvl_e32m8(remaining);

        // Load the vectors into the registers
        vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl);
        vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl);

        // acc += (a[i] - b[i])^2
        // The tail-undisturbed form keeps the lanes beyond vl intact; the
        // default tail-agnostic policy may overwrite the partial sums that
        // earlier, longer passes stored there.
        vfloat32m8_t vdiff = __riscv_vfsub_vv_f32m8(va, vb, vl);
        vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, vdiff, vdiff, vl);

        // Advance the a and b pointers to the next offset
        a += vl;
        b += vl;
        remaining -= vl;
    }

    // Fold all accumulator lanes into element 0 of an m1 register
    vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    vfloat32m1_t vred = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vzero, vlmax);

    float l2 = __riscv_vfmv_f_s_f32m1_f32(vred);
    return use_sqrt ? sqrtf(l2) : l2;
}
56+
57+
2258
/**
 * Float32 L2 distance.
 *
 * Forwards to float32_distance_l2_impl_rvv with the square root enabled.
 *
 * @param v1  first input vector
 * @param v2  second input vector
 * @param n   element count, passed through unchanged
 * @return whatever float32_distance_l2_impl_rvv returns for use_sqrt = true
 */
float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
    const bool take_sqrt = true;
    return float32_distance_l2_impl_rvv(v1, v2, n, take_sqrt);
}
2761

62+
2863
/**
 * Float32 squared L2 distance.
 *
 * Forwards to float32_distance_l2_impl_rvv with the square root disabled.
 *
 * @param v1  first input vector
 * @param v2  second input vector
 * @param n   element count, passed through unchanged
 * @return whatever float32_distance_l2_impl_rvv returns for use_sqrt = false
 */
float float32_distance_l2_squared_rvv (const void *v1, const void *v2, int n) {
    const bool take_sqrt = false;
    return float32_distance_l2_impl_rvv(v1, v2, n, take_sqrt);
}
3366

3467
/**
 * Float32 L1 distance (sum of absolute differences) using RVV.
 *
 * Sums |a[i] - b[i]| over the first n elements of v1 and v2, which are
 * read as arrays of float.
 *
 * @param v1  first input vector (read as const float *)
 * @param v2  second input vector (read as const float *)
 * @param n   number of float elements to compare; values <= 0 yield 0
 * @return the sum of absolute differences
 */
float float32_distance_l1_rvv (const void *v1, const void *v2, int n) {
    const float *a = (const float *)v1;
    const float *b = (const float *)v2;

    // A non-positive count means there is nothing to compare; converting it
    // to size_t unchecked would wrap around to a huge length.
    size_t remaining = (n > 0) ? (size_t)n : 0;

    // Zero every lane of the LMUL=8 accumulator so that the final reduction
    // can safely cover the whole register group.
    const size_t vlmax = __riscv_vsetvlmax_e32m8();
    vfloat32m8_t vsad = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);

    // Iterate by VL elements
    while (remaining > 0) {
        // VL must be derived from what is left, not from the total length,
        // otherwise the last pass loads past the end of the inputs.
        size_t vl = __riscv_vsetvl_e32m8(remaining);

        // Load the vectors into the registers
        vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl);
        vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl);

        // acc += abs(a[i] - b[i])
        // The tail-undisturbed add takes the accumulator as its passthrough
        // operand, so lanes beyond vl keep the partial sums from earlier,
        // longer passes instead of being left to the tail-agnostic policy.
        vfloat32m8_t vdiff = __riscv_vfsub_vv_f32m8(va, vb, vl);
        vfloat32m8_t vabs = __riscv_vfabs_v_f32m8(vdiff, vl);
        vsad = __riscv_vfadd_vv_f32m8_tu(vsad, vsad, vabs, vl);

        // Advance the a and b pointers to the next offset
        a += vl;
        b += vl;
        remaining -= vl;
    }

    // Fold all accumulator lanes into element 0 of an m1 register
    vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    vfloat32m1_t vred = __riscv_vfredusum_vs_f32m8_f32m1(vsad, vzero, vlmax);

    return __riscv_vfmv_f_s_f32m1_f32(vred);
}
45103

46-
float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
104+
/**
 * Float32 dot-product distance using RVV.
 *
 * Sums a[i] * b[i] over the first n elements of v1 and v2, which are read
 * as arrays of float, and returns the negated sum.
 * NOTE(review): the negation presumably makes a larger similarity rank as a
 * smaller distance — confirm against the other backends.
 *
 * @param v1  first input vector (read as const float *)
 * @param v2  second input vector (read as const float *)
 * @param n   number of float elements to multiply; for values <= 0 the
 *            accumulated sum is zero
 * @return the negated dot product
 */
float float32_distance_dot_rvv (const void *v1, const void *v2, int n) {
    const float *a = (const float *)v1;
    const float *b = (const float *)v2;

    // A non-positive count means there is nothing to multiply; converting it
    // to size_t unchecked would wrap around to a huge length.
    size_t remaining = (n > 0) ? (size_t)n : 0;

    // Zero every lane of the LMUL=8 accumulator so that the final reduction
    // can safely cover the whole register group.
    const size_t vlmax = __riscv_vsetvlmax_e32m8();
    vfloat32m8_t vdot = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);

    // Iterate by VL elements
    while (remaining > 0) {
        // Update VL with the remaining elements
        size_t vl = __riscv_vsetvl_e32m8(remaining);

        // Load the vectors into the registers
        vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl);
        vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl);

        // acc += a[i] * b[i]
        // The tail-undisturbed form keeps the lanes beyond vl intact; the
        // default tail-agnostic policy may overwrite the partial sums that
        // earlier, longer passes stored there.
        vdot = __riscv_vfmacc_vv_f32m8_tu(vdot, va, vb, vl);

        // Advance the a and b pointers to the next offset
        a += vl;
        b += vl;
        remaining -= vl;
    }

    // Fold all accumulator lanes into element 0 of an m1 register
    vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    vfloat32m1_t vred = __riscv_vfredusum_vs_f32m8_f32m1(vdot, vzero, vlmax);

    float dot = __riscv_vfmv_f_s_f32m1_f32(vred);
    return -dot;
}
138+
139+
float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
140+
const float *a = (const float *)v1;
141+
const float *b = (const float *)v2;
142+
143+
// Use LMUL=4, we have 8 registers to work with.
144+
size_t vl = __riscv_vsetvl_e32m4(n);
145+
146+
// Zero out the starting registers
147+
vfloat32m4_t vdot = __riscv_vfmv_v_f_f32m4(0.0f, vl);
148+
vfloat32m4_t vmagn_a = __riscv_vfmv_v_f_f32m4(0.0f, vl);
149+
vfloat32m4_t vmagn_b = __riscv_vfmv_v_f_f32m4(0.0f, vl);
150+
151+
// Iterate by VL elements
152+
for (size_t i = n; i > 0; i -= vl) {
153+
// Update VL with the remaining elements
154+
vl = __riscv_vsetvl_e32m4(i);
155+
156+
// Load the vectors into the registers
157+
vfloat32m4_t va = __riscv_vle32_v_f32m4(a, vl);
158+
vfloat32m4_t vb = __riscv_vle32_v_f32m4(b, vl);
159+
160+
// Compute the dot product for the entire register
161+
vdot = __riscv_vfmacc_vv_f32m4(vdot, va, vb, vl);
71162

72163
// Also calculate the magnitude value for both a and b
73-
vfloat32m8_t magn_a = __riscv_vfmul_vv_f32m8(va, va, vl);
74-
magn_a_acc = __riscv_vfredusum_vs_f32m8_f32m1(magn_a, magn_a_acc, vl);
75-
vfloat32m8_t magn_b = __riscv_vfmul_vv_f32m8(vb, vb, vl);
76-
magn_b_acc = __riscv_vfredusum_vs_f32m8_f32m1(magn_b, magn_b_acc, vl);
164+
vmagn_a = __riscv_vfmacc_vv_f32m4(vmagn_a, va, va, vl);
165+
vmagn_b = __riscv_vfmacc_vv_f32m4(vmagn_b, vb, vb, vl);
77166

78167
// Advance the a and b pointers to the next offset
79168
a = &a[vl];
80169
b = &b[vl];
81170
}
82171

172+
// Now do a final reduction on the registers to sum the remaining elements
173+
vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
174+
vdot_acc = __riscv_vfredusum_vs_f32m4_f32m1(vdot, vdot_acc, vl);
175+
176+
vfloat32m1_t vmagn_a_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
177+
vmagn_a_acc = __riscv_vfredusum_vs_f32m4_f32m1(vmagn_a, vmagn_a_acc, vl);
178+
179+
vfloat32m1_t vmagn_b_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
180+
vmagn_b_acc = __riscv_vfredusum_vs_f32m4_f32m1(vmagn_b, vmagn_b_acc, vl);
181+
83182
// Copy the accumulators back into a scalar register, to finalize the calculations
84183
// TODO: With default flags this does not use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
85-
float dot = __riscv_vfmv_f_s_f32m1_f32(dot_acc);
86-
float magn_a = sqrtf(__riscv_vfmv_f_s_f32m1_f32(magn_a_acc));
87-
float magn_b = sqrtf(__riscv_vfmv_f_s_f32m1_f32(magn_b_acc));
184+
float dot = __riscv_vfmv_f_s_f32m1_f32(vdot_acc);
185+
float magn_a = sqrtf(__riscv_vfmv_f_s_f32m1_f32(vmagn_a_acc));
186+
float magn_b = sqrtf(__riscv_vfmv_f_s_f32m1_f32(vmagn_b_acc));
88187

89188
if (magn_a == 0.0f || magn_b == 0.0f) return 1.0f;
90189

@@ -94,7 +193,6 @@ float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
94193
return 1.0f - cosine_similarity;
95194
}
96195

97-
98196
// MARK: - FLOAT16 -
99197

100198
float float16_distance_l2_rvv (const void *v1, const void *v2, int n) {
@@ -236,14 +334,14 @@ float bit1_distance_hamming_rvv (const void *v1, const void *v2, int n) {
236334
// MARK: -
237335

238336
void init_distance_functions_rvv (void) {
239-
#if defined(__riscv_vector)
240-
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
337+
#if defined(__riscv_v_intrinsic)
338+
dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F32] = float32_distance_l2_rvv;
241339
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_F16] = float16_distance_l2_rvv;
242340
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_rvv;
243341
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_U8] = uint8_distance_l2_rvv;
244342
// dispatch_distance_table[VECTOR_DISTANCE_L2][VECTOR_TYPE_I8] = int8_distance_l2_rvv;
245343

246-
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
344+
dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F32] = float32_distance_l2_squared_rvv;
247345
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_F16] = float16_distance_l2_squared_rvv;
248346
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_BF16] = bfloat16_distance_l2_squared_rvv;
249347
// dispatch_distance_table[VECTOR_DISTANCE_SQUARED_L2][VECTOR_TYPE_U8] = uint8_distance_l2_squared_rvv;
@@ -255,13 +353,13 @@ void init_distance_functions_rvv (void) {
255353
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_U8] = uint8_distance_cosine_rvv;
256354
// dispatch_distance_table[VECTOR_DISTANCE_COSINE][VECTOR_TYPE_I8] = int8_distance_cosine_rvv;
257355

258-
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
356+
dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F32] = float32_distance_dot_rvv;
259357
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_F16] = float16_distance_dot_rvv;
260358
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_BF16] = bfloat16_distance_dot_rvv;
261359
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_U8] = uint8_distance_dot_rvv;
262360
// dispatch_distance_table[VECTOR_DISTANCE_DOT][VECTOR_TYPE_I8] = int8_distance_dot_rvv;
263361

264-
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
362+
dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F32] = float32_distance_l1_rvv;
265363
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_F16] = float16_distance_l1_rvv;
266364
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_BF16] = bfloat16_distance_l1_rvv;
267365
// dispatch_distance_table[VECTOR_DISTANCE_L1][VECTOR_TYPE_U8] = uint8_distance_l1_rvv;

0 commit comments

Comments
 (0)