Skip to content

Commit 4f80b77

Browse files
authored
Cache graph_signature property lookups in partitioning hot loops (#18352)
inputs_to_parameters, inputs_to_buffers, etc. are @Property methods that rebuild a dict from input_specs on every access. In partitioning loops that check each graph node, this causes O(nodes × params) dict constructions. For a 40-layer MoE model with ~26K nodes and ~1,671 parameters, this results in ~260M iterations of pure overhead that showed up as multiple minutes during to_edge_transform_and_lower. Cache the dicts before the loops in: - backend/utils.py: tag_constant_data (2 loops × 3 dict lookups each) - backend/utils.py: tag_mutated_buffer (1 loop × 2 dict lookups) - lowered_backend_module.py: __call__ (2 passes × 2 dict lookups) - lowered_backend_module.py: arrange_graph_placeholders (1 loop × 2) No behavioral change — the graph signature is not mutated between caching and use in any of these functions.
1 parent dff7d78 commit 4f80b77

2 files changed

Lines changed: 34 additions & 25 deletions

File tree

exir/backend/utils.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from executorch.exir.lowered_backend_module import create_submodule_from_nodes
2626
from tabulate import tabulate
27-
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2827
from torch.fx.experimental.symbolic_shapes import has_free_symbols
2928
from torch.fx.node import Node
3029
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -351,15 +350,22 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
351350
subgraph. Throw error when const/param/buffers is used across different partitions. That is the
352351
underlying data will be owned by multiple delegates.
353352
"""
353+
# Cache signature lookups to avoid rebuilding dicts on every access.
354+
sig = edge_program.graph_signature
355+
params_map = sig.inputs_to_parameters
356+
buffers_map = sig.inputs_to_buffers
357+
constants_map = sig.inputs_to_lifted_tensor_constants
358+
buffers_to_mutate = sig.buffers_to_mutate
359+
354360
mutated_buffer = set()
355361
for node in edge_program.graph.nodes:
356362
if node.op == "placeholder" and (
357-
is_param(edge_program, node)
358-
or is_buffer(edge_program, node)
359-
or is_lifted_tensor_constant(edge_program, node)
363+
node.name in params_map
364+
or node.name in buffers_map
365+
or node.name in constants_map
360366
):
361367
for node_user in node.users:
362-
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
368+
if node_user.name in buffers_to_mutate:
363369
logging.info(
364370
"The buffer node is a mutated buffer node, which is not constant."
365371
)
@@ -368,9 +374,9 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
368374
for node in edge_program.graph.nodes:
369375
# go through const/param/buffer nodes, if all users of const/param/buffer nodes are partitioned then partition
370376
if node.op == "placeholder" and (
371-
is_param(edge_program, node)
372-
or is_buffer(edge_program, node)
373-
or is_lifted_tensor_constant(edge_program, node)
377+
node.name in params_map
378+
or node.name in buffers_map
379+
or node.name in constants_map
374380
):
375381
if node not in mutated_buffer:
376382
user_tags = set()
@@ -397,12 +403,17 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
397403
subgraph. Throw error when buffers is used across different partitions. That is the
398404
underlying data will be owned by multiple delegates.
399405
"""
406+
# Cache signature lookups to avoid rebuilding dicts on every access.
407+
sig = edge_program.graph_signature
408+
buffers_map = sig.inputs_to_buffers
409+
buffers_to_mutate = sig.buffers_to_mutate
410+
400411
for node in edge_program.graph.nodes:
401412
# Determine whether this node is a mutated buffer
402413
is_mutated_buffer_node = False
403-
if node.op == "placeholder" and is_buffer(edge_program, node):
414+
if node.op == "placeholder" and node.name in buffers_map:
404415
for node_user in node.users:
405-
if node_user.name in edge_program.graph_signature.buffers_to_mutate:
416+
if node_user.name in buffers_to_mutate:
406417
is_mutated_buffer_node = True
407418
break
408419
# This node is mutated buffer, tag it

exir/lowered_backend_module.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,19 @@ def program(
223223

224224
lowered_exported_program = copy.deepcopy(self._original_exported_program)
225225

226+
# Cache these properties to avoid rebuilding the dict on each access.
227+
sig = lowered_exported_program.graph_signature
228+
params_map = sig.inputs_to_parameters
229+
buffers_map = sig.inputs_to_buffers
230+
226231
# The real input nodes are the ones not buffer or parameter
227232
all_input_nodes = [
228233
node
229234
for node in lowered_exported_program.graph.nodes
230235
if (
231236
node.op == "placeholder"
232-
and node.name
233-
not in lowered_exported_program.graph_signature.inputs_to_buffers
234-
and node.name
235-
not in lowered_exported_program.graph_signature.inputs_to_parameters
237+
and node.name not in buffers_map
238+
and node.name not in params_map
236239
)
237240
]
238241

@@ -250,9 +253,7 @@ def program(
250253
# Find placeholders that are parameters or buffers, remove them from the main graph
251254
for node in lowered_exported_program.graph.nodes:
252255
if node.op == "placeholder" and (
253-
node.name in lowered_exported_program.graph_signature.inputs_to_buffers
254-
or node.name
255-
in lowered_exported_program.graph_signature.inputs_to_parameters
256+
node.name in buffers_map or node.name in params_map
256257
):
257258
lowered_exported_program.graph.erase_node(node)
258259

@@ -402,22 +403,19 @@ def arrange_graph_placeholders(
402403
graph_sign = owning_program.graph_signature
403404

404405
# Add all placeholders into the graph first:
406+
# Cache these properties — each call rebuilds the dict from input_specs.
407+
params_map = graph_sign.inputs_to_parameters
408+
buffers_map = graph_sign.inputs_to_buffers
405409
param_nodes = []
406410
buffer_nodes = []
407411
input_nodes = []
408412
for node in gm.graph.nodes:
409413
if node.op != "placeholder":
410414
continue
411415

412-
if (
413-
node.name in graph_sign.inputs_to_parameters
414-
and node.meta.get("delegation_tag", None) == tag
415-
):
416+
if node.name in params_map and node.meta.get("delegation_tag", None) == tag:
416417
param_nodes.append(node)
417-
elif (
418-
node.name in graph_sign.inputs_to_buffers
419-
and node.meta.get("delegation_tag", None) == tag
420-
):
418+
elif node.name in buffers_map and node.meta.get("delegation_tag", None) == tag:
421419
buffer_nodes.append(node)
422420
else:
423421
input_nodes.append(node)

0 commit comments

Comments
 (0)