Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions src/backend/linalg_internal_cpu/Sum_internal.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "Sum_internal.hpp"

#include <span>

#include "boost/smart_ptr/intrusive_ptr.hpp"

#include "backend/Storage.hpp"
#include "backend/linalg_internal_cpu/pairwise_sum.hpp"
#include "cytnx_error.hpp"
#include "Type.hpp"

Expand Down Expand Up @@ -88,42 +91,30 @@ namespace cytnx {
cytnx_double *_ten = (cytnx_double *)ten->data();
cytnx_double *_out = (cytnx_double *)out->data();

_out[0] = 0;
for (cytnx_uint64 n = 0; n < Nelem; n++) {
_out[0] += _ten[n];
}
_out[0] = PairwiseSum(std::span<const cytnx_double>(_ten, Nelem));
}

void Sum_internal_f(boost::intrusive_ptr<Storage_base> &out,
const boost::intrusive_ptr<Storage_base> &ten, const cytnx_uint64 &Nelem) {
cytnx_float *_ten = (cytnx_float *)ten->data();
cytnx_float *_out = (cytnx_float *)out->data();

_out[0] = 0;
for (cytnx_uint64 n = 0; n < Nelem; n++) {
_out[0] += _ten[n];
}
_out[0] = PairwiseSum(std::span<const cytnx_float>(_ten, Nelem));
}
void Sum_internal_cd(boost::intrusive_ptr<Storage_base> &out,
const boost::intrusive_ptr<Storage_base> &ten, const cytnx_uint64 &Nelem) {
cytnx_complex128 *_ten = (cytnx_complex128 *)ten->data();
cytnx_complex128 *_out = (cytnx_complex128 *)out->data();

_out[0] = 0;
for (cytnx_uint64 n = 0; n < Nelem; n++) {
_out[0] += _ten[n];
}
_out[0] = PairwiseSum(std::span<const cytnx_complex128>(_ten, Nelem));
}

void Sum_internal_cf(boost::intrusive_ptr<Storage_base> &out,
const boost::intrusive_ptr<Storage_base> &ten, const cytnx_uint64 &Nelem) {
cytnx_complex64 *_ten = (cytnx_complex64 *)ten->data();
cytnx_complex64 *_out = (cytnx_complex64 *)out->data();

_out[0] = 0;
for (cytnx_uint64 n = 0; n < Nelem; n++) {
_out[0] += _ten[n];
}
_out[0] = PairwiseSum(std::span<const cytnx_complex64>(_ten, Nelem));
}

