from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Set

-from ngraph.dsl.selectors import flatten_link_attrs, flatten_node_attrs
+from ngraph.dsl.selectors import (
+    flatten_link_attrs,
+    flatten_node_attrs,
+    flatten_risk_group_attrs,
+)
from ngraph.logging import get_logger
from ngraph.model.failure.policy_set import FailurePolicySet
from ngraph.types.base import FlowPlacement
@@ -72,23 +76,21 @@ def _create_cache_key(
    Returns:
        Tuple suitable for use as a cache key.
    """
-    # Basic components that are always hashable
    base_key = (
        tuple(sorted(excluded_nodes)),
        tuple(sorted(excluded_links)),
        analysis_name,
    )

-    # Normalize analysis_kwargs for hashing
    hashable_kwargs = []
    for key, value in sorted(analysis_kwargs.items()):
-        # Try to hash the (key, value) pair directly
        if _is_hashable((key, value)):
            hashable_kwargs.append((key, value))
        else:
            # Use object id for non-hashable values. Avoid str() which triggers
-            # deep __repr__ traversals on large objects (e.g., graphs with thousands
-            # of edges). id() works here because these objects persist across calls.
+            # deep __repr__ traversals on large objects (e.g., graphs with
+            # thousands of edges). id() is safe here because these objects
+            # persist across calls within one FailureManager lifetime.
            hashable_kwargs.append((key, f"{type(value).__name__}_{id(value)}"))

    return base_key + (tuple(hashable_kwargs),)
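A minimal sketch of how this key drives deduplication, with invented values (it assumes `_is_hashable` wraps `hash()` in a try/except, as the comments above suggest):

key_a = _create_cache_key({"n1"}, {"l5"}, "max_flow", {"alpha": 0.5})
key_b = _create_cache_key({"n1"}, {"l5"}, "max_flow", {"alpha": 0.5})
assert key_a == key_b  # same failure pattern and kwargs -> one unique task

graph = {"edges": []}  # dicts are unhashable -> (key, value) takes the id() branch
key_c = _create_cache_key({"n1"}, {"l5"}, "max_flow", {"graph": graph})
key_d = _create_cache_key({"n1"}, {"l5"}, "max_flow", {"graph": graph})
assert key_c == key_d  # identity-based surrogate is stable for a live object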
@@ -244,6 +246,11 @@ def __init__(
        self.policy_name = policy_name
        self._merged_node_attrs: dict[str, dict[str, Any]] | None = None
        self._merged_link_attrs: dict[str, dict[str, Any]] | None = None
+        self._merged_rg_attrs: dict[str, dict[str, Any]] | None = None
+        self._prepared_policy_matches: dict[int, dict[int, tuple[str, ...]]] = {}
+        self._risk_group_exclusions: (
+            dict[str, tuple[frozenset[str], frozenset[str]]] | None
+        ) = None

    def get_failure_policy(self) -> "FailurePolicy | None":
        """Get failure policy for analysis.
@@ -264,6 +271,27 @@ def get_failure_policy(self) -> "FailurePolicy | None":
        else:
            return None

+    def _ensure_flattened_maps(self) -> None:
+        """Build flattened attribute views for all entity types (once).
+
+        Merges top-level model fields (name, disabled, etc.) with .attrs
+        so that condition matching in apply_failures works uniformly.
+        All three maps are built together to prevent partial initialization.
+        """
+        if self._merged_node_attrs is not None:
+            return
+        self._merged_node_attrs = {
+            name: flatten_node_attrs(node) for name, node in self.network.nodes.items()
+        }
+        self._merged_link_attrs = {
+            lid: flatten_link_attrs(link, lid)
+            for lid, link in self.network.links.items()
+        }
+        self._merged_rg_attrs = {
+            name: flatten_risk_group_attrs(rg)
+            for name, rg in self.network.risk_groups.items()
+        }
+
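For orientation, one entry of the merged node view might look like this; the exact field set is defined by the flatteners in ngraph.dsl.selectors, so treat the shape below as illustrative:

# Hypothetical node "dc1/leaf1" with user attrs {"site": "dc1"}:
# self._merged_node_attrs["dc1/leaf1"] == {
#     "name": "dc1/leaf1",           # top-level model field
#     "disabled": False,             # top-level model field
#     "risk_groups": {"rg-pdu1"},    # top-level model field
#     "site": "dc1",                 # merged in from node.attrs
# }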
    def compute_exclusions(
        self,
        policy: "FailurePolicy | None" = None,
@@ -293,29 +321,27 @@ def compute_exclusions(
        if policy is None:
            return excluded_nodes, excluded_links

-        # Build merged views of nodes and links including top-level fields required by
-        # policy matching and risk-group expansion. Results are cached for reuse.
-        if self._merged_node_attrs is None:
-            self._merged_node_attrs = {
-                node_name: flatten_node_attrs(node)
-                for node_name, node in self.network.nodes.items()
-            }
-        if self._merged_link_attrs is None:
-            self._merged_link_attrs = {
-                link_id: flatten_link_attrs(link, link_id)
-                for link_id, link in self.network.links.items()
-            }
-
+        self._ensure_flattened_maps()
+        assert (
+            self._merged_node_attrs is not None
+        )  # guaranteed by _ensure_flattened_maps
+        assert self._merged_link_attrs is not None
+        assert self._merged_rg_attrs is not None
        node_map = self._merged_node_attrs
        link_map = self._merged_link_attrs
+        rg_map = self._merged_rg_attrs
+        prepared_matches = self._get_prepared_policy_matches(
+            policy, node_map, link_map, rg_map
+        )

        # Apply failure policy with optional deterministic seed override
        failed_ids = policy.apply_failures(
            node_map,
            link_map,
-            self.network.risk_groups,
+            rg_map,
            seed=seed_offset,
            failure_trace=failure_trace,
+            prepared_matches=prepared_matches,
        )

        # Separate entity types for exclusion sets
@@ -325,23 +351,99 @@ def compute_exclusions(
            elif f_id in self.network.links:
                excluded_links.add(f_id)
            elif f_id in self.network.risk_groups:
-                # Recursively expand risk groups
-                risk_group = self.network.risk_groups[f_id]
-                to_check = [risk_group]
-                while to_check:
-                    grp = to_check.pop()
-                    # Add all nodes/links in this risk group
-                    for node_name, node in self.network.nodes.items():
-                        if grp.name in node.risk_groups:
-                            excluded_nodes.add(node_name)
-                    for link_id, link in self.network.links.items():
-                        if grp.name in link.risk_groups:
-                            excluded_links.add(link_id)
-                    # Check children recursively
-                    to_check.extend(grp.children)
+                risk_group_nodes, risk_group_links = self._get_risk_group_exclusions(
+                    f_id
+                )
+                excluded_nodes.update(risk_group_nodes)
+                excluded_links.update(risk_group_links)

        return excluded_nodes, excluded_links
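A hypothetical call site for the method above (`manager` stands in for the enclosing object; the seed offset is invented):

trace: dict[str, Any] = {}
excluded_nodes, excluded_links = manager.compute_exclusions(
    manager.get_failure_policy(), 42, failure_trace=trace
)
# Risk-group failures have already been expanded into concrete node and
# link IDs, so callers only ever see entity-level exclusion sets.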

+    def _get_risk_group_exclusions(
+        self,
+        risk_group_name: str,
+    ) -> tuple[frozenset[str], frozenset[str]]:
+        """Return transitive node/link exclusions for a failed risk group."""
+        if self._risk_group_exclusions is None:
+            self._risk_group_exclusions = self._build_risk_group_exclusions()
+        return self._risk_group_exclusions.get(
+            risk_group_name,
+            (frozenset(), frozenset()),
+        )
+
+    def _build_risk_group_exclusions(
+        self,
+    ) -> dict[str, tuple[frozenset[str], frozenset[str]]]:
+        """Build transitive member index for each risk group once per manager."""
+        direct_nodes: dict[str, set[str]] = {
+            name: set() for name in self.network.risk_groups
+        }
+        direct_links: dict[str, set[str]] = {
+            name: set() for name in self.network.risk_groups
+        }
+
+        for node_name, node in self.network.nodes.items():
+            for risk_group_name in node.risk_groups:
+                direct_nodes.setdefault(risk_group_name, set()).add(node_name)
+
+        for link_id, link in self.network.links.items():
+            for risk_group_name in link.risk_groups:
+                direct_links.setdefault(risk_group_name, set()).add(link_id)
+
+        expanded: dict[str, tuple[frozenset[str], frozenset[str]]] = {}
+
+        def expand_group(name: str) -> tuple[frozenset[str], frozenset[str]]:
+            cached = expanded.get(name)
+            if cached is not None:
+                return cached
+            if name in visiting:
+                return (
+                    frozenset(direct_nodes.get(name, ())),
+                    frozenset(direct_links.get(name, ())),
+                )
+
+            visiting.add(name)
+            nodes = set(direct_nodes.get(name, ()))
+            links = set(direct_links.get(name, ()))
+            risk_group = self.network.risk_groups.get(name)
+            if risk_group is not None:
+                for child in risk_group.children:
+                    child_nodes, child_links = expand_group(child.name)
+                    nodes.update(child_nodes)
+                    links.update(child_links)
+
+            result = (frozenset(nodes), frozenset(links))
+            visiting.remove(name)
+            expanded[name] = result
+            return result
+
+        visiting: set[str] = set()
+        for risk_group_name in self.network.risk_groups:
+            expand_group(risk_group_name)
+
+        return expanded
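A worked example of the expansion on an invented two-level hierarchy:

# rg-dc1 (direct member: node "dc1/spine1")
#   └── rg-dc1-pdu1 (direct member: link "L17")
#
# _build_risk_group_exclusions() yields:
#   expanded["rg-dc1-pdu1"] == (frozenset(), frozenset({"L17"}))
#   expanded["rg-dc1"]      == (frozenset({"dc1/spine1"}), frozenset({"L17"}))
#
# The `visiting` set breaks malformed cyclic child references: re-entering
# a group during its own expansion returns only its direct members rather
# than recursing forever.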
+
+    def _get_prepared_policy_matches(
+        self,
+        policy: "FailurePolicy",
+        node_map: dict[str, dict[str, Any]],
+        link_map: dict[str, dict[str, Any]],
+        rg_map: dict[str, dict[str, Any]],
+    ) -> dict[int, tuple[str, ...]]:
+        """Prepare stable ordered candidate pools for a policy once per manager."""
+        policy_key = id(policy)
+        cached = self._prepared_policy_matches.get(policy_key)
+        if cached is not None:
+            return cached
+
+        prepared = policy.prepare_matches(
+            node_map,
+            link_map,
+            rg_map,
+        )
+        self._prepared_policy_matches[policy_key] = prepared
+        return prepared
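Keying the cache on `id(policy)` mirrors the trade-off in `_create_cache_key`: object identity is stable because the policy objects persist for the manager's lifetime, so a collected-and-reallocated address cannot alias a live entry. A sketch of the resulting contract (`ctx` is a stand-in for this object):

first = ctx._get_prepared_policy_matches(policy, node_map, link_map, rg_map)
again = ctx._get_prepared_policy_matches(policy, node_map, link_map, rg_map)
assert first is again  # the second call is a plain dict lookup, no re-matching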
+
    def run_monte_carlo_analysis(
        self,
        analysis_func: AnalysisFunction,
@@ -432,10 +534,6 @@ def run_monte_carlo_analysis(
432534 f"parallelism={ parallelism } , policy={ self .policy_name } "
433535 )
434536
435- # Pre-compute worker arguments for all iterations
436- logger .debug ("Pre-computing failure exclusions for all iterations" )
437- pre_compute_start = time .time ()
438-
439537 # Baseline is always run first (no failures, separate from failure iterations)
440538 baseline_arg = (
441539 self .network ,
@@ -448,16 +546,16 @@ def run_monte_carlo_analysis(
            func_name,
        )

-        # Build failure iteration arguments (indexed 0..iterations-1)
+        logger.debug("Pre-computing failure exclusions for all iterations")
+        pre_compute_start = time.time()
+
        worker_args: list[tuple] = []
        key_to_first_arg: dict[tuple, tuple] = {}
        key_to_count: dict[tuple, int] = {}
        key_to_trace: dict[tuple, dict[str, Any]] = {}

        for i in range(iterations):
            seed_offset = seed + i if seed is not None else None
-
-            # Pre-compute exclusions for this failure iteration
            trace = {} if store_failure_patterns else None
            excluded_nodes, excluded_links = self.compute_exclusions(
                policy, seed_offset, failure_trace=trace
@@ -475,14 +573,12 @@ def run_monte_carlo_analysis(
            )
            worker_args.append(arg)

-            # Build deduplication key (excludes iteration index)
            dedup_key = _create_cache_key(
                excluded_nodes, excluded_links, func_name, analysis_kwargs
            )
            if dedup_key not in key_to_first_arg:
                key_to_first_arg[dedup_key] = arg
                key_to_count[dedup_key] = 1
-                # Store trace for first occurrence
                if trace is not None:
                    key_to_trace[dedup_key] = trace
            else:
@@ -493,7 +589,6 @@ def run_monte_carlo_analysis(
493589 f"Pre-computed { len (worker_args )} failure exclusion sets in { pre_compute_time :.2f} s"
494590 )
495591
496- # Prepare unique tasks (deduplicated by failure pattern + analysis params)
497592 unique_worker_args : list [tuple ] = list (key_to_first_arg .values ())
498593 num_unique_tasks : int = len (unique_worker_args )
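With invented numbers, the bookkeeping above keeps the effective sample size intact while shrinking the work queue:

# 1000 iterations collapsing to 37 distinct failure patterns:
#   len(worker_args)           == 1000  # one arg tuple per iteration
#   len(unique_worker_args)    == 37    # tasks actually executed
#   sum(key_to_count.values()) == 1000  # replication factors preserved
# key_to_result and key_to_count later let per-iteration results be
# reconstituted without re-running duplicate analyses.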
        if iterations > 0:
@@ -503,17 +598,14 @@ def run_monte_carlo_analysis(

        start_time = time.time()

-        # Always run baseline first (separate from failure iterations)
        baseline_result_raw = self._run_serial([baseline_arg])
        baseline_result = baseline_result_raw[0] if baseline_result_raw else None

-        # Enrich baseline result with failure metadata
        if baseline_result is not None and hasattr(baseline_result, "failure_id"):
            baseline_result.failure_id = ""
            baseline_result.failure_state = {"excluded_nodes": [], "excluded_links": []}
-            baseline_result.failure_trace = None  # No policy applied for baseline
+            baseline_result.failure_trace = None

-        # Execute failure iterations (deduplicated)
        if iterations > 0:
            use_parallel = parallelism > 1 and num_unique_tasks > 1
            if use_parallel:
@@ -523,7 +615,6 @@ def run_monte_carlo_analysis(
            else:
                unique_result_values = self._run_serial(unique_worker_args)

-            # Map unique task results back to their dedup keys
            key_to_result: dict[tuple, Any] = {}
            for (dedup_key, _arg), value in zip(
                key_to_first_arg.items(), unique_result_values, strict=True