Skip to content

Commit de5e6ea

Browse files
committed
clean up
1 parent c147438 commit de5e6ea

10 files changed

Lines changed: 78 additions & 117 deletions

File tree

bindings/python/module.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,22 +472,22 @@ PYBIND11_MODULE(_netgraph_core, m) {
472472
py::gil_scoped_release rel; auto total = fs.place_max_flow(src, dst, placement, shortest_path); py::gil_scoped_acquire acq; return total;
473473
}, py::arg("src"), py::arg("dst"), py::arg("flow_placement") = FlowPlacement::Proportional, py::arg("shortest_path") = false)
474474
.def("compute_min_cut", [](const FlowState& fs, std::int32_t src, py::object node_mask, py::object edge_mask){
475-
const bool* node_ptr = nullptr; const bool* edge_ptr = nullptr; py::array node_arr, edge_arr;
475+
std::span<const bool> node_span, edge_span; py::array node_arr, edge_arr;
476476
if (!node_mask.is_none()) {
477477
node_arr = py::cast<py::array>(node_mask);
478478
if (!(node_arr.flags() & py::array::c_style)) throw py::type_error("node_mask must be C-contiguous (np.ascontiguousarray)");
479479
auto b = node_arr.request();
480480
if (b.ndim != 1 || b.format != py::format_descriptor<bool>::format()) throw py::type_error("node_mask must be 1-D bool");
481-
node_ptr = static_cast<const bool*>(b.ptr);
481+
node_span = std::span<const bool>(static_cast<const bool*>(b.ptr), static_cast<std::size_t>(b.shape[0]));
482482
}
483483
if (!edge_mask.is_none()) {
484484
edge_arr = py::cast<py::array>(edge_mask);
485485
if (!(edge_arr.flags() & py::array::c_style)) throw py::type_error("edge_mask must be C-contiguous (np.ascontiguousarray)");
486486
auto b = edge_arr.request();
487487
if (b.ndim != 1 || b.format != py::format_descriptor<bool>::format()) throw py::type_error("edge_mask must be 1-D bool");
488-
edge_ptr = static_cast<const bool*>(b.ptr);
488+
edge_span = std::span<const bool>(static_cast<const bool*>(b.ptr), static_cast<std::size_t>(b.shape[0]));
489489
}
490-
py::gil_scoped_release rel; auto mc = fs.compute_min_cut(src, node_ptr, edge_ptr); py::gil_scoped_acquire acq; return mc;
490+
py::gil_scoped_release rel; auto mc = fs.compute_min_cut(src, node_span, edge_span); py::gil_scoped_acquire acq; return mc;
491491
}, py::arg("src"), py::kw_only(), py::arg("node_mask") = py::none(), py::arg("edge_mask") = py::none());
492492

493493
// FlowIndex and FlowGraph bindings

include/netgraph/core/flow_state.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class FlowState {
5656
// from source s on the residual graph (forward arcs: residual>MIN; reverse arcs:
5757
// positive flow). Honors optional masks.
5858
[[nodiscard]] MinCut compute_min_cut(NodeId src,
59-
const bool* node_mask = nullptr,
60-
const bool* edge_mask = nullptr) const;
59+
std::span<const bool> node_mask = {},
60+
std::span<const bool> edge_mask = {}) const;
6161

6262
// Apply or revert a set of edge flow deltas directly.
6363
// When add==true, treats each (eid, flow) as additional placed flow on the edge.

include/netgraph/core/k_shortest_paths.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace netgraph::core {
2020
const StrictMultiDiGraph& g, NodeId src, NodeId dst,
2121
int k, std::optional<double> max_cost_factor,
2222
bool unique,
23-
const bool* node_mask = nullptr,
24-
const bool* edge_mask = nullptr);
23+
std::span<const bool> node_mask = {},
24+
std::span<const bool> edge_mask = {});
2525

2626
} // namespace netgraph::core

include/netgraph/core/max_flow.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ calc_max_flow(const StrictMultiDiGraph& g, NodeId src, NodeId dst,
3232
bool with_edge_flows,
3333
bool with_reachable,
3434
bool with_residuals,
35-
const bool* node_mask = nullptr,
36-
const bool* edge_mask = nullptr);
35+
std::span<const bool> node_mask = {},
36+
std::span<const bool> edge_mask = {});
3737

3838
[[nodiscard]] std::vector<FlowSummary>
3939
batch_max_flow(const StrictMultiDiGraph& g,
@@ -42,7 +42,7 @@ batch_max_flow(const StrictMultiDiGraph& g,
4242
bool with_edge_flows,
4343
bool with_reachable,
4444
bool with_residuals,
45-
const std::vector<const bool*>& node_masks = {},
46-
const std::vector<const bool*>& edge_masks = {});
45+
const std::vector<std::span<const bool>>& node_masks = {},
46+
const std::vector<std::span<const bool>>& edge_masks = {});
4747