void Sum_internal_b(boost::intrusive_ptr<Storage_base> &out,
Expand Down
62 changes: 62 additions & 0 deletions src/backend/linalg_internal_cpu/pairwise_sum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#ifndef CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_
#define CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_

#include <cstddef>
#include <iterator>
#include <ranges>

namespace cytnx {
namespace linalg_internal {

// Recursive (divide-and-conquer) core of the pairwise summation, matching
// NumPy's np.add.reduce: a straight loop for the smallest blocks, an
// eight-accumulator unrolled loop up to 128 elements, and a split into two
// halves (rounded to a multiple of eight) above that. Worst-case rounding
// error grows as O(log N * eps) instead of the O(N * eps) of a naive serial
// accumulation, at essentially the same cost.
template <class T, std::random_access_iterator It>
T PairwiseSumBlocks(It first, std::size_t n) {
if (n < 8) {
T res = T(0);
for (std::size_t i = 0; i < n; ++i) res += first[static_cast<std::ptrdiff_t>(i)];
return res;
}
if (n <= 128) {
T r0 = first[0], r1 = first[1], r2 = first[2], r3 = first[3];
T r4 = first[4], r5 = first[5], r6 = first[6], r7 = first[7];
std::size_t i = 8;
for (; i + 8 <= n; i += 8) {
auto p = first + static_cast<std::ptrdiff_t>(i);
r0 += p[0];
r1 += p[1];
r2 += p[2];
r3 += p[3];
r4 += p[4];
r5 += p[5];
r6 += p[6];
r7 += p[7];
}
T res = ((r0 + r1) + (r2 + r3)) + ((r4 + r5) + (r6 + r7));
for (; i < n; ++i) res += first[static_cast<std::ptrdiff_t>(i)];
return res;
}
std::size_t half = n / 2;
half -= half % 8;
return PairwiseSumBlocks<T>(first, half) +
PairwiseSumBlocks<T>(first + static_cast<std::ptrdiff_t>(half), n - half);
}

// Pairwise sum over a random-access range. The element type is deduced from
// the range. A contiguous std::span sums every element; pass a strided view
// (see stride_view.hpp) to sum a strided sequence such as a matrix diagonal.
template <std::ranges::random_access_range R>
std::ranges::range_value_t<R> PairwiseSum(R&& range) {
using T = std::ranges::range_value_t<R>;
return PairwiseSumBlocks<T>(std::ranges::begin(range),
static_cast<std::size_t>(std::ranges::size(range)));
}

} // namespace linalg_internal
} // namespace cytnx

#endif // CYTNX_BACKEND_LINALG_INTERNAL_CPU_PAIRWISE_SUM_H_
79 changes: 78 additions & 1 deletion tests/linalg_test/sum_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace cytnx {
* Note: `cytnx_bool` is not supported for the `linalg::Sum()` function.
* This test also assesses the accuracy of summing floating-point numbers.
*/
TYPED_TEST(LinalgSumHomogeneousValuesTest, DISABLED_Accuracy) {
TYPED_TEST(LinalgSumHomogeneousValuesTest, Accuracy) {
Comment thread
IvanaGyro marked this conversation as resolved.
TypeParam value = LinalgSumHomogeneousValuesTest<TypeParam>::value;
int element_number = 10000;
Comment thread
IvanaGyro marked this conversation as resolved.
unsigned int dtype = Type_class().cy_typeid(value);
Expand All @@ -64,4 +64,81 @@ namespace cytnx {

EXPECT_NUMBER_EQ(sum_result.at<TypeParam>({0}), value * static_cast<TypeParam>(element_number));
}

/**
* Exercises every branch of `PairwiseSumBlocks` -- the recursive core that
* `linalg::Sum` dispatches floating-point reductions through.
*
* * n < 8 : straight serial loop
* * 8 <= n <= 128 : 8-accumulator unrolled body, optionally with a scalar
* tail when n % 8 != 0
* * n > 128 : recursive split into two halves rounded to a multiple
* of 8
*
* The original Accuracy test only covers the n = 10000 recursive-split case;
* a regression in either of the small-n branches would not be caught there.
* Sizes 7/8/9/15/128/129/137 straddle the thresholds (including the off-by-one
* tail cases). The expected result is exact, so any branch that drops or
* double-counts an element fails immediately.
*/
TEST(LinalgSumBoundaryTest, EachPairwiseSumBranch) {
const cytnx_double value = 1.0;
for (int n : {1, 7, 8, 9, 15, 128, 129, 137, 1024}) {
Tensor tensor(/* shape */ {static_cast<unsigned long>(n)}, Type.Double, Device.cpu,
/* init_zero */ false);
tensor.fill(value);
Tensor sum_result = linalg::Sum(tensor);
EXPECT_EQ(sum_result.shape().size(), 1);
EXPECT_EQ(sum_result.shape()[0], 1);
EXPECT_DOUBLE_EQ(sum_result.at<cytnx_double>({0}), value * static_cast<cytnx_double>(n))
<< "n=" << n;
}
}

/**
* The dynamic-range case that motivates pairwise summation, and the one
* input distribution where naive serial accumulation visibly fails.
*
* Both arrays are [+L, 1, 1, ..., 1, -L] with L far above the precision
* threshold (2^53 for double, 2^23 for float); the exact sum is N - 2. Under
* naive accumulation the running total reaches L on the first element and
* the subsequent +1's vanish in IEEE 754 rounding -- 1.0 is below the unit
* in the last place at that magnitude -- so naive returns ~0. Pairwise
* keeps small values together in the tree until the cancellation of +/-L at
* the top, so the small terms survive (modulo a handful of 1's in the
* unrolled-accumulator blocks that hold L itself).
*
* The contract this asserts is qualitative on purpose: any reasonable
* pairwise implementation must land much closer to N - 2 than to 0. A
* serial-accumulation regression would collapse to ~0 and fail
* `EXPECT_GT(result, N / 2)` immediately; the exact pairwise result depends
* on which accumulators receive +/-L (a few small terms are lost there).
*/
TEST(LinalgSumHeterogeneousMagnitudeTest, RecoversTermsLostByNaiveAccumulation_Double) {
constexpr int N = 1024;
Tensor tensor(/* shape */ {static_cast<unsigned long>(N)}, Type.Double, Device.cpu,
/* init_zero */ false);
tensor.fill(static_cast<cytnx_double>(1));
tensor.at<cytnx_double>({0}) = 1e16;
tensor.at<cytnx_double>({N - 1}) = -1e16;
Tensor sum_result = linalg::Sum(tensor);
const cytnx_double result = sum_result.at<cytnx_double>({0});
EXPECT_GT(result, static_cast<cytnx_double>(N / 2));
EXPECT_LE(result, static_cast<cytnx_double>(N));
}

TEST(LinalgSumHeterogeneousMagnitudeTest, RecoversTermsLostByNaiveAccumulation_Float) {
constexpr int N = 1024;
Tensor tensor(/* shape */ {static_cast<unsigned long>(N)}, Type.Float, Device.cpu,
/* init_zero */ false);
tensor.fill(static_cast<cytnx_float>(1));
// float has ~7.2 decimal digits of precision; 1e8 already exceeds 2^23, so
// a serial `1e8 + 1` collapses to 1e8 and the unit terms are lost.
tensor.at<cytnx_float>({0}) = 1e8f;
tensor.at<cytnx_float>({N - 1}) = -1e8f;
Tensor sum_result = linalg::Sum(tensor);
const cytnx_float result = sum_result.at<cytnx_float>({0});
EXPECT_GT(result, static_cast<cytnx_float>(N / 2));
EXPECT_LE(result, static_cast<cytnx_float>(N));
}
} // namespace cytnx
Loading