Skip to content

Commit 8c45fc2

Browse files
author
qballand
committed
Improve standard visitor support
1 parent 3d50334 commit 8c45fc2

2 files changed

Lines changed: 16 additions & 1 deletion

File tree

freexgraph/freexgraph.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,20 @@ class FreExNode:
3939
parents: Set[str]
4040
"""Parents of the node to add """
4141

42+
extension_node: bool
43+
""" """
44+
4245
def __init__(
4346
self,
4447
uid: str = None,
4548
*,
4649
fork_id: Optional[str] = None,
4750
parents: Set[str] = None,
4851
graph_ref: nx.DiGraph = None,
52+
extension_node: bool = False,
4953
):
5054
self.parents = parents or set()
55+
self.extension_node = extension_node
5156
self._graph_ref = graph_ref
5257
self._id = uid
5358
self._fork_id = fork_id
@@ -63,6 +68,8 @@ def apply_accept_(self, visitor: AnyVisitor) -> bool:
6368
from freexgraph.standard_visitor import is_standard_visitor
6469

6570
if is_standard_visitor(visitor):
71+
if self.extension_node:
72+
self.accept(visitor)
6673
return visitor.visit_standard(self)
6774
return self.accept(visitor)
6875

@@ -126,7 +133,9 @@ class GraphNode(FreExNode):
126133
def __init__(
127134
self, uid: str = None, *, graph: "FreExGraph", parents: Set[str] = None
128135
):
129-
super().__init__(uid=uid, parents=parents, graph_ref=graph._graph)
136+
super().__init__(
137+
uid=uid, parents=parents, graph_ref=graph._graph, extension_node=True
138+
)
130139
self._graph_ex = graph
131140

132141
def accept(self, visitor: AnyVisitor) -> bool:

test/basic_freexgraph_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,12 @@ def test_graph_node(valid_basic_execution_graph, visitor_test):
207207
assert visitor_test.inner_graph_started[0].startswith(id_graph)
208208
assert visitor_test.inner_graph_started == visitor_test.inner_graph_ended
209209

210+
# test find in graph node from root
211+
finder = FindFirstVisitor(lambda n: n.id.startswith("id3"))
212+
finder.visit(execution_graph.root)
213+
assert finder.found()
214+
assert finder.result.id.startswith("id3")
215+
210216

211217
def test_add_nodes(node_list_complex_graph, visitor_test):
212218
execution_graph = FreExGraph()

0 commit comments

Comments
 (0)