Skip to content

Commit 55c858b

Browse files
yfukai and JoOkuma authored
Adding update_node signal and connected them to spatial filters (#266)
* added update_node signal and connected them to spatial filters * fixed failing benchmarks * bm timeout longer * Update src/tracksdata/graph/filters/_spatial_filter.py Co-authored-by: Jordão Bragantini <jordao.bragantini@gmail.com> * fixed unnecessary calls of dict() * lint * moved heavy computation to is_signal_on blocks * Apply suggestion from @JoOkuma * Update src/tracksdata/graph/_sql_graph.py Co-authored-by: Jordão Bragantini <jordao.bragantini@gmail.com> * Update src/tracksdata/graph/_sql_graph.py Co-authored-by: Jordão Bragantini <jordao.bragantini@gmail.com> * Update src/tracksdata/graph/_rustworkx_graph.py Co-authored-by: Jordão Bragantini <jordao.bragantini@gmail.com> --------- Co-authored-by: Jordão Bragantini <jordao.bragantini@gmail.com>
1 parent 4d47cd8 commit 55c858b

9 files changed

Lines changed: 261 additions & 57 deletions

File tree

.github/workflows/benchmarks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ jobs:
6060
--machine github-actions \
6161
--python 3.12 \
6262
--factor 1.5 \
63+
--attribute timeout=180 \
6364
--show-stderr || status=$?
6465
if [ "$status" -eq 2 ]; then
6566
echo "asv: benchmark run failed (exit 2). Failing CI." >&2

src/tracksdata/graph/_base_graph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ class BaseGraph(abc.ABC):
4747
Base class for a graph backend.
4848
"""
4949

50-
node_added = Signal(int)
51-
node_removed = Signal(int)
50+
node_added = Signal(int, object)
51+
node_removed = Signal(int, object)
52+
node_updated = Signal(int, object, object)
5253

5354
def __init__(self) -> None:
5455
self._cache = {}

src/tracksdata/graph/_graph_view.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable, Sequence
2-
from typing import Any, Literal, overload
2+
from typing import Any, Literal, cast, overload
33

44
import bidict
55
import polars as pl
@@ -400,8 +400,10 @@ def add_node(
400400
else:
401401
self._out_of_sync = True
402402

403-
self._root.node_added.emit_fast(parent_node_id)
404-
self.node_added.emit_fast(parent_node_id)
403+
if is_signal_on(self._root.node_added):
404+
self._root.node_added.emit(parent_node_id, attrs)
405+
if is_signal_on(self.node_added):
406+
self.node_added.emit(parent_node_id, attrs)
405407

406408
return parent_node_id
407409

@@ -417,12 +419,12 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
417419
self._out_of_sync = True
418420

419421
if is_signal_on(self._root.node_added):
420-
for node_id in parent_node_ids:
421-
self._root.node_added.emit_fast(node_id)
422+
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
423+
self._root.node_added.emit(node_id, node_attrs)
422424

423425
if is_signal_on(self.node_added):
424-
for node_id in parent_node_ids:
425-
self.node_added.emit_fast(node_id)
426+
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
427+
self.node_added.emit(node_id, node_attrs)
426428

427429
return parent_node_ids
428430

@@ -446,9 +448,11 @@ def remove_node(self, node_id: int) -> None:
446448
if node_id not in self._external_to_local:
447449
raise ValueError(f"Node {node_id} does not exist in the graph.")
448450

451+
if is_signal_on(self.node_removed):
452+
old_attrs = self.nodes[node_id].to_dict()
453+
449454
# Remove from root graph first, because removing bounding box requires node attrs
450455
self._root.remove_node(node_id)
451-
self.node_removed.emit_fast(node_id)
452456

453457
if self.sync:
454458
# Get the local node ID and remove from local graph
@@ -474,6 +478,9 @@ def remove_node(self, node_id: int) -> None:
474478
else:
475479
self._out_of_sync = True
476480

481+
if is_signal_on(self.node_removed):
482+
self.node_removed.emit(node_id, old_attrs)
483+
477484
def add_edge(
478485
self,
479486
source_id: int,
@@ -652,6 +659,12 @@ def update_node_attrs(
652659
) -> None:
653660
if node_ids is None:
654661
node_ids = self.node_ids()
662+
else:
663+
node_ids = list(node_ids)
664+
665+
if is_signal_on(self.node_updated):
666+
old_attrs_by_id = self._root.filter(node_ids=node_ids).node_attrs()
667+
old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_attrs_by_id.to_dicts()}
655668

656669
self._root.update_node_attrs(
657670
node_ids=node_ids,
@@ -660,13 +673,23 @@ def update_node_attrs(
660673
# because attributes are passed by reference, we don't need this if both are rustworkx graphs
661674
if not self._is_root_rx_graph:
662675
if self.sync:
663-
super().update_node_attrs(
664-
node_ids=self._map_to_local(node_ids),
665-
attrs=attrs,
666-
)
676+
with self.node_updated.blocked():
677+
super().update_node_attrs(
678+
node_ids=self._map_to_local(node_ids),
679+
attrs=attrs,
680+
)
667681
else:
668682
self._out_of_sync = True
669683

684+
if is_signal_on(self.node_updated):
685+
for node_id in node_ids:
686+
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
687+
self.node_updated.emit(
688+
node_id,
689+
old_attrs_by_id[node_id],
690+
self._root.nodes[node_id].to_dict(),
691+
)
692+
670693
def update_edge_attrs(
671694
self,
672695
*,

src/tracksdata/graph/_mapped_graph_mixin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from collections.abc import Sequence
9+
from numbers import Integral
910
from typing import Any, overload
1011

1112
import bidict
@@ -84,7 +85,7 @@ def _map_to_external(self, local_ids: int | Sequence[int] | None) -> int | list[
8485
"""
8586
if local_ids is None:
8687
return None
87-
if isinstance(local_ids, int):
88+
if isinstance(local_ids, Integral):
8889
return self._local_to_external[local_ids]
8990
return [self._local_to_external[lid] for lid in local_ids]
9091

@@ -113,7 +114,7 @@ def _map_to_local(self, external_ids: int | Sequence[int] | None) -> int | list[
113114
"""
114115
if external_ids is None:
115116
return None
116-
if isinstance(external_ids, int):
117+
if isinstance(external_ids, Integral):
117118
return self._external_to_local[external_ids]
118119
return [self._external_to_local[eid] for eid in external_ids]
119120

src/tracksdata/graph/_rustworkx_graph.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,8 @@ def add_node(
492492

493493
node_id = self.rx_graph.add_node(attrs)
494494
self._time_to_nodes.setdefault(attrs["t"], []).append(node_id)
495-
self.node_added.emit_fast(node_id)
495+
if is_signal_on(self.node_added):
496+
self.node_added.emit(node_id, attrs)
496497
return node_id
497498

498499
def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None = None) -> list[int]:
@@ -523,8 +524,8 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
523524

524525
# checking if it has connections to reduce overhead
525526
if is_signal_on(self.node_added):
526-
for node_id in node_indices:
527-
self.node_added.emit_fast(node_id)
527+
for node_id, node_attrs in zip(node_indices, nodes, strict=True):
528+
self.node_added.emit(node_id, node_attrs)
528529

529530
return node_indices
530531

@@ -548,7 +549,9 @@ def remove_node(self, node_id: int) -> None:
548549
if node_id not in self.rx_graph.node_indices():
549550
raise ValueError(f"Node {node_id} does not exist in the graph.")
550551

551-
self.node_removed.emit_fast(node_id)
552+
old_attrs = None
553+
if is_signal_on(self.node_removed):
554+
old_attrs = dict(self.rx_graph[node_id])
552555

553556
# Get the time value before removing the node
554557
t = self.rx_graph[node_id]["t"]
@@ -566,6 +569,9 @@ def remove_node(self, node_id: int) -> None:
566569
if self._overlaps is not None:
567570
self._overlaps = [overlap for overlap in self._overlaps if node_id != overlap[0] and node_id != overlap[1]]
568571

572+
if is_signal_on(self.node_removed):
573+
self.node_removed.emit(node_id, old_attrs)
574+
569575
def add_edge(
570576
self,
571577
source_id: int,
@@ -1217,6 +1223,9 @@ def update_node_attrs(
12171223
if node_ids is None:
12181224
node_ids = self.node_ids()
12191225

1226+
if is_signal_on(self.node_updated):
1227+
old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids}
1228+
12201229
for key, value in attrs.items():
12211230
if key not in self.node_attr_keys():
12221231
raise ValueError(f"Node attribute key '{key}' not found in graph. Expected '{self.node_attr_keys()}'")
@@ -1231,6 +1240,10 @@ def update_node_attrs(
12311240
for node_id, v in zip(node_ids, value, strict=False):
12321241
self._graph[node_id][key] = v
12331242

1243+
if is_signal_on(self.node_updated):
1244+
for node_id in node_ids:
1245+
self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id]))
1246+
12341247
def update_edge_attrs(
12351248
self,
12361249
*,
@@ -1616,7 +1629,8 @@ def add_node(
16161629
self._next_external_id = max(self._next_external_id, index + 1)
16171630
# Add mapping using mixin
16181631
self._add_id_mapping(node_id, index)
1619-
self.node_added.emit_fast(index)
1632+
if is_signal_on(self.node_added):
1633+
self.node_added.emit(index, attrs)
16201634
return index
16211635

16221636
def bulk_add_nodes(
@@ -1662,8 +1676,8 @@ def bulk_add_nodes(
16621676
self._add_id_mappings(list(zip(graph_ids, indices, strict=True)))
16631677

16641678
if is_signal_on(self.node_added):
1665-
for index in indices:
1666-
self.node_added.emit_fast(index)
1679+
for index, node_attrs in zip(indices, nodes, strict=True):
1680+
self.node_added.emit(index, node_attrs)
16671681

16681682
return indices
16691683

@@ -1941,8 +1955,25 @@ def update_node_attrs(
19411955
node_ids : Sequence[int] | None
19421956
The node ids to update.
19431957
"""
1944-
node_ids = self._get_local_ids() if node_ids is None else self._map_to_local(node_ids)
1945-
super().update_node_attrs(attrs=attrs, node_ids=node_ids)
1958+
external_node_ids = self.node_ids() if node_ids is None else node_ids
1959+
local_node_ids = self._map_to_local(external_node_ids)
1960+
1961+
if is_signal_on(self.node_updated):
1962+
old_attrs_by_id = {
1963+
external_node_id: dict(self._graph[local_node_id])
1964+
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True)
1965+
}
1966+
1967+
with self.node_updated.blocked():
1968+
super().update_node_attrs(attrs=attrs, node_ids=local_node_ids)
1969+
1970+
if is_signal_on(self.node_updated) and old_attrs_by_id is not None:
1971+
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True):
1972+
self.node_updated.emit(
1973+
external_node_id,
1974+
old_attrs_by_id[external_node_id],
1975+
dict(self._graph[local_node_id]),
1976+
)
19461977

