|
17 | 17 | extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX]; |
18 | 18 | extern const char *distance_backend_name; |
19 | 19 |
|
| 20 | +// MARK: - UTILS - |
| 21 | + |
| 22 | +// Reduces a vector by summing all of it's elements into a single scalar float |
| 23 | +float float32_sum_vector_f32m8(vfloat32m8_t vec, size_t vl) { |
| 24 | + vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1); |
| 25 | + vl = __riscv_vsetvl_e32m8(vl); |
| 26 | + acc = __riscv_vfredusum_vs_f32m8_f32m1(vec, acc, vl); |
| 27 | + return __riscv_vfmv_f_s_f32m1_f32(acc); |
| 28 | +} |
| 29 | + |
| 30 | +// Reduces a vector by summing all of it's elements into a single scalar float |
| 31 | +float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) { |
| 32 | + vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1); |
| 33 | + vl = __riscv_vsetvl_e32m4(vl); |
| 34 | + acc = __riscv_vfredusum_vs_f32m4_f32m1(vec, acc, vl); |
| 35 | + return __riscv_vfmv_f_s_f32m1_f32(acc); |
| 36 | +} |
| 37 | + |
20 | 38 | // MARK: - FLOAT32 - |
21 | 39 |
|
22 | 40 | float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) { |
@@ -46,15 +64,10 @@ float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool |
46 | 64 | } |
47 | 65 |
|
48 | 66 | // Copy the accumulators back into a scalar register |
49 | | - vfloat32m1_t vl2_acc = __riscv_vfmv_v_f_f32m1(0.0f, 1); |
50 | | - vl = __riscv_vsetvl_e32m8(n); |
51 | | - vl2_acc = __riscv_vfredusum_vs_f32m8_f32m1(vl2, vl2_acc, vl); |
52 | | - |
53 | | - float l2 = __riscv_vfmv_f_s_f32m1_f32(vl2_acc); |
| 67 | + float l2 = float32_sum_vector_f32m8(vl2, n); |
54 | 68 | return use_sqrt ? sqrtf(l2) : l2; |
55 | 69 | } |
56 | 70 |
|
57 | | - |
58 | 71 | float float32_distance_l2_rvv (const void *v1, const void *v2, int n) { |
59 | 72 | return float32_distance_l2_impl_rvv(v1, v2, n, true); |
60 | 73 | } |
@@ -93,12 +106,7 @@ float float32_distance_l1_rvv (const void *v1, const void *v2, int n) { |
93 | 106 | } |
94 | 107 |
|
95 | 108 | // Copy the accumulators back into a scalar register |
96 | | - vfloat32m1_t vsad_acc = __riscv_vfmv_v_f_f32m1(0.0f, 1); |
97 | | - vl = __riscv_vsetvl_e32m8(n); |
98 | | - vsad_acc = __riscv_vfredusum_vs_f32m8_f32m1(vsad, vsad_acc, vl); |
99 | | - |
100 | | - float sad = __riscv_vfmv_f_s_f32m1_f32(vsad_acc); |
101 | | - return sad; |
| 109 | + return float32_sum_vector_f32m8(vsad, n); |
102 | 110 | } |
103 | 111 |
|
104 | 112 | float float32_distance_dot_rvv (const void *v1, const void *v2, int n) { |
@@ -128,11 +136,7 @@ float float32_distance_dot_rvv (const void *v1, const void *v2, int n) { |
128 | 136 | } |
129 | 137 |
|
130 | 138 | // Copy the accumulators back into a scalar register |
131 | | - vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1(0.0f, 1); |
132 | | - vl = __riscv_vsetvl_e32m8(n); |
133 | | - vdot_acc = __riscv_vfredusum_vs_f32m8_f32m1(vdot, vdot_acc, vl); |
134 | | - |
135 | | - float dot = __riscv_vfmv_f_s_f32m1_f32(vdot_acc); |
| 139 | + float dot = float32_sum_vector_f32m8(vdot, n); |
136 | 140 | return -dot; |
137 | 141 | } |
138 | 142 |
|
@@ -170,20 +174,10 @@ float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) { |
170 | 174 | } |
171 | 175 |
|
172 | 176 | // Now do a final reduction on the registers to sum the remaining elements |
173 | | - vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl); |
174 | | - vdot_acc = __riscv_vfredusum_vs_f32m4_f32m1(vdot, vdot_acc, vl); |
175 | | - |
176 | | - vfloat32m1_t vmagn_a_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl); |
177 | | - vmagn_a_acc = __riscv_vfredusum_vs_f32m4_f32m1(vmagn_a, vmagn_a_acc, vl); |
178 | | - |
179 | | - vfloat32m1_t vmagn_b_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl); |
180 | | - vmagn_b_acc = __riscv_vfredusum_vs_f32m4_f32m1(vmagn_b, vmagn_b_acc, vl); |
181 | | - |
182 | | - // Copy the accumulators back into a scalar register, to finalize the calculations |
183 | | - // TODO: With default flags this does not use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that |
184 | | - float dot = __riscv_vfmv_f_s_f32m1_f32(vdot_acc); |
185 | | - float magn_a = sqrtf(__riscv_vfmv_f_s_f32m1_f32(vmagn_a_acc)); |
186 | | - float magn_b = sqrtf(__riscv_vfmv_f_s_f32m1_f32(vmagn_b_acc)); |
| 177 | + // TODO: With default flags this does not always use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that |
| 178 | + float dot = float32_sum_vector_f32m4(vdot, n); |
| 179 | + float magn_a = sqrtf(float32_sum_vector_f32m4(vmagn_a, n)); |
| 180 | + float magn_b = sqrtf(float32_sum_vector_f32m4(vmagn_b, n)); |
187 | 181 |
|
188 | 182 | if (magn_a == 0.0f || magn_b == 0.0f) return 1.0f; |
189 | 183 |
|
|
0 commit comments