Skip to content

Commit f318bce

Browse files
Fix large sparse conversion
1 parent fd0579a commit f318bce

1 file changed

Lines changed: 52 additions & 18 deletions

File tree

src/actionet/wp_utils.cpp

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Conversion functions between Python and C++ data structures
33

44
#include "wp_utils.h"
5+
#include <limits>
56

67
// Convert NumPy array to Armadillo dense matrix
78
arma::mat numpy_to_arma_mat(py::array_t<double> arr) {
@@ -27,28 +28,61 @@ arma::sp_mat scipy_to_arma_sparse(py::object scipy_sparse) {
2728
py::object csc = scipy_sparse.attr("tocsc")();
2829

2930
auto data = csc.attr("data").cast<py::array_t<double>>();
30-
auto indices = csc.attr("indices").cast<py::array_t<int>>();
31-
auto indptr = csc.attr("indptr").cast<py::array_t<int>>();
32-
auto shape = csc.attr("shape").cast<std::pair<int, int>>();
31+
auto indices = csc.attr("indices").cast<py::array_t<py::ssize_t, py::array::forcecast>>();
32+
auto indptr = csc.attr("indptr").cast<py::array_t<py::ssize_t, py::array::forcecast>>();
33+
auto shape = csc.attr("shape").cast<std::pair<py::ssize_t, py::ssize_t>>();
3334

3435
auto data_ptr = data.data();
3536
auto indices_ptr = indices.data();
3637
auto indptr_ptr = indptr.data();
3738

38-
arma::umat locations(2, data.size());
39-
arma::vec values(data.size());
39+
if (shape.first < 0 || shape.second < 0) {
40+
throw std::runtime_error("Sparse matrix shape must be non-negative");
41+
}
42+
43+
const size_t n_rows = static_cast<size_t>(shape.first);
44+
const size_t n_cols = static_cast<size_t>(shape.second);
45+
const size_t nnz = static_cast<size_t>(data.size());
46+
47+
if (indptr.size() < 1 || static_cast<size_t>(indptr.size()) != (n_cols + 1)) {
48+
throw std::runtime_error("Invalid CSC indptr length");
49+
}
50+
if (indptr_ptr[indptr.size() - 1] != static_cast<py::ssize_t>(nnz)) {
51+
throw std::runtime_error("CSC indptr does not match data length");
52+
}
4053

41-
size_t idx = 0;
42-
for (int col = 0; col < shape.second; ++col) {
43-
for (int j = indptr_ptr[col]; j < indptr_ptr[col + 1]; ++j) {
44-
locations(0, idx) = indices_ptr[j]; // row
45-
locations(1, idx) = col; // col
54+
auto to_uword = [](py::ssize_t v, const char* name) -> arma::uword {
55+
if (v < 0) {
56+
throw std::runtime_error(std::string("Negative index in sparse matrix: ") + name);
57+
}
58+
if (static_cast<unsigned long long>(v) > std::numeric_limits<arma::uword>::max()) {
59+
throw std::runtime_error(std::string("Index too large for Armadillo uword: ") + name);
60+
}
61+
return static_cast<arma::uword>(v);
62+
};
63+
64+
arma::umat locations(2, nnz);
65+
arma::vec values(nnz);
66+
67+
for (py::ssize_t col = 0; col < shape.second; ++col) {
68+
const py::ssize_t start = indptr_ptr[col];
69+
const py::ssize_t end = indptr_ptr[col + 1];
70+
if (start < 0 || end < start) {
71+
throw std::runtime_error("Invalid CSC indptr range");
72+
}
73+
for (py::ssize_t j = start; j < end; ++j) {
74+
const size_t idx = static_cast<size_t>(j);
75+
const py::ssize_t row = indices_ptr[j];
76+
if (static_cast<size_t>(row) >= n_rows) {
77+
throw std::runtime_error("Row index out of bounds in sparse matrix");
78+
}
79+
locations(0, idx) = to_uword(row, "row");
80+
locations(1, idx) = to_uword(col, "col");
4681
values(idx) = data_ptr[j];
47-
++idx;
4882
}
4983
}
5084

51-
return arma::sp_mat(locations, values, shape.first, shape.second);
85+
return arma::sp_mat(locations, values, n_rows, n_cols);
5286
}
5387

5488
// Convert Armadillo dense matrix to NumPy array
@@ -70,21 +104,21 @@ py::object arma_sparse_to_scipy(const arma::sp_mat& sp_mat) {
70104
py::module scipy_sparse = py::module::import("scipy.sparse");
71105

72106
std::vector<double> data;
73-
std::vector<int> rows;
74-
std::vector<int> cols;
107+
std::vector<py::ssize_t> rows;
108+
std::vector<py::ssize_t> cols;
75109

76110
for (arma::sp_mat::const_iterator it = sp_mat.begin(); it != sp_mat.end(); ++it) {
77111
data.push_back(*it);
78-
rows.push_back(it.row());
79-
cols.push_back(it.col());
112+
rows.push_back(static_cast<py::ssize_t>(it.row()));
113+
cols.push_back(static_cast<py::ssize_t>(it.col()));
80114
}
81115

82116
return scipy_sparse.attr("coo_matrix")(
83117
py::make_tuple(
84118
py::array_t<double>(data.size(), data.data()),
85119
py::make_tuple(
86-
py::array_t<int>(rows.size(), rows.data()),
87-
py::array_t<int>(cols.size(), cols.data())
120+
py::array_t<py::ssize_t>(rows.size(), rows.data()),
121+
py::array_t<py::ssize_t>(cols.size(), cols.data())
88122
)
89123
),
90124
py::make_tuple(sp_mat.n_rows, sp_mat.n_cols)

0 commit comments

Comments
 (0)