Commit 5a8a489

remove attention abstraction (Comfy-Org#5324)
1 parent 8ce2a10 commit 5a8a489

1 file changed: 8 additions & 10 deletions

File tree

  • comfy/ldm/modules/diffusionmodules

comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]
 
-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
 
 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x
@@ -531,8 +529,8 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
         qkv, intermediates = self.pre_attention(x, c)
         attn = optimized_attention(
-            qkv,
-            num_heads=self.attn.num_heads,
+            qkv[0], qkv[1], qkv[2],
+            heads=self.attn.num_heads,
         )
         return self.post_attention(attn, *intermediates)

@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)
 
     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operation
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)
 
 class ContextProcessorBlock(nn.Module):
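
In short, the commit drops the module-local optimized_attention wrapper that took a packed (q, k, v) tuple and a num_heads keyword, and has every call site invoke the shared implementation from comfy/ldm/modules/attention.py directly with separate q, k, v tensors and a heads keyword. A minimal sketch of the new call pattern follows; it is not part of the commit and assumes a ComfyUI checkout is importable and that q, k, v are (batch, seq_len, heads * head_dim) tensors, as the diff's reshape of the split qkv suggests.

# Sketch only: illustrates the API change, not the author's code.
import torch
from comfy.ldm.modules.attention import optimized_attention

batch, seq_len, heads, head_dim = 2, 16, 8, 64  # illustrative sizes, not from the diff
q = torch.randn(batch, seq_len, heads * head_dim)
k = torch.randn(batch, seq_len, heads * head_dim)
v = torch.randn(batch, seq_len, heads * head_dim)

# Old pattern (removed wrapper): optimized_attention((q, k, v), num_heads=heads)
# New pattern: pass the tensors separately and use the heads keyword.
out = optimized_attention(q, k, v, heads=heads)
print(out.shape)  # expected: (batch, seq_len, heads * head_dim)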