4848
} // namespace netgraph::core

include/netgraph/core/shortest_paths.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ struct PredDAG {
2525
// Optional node/edge masks:
2626
// - node_mask[v] == true means node v is allowed; false excludes it from search.
2727
// - edge_mask[e] == true means edge e is allowed; false excludes it from search.
28-
// If masks are nullptr, they are ignored.
28+
// Empty mask spans are ignored.
2929
[[nodiscard]] std::pair<std::vector<Cost>, PredDAG>
3030
shortest_paths(const StrictMultiDiGraph& g, NodeId src,
3131
std::optional<NodeId> dst,
3232
bool multipath,
3333
const EdgeSelection& selection,
3434
std::span<const Cap> residual = {},
35-
const bool* node_mask = nullptr,
36-
const bool* edge_mask = nullptr);
35+
std::span<const bool> node_mask = {},
36+
std::span<const bool> edge_mask = {});
3737

3838
// Enumerate concrete paths represented by a PredDAG from src to dst.
3939
// Each path is returned as a sequence of (node_id, (edge_ids...)) pairs ending with (dst, ()).

src/cpu_backend.cpp

Lines changed: 19 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "netgraph/core/k_shortest_paths.hpp"
66
#include "netgraph/core/max_flow.hpp"
77
#include "netgraph/core/shortest_paths.hpp"
8-
#include <cassert>
98

109
namespace netgraph::core {
1110

@@ -24,69 +23,37 @@ class CpuBackend final : public Backend {
2423
std::pair<std::vector<Cost>, PredDAG> spf(
2524
const GraphHandle& gh, NodeId src, const SpfOptions& opts) override {
2625
const StrictMultiDiGraph& g = *gh.graph;
27-
// Precondition checks (debug mode only)
28-
assert(src >= 0 && src < g.num_nodes());
29-
if (opts.dst.has_value()) {
30-
assert(*opts.dst >= 0 && *opts.dst < g.num_nodes());
31-
}
32-
if (!opts.node_mask.empty()) {
33-
assert(opts.node_mask.size() == static_cast<size_t>(g.num_nodes()));
34-
}
35-
if (!opts.edge_mask.empty()) {
36-
assert(opts.edge_mask.size() == static_cast<size_t>(g.num_edges()));
37-
}
38-
if (!opts.residual.empty()) {
39-
assert(opts.residual.size() == static_cast<size_t>(g.num_edges()));
40-
}
41-
// Convert spans to raw pointers for compatibility with existing free functions
42-
const bool* node_ptr = opts.node_mask.empty() ? nullptr : opts.node_mask.data();
43-
const bool* edge_ptr = opts.edge_mask.empty() ? nullptr : opts.edge_mask.data();
26+
// Forward spans directly
27+
std::span<const bool> node_span = (opts.node_mask.size() == static_cast<size_t>(g.num_nodes())) ? opts.node_mask : std::span<const bool>{};
28+
std::span<const bool> edge_span = (opts.edge_mask.size() == static_cast<size_t>(g.num_edges())) ? opts.edge_mask : std::span<const bool>{};
4429
return netgraph::core::shortest_paths(g, src, opts.dst, opts.multipath, opts.selection,
45-
opts.residual, node_ptr, edge_ptr);
30+
opts.residual, node_span, edge_span);
4631
}
4732

4833
std::pair<Flow, FlowSummary> max_flow(
4934
const GraphHandle& gh, NodeId src, NodeId dst, const MaxFlowOptions& opts) override {
5035
const StrictMultiDiGraph& g = *gh.graph;
51-
// Precondition checks (debug mode only)
52-
assert(src >= 0 && src < g.num_nodes());
53-
assert(dst >= 0 && dst < g.num_nodes());
54-
if (!opts.node_mask.empty()) {
55-
assert(opts.node_mask.size() == static_cast<size_t>(g.num_nodes()));
56-
}
57-
if (!opts.edge_mask.empty()) {
58-
assert(opts.edge_mask.size() == static_cast<size_t>(g.num_edges()));
59-
}
60-
// Convert spans to raw pointers for compatibility with existing free functions
61-
const bool* node_ptr = opts.node_mask.empty() ? nullptr : opts.node_mask.data();
62-
const bool* edge_ptr = opts.edge_mask.empty() ? nullptr : opts.edge_mask.data();
36+
// Forward spans directly
37+
std::span<const bool> node_span = (opts.node_mask.size() == static_cast<size_t>(g.num_nodes())) ? opts.node_mask : std::span<const bool>{};
38+
std::span<const bool> edge_span = (opts.edge_mask.size() == static_cast<size_t>(g.num_edges())) ? opts.edge_mask : std::span<const bool>{};
6339
return netgraph::core::calc_max_flow(
6440
g, src, dst,
6541
opts.placement, opts.shortest_path,
6642
opts.with_edge_flows,
6743
opts.with_reachable,
6844
opts.with_residuals,
69-
node_ptr, edge_ptr);
45+
node_span, edge_span);
7046
}
7147