19471978
def remove_node(self, node_id: int) -> None:
19481979
"""
@@ -1963,11 +1994,15 @@ def remove_node(self, node_id: int) -> None:
19631994

19641995
local_node_id = self._map_to_local(node_id)
19651996

1966-
self.node_removed.emit_fast(node_id)
1997+
if is_signal_on(self.node_removed):
1998+
old_attrs = dict(self._graph[local_node_id])
1999+
19672000
with self.node_removed.blocked():
19682001
super().remove_node(local_node_id)
19692002

19702003
self._remove_id_mapping(external_id=node_id)
2004+
if is_signal_on(self.node_removed):
2005+
self.node_removed.emit(node_id, old_attrs)
19712006

19722007
def filter(
19732008
self,

src/tracksdata/graph/_sql_graph.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,8 @@ def add_node(
718718
if index is None:
719719
self._max_id_per_time[time] = node_id
720720

721-
self.node_added.emit_fast(node_id)
721+
if is_signal_on(self.node_added):
722+
self.node_added.emit(node_id, attrs)
722723

723724
return node_id
724725

@@ -785,8 +786,9 @@ def bulk_add_nodes(
785786
self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node)
786787

787788
if is_signal_on(self.node_added):
788-
for node_id in node_ids:
789-
self.node_added.emit_fast(node_id)
789+
for node_id, node_attrs in zip(node_ids, nodes, strict=True):
790+
new_attrs = {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
791+
self.node_added.emit(node_id, new_attrs)
790792

791793
return node_ids
792794

@@ -808,14 +810,15 @@ def remove_node(self, node_id: int) -> None:
808810
ValueError
809811
If the node_id does not exist in the graph.
810812
"""
811-
self.node_removed.emit_fast(node_id)
812-
813813
with Session(self._engine) as session:
814814
# Check if the node exists
815815
node = session.query(self.Node).filter(self.Node.node_id == node_id).first()
816816
if node is None:
817817
raise ValueError(f"Node {node_id} does not exist in the graph.")
818818

