@@ -367,24 +367,17 @@ PYBIND11_MODULE(_netgraph_core, m, py::mod_gil_not_used()) {
367367 });
368368
369369 // FlowState bindings
370- py::class_<FlowState>(m, " FlowState" , py::dynamic_attr ())
371- .def (" __init__" , [](py::object self, py::object graph_obj){
372- const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(graph_obj);
373- FlowState* fs = self.cast <FlowState*>();
374- new (fs) FlowState (g);
375- self.attr (" _graph_ref" ) = graph_obj;
376- }, py::arg (" graph" ))
377- .def (" __init__" , [](py::object self, py::object graph_obj, py::array residual){
378- const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(graph_obj);
370+ py::class_<FlowState>(m, " FlowState" )
371+ .def (py::init ([](const StrictMultiDiGraph& g){ return FlowState (g); }),
372+ py::arg (" graph" ), py::keep_alive<1 , 2 >())
373+ .def (py::init ([](const StrictMultiDiGraph& g, py::array residual){
379374 if (!(py::isinstance<py::array_t <double >>(residual))) throw py::type_error (" residual must be a numpy float64 array" );
380375 if (!(residual.flags () & py::array::c_style)) throw py::type_error (" residual must be C-contiguous" );
381376 auto buf = residual.request ();
382377 if (buf.ndim != 1 || static_cast <std::size_t >(buf.shape [0 ]) != static_cast <std::size_t >(g.num_edges ())) throw py::type_error (" residual length must equal num_edges" );
383378 std::span<const double > rspan (static_cast <const double *>(buf.ptr ), static_cast <std::size_t >(buf.shape [0 ]));
384- FlowState* fs = self.cast <FlowState*>();
385- new (fs) FlowState (g, rspan);
386- self.attr (" _graph_ref" ) = graph_obj;
387- }, py::arg (" graph" ), py::arg (" residual" ))
379+ return FlowState (g, rspan);
380+ }), py::arg (" graph" ), py::arg (" residual" ), py::keep_alive<1 , 2 >())
388381 .def (" reset" , [](FlowState& fs){ fs.reset (); })
389382 .def (" reset" , [](FlowState& fs, py::array residual){
390383 if (!(py::isinstance<py::array_t <double >>(residual))) throw py::type_error (" residual must be a numpy float64 array" );
@@ -422,8 +415,7 @@ PYBIND11_MODULE(_netgraph_core, m, py::mod_gil_not_used()) {
422415 FlowPlacement placement, bool shortest_path, bool require_capacity,
423416 py::object node_mask, py::object edge_mask){
424417 FlowState& fs = py::cast<FlowState&>(self_obj);
425- // Get graph reference to validate mask lengths
426- const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(self_obj.attr (" _graph_ref" ));
418+ const StrictMultiDiGraph& g = fs.graph ();
427419 auto node_bs = to_bool_span_from_numpy (node_mask, static_cast <std::size_t >(g.num_nodes ()), " node_mask" );
428420 auto edge_bs = to_bool_span_from_numpy (edge_mask, static_cast <std::size_t >(g.num_edges ()), " edge_mask" );
429421 py::gil_scoped_release rel;
@@ -437,8 +429,7 @@ PYBIND11_MODULE(_netgraph_core, m, py::mod_gil_not_used()) {
437429 py::kw_only (), py::arg (" node_mask" ) = py::none (), py::arg (" edge_mask" ) = py::none ())
438430 .def (" compute_min_cut" , [](py::object self_obj, std::int32_t src, py::object node_mask, py::object edge_mask){
439431 const FlowState& fs = py::cast<const FlowState&>(self_obj);
440- // Get graph reference to validate mask lengths
441- const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(self_obj.attr (" _graph_ref" ));
432+ const StrictMultiDiGraph& g = fs.graph ();
442433 auto node_bs = to_bool_span_from_numpy (node_mask, static_cast <std::size_t >(g.num_nodes ()), " node_mask" );
443434 auto edge_bs = to_bool_span_from_numpy (edge_mask, static_cast <std::size_t >(g.num_edges ()), " edge_mask" );
444435 py::gil_scoped_release rel;
@@ -455,13 +446,9 @@ PYBIND11_MODULE(_netgraph_core, m, py::mod_gil_not_used()) {
455446 .def_readonly (" flowClass" , &FlowIndex::flowClass)
456447 .def_readonly (" flowId" , &FlowIndex::flowId);
457448
458- py::class_<FlowGraph>(m, " FlowGraph" , py::dynamic_attr ())
459- .def (" __init__" , [](py::object self, py::object graph_obj){
460- const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(graph_obj);
461- FlowGraph* fg = self.cast <FlowGraph*>();
462- new (fg) FlowGraph (g);
463- self.attr (" _graph_ref" ) = graph_obj;
464- }, py::arg (" graph" ))
449+ py::class_<FlowGraph>(m, " FlowGraph" )
450+ .def (py::init ([](const StrictMultiDiGraph& g){ return FlowGraph (g); }),
451+ py::arg (" graph" ), py::keep_alive<1 , 2 >())
465452 .def (" capacity_view" , [](py::object self_obj){
466453 const FlowGraph& fg = py::cast<const FlowGraph&>(self_obj);
467454 auto s = fg.capacity_view ();
@@ -572,32 +559,26 @@ PYBIND11_MODULE(_netgraph_core, m, py::mod_gil_not_used()) {
572559 .def_readwrite (" diminishing_returns_window" , &FlowPolicyConfig::diminishing_returns_window)
573560 .def_readwrite (" diminishing_returns_epsilon_frac" , &FlowPolicyConfig::diminishing_returns_epsilon_frac);
574561
575- py::class_<FlowPolicy>(m, " FlowPolicy" , py::dynamic_attr () )
576- .def (" __init__" , [](py::object self, py::object algs_obj, py::object graph_obj , FlowPolicyConfig cfg,
562+ py::class_<FlowPolicy>(m, " FlowPolicy" )
563+ .def (" __init__" , [](FlowPolicy& self, py::object algs_obj, const PyGraph& pg , FlowPolicyConfig cfg,
577564 py::object node_mask, py::object edge_mask){
578565 std::shared_ptr<Algorithms> algs = py::cast<std::shared_ptr<Algorithms>>(algs_obj);
579- const PyGraph& pg = py::cast<const PyGraph&>(graph_obj);
580566
581- // Convert masks to spans (FlowPolicy will copy the data)
582567 auto node_bs = to_bool_span_from_numpy (node_mask, static_cast <std::size_t >(pg.num_nodes ), " node_mask" );
583568 auto edge_bs = to_bool_span_from_numpy (edge_mask, static_cast <std::size_t >(pg.num_edges ), " edge_mask" );
584569
585- // Update config with mask spans (will be copied by FlowPolicy constructor)
586570 cfg.node_mask = node_bs.view ;
587571 cfg.edge_mask = edge_bs.view ;
588572
589573 ExecutionContext ctx (algs, pg.handle );
590- FlowPolicy* fp = self.cast <FlowPolicy*>();
591- new (fp) FlowPolicy (ctx, cfg);
592- self.attr (" _algorithms_ref" ) = algs_obj;
593- self.attr (" _graph_ref" ) = graph_obj;
574+ new (&self) FlowPolicy (ctx, cfg);
594575 },
595576 py::arg (" algorithms" ), py::arg (" graph" ), py::arg (" config" ),
596577 py::kw_only (),
597578 py::arg (" node_mask" ) = py::none (),
598579 py::arg (" edge_mask" ) = py::none (),
599- py::keep_alive<1 , 2 >() // self keeps algorithms alive
600- )
580+ py::keep_alive<1 , 2 >(), // self keeps algorithms alive
581+ py::keep_alive< 1 , 3 >()) // self keeps graph alive
601582 .def (" flow_count" , &FlowPolicy::flow_count)
602583 .def (" placed_demand" , &FlowPolicy::placed_demand)
603584 .def (" place_demand" , [](FlowPolicy& p, FlowGraph& fg, std::int32_t src, std::int32_t dst, FlowClass flowClass, double volume, py::object target_per_flow, py::object min_flow){ std::optional<double > tpf; if (!target_per_flow.is_none ()) tpf = py::cast<double >(target_per_flow); std::optional<double > mfl; if (!min_flow.is_none ()) mfl = py::cast<double >(min_flow); py::gil_scoped_release rel; auto pr = p.place_demand (fg, src, dst, flowClass, volume, tpf, mfl); py::gil_scoped_acquire acq; return py::make_tuple (pr.first , pr.second ); }, py::arg (" flow_graph" ), py::arg (" src" ), py::arg (" dst" ), py::arg (" flowClass" ), py::arg (" volume" ), py::arg (" target_per_flow" ) = py::none (), py::arg (" min_flow" ) = py::none ())
0 commit comments