Skip to content

Commit 25723f2

Browse files
authored
performance improvements for spf and max-flow (#60)
1 parent 5a4c76a commit 25723f2

13 files changed

Lines changed: 2138 additions & 1139 deletions

ngraph/blueprints.py

Lines changed: 117 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from __future__ import annotations
2-
3-
import copy
41
from dataclasses import dataclass
52
from typing import Any, Dict, List
63

@@ -151,33 +148,36 @@ def _expand_group(
151148
if "use_blueprint" in group_def:
152149
# Expand blueprint subgroups
153150
blueprint_name: str = group_def["use_blueprint"]
154-
bp = ctx.blueprints.get(blueprint_name)
155-
if not bp:
156-
raise ValueError(
157-
f"Group '{group_name}' references unknown blueprint '{blueprint_name}'."
158-
)
159-
160-
param_overrides: Dict[str, Any] = group_def.get("parameters", {})
161-
coords = group_def.get("coords")
162-
163-
# For each subgroup in the blueprint, apply overrides and expand
164-
for bp_sub_name, bp_sub_def in bp.groups.items():
165-
merged_def = _apply_parameters(bp_sub_name, bp_sub_def, param_overrides)
166-
if coords is not None and "coords" not in merged_def:
167-
merged_def["coords"] = coords
168-
169-
_expand_group(
170-
ctx,
171-
parent_path=effective_path,
172-
group_name=bp_sub_name,
173-
group_def=merged_def,
174-
blueprint_expansion=True,
175-
)
176-
177-
# Expand blueprint adjacency
178-
for adj_def in bp.adjacency:
179-
_expand_blueprint_adjacency(ctx, adj_def, effective_path)
180-
151+
try:
152+
bp = ctx.blueprints.get(blueprint_name)
153+
if not bp:
154+
raise ValueError(
155+
f"Group '{group_name}' references unknown blueprint '{blueprint_name}'."
156+
)
157+
158+
param_overrides: Dict[str, Any] = group_def.get("parameters", {})
159+
coords = group_def.get("coords")
160+
161+
# For each subgroup in the blueprint, apply overrides and expand
162+
for bp_sub_name, bp_sub_def in bp.groups.items():
163+
merged_def = _apply_parameters(bp_sub_name, bp_sub_def, param_overrides)
164+
if coords is not None and "coords" not in merged_def:
165+
merged_def["coords"] = coords
166+
167+
_expand_group(
168+
ctx,
169+
parent_path=effective_path,
170+
group_name=bp_sub_name,
171+
group_def=merged_def,
172+
blueprint_expansion=True,
173+
)
174+
175+
# Expand blueprint adjacency
176+
for adj_def in bp.adjacency:
177+
_expand_blueprint_adjacency(ctx, adj_def, effective_path)
178+
179+
except Exception as e:
180+
raise ValueError(f"Error expanding blueprint '{blueprint_name}': {e}")
181181
else:
182182
# It's a direct node group
183183
node_count = group_def.get("node_count", 1)
@@ -217,11 +217,12 @@ def _expand_blueprint_adjacency(
217217
target_rel = adj_def["target"]
218218
pattern = adj_def.get("pattern", "mesh")
219219
link_params = adj_def.get("link_params", {})
220+
link_count = adj_def.get("link_count", 1)
220221

221222
src_path = _join_paths(parent_path, source_rel)
222223
tgt_path = _join_paths(parent_path, target_rel)
223224

224-
_expand_adjacency_pattern(ctx, src_path, tgt_path, pattern, link_params)
225+
_expand_adjacency_pattern(ctx, src_path, tgt_path, pattern, link_params, link_count)
225226

226227

227228
def _expand_adjacency(
@@ -239,13 +240,16 @@ def _expand_adjacency(
239240
source_path_raw = adj_def["source"]
240241
target_path_raw = adj_def["target"]
241242
pattern = adj_def.get("pattern", "mesh")
243+
link_count = adj_def.get("link_count", 1)
242244
link_params = adj_def.get("link_params", {})
243245

244246
# Convert to an absolute or relative path
245247
source_path = _join_paths("", source_path_raw)
246248
target_path = _join_paths("", target_path_raw)
247249

248-
_expand_adjacency_pattern(ctx, source_path, target_path, pattern, link_params)
250+
_expand_adjacency_pattern(
251+
ctx, source_path, target_path, pattern, link_params, link_count
252+
)
249253

250254

251255
def _expand_adjacency_pattern(
@@ -254,23 +258,25 @@ def _expand_adjacency_pattern(
254258
target_path: str,
255259
pattern: str,
256260
link_params: Dict[str, Any],
261+
link_count: int = 1,
257262
) -> None:
258263
"""
259264
Generates Link objects for the chosen adjacency pattern among matched nodes.
260265
261266
Supported Patterns:
262267
* "mesh": Connect every node from source side to every node on target side,
263-
skipping self-loops, and deduplicating reversed pairs.
264-
* "one_to_one": Pair each source node with exactly one target node (wrap-around).
265-
* "ring": (Example pattern) For demonstration, connect nodes in a ring among
266-
the union of source + target sets (ignores directionality).
268+
skipping self-loops and deduplicating reversed pairs.
269+
* "one_to_one": Pair each source node with exactly one target node (wrap-around),
270+
requiring that the larger set size is an integer multiple
271+
of the smaller set size.
267272
268273
Args:
269274
ctx (DSLExpansionContext): The context containing the target network.
270275
source_path (str): The path pattern identifying the source node group(s).
271276
target_path (str): The path pattern identifying the target node group(s).
272-
pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one", "ring").
277+
pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one").
273278
link_params (Dict[str, Any]): Additional link parameters (capacity, cost, attrs).
279+
link_count (int): Number of parallel links to create for each adjacency.
274280
"""
275281
source_node_groups = ctx.network.select_node_groups_by_path(source_path)
276282
target_node_groups = ctx.network.select_node_groups_by_path(target_path)
@@ -292,21 +298,21 @@ def _expand_adjacency_pattern(
292298
pair = tuple(sorted((sn.name, tn.name)))
293299
if pair not in dedup_pairs:
294300
dedup_pairs.add(pair)
295-
_create_link(ctx.network, sn.name, tn.name, link_params)
301+
_create_link(ctx.network, sn.name, tn.name, link_params, link_count)
296302

297303
elif pattern == "one_to_one":
298304
s_count = len(source_nodes)
299305
t_count = len(target_nodes)
300-
bigger, smaller = max(s_count, t_count), min(s_count, t_count)
306+
bigger_count = max(s_count, t_count)
307+
smaller_count = min(s_count, t_count)
301308

302-
# Basic check for wrap-around scenario
303-
if bigger % smaller != 0:
309+
if bigger_count % smaller_count != 0:
304310
raise ValueError(
305-
f"one_to_one pattern requires either equal node counts "
306-
f"or a valid wrap-around. Got {s_count} vs {t_count}."
311+
f"one_to_one pattern requires sizes with a multiple factor. "
312+
f"Got source={s_count}, target={t_count}."
307313
)
308314

309-
for i in range(bigger):
315+
for i in range(bigger_count):
310316
if s_count >= t_count:
311317
sn = source_nodes[i].name
312318
tn = target_nodes[i % t_count].name
@@ -320,36 +326,46 @@ def _expand_adjacency_pattern(
320326
pair = tuple(sorted((sn, tn)))
321327
if pair not in dedup_pairs:
322328
dedup_pairs.add(pair)
323-
_create_link(ctx.network, sn, tn, link_params)
329+
_create_link(ctx.network, sn, tn, link_params, link_count)
324330
else:
325331
raise ValueError(f"Unknown adjacency pattern: {pattern}")
326332

327333

328334
def _create_link(
329-
net: Network, source: str, target: str, link_params: Dict[str, Any]
335+
net: Network,
336+
source: str,
337+
target: str,
338+
link_params: Dict[str, Any],
339+
link_count: int = 1,
330340
) -> None:
331341
"""
332-
Creates and adds a Link to the network, applying capacity/cost/attrs from link_params.
342+
Creates and adds one or more Links to the network, applying capacity, cost,
343+
and attributes from link_params. Uses deep copies of the attributes to avoid
344+
accidental shared mutations.
333345
334346
Args:
335-
net (Network): The network to which the new link is added.
347+
net (Network): The network to which the new link(s) is/are added.
336348
source (str): Source node name for the link.
337349
target (str): Target node name for the link.
338350
link_params (Dict[str, Any]): A dict possibly containing 'capacity', 'cost',
339351
and 'attrs' keys.
352+
link_count (int): Number of parallel links to create between source and target.
340353
"""
341-
capacity = link_params.get("capacity", 1.0)
342-
cost = link_params.get("cost", 1.0)
343-
attrs = copy.deepcopy(link_params.get("attrs", {}))
354+
import copy
344355

345-
link = Link(
346-
source=source,
347-
target=target,
348-
capacity=capacity,
349-
cost=cost,
350-
attrs=attrs,
351-
)
352-
net.add_link(link)
356+
for _ in range(link_count):
357+
capacity = link_params.get("capacity", 1.0)
358+
cost = link_params.get("cost", 1.0)
359+
attrs = copy.deepcopy(link_params.get("attrs", {}))
360+
361+
link = Link(
362+
source=source,
363+
target=target,
364+
capacity=capacity,
365+
cost=cost,
366+
attrs=attrs,
367+
)
368+
net.add_link(link)
353369

354370

355371
def _process_direct_nodes(net: Network, network_data: Dict[str, Any]) -> None:
@@ -403,14 +419,8 @@ def _process_direct_links(net: Network, network_data: Dict[str, Any]) -> None:
403419
if source == target:
404420
raise ValueError(f"Link cannot have the same source and target: {source}")
405421
link_params = link_info.get("link_params", {})
406-
link = Link(
407-
source=source,
408-
target=target,
409-
capacity=link_params.get("capacity", 1.0),
410-
cost=link_params.get("cost", 1.0),
411-
attrs=link_params.get("attrs", {}),
412-
)
413-
net.add_link(link)
422+
link_count = link_info.get("link_count", 1)
423+
_create_link(net, source, target, link_params, link_count)
414424

415425

416426
def _process_link_overrides(network: Network, network_data: Dict[str, Any]) -> None:
@@ -475,7 +485,7 @@ def _update_links(
475485
any_direction: bool = True,
476486
) -> None:
477487
"""
478-
Update all Link objects between nodes matching 'source' and 'target' paths
488+
Updates all Link objects between nodes matching 'source' and 'target' paths
479489
with new parameters.
480490
481491
If any_direction=True, both (source->target) and (target->source) links
@@ -537,8 +547,11 @@ def _apply_parameters(
537547
Applies user-provided parameter overrides to a blueprint subgroup.
538548
539549
Example:
540-
If 'spine.node_count'=6 is in params_overrides,
541-
we set 'node_count'=6 for the 'spine' subgroup.
550+
If 'spine.node_count' = 6 is in params_overrides,
551+
it sets 'node_count'=6 for the 'spine' subgroup.
552+
553+
If 'spine.node_attrs.hw_type' = 'Dell',
554+
it sets subgroup_def['node_attrs']['hw_type'] = 'Dell'.
542555
543556
Args:
544557
subgroup_name (str): Name of the subgroup in the blueprint (e.g. 'spine').
@@ -547,23 +560,50 @@ def _apply_parameters(
547560
{'spine.node_count': 6, 'spine.node_attrs.hw_type': 'Dell'}.
548561
549562
Returns:
550-
Dict[str, Any]: A copy of subgroup_def with parameter overrides applied.
563+
Dict[str, Any]: A copy of subgroup_def with parameter overrides applied,
564+
including nested dictionary fields if specified by dotted paths (e.g. node_attrs.foo).
551565
"""
552-
out = dict(subgroup_def)
566+
import copy
567+
568+
out = copy.deepcopy(subgroup_def)
569+
553570
for key, val in params_overrides.items():
554571
parts = key.split(".")
555572
if parts[0] == subgroup_name and len(parts) > 1:
556-
field_name = ".".join(parts[1:])
557-
out[field_name] = val
573+
# We have a dotted path that might refer to nested dictionaries.
574+
subpath = parts[1:]
575+
_apply_nested_path(out, subpath, val)
576+
558577
return out
559578

560579

580+
def _apply_nested_path(
581+
node_def: Dict[str, Any], path_parts: List[str], value: Any
582+
) -> None:
583+
"""
584+
Recursively applies a path like ["node_attrs", "role"] to set node_def["node_attrs"]["role"] = value.
585+
Creates intermediate dicts as needed.
586+
"""
587+
if not path_parts:
588+
return
589+
key = path_parts[0]
590+
if len(path_parts) == 1:
591+
node_def[key] = value
592+
return
593+
594+
# Ensure that node_def[key] is a dict
595+
if key not in node_def or not isinstance(node_def[key], dict):
596+
node_def[key] = {}
597+
_apply_nested_path(node_def[key], path_parts[1:], value)
598+
599+
561600
def _join_paths(parent_path: str, rel_path: str) -> str:
562601
"""
563602
Joins two path segments according to NetGraph's DSL conventions:
564-
- If rel_path starts with '/', strip the leading slash and treat it
565-
as appended to parent_path if parent_path is not empty.
566-
- Otherwise, simply append rel_path to parent_path if parent_path is non-empty.
603+
604+
- If rel_path starts with '/', strip the leading slash and treat it as
605+
appended to parent_path if parent_path is not empty.
606+
- Otherwise, simply append rel_path to parent_path if parent_path is non-empty.
567607
568608
Args:
569609
parent_path (str): The existing path prefix.

0 commit comments

Comments
 (0)