 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]
 
-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
 
 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x
@@ -531,8 +529,8 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
         qkv, intermediates = self.pre_attention(x, c)
         attn = optimized_attention(
-            qkv,
-            num_heads=self.attn.num_heads,
+            qkv[0], qkv[1], qkv[2],
+            heads=self.attn.num_heads,
         )
         return self.post_attention(attn, *intermediates)
 
@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)
 
     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operation
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)
 
 class ContextProcessorBlock(nn.Module):
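
For reference, a minimal sketch of the calling convention the updated sites now use, assuming optimized_attention (imported here via "from ..attention import optimized_attention") accepts separate q, k, v tensors of shape (batch, seq_len, heads * dim_head) plus a heads keyword. The shapes and values below are illustrative only and not taken from the diff:

import torch
from comfy.ldm.modules.attention import optimized_attention

# Illustrative shapes: batch of 2, 16 tokens, 8 heads of 64 dims each (assumption, not from the diff).
q = torch.randn(2, 16, 8 * 64)
k = torch.randn(2, 16, 8 * 64)
v = torch.randn(2, 16, 8 * 64)

# The removed local wrapper took a (q, k, v) tuple plus num_heads;
# call sites now pass the tensors and the head count directly:
out = optimized_attention(q, k, v, heads=8)  # result shape: (2, 16, 8 * 64)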