From d7630821e00abbdc75ac9c8ceab013feab3772a5 Mon Sep 17 00:00:00 2001 From: adlashab Date: Thu, 21 May 2026 02:27:19 -0700 Subject: [PATCH] [inductor] BatchLinearLHSFusion: also match torch._C._nn.linear Dynamo inlines torch.nn.functional.linear into torch._C._nn.linear on hot paths, so BatchLinearLHSFusion.match silently misses the inlined form. This one-line change extends CallFunctionVarArgs to also accept torch._C._nn.linear. The sibling fusion class PreGradBatchLinearFusion in the same file already accepts both call targets (line 626 in current main); this brings BatchLinearLHSFusion in line with that. The fusion is gated by pre_grad_fusion_options["batch_linear_lhs"] (unset by default), so this is invisible to users who haven't opted in. is_linear_node_can_be_fused still filters out shapes that can't actually be fused (different weight shapes per head, etc). Adds a regression test that builds a module out of nn.Linear modules (Dynamo inlines those to _C._nn.linear). Without this change the test expects 0 fires; with it, 1 fire. --- test/inductor/test_group_batch_fusion.py | 142 ++++++++++++++++++ .../_inductor/fx_passes/group_batch_fusion.py | 6 +- 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 670258df00197..0ddfd7aeeac93 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -458,6 +458,148 @@ def test_batch_linear_lhs_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_linear_lhs": {}}, + post_grad_fusion_options={}, + ) + def test_batch_linear_lhs_fusion_nn_linear_inlined(self): + # Same shape pattern as test_batch_linear_lhs_fusion, but the linears + # are produced through nn.Linear modules. Dynamo inlines those to + # torch._C._nn.linear, which the upstream matcher used to miss. + class M(torch.nn.Module): + def __init__(self, z, n, has_bias): + super().__init__() + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(z, z - i % 5, bias=has_bias) for i in range(n)] + ) + + def forward(self, x): + x = x + 1.2 + outs = [lin(x) for lin in self.linears] + return torch.sigmoid(torch.cat(outs, dim=1)) + + z, n = 10, 10 + for has_bias in [True, False]: + counters.clear() + module = M(z, n, has_bias).to(GPU_TYPE) + input = [torch.randn(20, z, device=GPU_TYPE)] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + self.compare_pred(module, traced, input) + self.assertEqual(counters["inductor"]["batch_linear_lhs"], 1) + ref.sum().backward() + res.sum().backward() + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + + @requires_gpu() + def test_as_strided_storage_offset_after_mm_fusion(self): + """ + Post-grad batch linear fusion rewrites parallel mm nodes into a + batched bmm followed by select views. The select outputs preserve the + original row stride, but they inherit a non-zero storage offset. + Downstream as_strided must inherit that offset instead of resetting to + the base storage offset. + """ + import copy + + from torch._dynamo.backends.common import aot_autograd + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor.decomposition import select_decomp_table + from torch._inductor.fx_passes.group_batch_fusion import ( + graph_search_options, + PostGradBatchLinearFusion, + ) + from torch._inductor.pattern_matcher import stable_topological_sort + + fused_counts = [] + + def fusing_compiler(gm, example_inputs): + opts = graph_search_options.copy() + opts["min_fuse_set_size"] = 3 + rule = PostGradBatchLinearFusion(graph_search_options=opts) + mm_nodes = [n for n in gm.graph.nodes if rule.match(n) is not None] + fused_counts.append(len(mm_nodes)) + if len(mm_nodes) >= 3: + rule.fuse(gm.graph, mm_nodes[:3]) + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + return compile_fx_inner(gm, example_inputs) + + fusing_backend = aot_autograd( + fw_compiler=fusing_compiler, + decompositions=select_decomp_table(), + ) + + class QKVModel(torch.nn.Module): + def __init__(self, hidden_size=64, num_heads=4): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.scale = self.head_dim**0.5 + + def forward(self, x): + # Reduced-size version of the issue author's QKV repro. + x = x.permute(1, 0, 2) + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + seq_len, batch_size, _ = q.shape + + q = q / self.scale + q = q.view(seq_len, batch_size, self.num_heads, self.head_dim) + q = q.permute(2, 1, 0, 3) + q = q.reshape(self.num_heads, batch_size * seq_len, self.head_dim) + q = q.as_strided( + (self.num_heads, batch_size * seq_len, self.head_dim), + (self.head_dim, self.num_heads * self.head_dim, 1), + ) + + k = k.view(seq_len, batch_size, self.num_heads, self.head_dim) + k = k.permute(2, 1, 0, 3) + k = k.reshape(self.num_heads, batch_size * seq_len, self.head_dim) + k = k.as_strided( + (self.num_heads, batch_size * seq_len, self.head_dim), + (self.head_dim, self.num_heads * self.head_dim, 1), + ) + k = k.permute(0, 2, 1) + + v = v.view(seq_len, batch_size, self.num_heads, self.head_dim) + v = v.permute(2, 1, 0, 3) + v = v.reshape(self.num_heads, batch_size * seq_len, self.head_dim) + v = v.as_strided( + (self.num_heads, batch_size * seq_len, self.head_dim), + (self.head_dim, self.num_heads * self.head_dim, 1), + ) + v = v.permute(0, 2, 1) + + return torch.bmm(q, k), k, v + + torch.manual_seed(42) + model = QKVModel().to(GPU_TYPE).eval() + x = torch.randn(1, 8, 64, device=GPU_TYPE) + + torch._dynamo.reset() + with torch.no_grad(): + ref = model(x) + + torch._dynamo.reset() + compiled = torch.compile(copy.deepcopy(model), backend=fusing_backend) + with torch.no_grad(): + res = compiled(x) + + self.assertEqual(fused_counts, [3]) + for ref_t, res_t in zip(ref, res): + self.assertEqual(ref_t, res_t, rtol=1e-3, atol=1e-3) + @requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={"batch_linear": {}}, diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 8782a5402538e..644f14fc8b965 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -493,9 +493,9 @@ class BatchLinearLHSFusion(BatchFusion): """ def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None: - if CallFunctionVarArgs(torch.nn.functional.linear).match( - node - ) and is_linear_node_can_be_fused(node): + if CallFunctionVarArgs( + [torch.nn.functional.linear, torch._C._nn.linear] + ).match(node) and is_linear_node_can_be_fused(node): input = get_arg_value(node, 0, "input") bias = get_arg_value(node, 2, "bias") group_key = ("batch_linear_lhs", bias is None, input)