Skip to content

Commit a5c9435

Browse files
committed
improving performance of explorer
1 parent 338b08f commit a5c9435

1 file changed

Lines changed: 126 additions & 80 deletions

File tree

ngraph/explorer.py

Lines changed: 126 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
23
import logging
34
from dataclasses import dataclass, field
45
from typing import Dict, List, Set, Optional
@@ -57,7 +58,7 @@ class TreeStats:
5758
total_power: float = 0.0
5859

5960

60-
@dataclass
61+
@dataclass(eq=False)
6162
class TreeNode:
6263
"""
6364
Represents a node in the hierarchical tree.
@@ -79,6 +80,14 @@ class TreeNode:
7980
stats: TreeStats = field(default_factory=TreeStats)
8081
raw_nodes: List[Node] = field(default_factory=list)
8182

83+
def __hash__(self) -> int:
84+
"""
85+
Make the node hashable based on object identity.
86+
This preserves uniqueness in sets/dicts without
87+
forcing equality by fields.
88+
"""
89+
return id(self)
90+
8291
def add_child(self, child_name: str) -> TreeNode:
8392
"""
8493
Ensure a child node named 'child_name' exists and return it.
@@ -135,6 +144,9 @@ def __init__(
135144
self._node_map: Dict[str, TreeNode] = {} # node_name -> deepest TreeNode
136145
self._path_map: Dict[str, TreeNode] = {} # path -> TreeNode
137146

147+
# Cache for storing each node's ancestor set:
148+
self._ancestors_cache: Dict[TreeNode, Set[TreeNode]] = {}
149+
138150
@classmethod
139151
def explore_network(
140152
cls,
@@ -160,15 +172,15 @@ def explore_network(
160172
# 1) Build the hierarchical structure
161173
instance.root_node = instance._build_hierarchy_tree()
162174

163-
# 2) Compute subtree sets
175+
# 2) Compute subtree sets (subtree_nodes)
164176
instance._compute_subtree_sets(instance.root_node)
165177

166178
# 3) Build node and path maps
167179
instance._build_node_map(instance.root_node)
168180
instance._build_path_map(instance.root_node)
169181

170-
# 4) Aggregate statistics (nodes, links, cost, power)
171-
instance._aggregate_stats(instance.root_node)
182+
# 4) Aggregate statistics (node counts, link stats, cost, power)
183+
instance._compute_statistics()
172184

173185
return instance
174186

@@ -270,96 +282,130 @@ def _roll_up_if_leaf(self, path: str) -> str:
270282
node = node.parent
271283
return self._compute_full_path(node)
272284

273-
def _aggregate_stats(self, node: TreeNode) -> None:
285+
def _get_ancestors(self, node: TreeNode) -> Set[TreeNode]:
286+
"""
287+
Return a cached set of this node's ancestors (including itself),
288+
up to the root.
274289
"""
275-
Summarize node count, link stats, and cost/power usage for this subtree,
276-
then recurse to children.
290+
if node in self._ancestors_cache:
291+
return self._ancestors_cache[node]
277292

278-
Args:
279-
node (TreeNode): The current tree node to process.
293+
ancestors = set()
294+
current = node
295+
while current is not None:
296+
ancestors.add(current)
297+
current = current.parent
298+
self._ancestors_cache[node] = ancestors
299+
return ancestors
300+
301+
def _compute_statistics(self) -> None:
280302
"""
281-
# 1) Node count
282-
node.stats.node_count = len(node.subtree_nodes)
303+
Computes all subtree statistics in a more efficient manner:
283304
284-
# 2) Accumulate node-level cost/power
285-
for node_name in node.subtree_nodes:
286-
nd = self.network.nodes[node_name]
305+
- node_count is set from each node's 'subtree_nodes' (already stored).
306+
- For each network node, cost/power is added to all ancestors in the
307+
hierarchy.
308+
- For each link, we figure out which subtrees see it as internal or
309+
external, and update stats accordingly.
310+
"""
311+
312+
# 1) node_count: use subtree sets
313+
# (each node gets the size of subtree_nodes)
314+
# stats are zeroed initially in the constructor.
315+
def set_node_counts(node: TreeNode) -> None:
316+
node.stats.node_count = len(node.subtree_nodes)
317+
for child in node.children.values():
318+
set_node_counts(child)
319+
320+
set_node_counts(self.root_node)
321+
322+
# 2) Accumulate node cost/power into all ancestor stats
323+
for nd in self.network.nodes.values():
287324
hw_component = nd.attrs.get("hw_component")
325+
comp = None
288326
if hw_component:
289327
comp = self.components_library.get(hw_component)
290-
if comp:
291-
node.stats.total_cost += comp.total_cost()
292-
node.stats.total_power += comp.total_power()
293-
else:
328+
if comp is None:
294329
logger.warning(
295330
"Node '%s' references unknown hw_component '%s'.",
296-
node_name,
331+
nd.name,
297332
hw_component,
298333
)
299334

300-
# Minor early-out: if this node is a leaf and has <= 1 raw node,
301-
# no links are possible within or external to it.
302-
# (If it has multiple raw_nodes, we do need the link checks.)
303-
if node.is_leaf() and len(node.raw_nodes) <= 1:
304-
return
305-
306-
# 3) Evaluate link-level stats
335+
# Walk up from the deepest node
336+
node_for_name = self._node_map[nd.name]
337+
ancestors = self._get_ancestors(node_for_name)
338+
if comp:
339+
cval = comp.total_cost()
340+
pval = comp.total_power()
341+
for an in ancestors:
342+
an.stats.total_cost += cval
343+
an.stats.total_power += pval
344+
345+
# 3) Single pass to accumulate link stats
346+
# For each link, determine for which subtrees it's internal vs external,
347+
# and update stats accordingly. Also add link hw cost/power if applicable.
307348
for link in self.network.links.values():
308-
src_in = link.source in node.subtree_nodes
309-
dst_in = link.target in node.subtree_nodes
310-
311-
if src_in and dst_in:
312-
# Internal link
313-
node.stats.internal_link_count += 1
314-
node.stats.internal_link_capacity += link.capacity
315-
316-
# If there's an hw_component on the link, add cost/power
317-
hw_comp = link.attrs.get("hw_component")
318-
if hw_comp:
319-
comp = self.components_library.get(hw_comp)
320-
if comp:
321-
node.stats.total_cost += comp.total_cost()
322-
node.stats.total_power += comp.total_power()
323-
else:
324-
logger.warning(
325-
"Link '%s->%s' references unknown hw_component '%s'.",
326-
link.source,
327-
link.target,
328-
hw_comp,
329-
)
330-
elif src_in ^ dst_in:
331-
# External link
332-
node.stats.external_link_count += 1
333-
node.stats.external_link_capacity += link.capacity
334-
335-
other_side = link.target if src_in else link.source
336-
other_node = self._node_map.get(other_side)
337-
if other_node:
338-
other_path = self._compute_full_path(other_node)
339-
bd = node.stats.external_link_details.setdefault(
340-
other_path, ExternalLinkBreakdown()
349+
src = link.source
350+
dst = link.target
351+
352+
# Check link's hw_component
353+
hw_comp = link.attrs.get("hw_component")
354+
link_comp = None
355+
if hw_comp:
356+
link_comp = self.components_library.get(hw_comp)
357+
if link_comp is None:
358+
logger.warning(
359+
"Link '%s->%s' references unknown hw_component '%s'.",
360+
src,
361+
dst,
362+
hw_comp,
341363
)
342-
bd.link_count += 1
343-
bd.link_capacity += link.capacity
344-
345-
# Possibly add link optic/hw cost/power
346-
hw_comp = link.attrs.get("hw_component")
347-
if hw_comp:
348-
comp = self.components_library.get(hw_comp)
349-
if comp:
350-
node.stats.total_cost += comp.total_cost()
351-
node.stats.total_power += comp.total_power()
352-
else:
353-
logger.warning(
354-
"Link '%s->%s' references unknown hw_component '%s'.",
355-
link.source,
356-
link.target,
357-
hw_comp,
358-
)
359-
360-
# 4) Recurse to children
361-
for child in node.children.values():
362-
self._aggregate_stats(child)
364+
365+
src_node = self._node_map[src]
366+
dst_node = self._node_map[dst]
367+
A_src = self._get_ancestors(src_node)
368+
A_dst = self._get_ancestors(dst_node)
369+
370+
# Intersection => internal
371+
# XOR => external
372+
inter = A_src & A_dst
373+
xor = A_src ^ A_dst
374+
375+
# Capacity
376+
cap = link.capacity
377+
378+
# For cost/power from link, we add to any node
379+
# that sees it either internal or external.
380+
link_cost = link_comp.total_cost() if link_comp else 0.0
381+
link_power = link_comp.total_power() if link_comp else 0.0
382+
383+
# Internal link updates
384+
for an in inter:
385+
an.stats.internal_link_count += 1
386+
an.stats.internal_link_capacity += cap
387+
an.stats.total_cost += link_cost
388+
an.stats.total_power += link_power
389+
390+
# External link updates
391+
for an in xor:
392+
an.stats.external_link_count += 1
393+
an.stats.external_link_capacity += cap
394+
an.stats.total_cost += link_cost
395+
an.stats.total_power += link_power
396+
397+
# Update external_link_details
398+
if an in A_src:
399+
# 'an' sees the other side as 'dst'
400+
other_path = self._compute_full_path(dst_node)
401+
else:
402+
# 'an' sees the other side as 'src'
403+
other_path = self._compute_full_path(src_node)
404+
bd = an.stats.external_link_details.setdefault(
405+
other_path, ExternalLinkBreakdown()
406+
)
407+
bd.link_count += 1
408+
bd.link_capacity += cap
363409

364410
def print_tree(
365411
self,

0 commit comments

Comments
 (0)