@@ -123,7 +123,9 @@ class GraphNode(FreExNode):
123123
124124 _graph_ex : "FreExGraph"
125125
126- def __init__ (self , uid : str , * , graph : "FreExGraph" , parents : Set [str ] = None ):
126+ def __init__ (
127+ self , uid : str = None , * , graph : "FreExGraph" , parents : Set [str ] = None
128+ ):
127129 super ().__init__ (uid = uid , parents = parents , graph_ref = graph ._graph )
128130 self ._graph_ex = graph
129131
@@ -278,6 +280,46 @@ def get_node(self, node_id: str) -> Optional[AnyFreExNode]:
278280 return None
279281 return self ._graph .nodes [node_id ]["content" ]
280282
283+ def sub_graph (
284+ self , from_node_id : str , to_nodes_id : Optional [List [str ]] = None
285+ ) -> "FreExGraph" :
286+ """Utility method to retrieve a subgraph from a given node until the end of the graph or until one of the
287+ provided node is encountered.
288+
289+ :param from_node_id: node from which the sub graph start
290+ :param to_nodes_id: nodes on which the sub graph stop, if none encountered, subgraph go until the leaf nodes
291+ :return: a sub graph delimited by the provided nodes id
292+ """
293+ from_node : FreExNode = self .get_node (from_node_id )
294+ assert (
295+ from_node is not None
296+ ), f"Error sub graph from node { from_node_id } , node has to be in the execution graph"
297+
298+ nodes_in_subgraph : List [FreExNode ] = []
299+ nodes_in_subgraph_id : List [str ] = []
300+
301+ def add_node_in_subgraph (current_node : FreExNode ):
302+ current_node .parents = {
303+ p for p in current_node .parents if p in nodes_in_subgraph_id
304+ }
305+ nodes_in_subgraph .append (current_node )
306+ if to_nodes_id is not None and current_node in to_nodes_id :
307+ return
308+ all_suc = list (self ._graph .successors (current_node .id ))
309+ for successor in all_suc :
310+ n = self .get_node (successor )
311+ assert (
312+ n is not None
313+ ), f"Error sub graph to node { n .id } , node has to be in the execution graph"
314+ add_node_in_subgraph (n )
315+ nodes_in_subgraph_id .append (n .id )
316+
317+ add_node_in_subgraph (from_node )
318+
319+ sub_graph = FreExGraph ()
320+ sub_graph .add_nodes (nodes_in_subgraph )
321+ return sub_graph
322+
281323 def fork_from_node (
282324 self , forked_node : FreExNode , * , join_id : Optional [str ] = None
283325 ) -> None :
0 commit comments