7248
std::vector<std::pair<std::vector<Cost>, PredDAG>> ksp(
7349
const GraphHandle& gh, NodeId src, NodeId dst, const KspOptions& opts) override {
7450
const StrictMultiDiGraph& g = *gh.graph;
75-
// Precondition checks (debug mode only)
76-
assert(src >= 0 && src < g.num_nodes());
77-
assert(dst >= 0 && dst < g.num_nodes());
78-
assert(opts.k > 0);
79-
if (!opts.node_mask.empty()) {
80-
assert(opts.node_mask.size() == static_cast<size_t>(g.num_nodes()));
81-
}
82-
if (!opts.edge_mask.empty()) {
83-
assert(opts.edge_mask.size() == static_cast<size_t>(g.num_edges()));
84-
}
85-
// Convert spans to raw pointers for compatibility with existing free functions
86-
const bool* node_ptr = opts.node_mask.empty() ? nullptr : opts.node_mask.data();
87-
const bool* edge_ptr = opts.edge_mask.empty() ? nullptr : opts.edge_mask.data();
51+
if (opts.k <= 0) { return {}; }
52+
// Forward spans directly
53+
std::span<const bool> node_span = (opts.node_mask.size() == static_cast<size_t>(g.num_nodes())) ? opts.node_mask : std::span<const bool>{};
54+
std::span<const bool> edge_span = (opts.edge_mask.size() == static_cast<size_t>(g.num_edges())) ? opts.edge_mask : std::span<const bool>{};
8855
return netgraph::core::k_shortest_paths(g, src, dst, opts.k, opts.max_cost_factor,
89-
opts.unique, node_ptr, edge_ptr);
56+
opts.unique, node_span, edge_span);
9057
}
9158

