Skip to content

Commit 0570383

Browse files
committed
distance-rvv: Cleanup reduction operations
1 parent ae023f6 commit 0570383

1 file changed

Lines changed: 25 additions & 31 deletions

File tree

src/distance-rvv.c

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@
1717
extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
1818
extern const char *distance_backend_name;
1919

20+
// MARK: - UTILS -
21+
22+
// Reduces a vector by summing all of it's elements into a single scalar float
23+
float float32_sum_vector_f32m8(vfloat32m8_t vec, size_t vl) {
24+
vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
25+
vl = __riscv_vsetvl_e32m8(vl);
26+
acc = __riscv_vfredusum_vs_f32m8_f32m1(vec, acc, vl);
27+
return __riscv_vfmv_f_s_f32m1_f32(acc);
28+
}
29+
30+
// Reduces a vector by summing all of it's elements into a single scalar float
31+
float float32_sum_vector_f32m4(vfloat32m4_t vec, size_t vl) {
32+
vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
33+
vl = __riscv_vsetvl_e32m4(vl);
34+
acc = __riscv_vfredusum_vs_f32m4_f32m1(vec, acc, vl);
35+
return __riscv_vfmv_f_s_f32m1_f32(acc);
36+
}
37+
2038
// MARK: - FLOAT32 -
2139

2240
float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool use_sqrt) {
@@ -46,15 +64,10 @@ float float32_distance_l2_impl_rvv (const void *v1, const void *v2, int n, bool
4664
}
4765

4866
// Copy the accumulators back into a scalar register
49-
vfloat32m1_t vl2_acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
50-
vl = __riscv_vsetvl_e32m8(n);
51-
vl2_acc = __riscv_vfredusum_vs_f32m8_f32m1(vl2, vl2_acc, vl);
52-
53-
float l2 = __riscv_vfmv_f_s_f32m1_f32(vl2_acc);
67+
float l2 = float32_sum_vector_f32m8(vl2, n);
5468
return use_sqrt ? sqrtf(l2) : l2;
5569
}
5670

57-
5871
float float32_distance_l2_rvv (const void *v1, const void *v2, int n) {
5972
return float32_distance_l2_impl_rvv(v1, v2, n, true);
6073
}
@@ -93,12 +106,7 @@ float float32_distance_l1_rvv (const void *v1, const void *v2, int n) {
93106
}
94107

95108
// Copy the accumulators back into a scalar register
96-
vfloat32m1_t vsad_acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
97-
vl = __riscv_vsetvl_e32m8(n);
98-
vsad_acc = __riscv_vfredusum_vs_f32m8_f32m1(vsad, vsad_acc, vl);
99-
100-
float sad = __riscv_vfmv_f_s_f32m1_f32(vsad_acc);
101-
return sad;
109+
return float32_sum_vector_f32m8(vsad, n);
102110
}
103111

104112
float float32_distance_dot_rvv (const void *v1, const void *v2, int n) {
@@ -128,11 +136,7 @@ float float32_distance_dot_rvv (const void *v1, const void *v2, int n) {
128136
}
129137

130138
// Copy the accumulators back into a scalar register
131-
vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1(0.0f, 1);
132-
vl = __riscv_vsetvl_e32m8(n);
133-
vdot_acc = __riscv_vfredusum_vs_f32m8_f32m1(vdot, vdot_acc, vl);
134-
135-
float dot = __riscv_vfmv_f_s_f32m1_f32(vdot_acc);
139+
float dot = float32_sum_vector_f32m8(vdot, n);
136140
return -dot;
137141
}
138142

@@ -170,20 +174,10 @@ float float32_distance_cosine_rvv (const void *v1, const void *v2, int n) {
170174
}
171175

172176
// Now do a final reduction on the registers to sum the remaining elements
173-
vfloat32m1_t vdot_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
174-
vdot_acc = __riscv_vfredusum_vs_f32m4_f32m1(vdot, vdot_acc, vl);
175-
176-
vfloat32m1_t vmagn_a_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
177-
vmagn_a_acc = __riscv_vfredusum_vs_f32m4_f32m1(vmagn_a, vmagn_a_acc, vl);
178-
179-
vfloat32m1_t vmagn_b_acc = __riscv_vfmv_v_f_f32m1(0.0f, vl);
180-
vmagn_b_acc = __riscv_vfredusum_vs_f32m4_f32m1(vmagn_b, vmagn_b_acc, vl);
181-
182-
// Copy the accumulators back into a scalar register, to finalize the calculations
183-
// TODO: With default flags this does not use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
184-
float dot = __riscv_vfmv_f_s_f32m1_f32(vdot_acc);
185-
float magn_a = sqrtf(__riscv_vfmv_f_s_f32m1_f32(vmagn_a_acc));
186-
float magn_b = sqrtf(__riscv_vfmv_f_s_f32m1_f32(vmagn_b_acc));
177+
// TODO: With default flags this does not always use the fsqrt.s/fmin.s/fmax.s instruction, we should fix that
178+
float dot = float32_sum_vector_f32m4(vdot, n);
179+
float magn_a = sqrtf(float32_sum_vector_f32m4(vmagn_a, n));
180+
float magn_b = sqrtf(float32_sum_vector_f32m4(vmagn_b, n));
187181

188182
if (magn_a == 0.0f || magn_b == 0.0f) return 1.0f;
189183

0 commit comments

Comments
 (0)