55#include " netgraph/core/k_shortest_paths.hpp"
66#include " netgraph/core/max_flow.hpp"
77#include " netgraph/core/shortest_paths.hpp"
8- #include < cassert>
98
109namespace netgraph ::core {
1110
@@ -24,69 +23,37 @@ class CpuBackend final : public Backend {
2423 std::pair<std::vector<Cost>, PredDAG> spf (
2524 const GraphHandle& gh, NodeId src, const SpfOptions& opts) override {
2625 const StrictMultiDiGraph& g = *gh.graph ;
27- // Precondition checks (debug mode only)
28- assert (src >= 0 && src < g.num_nodes ());
29- if (opts.dst .has_value ()) {
30- assert (*opts.dst >= 0 && *opts.dst < g.num_nodes ());
31- }
32- if (!opts.node_mask .empty ()) {
33- assert (opts.node_mask .size () == static_cast <size_t >(g.num_nodes ()));
34- }
35- if (!opts.edge_mask .empty ()) {
36- assert (opts.edge_mask .size () == static_cast <size_t >(g.num_edges ()));
37- }
38- if (!opts.residual .empty ()) {
39- assert (opts.residual .size () == static_cast <size_t >(g.num_edges ()));
40- }
41- // Convert spans to raw pointers for compatibility with existing free functions
42- const bool * node_ptr = opts.node_mask .empty () ? nullptr : opts.node_mask .data ();
43- const bool * edge_ptr = opts.edge_mask .empty () ? nullptr : opts.edge_mask .data ();
26+ // Forward spans directly
27+ std::span<const bool > node_span = (opts.node_mask .size () == static_cast <size_t >(g.num_nodes ())) ? opts.node_mask : std::span<const bool >{};
28+ std::span<const bool > edge_span = (opts.edge_mask .size () == static_cast <size_t >(g.num_edges ())) ? opts.edge_mask : std::span<const bool >{};
4429 return netgraph::core::shortest_paths (g, src, opts.dst , opts.multipath , opts.selection ,
45- opts.residual , node_ptr, edge_ptr );
30+ opts.residual , node_span, edge_span );
4631 }
4732
4833 std::pair<Flow, FlowSummary> max_flow (
4934 const GraphHandle& gh, NodeId src, NodeId dst, const MaxFlowOptions& opts) override {
5035 const StrictMultiDiGraph& g = *gh.graph ;
51- // Precondition checks (debug mode only)
52- assert (src >= 0 && src < g.num_nodes ());
53- assert (dst >= 0 && dst < g.num_nodes ());
54- if (!opts.node_mask .empty ()) {
55- assert (opts.node_mask .size () == static_cast <size_t >(g.num_nodes ()));
56- }
57- if (!opts.edge_mask .empty ()) {
58- assert (opts.edge_mask .size () == static_cast <size_t >(g.num_edges ()));
59- }
60- // Convert spans to raw pointers for compatibility with existing free functions
61- const bool * node_ptr = opts.node_mask .empty () ? nullptr : opts.node_mask .data ();
62- const bool * edge_ptr = opts.edge_mask .empty () ? nullptr : opts.edge_mask .data ();
36+ // Forward spans directly
37+ std::span<const bool > node_span = (opts.node_mask .size () == static_cast <size_t >(g.num_nodes ())) ? opts.node_mask : std::span<const bool >{};
38+ std::span<const bool > edge_span = (opts.edge_mask .size () == static_cast <size_t >(g.num_edges ())) ? opts.edge_mask : std::span<const bool >{};
6339 return netgraph::core::calc_max_flow (
6440 g, src, dst,
6541 opts.placement , opts.shortest_path ,
6642 opts.with_edge_flows ,
6743 opts.with_reachable ,
6844 opts.with_residuals ,
69- node_ptr, edge_ptr );
45+ node_span, edge_span );
7046 }
7147
7248 std::vector<std::pair<std::vector<Cost>, PredDAG>> ksp (
7349 const GraphHandle& gh, NodeId src, NodeId dst, const KspOptions& opts) override {
7450 const StrictMultiDiGraph& g = *gh.graph ;
75- // Precondition checks (debug mode only)
76- assert (src >= 0 && src < g.num_nodes ());
77- assert (dst >= 0 && dst < g.num_nodes ());
78- assert (opts.k > 0 );
79- if (!opts.node_mask .empty ()) {
80- assert (opts.node_mask .size () == static_cast <size_t >(g.num_nodes ()));
81- }
82- if (!opts.edge_mask .empty ()) {
83- assert (opts.edge_mask .size () == static_cast <size_t >(g.num_edges ()));
84- }
85- // Convert spans to raw pointers for compatibility with existing free functions
86- const bool * node_ptr = opts.node_mask .empty () ? nullptr : opts.node_mask .data ();
87- const bool * edge_ptr = opts.edge_mask .empty () ? nullptr : opts.edge_mask .data ();
51+ if (opts.k <= 0 ) { return {}; }
52+ // Forward spans directly
53+ std::span<const bool > node_span = (opts.node_mask .size () == static_cast <size_t >(g.num_nodes ())) ? opts.node_mask : std::span<const bool >{};
54+ std::span<const bool > edge_span = (opts.edge_mask .size () == static_cast <size_t >(g.num_edges ())) ? opts.edge_mask : std::span<const bool >{};
8855 return netgraph::core::k_shortest_paths (g, src, dst, opts.k , opts.max_cost_factor ,
89- opts.unique , node_ptr, edge_ptr );
56+ opts.unique , node_span, edge_span );
9057 }
9158
9259 std::vector<FlowSummary> batch_max_flow (
@@ -96,24 +63,16 @@ class CpuBackend final : public Backend {
9663 const std::vector<std::span<const bool >>& node_masks,
9764 const std::vector<std::span<const bool >>& edge_masks) override {
9865 const StrictMultiDiGraph& g = *gh.graph ;
99- // Precondition checks (debug mode only)
100- for (const auto & [src, dst] : pairs) {
101- assert (src >= 0 && src < g.num_nodes ());
102- assert (dst >= 0 && dst < g.num_nodes ());
103- }
104- assert (node_masks.size () == edge_masks.size () || node_masks.empty () || edge_masks.empty ());
66+ // Forward spans directly
67+ std::vector<std::span<const bool >> node_ptrs, edge_ptrs;
68+ node_ptrs.reserve (node_masks.size ());
69+ edge_ptrs.reserve (edge_masks.size ());
10570 for (const auto & span : node_masks) {
106- assert ( span.size () == static_cast <size_t >(g.num_nodes ()));
71+ node_ptrs. push_back (( span.size () == static_cast <size_t >(g.num_nodes ())) ? span : std::span< const bool >{} );
10772 }
10873 for (const auto & span : edge_masks) {
109- assert ( span.size () == static_cast <size_t >(g.num_edges ()));
74+ edge_ptrs. push_back (( span.size () == static_cast <size_t >(g.num_edges ())) ? span : std::span< const bool >{} );
11075 }
111- // Convert spans to raw pointers for compatibility with existing free functions
112- std::vector<const bool *> node_ptrs, edge_ptrs;
113- node_ptrs.reserve (node_masks.size ());
114- edge_ptrs.reserve (edge_masks.size ());
115- for (const auto & span : node_masks) node_ptrs.push_back (span.data ());
116- for (const auto & span : edge_masks) edge_ptrs.push_back (span.data ());
11776 return netgraph::core::batch_max_flow (g, pairs,
11877 opts.placement , opts.shortest_path ,
11978 opts.with_edge_flows , opts.with_reachable , opts.with_residuals ,
0 commit comments