@@ -88,7 +88,7 @@ class FailurePolicy:
8888 A list of FailureRules to apply.
8989 attrs (Dict[str, Any]):
9090 Arbitrary metadata about this policy (e.g. "name", "description").
91- If `fail_shared_risk_groups=True`, then shared-risk expansion is used.
91+ If `fail_shared_risk_groups=True`, shared-risk expansion is used.
9292 """
9393
9494 rules : List [FailureRule ] = field (default_factory = list )
@@ -110,8 +110,9 @@ def apply_failures(
110110 Returns:
111111 A list of failed entity IDs (union of all rule matches).
112112 """
113+ # Merge all entities into a single dict
113114 all_entities = {** nodes , ** links }
114- failed_entities = set ()
115+ failed_entities : set [ str ] = set ()
115116
116117 # 1) Collect matched failures from each rule
117118 for rule in self .rules :
@@ -147,7 +148,7 @@ def _match_entities(
147148 return list (all_entities .keys ())
148149
149150 if not conditions :
150- # If zero conditions, we match nothing unless logic='any'.
151+ # If zero conditions, match nothing unless logic='any'.
151152 return []
152153
153154 matched = []
@@ -174,12 +175,13 @@ def _evaluate_conditions(
174175 Returns:
175176 True if conditions pass, else False.
176177 """
177- if logic not in ("and" , "or" ):
178+ if logic == "and" :
179+ return all (_evaluate_condition (entity_attrs , c ) for c in conditions )
180+ elif logic == "or" :
181+ return any (_evaluate_condition (entity_attrs , c ) for c in conditions )
182+ else :
178183 raise ValueError (f"Unsupported logic: { logic } " )
179184
180- results = [_evaluate_condition (entity_attrs , c ) for c in conditions ]
181- return all (results ) if logic == "and" else any (results )
182-
183185 @staticmethod
184186 def _select_entities (
185187 entity_ids : List [str ],
@@ -191,7 +193,7 @@ def _select_entities(
191193
192194 Args:
193195 entity_ids: Matched entity IDs from _match_entities.
194- all_entities: Full entity map (unused now , but available if needed).
196+ all_entities: Full entity map (unused currently , but available if needed).
195197 rule: The FailureRule specifying 'random', 'choice', or 'all'.
196198
197199 Returns:
@@ -204,44 +206,44 @@ def _select_entities(
204206 return [eid for eid in entity_ids if random () < rule .probability ]
205207 elif rule .rule_type == "choice" :
206208 count = min (rule .count , len (entity_ids ))
207- return sample (sorted ( entity_ids ) , k = count )
209+ return sample (entity_ids , k = count )
208210 elif rule .rule_type == "all" :
209211 return entity_ids
210212 else :
211213 raise ValueError (f"Unsupported rule_type: { rule .rule_type } " )
212214
213215 def _expand_shared_risk_groups (
214- self , failed_entities : set [str ], all_entities : Dict [str , Dict [str , Any ]]
216+ self ,
217+ failed_entities : set [str ],
218+ all_entities : Dict [str , Dict [str , Any ]],
215219 ) -> None :
216220 """
217- Expand the 'failed_entities' set so that if an entity has
218- shared_risk_group=X, all other entities with the same group also fail.
219-
220- This is done iteratively until no new failures are found.
221+ Expand 'failed_entities' so that if an entity is in a shared_risk_group,
222+ all other entities in that same group also fail. Continues until no new
223+ failures are added.
221224
222225 Args:
223- failed_entities: Set of entity IDs already marked as failed.
224- all_entities: Map of entity_id -> attributes (which may contain 'shared_risk_group') .
226+ failed_entities: A set of entity IDs already marked as failed.
227+ all_entities: { entity_id -> attributes}, possibly including 'shared_risk_groups' .
225228 """
226- # Pre-compute SRG -> entity IDs mapping for efficiency
227- srg_map = defaultdict (set )
229+ # Build a map of srg_value -> set of entity IDs
230+ srg_map : Dict [ Any , set [ str ]] = defaultdict (set )
228231 for eid , attrs in all_entities .items ():
229- srg = attrs .get ("shared_risk_group" )
230- if srg :
232+ srgs = attrs .get ("shared_risk_groups" , [] )
233+ for srg in srgs :
231234 srg_map [srg ].add (eid )
232235
236+ # BFS through the shared risk groups
233237 queue = deque (failed_entities )
234238 while queue :
235239 current = queue .popleft ()
236- current_srg = all_entities [current ].get ("shared_risk_group" )
237- if not current_srg :
238- continue
239-
240- # All entities in the same SRG should fail
241- for other_eid in srg_map [current_srg ]:
242- if other_eid not in failed_entities :
243- failed_entities .add (other_eid )
244- queue .append (other_eid )
240+ current_srgs = all_entities [current ].get ("shared_risk_groups" , [])
241+ for current_srg in current_srgs :
242+ # Fail every entity in this SRG
243+ for other_eid in srg_map [current_srg ]:
244+ if other_eid not in failed_entities :
245+ failed_entities .add (other_eid )
246+ queue .append (other_eid )
245247
246248
247249def _evaluate_condition (entity : Dict [str , Any ], cond : FailureCondition ) -> bool :
0 commit comments