819+
if is_signal_on(self.node_removed):
820+
old_attrs = {key: getattr(node, key) for key in self.node_attr_keys()}
821+
819822
# Remove all edges where this node is source or target
820823
session.query(self.Edge).filter(
821824
sa.or_(self.Edge.source_id == node_id, self.Edge.target_id == node_id)
@@ -829,6 +832,8 @@ def remove_node(self, node_id: int) -> None:
829832
# Remove the node itself
830833
session.delete(node)
831834
session.commit()
835+
if is_signal_on(self.node_removed):
836+
self.node_removed.emit(node_id, old_attrs)
832837

833838
def add_edge(
834839
self,
@@ -1813,8 +1818,27 @@ def update_node_attrs(
18131818
if "t" in attrs:
18141819
raise ValueError("Node attribute 't' cannot be updated.")
18151820

1821+
updated_node_ids = self.node_ids() if node_ids is None else list(node_ids)
1822+
if len(updated_node_ids) == 0:
1823+
return
1824+
1825+
attr_keys = self.node_attr_keys()
1826+
if is_signal_on(self.node_updated):
1827+
old_df = self.filter(node_ids=updated_node_ids).node_attrs(
1828+
attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]
1829+
)
1830+
old_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in old_df.rows(named=True)}
1831+
18161832
self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs)
18171833

1834+
if is_signal_on(self.node_updated):
1835+
new_df = self.filter(node_ids=updated_node_ids).node_attrs(
1836+
attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]
1837+
)
1838+
new_attrs_by_id = {row[DEFAULT_ATTR_KEYS.NODE_ID]: row for row in new_df.rows(named=True)}
1839+
for node_id in updated_node_ids:
1840+
self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id])
1841+
18181842
def update_edge_attrs(
18191843
self,
18201844
*,

0 commit comments

Comments
 (0)