9259
std::vector<FlowSummary> batch_max_flow(
@@ -96,24 +63,16 @@ class CpuBackend final : public Backend {
9663
const std::vector<std::span<const bool>>& node_masks,
9764
const std::vector<std::span<const bool>>& edge_masks) override {
9865
const StrictMultiDiGraph& g = *gh.graph;
99-
// Precondition checks (debug mode only)
100-
for (const auto& [src, dst] : pairs) {
101-
assert(src >= 0 && src < g.num_nodes());
102-
assert(dst >= 0 && dst < g.num_nodes());
103-
}
104-
assert(node_masks.size() == edge_masks.size() || node_masks.empty() || edge_masks.empty());
66+
// Forward spans directly
67+
std::vector<std::span<const bool>> node_ptrs, edge_ptrs;
68+
node_ptrs.reserve(node_masks.size());
69+
edge_ptrs.reserve(edge_masks.size());
10570
for (const auto& span : node_masks) {
106-
assert(span.size() == static_cast<size_t>(g.num_nodes()));
71+
node_ptrs.push_back((span.size() == static_cast<size_t>(g.num_nodes())) ? span : std::span<const bool>{});
10772
}
10873
for (const auto& span : edge_masks) {
109-
assert(span.size() == static_cast<size_t>(g.num_edges()));
74+
edge_ptrs.push_back((span.size() == static_cast<size_t>(g.num_edges())) ? span : std::span<const bool>{});
11075
}
111-
// Convert spans to raw pointers for compatibility with existing free functions
112-
std::vector<const bool*> node_ptrs, edge_ptrs;
113-
node_ptrs.reserve(node_masks.size());
114-
edge_ptrs.reserve(edge_masks.size());
115-
for (const auto& span : node_masks) node_ptrs.push_back(span.data());
116-
for (const auto& span : edge_masks) edge_ptrs.push_back(span.data());
11776
return netgraph::core::batch_max_flow(g, pairs,
11877
opts.placement, opts.shortest_path,
11978
opts.with_edge_flows, opts.with_reachable, opts.with_residuals,

src/flow_state.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ Flow FlowState::place_max_flow(NodeId src, NodeId dst, FlowPlacement placement,
336336
return total;
337337
}
338338

339-
MinCut FlowState::compute_min_cut(NodeId src, const bool* node_mask, const bool* edge_mask) const {
339+
MinCut FlowState::compute_min_cut(NodeId src, std::span<const bool> node_mask, std::span<const bool> edge_mask) const {
340340
MinCut out;
341341
const auto N = g_->num_nodes();
342342
const auto row = g_->row_offsets_view();
@@ -348,17 +348,19 @@ MinCut FlowState::compute_min_cut(NodeId src, const bool* node_mask, const bool*
348348
std::vector<char> visited(static_cast<std::size_t>(N), 0);
349349
std::queue<std::int32_t> q;
350350
if (src >= 0 && src < N) { visited[static_cast<std::size_t>(src)] = 1; q.push(src); }
351+
const bool use_node_mask = (node_mask.size() == static_cast<std::size_t>(N));
352+
const bool use_edge_mask = (edge_mask.size() == static_cast<std::size_t>(g_->num_edges()));
351353
while (!q.empty()) {
352354
auto u = q.front(); q.pop();
353-
if (node_mask && !node_mask[static_cast<std::size_t>(u)]) continue;
355+
if (use_node_mask && !node_mask[static_cast<std::size_t>(u)]) continue;
354356
// Forward residual arcs
355357
auto start = static_cast<std::size_t>(row[static_cast<std::size_t>(u)]);
356358
auto end = static_cast<std::size_t>(row[static_cast<std::size_t>(u)+1]);
357359
for (std::size_t j = start; j < end; ++j) {
358360
auto v = static_cast<std::int32_t>(col[j]);
359361
auto eid = static_cast<std::size_t>(aei[j]);
360-
if (edge_mask && !edge_mask[eid]) continue;
361-
if (node_mask && !node_mask[static_cast<std::size_t>(v)]) continue;
362+
if (use_edge_mask && !edge_mask[eid]) continue;
363+
if (use_node_mask && !node_mask[static_cast<std::size_t>(v)]) continue;
362364
if (residual_[eid] > kMinCap && !visited[static_cast<std::size_t>(v)]) {
363365
visited[static_cast<std::size_t>(v)] = 1;
364366
q.push(v);
@@ -370,8 +372,8 @@ MinCut FlowState::compute_min_cut(NodeId src, const bool* node_mask, const bool*
370372
for (std::size_t j = rs; j < re; ++j) {
371373
auto w = static_cast<std::int32_t>(in_col[j]);
372374
auto eid = static_cast<std::size_t>(in_aei[j]);
373-
if (edge_mask && !edge_mask[eid]) continue;
374-
if (node_mask && !node_mask[static_cast<std::size_t>(w)]) continue;
375+
if (use_edge_mask && !edge_mask[eid]) continue;
376+
if (use_node_mask && !node_mask[static_cast<std::size_t>(w)]) continue;
375377
double flow_e = g_->capacity_view()[eid] - residual_[eid];
376378
if (flow_e > kMinFlow && !visited[static_cast<std::size_t>(w)]) {
377379
visited[static_cast<std::size_t>(w)] = 1;
@@ -382,15 +384,15 @@ MinCut FlowState::compute_min_cut(NodeId src, const bool* node_mask, const bool*
382384
// Collect cut edges
383385
for (std::int32_t u = 0; u < N; ++u) {
384386
if (!visited[static_cast<std::size_t>(u)]) continue;
385-
if (node_mask && !node_mask[static_cast<std::size_t>(u)]) continue;
387+
if (use_node_mask && !node_mask[static_cast<std::size_t>(u)]) continue;
386388
auto s3 = static_cast<std::size_t>(row[static_cast<std::size_t>(u)]);
387389
auto e3 = static_cast<std::size_t>(row[static_cast<std::size_t>(u)+1]);
388390
for (std::size_t j = s3; j < e3; ++j) {
389391
auto v = static_cast<std::int32_t>(col[j]);
390-
if (node_mask && !node_mask[static_cast<std::size_t>(v)]) continue;
392+
if (use_node_mask && !node_mask[static_cast<std::size_t>(v)]) continue;
391393
if (visited[static_cast<std::size_t>(v)]) continue;
392394
auto eid = static_cast<std::size_t>(aei[j]);
393-
if (edge_mask && !edge_mask[eid]) continue;
395+
if (use_edge_mask && !edge_mask[eid]) continue;
394396
if (residual_[eid] <= kMinCap) {
395397
out.edges.push_back(static_cast<EdgeId>(eid));
396398
}

src/k_shortest_paths.cpp

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ std::vector<std::pair<std::vector<Cost>, PredDAG>> k_shortest_paths(
149149
const StrictMultiDiGraph& g, NodeId src, NodeId dst,
150150
int k, std::optional<double> max_cost_factor,
151151
bool unique,
152-
const bool* node_mask,
153-
const bool* edge_mask) {
152+
std::span<const bool> node_mask,
153+
std::span<const bool> edge_mask) {
154154
std::vector<Path> paths;
155155
if (k <= 0) return {};
156156
if (src < 0 || dst < 0 || src >= g.num_nodes() || dst >= g.num_nodes()) return {};
@@ -160,12 +160,12 @@ std::vector<std::pair<std::vector<Cost>, PredDAG>> k_shortest_paths(
160160
std::vector<unsigned char> edge_mask_vec;
161161
const std::vector<unsigned char>* nm_ptr = nullptr;
162162
const std::vector<unsigned char>* em_ptr = nullptr;
163-
if (node_mask) {
163+
if (node_mask.size() == static_cast<std::size_t>(g.num_nodes())) {
164164
node_mask_vec.assign(static_cast<std::size_t>(g.num_nodes()), static_cast<unsigned char>(1));
165165
for (std::size_t i=0;i<node_mask_vec.size();++i) node_mask_vec[i] = node_mask[i] ? 1u : 0u;
166166
nm_ptr = &node_mask_vec;
167167
}
168-
if (edge_mask) {
168+
if (edge_mask.size() == static_cast<std::size_t>(g.num_edges())) {
169169
edge_mask_vec.assign(static_cast<std::size_t>(g.num_edges()), static_cast<unsigned char>(1));
170170
for (std::size_t i=0;i<edge_mask_vec.size();++i) edge_mask_vec[i] = edge_mask[i] ? 1u : 0u;
171171
em_ptr = &edge_mask_vec;
@@ -255,23 +255,18 @@ std::vector<std::pair<std::vector<Cost>, PredDAG>> k_shortest_paths(
255255
for (std::size_t idx=0; idx<edge_mask_local.size(); ++idx) edge_mask_local[idx] = (edge_mask_local[idx] && (*em_ptr)[idx]) ? 1u : 0u;
256256
}
257257
// Multipath spur PredDAG from spur_node -> t
258-
std::unique_ptr<bool[]> nm_buf;
259-
std::unique_ptr<bool[]> em_buf;
260-
const bool* nm = nullptr;
261-
const bool* em = nullptr;
258+
std::unique_ptr<bool[]> nm_buf2;
259+
std::unique_ptr<bool[]> em_buf2;
260+
std::span<const bool> nm, em;
262261
if (!node_mask_local.empty()) {
263-
nm_buf = std::unique_ptr<bool[]>(new bool[static_cast<std::size_t>(g.num_nodes())]);
264-
for (std::size_t idx = 0; idx < static_cast<std::size_t>(g.num_nodes()); ++idx) {
265-
nm_buf[idx] = node_mask_local[idx] != 0;
266-
}
267-
nm = nm_buf.get();
262+
nm_buf2 = std::unique_ptr<bool[]>(new bool[static_cast<std::size_t>(g.num_nodes())]);
263+
for (std::size_t idx = 0; idx < static_cast<std::size_t>(g.num_nodes()); ++idx) nm_buf2[idx] = (node_mask_local[idx] != 0);
264+
nm = std::span<const bool>(nm_buf2.get(), static_cast<std::size_t>(g.num_nodes()));
268265
}
269266
if (!edge_mask_local.empty()) {
270-
em_buf = std::unique_ptr<bool[]>(new bool[static_cast<std::size_t>(g.num_edges())]);
271-
for (std::size_t idx = 0; idx < static_cast<std::size_t>(g.num_edges()); ++idx) {
272-
em_buf[idx] = edge_mask_local[idx] != 0;
273-
}
274-
em = em_buf.get();
267+
em_buf2 = std::unique_ptr<bool[]>(new bool[static_cast<std::size_t>(g.num_edges())]);
268+
for (std::size_t idx = 0; idx < static_cast<std::size_t>(g.num_edges()); ++idx) em_buf2[idx] = (edge_mask_local[idx] != 0);
269+
em = std::span<const bool>(em_buf2.get(), static_cast<std::size_t>(g.num_edges()));
275270
}
276271
EdgeSelection sel; sel.multi_edge = true; sel.require_capacity = false; sel.tie_break = EdgeTieBreak::Deterministic;
277272
auto [dist_spur, dag_spur] = shortest_paths(g, spur_node, dst, /*multipath=*/true, sel, std::span<const Cap>(), nm, em);

0 commit comments

Comments
 (0)