Skip to content

Commit 150fd12

Browse files
committed
improved type consistency
1 parent de5e6ea commit 150fd12

7 files changed

Lines changed: 61 additions & 68 deletions

File tree

bindings/python/module.cpp

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
/*
22
Pybind11 module exposing NetGraph-Core C++ APIs to Python.
3-
4-
Notes:
5-
- Accepts NumPy arrays (C-contiguous) and converts to spans for zero-copy
6-
views where possible.
7-
- Distances returned as float64 arrays with inf for unreachable.
8-
- Edge/Node masks are validated for dtype=bool and length.
3+
- Uses NumPy arrays with zero-copy spans where possible
4+
- Returns distances as float64 arrays (inf for unreachable)
5+
- Validates edge/node masks for bool dtype and length
6+
- Array-returning view methods expose non-owning buffers over internal state; treat as read-only
97
*/
108
#include <pybind11/pybind11.h>
119
#include <pybind11/numpy.h>
@@ -82,13 +80,12 @@ PYBIND11_MODULE(_netgraph_core, m) {
8280
auto dst_s = as_span<std::int32_t>(dst, "dst");
8381
if (src_s.size() != dst_s.size()) throw py::type_error("src and dst must have the same length");
8482
auto cap_s = as_span<double>(capacity, "capacity");
85-
// Require cost dtype to be int64 for consistency
83+
// Cost dtype must be int64 to match internal Cost type
8684
auto cost_s = as_span<std::int64_t>(cost, "cost");
87-
std::span<const Cost> cost_cs(reinterpret_cast<const Cost*>(cost_s.data()), cost_s.size());
8885
return StrictMultiDiGraph::from_arrays(num_nodes,
8986
src_s,
9087
dst_s,
91-
cap_s, cost_cs, add_reverse);
88+
cap_s, cost_s, add_reverse);
9289
},
9390
py::arg("num_nodes"), py::arg("src"), py::arg("dst"), py::arg("capacity"), py::arg("cost"),
9491
py::kw_only(), py::arg("add_reverse") = false)
@@ -97,7 +94,7 @@ PYBIND11_MODULE(_netgraph_core, m) {
9794
// external link ids removed; EdgeId is the canonical id
9895
.def("capacity_view", [](py::object self_obj, const StrictMultiDiGraph& g){
9996
auto s = g.capacity_view();
100-
return py::array(
97+
py::array out(
10198
py::buffer_info(
10299
const_cast<double*>(s.data()),
103100
sizeof(double),
@@ -108,10 +105,11 @@ PYBIND11_MODULE(_netgraph_core, m) {
108105
),
109106
self_obj
110107
);
108+
return out;
111109
})
112110
.def("edge_src_view", [](py::object self_obj, const StrictMultiDiGraph& g){
113111
auto s = g.edge_src_view();
114-
return py::array(
112+
py::array out(
115113
py::buffer_info(
116114
const_cast<std::int32_t*>(s.data()),
117115
sizeof(std::int32_t),
@@ -122,10 +120,11 @@ PYBIND11_MODULE(_netgraph_core, m) {
122120
),
123121
self_obj
124122
);
123+
return out;
125124
})
126125
.def("edge_dst_view", [](py::object self_obj, const StrictMultiDiGraph& g){
127126
auto s = g.edge_dst_view();
128-
return py::array(
127+
py::array out(
129128
py::buffer_info(
130129
const_cast<std::int32_t*>(s.data()),
131130
sizeof(std::int32_t),
@@ -136,10 +135,11 @@ PYBIND11_MODULE(_netgraph_core, m) {
136135
),
137136
self_obj
138137
);
138+
return out;
139139
})
140140
.def("cost_view", [](py::object self_obj, const StrictMultiDiGraph& g){
141141
auto s = g.cost_view();
142-
return py::array(
142+
py::array out(
143143
py::buffer_info(
144144
const_cast<Cost*>(s.data()),
145145
sizeof(Cost),
@@ -150,6 +150,7 @@ PYBIND11_MODULE(_netgraph_core, m) {
150150
),
151151
self_obj
152152
);
153+
return out;
153154
})
154155
.def("row_offsets_view", [](const StrictMultiDiGraph& g){
155156
auto s = g.row_offsets_view();
@@ -446,24 +447,18 @@ PYBIND11_MODULE(_netgraph_core, m) {
446447
})
447448
.def("capacity_view", [](py::object self_obj, const FlowState& fs){
448449
auto s = fs.capacity_view();
449-
return py::array(
450-
py::buffer_info(
451-
const_cast<double*>(s.data()), sizeof(double), py::format_descriptor<double>::format(), 1, { s.size() }, { sizeof(double) }
452-
), self_obj);
450+
py::array out(py::buffer_info(const_cast<double*>(s.data()), sizeof(double), py::format_descriptor<double>::format(), 1, { s.size() }, { sizeof(double) }), self_obj);
451+
return out;
453452
})
454453
.def("residual_view", [](py::object self_obj, const FlowState& fs){
455454
auto s = fs.residual_view();
456-
return py::array(
457-
py::buffer_info(
458-
const_cast<double*>(s.data()), sizeof(double), py::format_descriptor<double>::format(), 1, { s.size() }, { sizeof(double) }
459-
), self_obj);
455+
py::array out(py::buffer_info(const_cast<double*>(s.data()), sizeof(double), py::format_descriptor<double>::format(), 1, { s.size() }, { sizeof(double) }), self_obj);
456+
return out;
460457
})
461458
.def("edge_flow_view", [](py::object self_obj, const FlowState& fs){
462459
auto s = fs.edge_flow_view();
463-
return py::array(
464-
py::buffer_info(
465-
const_cast<double*>(s.data()), sizeof(double), py::format_descriptor<double>::format(), 1, { s.size() }, { sizeof(double) }
466-
), self_obj);
460+
py::array out(py::buffer_info(const_cast<double*>(s.data()), sizeof(double), py::format_descriptor<double>::format(), 1, { s.size() }, { sizeof(double) }), self_obj);
461+
return out;
467462
})
468463
.def("place_on_dag", [](FlowState& fs, std::int32_t src, std::int32_t dst, const PredDAG& dag, double requested_flow, FlowPlacement placement, bool shortest_path){
469464
py::gil_scoped_release rel; auto placed = fs.place_on_dag(src, dst, dag, requested_flow, placement, shortest_path); py::gil_scoped_acquire acq; return placed;
@@ -630,8 +625,8 @@ PYBIND11_MODULE(_netgraph_core, m) {
630625
py::arg("diminishing_returns_epsilon_frac") = 1e-3)
631626
.def("flow_count", &FlowPolicy::flow_count)
632627
.def("placed_demand", &FlowPolicy::placed_demand)
633-
.def("place_demand", [](FlowPolicy& p, FlowGraph& fg, std::int32_t src, std::int32_t dst, std::int32_t flowClass, double volume, py::object target_per_flow, py::object min_flow){ std::optional<double> tpf; if (!target_per_flow.is_none()) tpf = py::cast<double>(target_per_flow); std::optional<double> mfl; if (!min_flow.is_none()) mfl = py::cast<double>(min_flow); py::gil_scoped_release rel; auto pr = p.place_demand(fg, src, dst, flowClass, volume, tpf, mfl); py::gil_scoped_acquire acq; return py::make_tuple(pr.first, pr.second); }, py::arg("flow_graph"), py::arg("src"), py::arg("dst"), py::arg("flowClass"), py::arg("volume"), py::arg("target_per_flow") = py::none(), py::arg("min_flow") = py::none())
634-
.def("rebalance_demand", [](FlowPolicy& p, FlowGraph& fg, std::int32_t src, std::int32_t dst, std::int32_t flowClass, double target){ py::gil_scoped_release rel; auto pr = p.rebalance_demand(fg, src, dst, flowClass, target); py::gil_scoped_acquire acq; return py::make_tuple(pr.first, pr.second); },
628+
.def("place_demand", [](FlowPolicy& p, FlowGraph& fg, std::int32_t src, std::int32_t dst, FlowClass flowClass, double volume, py::object target_per_flow, py::object min_flow){ std::optional<double> tpf; if (!target_per_flow.is_none()) tpf = py::cast<double>(target_per_flow); std::optional<double> mfl; if (!min_flow.is_none()) mfl = py::cast<double>(min_flow); py::gil_scoped_release rel; auto pr = p.place_demand(fg, src, dst, flowClass, volume, tpf, mfl); py::gil_scoped_acquire acq; return py::make_tuple(pr.first, pr.second); }, py::arg("flow_graph"), py::arg("src"), py::arg("dst"), py::arg("flowClass"), py::arg("volume"), py::arg("target_per_flow") = py::none(), py::arg("min_flow") = py::none())
629+
.def("rebalance_demand", [](FlowPolicy& p, FlowGraph& fg, std::int32_t src, std::int32_t dst, FlowClass flowClass, double target){ py::gil_scoped_release rel; auto pr = p.rebalance_demand(fg, src, dst, flowClass, target); py::gil_scoped_acquire acq; return py::make_tuple(pr.first, pr.second); },
635630
py::arg("flow_graph"), py::arg("src"), py::arg("dst"), py::arg("flowClass"), py::arg("target"))
636631
.def("remove_demand", [](FlowPolicy& p, FlowGraph& fg){ py::gil_scoped_release rel; p.remove_demand(fg); py::gil_scoped_acquire acq; })
637632
.def_property_readonly("flows", [](const FlowPolicy& p){ py::dict out; for (auto const& kv : p.flows()) { const auto& idx = kv.first; const auto& f = kv.second; out[py::make_tuple(idx.src, idx.dst, idx.flowClass, idx.flowId)] = py::make_tuple(f.src, f.dst, f.cost, f.placed_flow); } return out; });

include/netgraph/core/flow_graph.hpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Shared flow ledger layering over FlowState with per-flow deltas. */
1+
/* FlowGraph manages per-flow edge deltas over FlowState. */
22
#pragma once
33

44
#include <cstdint>
@@ -14,9 +14,8 @@
1414

1515
namespace netgraph::core {
1616

17-
// FlowGraph is a shared, authoritative flow ledger over a StrictMultiDiGraph.
18-
// It composes a FlowState for residual and aggregate edge_flow management, and
19-
// maintains per-flow edge deltas to support exact removal/reopt.
17+
// FlowGraph manages per-flow edge deltas over a StrictMultiDiGraph.
18+
// Composes FlowState for residual/aggregate edge flow management.
2019
class FlowGraph {
2120
public:
2221
explicit FlowGraph(const StrictMultiDiGraph& g);
@@ -30,8 +29,7 @@ class FlowGraph {
3029
// Access underlying graph (const)
3130
[[nodiscard]] const StrictMultiDiGraph& graph() const noexcept { return *g_; }
3231

33-
// Placement: applies placement and records per-edge deltas for this flow.
34-
// Returns placed amount.
32+
// Apply placement and record per-edge deltas for this flow. Returns placed amount.
3533
[[nodiscard]] Flow place(const FlowIndex& idx, NodeId src, NodeId dst,
3634
const PredDAG& dag, Flow amount,
3735
FlowPlacement placement, bool shortest_path = false);
@@ -40,23 +38,22 @@ class FlowGraph {
4038
void remove(const FlowIndex& idx);
4139

4240
// Remove all flows belonging to a given flowClass.
43-
void remove_by_class(std::int32_t flowClass);
41+
void remove_by_class(FlowClass flowClass);
4442

4543
// Reset all state to initial capacity and clear ledger.
4644
void reset() noexcept;
4745

4846
// Inspect: return a copy of the flow's edges and amounts.
4947
[[nodiscard]] std::vector<std::pair<EdgeId, Flow>> get_flow_edges(const FlowIndex& idx) const;
5048

51-
// Attempt to reconstruct a single path (LSP) for this flow from the ledger.
52-
// Returns empty vector if the flow does not correspond to a unique simple path
53-
// (e.g., when placed with multipath/proportional splitting).
49+
// Reconstruct single path for this flow from ledger.
50+
// Returns empty vector if flow uses multipath/proportional splitting.
5451
[[nodiscard]] std::vector<EdgeId> get_flow_path(const FlowIndex& idx) const;
5552

5653
private:
5754
const StrictMultiDiGraph* g_ {nullptr};
5855
FlowState fs_;
59-
// Per-flow ledger: only edges with non-zero assigned flow are stored.
56+
// Per-flow ledger: stores only edges with non-zero flow
6057
std::unordered_map<FlowIndex, std::vector<std::pair<EdgeId, Flow>>, FlowIndexHash> ledger_;
6158
};
6259

include/netgraph/core/flow_policy.hpp

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
/*
2-
FlowPolicy — policy engine managing flows for a single demand.
3-
See src/flow_policy.cpp for detailed behavior notes.
4-
*/
1+
/* FlowPolicy manages flows for a single demand. */
52
#pragma once
63

74
#include <cstdint>
@@ -22,7 +19,7 @@ namespace netgraph::core {
2219

2320
enum class PathAlg : std::int32_t { SPF = 1 };
2421

25-
// Execution context bundles algorithms and graph handle for clear dependency injection
22+
// Execution context with algorithms and graph handle
2623
struct ExecutionContext {
2724
Algorithms* algorithms;
2825
GraphHandle graph;
@@ -32,8 +29,7 @@ struct ExecutionContext {
3229
: algorithms(&algs), graph(gh) {}
3330
};
3431

35-
// Configuration for FlowPolicy behavior. Mirrors the long-form constructor
36-
// parameters in a grouped, maintainable struct.
32+
// Configuration struct for FlowPolicy behavior
3733
struct FlowPolicyConfig {
3834
PathAlg path_alg { PathAlg::SPF };
3935
FlowPlacement flow_placement { FlowPlacement::Proportional };
@@ -51,8 +47,7 @@ struct FlowPolicyConfig {
5147
double diminishing_returns_epsilon_frac { 1e-3 };
5248
};
5349

54-
// FlowPolicy orchestrates flow creation, placement, reopt, and removal for a
55-
// single demand (src,dst,flowClass) on a shared FlowGraph.
50+
// FlowPolicy manages flow creation, placement, reoptimization for a single demand
5651
class FlowPolicy {
5752
public:
5853
// New config-based constructor
@@ -106,31 +101,30 @@ class FlowPolicy {
106101
// Core operations
107102
[[nodiscard]] std::pair<double,double> place_demand(FlowGraph& fg,
108103
NodeId src, NodeId dst,
109-
std::int32_t flowClass,
104+
FlowClass flowClass,
110105
double volume,
111106
std::optional<double> target_per_flow = std::nullopt,
112107
std::optional<double> min_flow = std::nullopt);
113108

114109
[[nodiscard]] std::pair<double,double> rebalance_demand(FlowGraph& fg,
115110
NodeId src, NodeId dst,
116-
std::int32_t flowClass,
111+
FlowClass flowClass,
117112
double target_per_flow);
118113

119114
void remove_demand(FlowGraph& fg);
120115

121116
[[nodiscard]] const std::unordered_map<FlowIndex, FlowRecord, FlowIndexHash>& flows() const noexcept { return flows_; }
122117

123-
// Configure static paths to be used for flow creation (if endpoints match).
124-
// Each entry is (src, dst, dag, cost). If provided, max_flow_count must be
125-
// equal to the number of static paths (or will be set to that number).
118+
// Configure static paths for flow creation. Each entry is (src, dst, dag, cost).
119+
// max_flow_count must equal the number of static paths if set.
126120
void set_static_paths(std::vector<std::tuple<NodeId, NodeId, PredDAG, Cost>> paths);
127121

128122
private:
129123
// Helpers
130124
[[nodiscard]] std::optional<std::pair<PredDAG, Cost>> get_path_bundle(const FlowGraph& fg,
131125
NodeId src, NodeId dst,
132126
std::optional<double> min_flow);
133-
[[nodiscard]] FlowRecord* create_flow(FlowGraph& fg, NodeId src, NodeId dst, std::int32_t flowClass,
127+
[[nodiscard]] FlowRecord* create_flow(FlowGraph& fg, NodeId src, NodeId dst, FlowClass flowClass,
134128
std::optional<double> min_flow);
135129
[[nodiscard]] FlowRecord* reoptimize_flow(FlowGraph& fg, const FlowIndex& idx, double headroom);
136130

@@ -154,7 +148,7 @@ class FlowPolicy {
154148
// State
155149
std::unordered_map<FlowIndex, FlowRecord, FlowIndexHash> flows_;
156150
Cost best_path_cost_ { std::numeric_limits<Cost>::max() };
157-
std::int64_t next_flow_id_ { 0 };
151+
FlowId next_flow_id_ { 0 };
158152

159153
// Static paths (optional)
160154
std::vector<std::tuple<NodeId, NodeId, PredDAG, Cost>> static_paths_;

include/netgraph/core/types.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Core type aliases and small helper structs used across the library. */
1+
/* Core type aliases and helper structs. */
22
#pragma once
33

44
#include <cstdint>
@@ -14,12 +14,16 @@ using Cap = double;
1414
// Semantic alias for flow amounts (same unit as capacity)
1515
using Flow = double;
1616

17+
// Aliases for flow classification and identifiers for consistency
18+
using FlowClass = std::int32_t;
19+
using FlowId = std::int64_t;
20+
1721
// Identity of a flow: endpoints + class (priority bucket) + per-policy unique id
1822
struct FlowIndex {
1923
NodeId src;
2024
NodeId dst;
21-
std::int32_t flowClass; // small priority bucket
22-
std::int64_t flowId; // per-policy unique id
25+
FlowClass flowClass; // small priority bucket
26+
FlowId flowId; // per-policy unique id
2327
friend bool operator==(const FlowIndex& a, const FlowIndex& b) noexcept {
2428
return a.src==b.src && a.dst==b.dst && a.flowClass==b.flowClass && a.flowId==b.flowId;
2529
}
@@ -34,8 +38,8 @@ struct FlowIndexHash {
3438
};
3539
combine(std::hash<NodeId>{}(k.src));
3640
combine(std::hash<NodeId>{}(k.dst));
37-
combine(std::hash<std::int32_t>{}(k.flowClass));
38-
combine(std::hash<std::int64_t>{}(k.flowId));
41+
combine(std::hash<FlowClass>{}(k.flowClass));
42+
combine(std::hash<FlowId>{}(k.flowId));
3943
return h;
4044
}
4145
};

src/flow_graph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void FlowGraph::remove(const FlowIndex& idx) {
5353
ledger_.erase(it);
5454
}
5555

56-
void FlowGraph::remove_by_class(std::int32_t flowClass) {
56+
void FlowGraph::remove_by_class(FlowClass flowClass) {
5757
std::vector<FlowIndex> to_rm;
5858
to_rm.reserve(ledger_.size());
5959
for (auto const& kv : ledger_) if (kv.first.flowClass == flowClass) to_rm.push_back(kv.first);

src/flow_policy.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,16 @@ std::optional<std::pair<PredDAG, Cost>> FlowPolicy::get_path_bundle(const FlowGr
5252
// Decide whether we need residual-aware SPF and whether to build an edge mask
5353
const bool require_residual = (sel.require_capacity || (flow_placement_ == FlowPlacement::EqualBalanced && min_flow.has_value()));
5454
const auto residual = fg.residual_view();
55-
std::vector<unsigned char> em; const bool* edge_mask_ptr = nullptr;
55+
std::vector<unsigned char> em; std::unique_ptr<bool[]> em_bool; const bool* edge_mask_ptr = nullptr; // Safe bool array for edge mask
5656
bool need_mask = false;
5757
if (require_residual) {
5858
// In shortest-path EqualBalanced mode, do not enforce per-edge minimum residual at SPF stage
5959
need_mask = (min_flow.has_value() && !(shortest_path_ && flow_placement_ == FlowPlacement::EqualBalanced));
6060
if (need_mask) {
61-
em.assign(residual.size(), 1u);
61+
em_bool.reset(new bool[residual.size()]);
6262
double thr = *min_flow;
63-
for (std::size_t i=0;i<residual.size();++i) em[i] = static_cast<unsigned char>(static_cast<double>(residual[i]) >= thr);
64-
edge_mask_ptr = reinterpret_cast<const bool*>(em.data());
63+
for (std::size_t i=0;i<residual.size();++i) em_bool[i] = static_cast<double>(residual[i]) >= thr;
64+
edge_mask_ptr = em_bool.get();
6565
}
6666
}
6767
SpfOptions opts;
@@ -99,7 +99,7 @@ std::optional<std::pair<PredDAG, Cost>> FlowPolicy::get_path_bundle(const FlowGr
9999

100100
/* Create a new flow using the current path bundle. Returns nullptr if no
101101
admissible path is available given constraints. */
102-
FlowRecord* FlowPolicy::create_flow(FlowGraph& fg, NodeId src, NodeId dst, std::int32_t flowClass,
102+
FlowRecord* FlowPolicy::create_flow(FlowGraph& fg, NodeId src, NodeId dst, FlowClass flowClass,
103103
std::optional<double> min_flow) {
104104
FlowIndex idx{src, dst, flowClass, next_flow_id_++};
105105
auto pb = get_path_bundle(fg, src, dst, min_flow);
@@ -141,7 +141,7 @@ FlowRecord* FlowPolicy::reoptimize_flow(FlowGraph& fg, const FlowIndex& idx, dou
141141
Returns (total_placed, leftover). */
142142
std::pair<double,double> FlowPolicy::place_demand(FlowGraph& fg,
143143
NodeId src, NodeId dst,
144-
std::int32_t flowClass,
144+
FlowClass flowClass,
145145
double volume,
146146
std::optional<double> target_per_flow,
147147
std::optional<double> min_flow) {
@@ -301,7 +301,7 @@ std::pair<double,double> FlowPolicy::place_demand(FlowGraph& fg,
301301
`target_per_flow`. Internally removes and re-places the same total volume. */
302302
std::pair<double,double> FlowPolicy::rebalance_demand(FlowGraph& fg,
303303
NodeId src, NodeId dst,
304-
std::int32_t flowClass,
304+
FlowClass flowClass,
305305
double target_per_flow) {
306306
double vol = placed_demand();
307307
remove_demand(fg);

0 commit comments

Comments
 (0)