Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions test/inductor/test_group_batch_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,148 @@ def test_batch_linear_lhs_fusion(self):
self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8)
counters.clear()

@requires_gpu()
@torch._inductor.config.patch(
    pre_grad_fusion_options={"batch_linear_lhs": {}},
    post_grad_fusion_options={},
)
def test_batch_linear_lhs_fusion_nn_linear_inlined(self):
    """Check that batch_linear_lhs fuses linears coming from nn.Linear.

    Uses the same shape pattern as test_batch_linear_lhs_fusion, except
    that every linear is produced by an nn.Linear submodule. Dynamo
    inlines those modules to torch._C._nn.linear, a call target the
    upstream matcher used to miss.
    """

    class ParallelLinears(torch.nn.Module):
        def __init__(self, z, n, has_bias):
            super().__init__()
            # Output widths cycle through z, z-1, ..., z-4.
            layers = []
            for i in range(n):
                layers.append(torch.nn.Linear(z, z - i % 5, bias=has_bias))
            self.linears = torch.nn.ModuleList(layers)

        def forward(self, x):
            # Every linear consumes the same shifted input tensor.
            shifted = x + 1.2
            outs = [lin(shifted) for lin in self.linears]
            merged = torch.cat(outs, dim=1)
            return torch.sigmoid(merged)

    in_features, num_linears = 10, 10
    for has_bias in (True, False):
        counters.clear()
        eager = ParallelLinears(in_features, num_linears, has_bias).to(GPU_TYPE)
        inputs = [torch.randn(20, in_features, device=GPU_TYPE)]
        compiled = torch.compile(eager)

        eager_out = eager(*inputs)
        compiled_out = compiled(*inputs)
        self.compare_pred(eager, compiled, inputs)

        # Exactly one batch_linear_lhs fusion must have been recorded.
        fused = counters["inductor"]["batch_linear_lhs"]
        self.assertEqual(fused, 1)

        eager_out.sum().backward()
        compiled_out.sum().backward()
        self.compare_parameters(eager, compiled, rtol=1e-8, atol=1e-8)
        self.compare_gradients(eager, compiled, rtol=1e-8, atol=1e-8)
        counters.clear()

@requires_gpu()
def test_as_strided_storage_offset_after_mm_fusion(self):
    """
    Post-grad batch linear fusion rewrites parallel mm nodes into a
    batched bmm followed by select views. The select outputs preserve the
    original row stride, but they inherit a non-zero storage offset.
    Downstream as_strided must inherit that offset instead of resetting to
    the base storage offset.

    The fusion is applied by hand inside a custom AOT-autograd forward
    compiler, and the compiled outputs are compared against eager.
    """
    import copy

    from torch._dynamo.backends.common import aot_autograd
    from torch._inductor.compile_fx import compile_fx_inner
    from torch._inductor.decomposition import select_decomp_table
    from torch._inductor.fx_passes.group_batch_fusion import (
        graph_search_options,
        PostGradBatchLinearFusion,
    )
    from torch._inductor.pattern_matcher import stable_topological_sort

    # One entry per invocation of the forward compiler: the number of
    # nodes the fusion rule matched in that graph.
    fused_counts = []

    def fusing_compiler(gm, example_inputs):
        # Work on a copy so the imported graph_search_options object is
        # not mutated by this test.
        opts = graph_search_options.copy()
        opts["min_fuse_set_size"] = 3
        rule = PostGradBatchLinearFusion(graph_search_options=opts)
        mm_nodes = [n for n in gm.graph.nodes if rule.match(n) is not None]
        fused_counts.append(len(mm_nodes))
        if len(mm_nodes) >= 3:
            # Fuse the first three matches, then restore a valid node
            # order and regenerate the module code before lowering.
            rule.fuse(gm.graph, mm_nodes[:3])
            stable_topological_sort(gm.graph)
            gm.graph.lint()
            gm.recompile()
        return compile_fx_inner(gm, example_inputs)

    fusing_backend = aot_autograd(
        fw_compiler=fusing_compiler,
        decompositions=select_decomp_table(),
    )

    class QKVModel(torch.nn.Module):
        def __init__(self, hidden_size=64, num_heads=4):
            super().__init__()
            self.hidden_size = hidden_size
            self.num_heads = num_heads
            self.head_dim = hidden_size // num_heads
            self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
            self.k_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
            self.v_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
            self.scale = self.head_dim**0.5

        def _as_heads(self, t, seq_len, batch_size):
            # Layout change shared by q, k and v: view with a separate
            # head axis, move heads to the front, flatten the two middle
            # axes, then re-view through as_strided with explicit strides
            # (this is the call the test is exercising).
            t = t.view(seq_len, batch_size, self.num_heads, self.head_dim)
            t = t.permute(2, 1, 0, 3)
            t = t.reshape(self.num_heads, batch_size * seq_len, self.head_dim)
            return t.as_strided(
                (self.num_heads, batch_size * seq_len, self.head_dim),
                (self.head_dim, self.num_heads * self.head_dim, 1),
            )

        def forward(self, x):
            # Reduced-size version of the issue author's QKV repro.
            x = x.permute(1, 0, 2)
            q = self.q_proj(x)
            k = self.k_proj(x)
            v = self.v_proj(x)
            seq_len, batch_size, _ = q.shape

            q = self._as_heads(q / self.scale, seq_len, batch_size)
            k = self._as_heads(k, seq_len, batch_size).permute(0, 2, 1)
            v = self._as_heads(v, seq_len, batch_size).permute(0, 2, 1)

            return torch.bmm(q, k), k, v

    torch.manual_seed(42)
    model = QKVModel().to(GPU_TYPE).eval()
    x = torch.randn(1, 8, 64, device=GPU_TYPE)

    torch._dynamo.reset()
    with torch.no_grad():
        ref = model(x)

    torch._dynamo.reset()
    compiled = torch.compile(copy.deepcopy(model), backend=fusing_backend)
    with torch.no_grad():
        res = compiled(x)

    # The forward compiler ran once and matched exactly three nodes.
    self.assertEqual(fused_counts, [3])
    # zip() stops at the shorter sequence, so pin the output count first;
    # otherwise a missing compiled output would go unnoticed.
    self.assertEqual(len(res), len(ref))
    for ref_t, res_t in zip(ref, res):
        self.assertEqual(ref_t, res_t, rtol=1e-3, atol=1e-3)

@requires_gpu()
@torch._inductor.config.patch(
pre_grad_fusion_options={"batch_linear": {}},
Expand Down
6 changes: 3 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,9 @@ class BatchLinearLHSFusion(BatchFusion):
"""

def match(self, node: torch.fx.Node) -> tuple[str, bool, Any] | None:
if CallFunctionVarArgs(torch.nn.functional.linear).match(
node
) and is_linear_node_can_be_fused(node):
if CallFunctionVarArgs(
[torch.nn.functional.linear, torch._C._nn.linear]
).match(node) and is_linear_node_can_be_fused(node):
input = get_arg_value(node, 0, "input")
bias = get_arg_value(node, 2, "bias")
group_key = ("batch_linear_lhs", bias is None, input)
Expand Down