Skip to content

Commit c716785

Browse files
committed
extended risk groups
1 parent e44b6d9 commit c716785

6 files changed

Lines changed: 336 additions & 256 deletions

File tree

ngraph/failure_policy.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class FailurePolicy:
8888
A list of FailureRules to apply.
8989
attrs (Dict[str, Any]):
9090
Arbitrary metadata about this policy (e.g. "name", "description").
91-
If `fail_shared_risk_groups=True`, then shared-risk expansion is used.
91+
If `fail_shared_risk_groups=True`, shared-risk expansion is used.
9292
"""
9393

9494
rules: List[FailureRule] = field(default_factory=list)
@@ -110,8 +110,9 @@ def apply_failures(
110110
Returns:
111111
A list of failed entity IDs (union of all rule matches).
112112
"""
113+
# Merge all entities into a single dict
113114
all_entities = {**nodes, **links}
114-
failed_entities = set()
115+
failed_entities: set[str] = set()
115116

116117
# 1) Collect matched failures from each rule
117118
for rule in self.rules:
@@ -147,7 +148,7 @@ def _match_entities(
147148
return list(all_entities.keys())
148149

149150
if not conditions:
150-
# If zero conditions, we match nothing unless logic='any'.
151+
# If zero conditions, match nothing unless logic='any'.
151152
return []
152153

153154
matched = []
@@ -174,12 +175,13 @@ def _evaluate_conditions(
174175
Returns:
175176
True if conditions pass, else False.
176177
"""
177-
if logic not in ("and", "or"):
178+
if logic == "and":
179+
return all(_evaluate_condition(entity_attrs, c) for c in conditions)
180+
elif logic == "or":
181+
return any(_evaluate_condition(entity_attrs, c) for c in conditions)
182+
else:
178183
raise ValueError(f"Unsupported logic: {logic}")
179184

180-
results = [_evaluate_condition(entity_attrs, c) for c in conditions]
181-
return all(results) if logic == "and" else any(results)
182-
183185
@staticmethod
184186
def _select_entities(
185187
entity_ids: List[str],
@@ -191,7 +193,7 @@ def _select_entities(
191193
192194
Args:
193195
entity_ids: Matched entity IDs from _match_entities.
194-
all_entities: Full entity map (unused now, but available if needed).
196+
all_entities: Full entity map (unused currently, but available if needed).
195197
rule: The FailureRule specifying 'random', 'choice', or 'all'.
196198
197199
Returns:
@@ -204,44 +206,44 @@ def _select_entities(
204206
return [eid for eid in entity_ids if random() < rule.probability]
205207
elif rule.rule_type == "choice":
206208
count = min(rule.count, len(entity_ids))
207-
return sample(sorted(entity_ids), k=count)
209+
return sample(entity_ids, k=count)
208210
elif rule.rule_type == "all":
209211
return entity_ids
210212
else:
211213
raise ValueError(f"Unsupported rule_type: {rule.rule_type}")
212214

213215
def _expand_shared_risk_groups(
214-
self, failed_entities: set[str], all_entities: Dict[str, Dict[str, Any]]
216+
self,
217+
failed_entities: set[str],
218+
all_entities: Dict[str, Dict[str, Any]],
215219
) -> None:
216220
"""
217-
Expand the 'failed_entities' set so that if an entity has
218-
shared_risk_group=X, all other entities with the same group also fail.
219-
220-
This is done iteratively until no new failures are found.
221+
Expand 'failed_entities' so that if an entity is in a shared_risk_group,
222+
all other entities in that same group also fail. Continues until no new
223+
failures are added.
221224
222225
Args:
223-
failed_entities: Set of entity IDs already marked as failed.
224-
all_entities: Map of entity_id -> attributes (which may contain 'shared_risk_group').
226+
failed_entities: A set of entity IDs already marked as failed.
227+
all_entities: {entity_id -> attributes}, possibly including 'shared_risk_groups'.
225228
"""
226-
# Pre-compute SRG -> entity IDs mapping for efficiency
227-
srg_map = defaultdict(set)
229+
# Build a map of srg_value -> set of entity IDs
230+
srg_map: Dict[Any, set[str]] = defaultdict(set)
228231
for eid, attrs in all_entities.items():
229-
srg = attrs.get("shared_risk_group")
230-
if srg:
232+
srgs = attrs.get("shared_risk_groups", [])
233+
for srg in srgs:
231234
srg_map[srg].add(eid)
232235

236+
# BFS through the shared risk groups
233237
queue = deque(failed_entities)
234238
while queue:
235239
current = queue.popleft()
236-
current_srg = all_entities[current].get("shared_risk_group")
237-
if not current_srg:
238-
continue
239-
240-
# All entities in the same SRG should fail
241-
for other_eid in srg_map[current_srg]:
242-
if other_eid not in failed_entities:
243-
failed_entities.add(other_eid)
244-
queue.append(other_eid)
240+
current_srgs = all_entities[current].get("shared_risk_groups", [])
241+
for current_srg in current_srgs:
242+
# Fail every entity in this SRG
243+
for other_eid in srg_map[current_srg]:
244+
if other_eid not in failed_entities:
245+
failed_entities.add(other_eid)
246+
queue.append(other_eid)
245247

246248

247249
def _evaluate_condition(entity: Dict[str, Any], cond: FailureCondition) -> bool:

0 commit comments

Comments
 (0)