11from __future__ import annotations
2+
23import logging
34from dataclasses import dataclass , field
45from typing import Dict , List , Set , Optional
@@ -57,7 +58,7 @@ class TreeStats:
5758 total_power : float = 0.0
5859
5960
60- @dataclass
61+ @dataclass ( eq = False )
6162class TreeNode :
6263 """
6364 Represents a node in the hierarchical tree.
@@ -79,6 +80,14 @@ class TreeNode:
7980 stats : TreeStats = field (default_factory = TreeStats )
8081 raw_nodes : List [Node ] = field (default_factory = list )
8182
83+ def __hash__ (self ) -> int :
84+ """
85+ Make the node hashable based on object identity.
86+ This preserves uniqueness in sets/dicts without
87+ forcing equality by fields.
88+ """
89+ return id (self )
90+
8291 def add_child (self , child_name : str ) -> TreeNode :
8392 """
8493 Ensure a child node named 'child_name' exists and return it.
@@ -135,6 +144,9 @@ def __init__(
135144 self ._node_map : Dict [str , TreeNode ] = {} # node_name -> deepest TreeNode
136145 self ._path_map : Dict [str , TreeNode ] = {} # path -> TreeNode
137146
147+ # Cache for storing each node's ancestor set:
148+ self ._ancestors_cache : Dict [TreeNode , Set [TreeNode ]] = {}
149+
138150 @classmethod
139151 def explore_network (
140152 cls ,
@@ -160,15 +172,15 @@ def explore_network(
160172 # 1) Build the hierarchical structure
161173 instance .root_node = instance ._build_hierarchy_tree ()
162174
163- # 2) Compute subtree sets
175+ # 2) Compute subtree sets (subtree_nodes)
164176 instance ._compute_subtree_sets (instance .root_node )
165177
166178 # 3) Build node and path maps
167179 instance ._build_node_map (instance .root_node )
168180 instance ._build_path_map (instance .root_node )
169181
170- # 4) Aggregate statistics (nodes, links , cost, power)
171- instance ._aggregate_stats ( instance . root_node )
182+ # 4) Aggregate statistics (node counts, link stats , cost, power)
183+ instance ._compute_statistics ( )
172184
173185 return instance
174186
@@ -270,96 +282,130 @@ def _roll_up_if_leaf(self, path: str) -> str:
270282 node = node .parent
271283 return self ._compute_full_path (node )
272284
273- def _aggregate_stats (self , node : TreeNode ) -> None :
285+ def _get_ancestors (self , node : TreeNode ) -> Set [TreeNode ]:
286+ """
287+ Return a cached set of this node's ancestors (including itself),
288+ up to the root.
274289 """
275- Summarize node count, link stats, and cost/power usage for this subtree,
276- then recurse to children.
290+ if node in self . _ancestors_cache :
291+ return self . _ancestors_cache [ node ]
277292
278- Args:
279- node (TreeNode): The current tree node to process.
293+ ancestors = set ()
294+ current = node
295+ while current is not None :
296+ ancestors .add (current )
297+ current = current .parent
298+ self ._ancestors_cache [node ] = ancestors
299+ return ancestors
300+
301+ def _compute_statistics (self ) -> None :
280302 """
281- # 1) Node count
282- node .stats .node_count = len (node .subtree_nodes )
303+ Computes all subtree statistics in a more efficient manner:
283304
284- # 2) Accumulate node-level cost/power
285- for node_name in node .subtree_nodes :
286- nd = self .network .nodes [node_name ]
305+ - node_count is set from each node's 'subtree_nodes' (already stored).
306+ - For each network node, cost/power is added to all ancestors in the
307+ hierarchy.
308+ - For each link, we figure out which subtrees see it as internal or
309+ external, and update stats accordingly.
310+ """
311+
312+ # 1) node_count: use subtree sets
313+ # (each node gets the size of subtree_nodes)
314+ # stats are zeroed initially in the constructor.
315+ def set_node_counts (node : TreeNode ) -> None :
316+ node .stats .node_count = len (node .subtree_nodes )
317+ for child in node .children .values ():
318+ set_node_counts (child )
319+
320+ set_node_counts (self .root_node )
321+
322+ # 2) Accumulate node cost/power into all ancestor stats
323+ for nd in self .network .nodes .values ():
287324 hw_component = nd .attrs .get ("hw_component" )
325+ comp = None
288326 if hw_component :
289327 comp = self .components_library .get (hw_component )
290- if comp :
291- node .stats .total_cost += comp .total_cost ()
292- node .stats .total_power += comp .total_power ()
293- else :
328+ if comp is None :
294329 logger .warning (
295330 "Node '%s' references unknown hw_component '%s'." ,
296- node_name ,
331+ nd . name ,
297332 hw_component ,
298333 )
299334
300- # Minor early-out: if this node is a leaf and has <= 1 raw node,
301- # no links are possible within or external to it.
302- # (If it has multiple raw_nodes, we do need the link checks.)
303- if node .is_leaf () and len (node .raw_nodes ) <= 1 :
304- return
305-
306- # 3) Evaluate link-level stats
335+ # Walk up from the deepest node
336+ node_for_name = self ._node_map [nd .name ]
337+ ancestors = self ._get_ancestors (node_for_name )
338+ if comp :
339+ cval = comp .total_cost ()
340+ pval = comp .total_power ()
341+ for an in ancestors :
342+ an .stats .total_cost += cval
343+ an .stats .total_power += pval
344+
345+ # 3) Single pass to accumulate link stats
346+ # For each link, determine for which subtrees it's internal vs external,
347+ # and update stats accordingly. Also add link hw cost/power if applicable.
307348 for link in self .network .links .values ():
308- src_in = link .source in node .subtree_nodes
309- dst_in = link .target in node .subtree_nodes
310-
311- if src_in and dst_in :
312- # Internal link
313- node .stats .internal_link_count += 1
314- node .stats .internal_link_capacity += link .capacity
315-
316- # If there's an hw_component on the link, add cost/power
317- hw_comp = link .attrs .get ("hw_component" )
318- if hw_comp :
319- comp = self .components_library .get (hw_comp )
320- if comp :
321- node .stats .total_cost += comp .total_cost ()
322- node .stats .total_power += comp .total_power ()
323- else :
324- logger .warning (
325- "Link '%s->%s' references unknown hw_component '%s'." ,
326- link .source ,
327- link .target ,
328- hw_comp ,
329- )
330- elif src_in ^ dst_in :
331- # External link
332- node .stats .external_link_count += 1
333- node .stats .external_link_capacity += link .capacity
334-
335- other_side = link .target if src_in else link .source
336- other_node = self ._node_map .get (other_side )
337- if other_node :
338- other_path = self ._compute_full_path (other_node )
339- bd = node .stats .external_link_details .setdefault (
340- other_path , ExternalLinkBreakdown ()
349+ src = link .source
350+ dst = link .target
351+
352+ # Check link's hw_component
353+ hw_comp = link .attrs .get ("hw_component" )
354+ link_comp = None
355+ if hw_comp :
356+ link_comp = self .components_library .get (hw_comp )
357+ if link_comp is None :
358+ logger .warning (
359+ "Link '%s->%s' references unknown hw_component '%s'." ,
360+ src ,
361+ dst ,
362+ hw_comp ,
341363 )
342- bd .link_count += 1
343- bd .link_capacity += link .capacity
344-
345- # Possibly add link optic/hw cost/power
346- hw_comp = link .attrs .get ("hw_component" )
347- if hw_comp :
348- comp = self .components_library .get (hw_comp )
349- if comp :
350- node .stats .total_cost += comp .total_cost ()
351- node .stats .total_power += comp .total_power ()
352- else :
353- logger .warning (
354- "Link '%s->%s' references unknown hw_component '%s'." ,
355- link .source ,
356- link .target ,
357- hw_comp ,
358- )
359-
360- # 4) Recurse to children
361- for child in node .children .values ():
362- self ._aggregate_stats (child )
364+
365+ src_node = self ._node_map [src ]
366+ dst_node = self ._node_map [dst ]
367+ A_src = self ._get_ancestors (src_node )
368+ A_dst = self ._get_ancestors (dst_node )
369+
370+ # Intersection => internal
371+ # XOR => external
372+ inter = A_src & A_dst
373+ xor = A_src ^ A_dst
374+
375+ # Capacity
376+ cap = link .capacity
377+
378+ # For cost/power from link, we add to any node
379+ # that sees it either internal or external.
380+ link_cost = link_comp .total_cost () if link_comp else 0.0
381+ link_power = link_comp .total_power () if link_comp else 0.0
382+
383+ # Internal link updates
384+ for an in inter :
385+ an .stats .internal_link_count += 1
386+ an .stats .internal_link_capacity += cap
387+ an .stats .total_cost += link_cost
388+ an .stats .total_power += link_power
389+
390+ # External link updates
391+ for an in xor :
392+ an .stats .external_link_count += 1
393+ an .stats .external_link_capacity += cap
394+ an .stats .total_cost += link_cost
395+ an .stats .total_power += link_power
396+
397+ # Update external_link_details
398+ if an in A_src :
399+ # 'an' sees the other side as 'dst'
400+ other_path = self ._compute_full_path (dst_node )
401+ else :
402+ # 'an' sees the other side as 'src'
403+ other_path = self ._compute_full_path (src_node )
404+ bd = an .stats .external_link_details .setdefault (
405+ other_path , ExternalLinkBreakdown ()
406+ )
407+ bd .link_count += 1
408+ bd .link_capacity += cap
363409
364410 def print_tree (
365411 self ,
0 commit comments