diff --git a/lightx2v/common/ops/attn/flash_attn.py b/lightx2v/common/ops/attn/flash_attn.py index d4b10f623..1c189bc9b 100755 --- a/lightx2v/common/ops/attn/flash_attn.py +++ b/lightx2v/common/ops/attn/flash_attn.py @@ -5,16 +5,19 @@ from .utils.sparge_util import block_map_ordinal_lut_triton, get_block_map_meansim try: - import flash_attn # noqa: F401 - from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn import flash_attn_func_v2 + from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2 except ImportError: - logger.info("flash_attn_varlen_func not found, please install flash_attn2 first") - flash_attn_varlen_func = None + logger.info("flash_attn2 not found, please install flash_attn2 first") + flash_attn_func_v2 = None + flash_attn_varlen_func_v2 = None try: + from flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 except ImportError: - logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first") + logger.info("flash_attn3 not found, please install flash_attn3 first") + flash_attn_func_v3 = None flash_attn_varlen_func_v3 = None try: @@ -49,18 +52,37 @@ def apply( bs = 1 elif len(q.shape) == 4: bs = q.shape[0] - q = q.reshape(-1, q.shape[-2], q.shape[-1]) - k = k.reshape(-1, k.shape[-2], k.shape[-1]) - v = v.reshape(-1, v.shape[-2], v.shape[-1]) - x = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ).reshape(bs * max_seqlen_q, -1) + total_seqlen = bs * max_seqlen_q + + if bs == 1: + if len(q.shape) == 3: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + x = flash_attn_func_v2(q, k, v).reshape(bs * max_seqlen_q, -1) + else: + if cu_seqlens_q.is_cpu: + cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True) + if cu_seqlens_kv.is_cpu: + cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True) + if max_seqlen_q.is_cpu: + max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True) + if max_seqlen_kv.is_cpu: + max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True) + if len(q.shape) == 4: + q = q.reshape(-1, q.shape[-2], q.shape[-1]) + k = k.reshape(-1, k.shape[-2], k.shape[-1]) + v = v.reshape(-1, v.shape[-2], v.shape[-1]) + x = flash_attn_varlen_func_v2( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ).reshape(total_seqlen, -1) + return x @@ -84,18 +106,37 @@ def apply( bs = 1 elif len(q.shape) == 4: bs = q.shape[0] - q = q.reshape(-1, q.shape[-2], q.shape[-1]) - k = k.reshape(-1, k.shape[-2], k.shape[-1]) - v = v.reshape(-1, v.shape[-2], v.shape[-1]) - x = flash_attn_varlen_func_v3( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - ).reshape(bs * max_seqlen_q, -1) + total_seqlen = bs * max_seqlen_q + + if bs == 1: + if len(q.shape) == 3: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + x = flash_attn_func_v3(q, k, v).reshape(bs * max_seqlen_q, -1) + else: + if cu_seqlens_q.is_cpu: + cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True) + if cu_seqlens_kv.is_cpu: + cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True) + if max_seqlen_q.is_cpu: + max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True) + if max_seqlen_kv.is_cpu: + max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True) + if len(q.shape) == 4: + q = q.reshape(-1, q.shape[-2], q.shape[-1]) + k = k.reshape(-1, k.shape[-2], k.shape[-1]) + v = v.reshape(-1, v.shape[-2], v.shape[-1]) + x = flash_attn_varlen_func_v3( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ).reshape(total_seqlen, -1) + return x diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py index 512ecdf63..a0b8d5288 100644 --- a/lightx2v/common/ops/attn/ring_attn.py +++ b/lightx2v/common/ops/attn/ring_attn.py @@ -59,7 +59,6 @@ def apply( slice_qkv_len, cu_seqlens_qkv, attention_module=None, - attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, use_fp4_comm=False, @@ -76,7 +75,6 @@ def apply( v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims] slice_qkv_len (int): 图像查询、键和值的长度 cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息 - attention_type (str): 注意力类型,默认为 "flash_attn2" use_fp8_comm: 是否使用 FP8 通信 use_fp4_comm: 是否使用 FP4 通信 diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py index 47e39538f..45552ae63 100755 --- a/lightx2v/common/ops/attn/ulysses_attn.py +++ b/lightx2v/common/ops/attn/ulysses_attn.py @@ -4,7 +4,6 @@ from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER -from lightx2v_platform.base.global_var import AI_DEVICE from .template import AttnWeightTemplate from .utils.all2all import all2all_head2seq @@ -31,7 +30,6 @@ def apply( slice_qkv_len, cu_seqlens_qkv, attention_module=None, - attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, use_fp4_comm=False, @@ -51,7 +49,6 @@ def apply( v (torch.Tensor): 值张量,形状为 [shard_seqlen, kv_heads, hidden_dims] slice_qkv_len (int): 图像或者文本查询、键和值的长度,根据 img_first 确定谁在前半部分 cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息 - attention_type (str): 注意力类型,默认为 "flash_attn2" q_only_img (bool): 若为 True,q 只含图像 token,k/v 同时含图像和文本 token。 此时只对 k/v 做 img/txt 分割,q 整体参与图像侧 all-to-all。 支持 cross-attention 等 q 不含文本 token 的场景。 @@ -108,16 +105,12 @@ def apply( cu_seqlens_kv[1] = txt_qkv_len + global_img_seqlen if txt_mask_len: cu_seqlens_kv = torch.cat((cu_seqlens_kv, torch.tensor([txt_mask_len + global_img_seqlen], dtype=torch.int32))) - if attention_type == "flash_attn2" or attention_type == "flash_attn3": - cu_seqlens_kv = cu_seqlens_kv.to(AI_DEVICE, non_blocking=True) max_seqlen_kv = global_img_seqlen + txt_qkv_len # q_only_img 时 q 只含图像 token,cu_seqlens_q 与 kv 侧不同 if q_only_img: cu_seqlens_q = torch.zeros([2], dtype=torch.int32) cu_seqlens_q[1] = global_img_seqlen - if attention_type == "flash_attn2" or attention_type == "flash_attn3": - cu_seqlens_q = cu_seqlens_q.to(AI_DEVICE, non_blocking=True) max_seqlen_q = global_img_seqlen else: cu_seqlens_q = cu_seqlens_kv @@ -599,7 +592,6 @@ def apply( slice_qkv_len, cu_seqlens_qkv, attention_module=None, - attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, use_fp4_comm=False, @@ -616,7 +608,6 @@ def apply( v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims] slice_qkv_len (int): 图像或者文本查询、键和值的长度,根据img_first确定谁在前半部分 cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息 - attention_type (str): 注意力类型,默认为 "flash_attn2" 返回: torch.Tensor: 计算得到的注意力结果 @@ -748,7 +739,6 @@ def apply( max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0] # 最大序列长度 # 调用注意力函数计算注意力结果 - # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv) attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs) # 分割图像和文本的注意力结果 diff --git a/lightx2v/models/networks/bagel/infer/transformer_infer.py b/lightx2v/models/networks/bagel/infer/transformer_infer.py index dfd954d1e..9151d2112 100644 --- a/lightx2v/models/networks/bagel/infer/transformer_infer.py +++ b/lightx2v/models/networks/bagel/infer/transformer_infer.py @@ -160,8 +160,8 @@ def self_attn( merged_value_states = packed_value_states key_values_lens = query_lens - cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE) - cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE) + cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) packed_attn_output = flash_attn_varlen_func( q=packed_query_states, diff --git a/lightx2v/models/networks/flux2/infer/transformer_infer.py b/lightx2v/models/networks/flux2/infer/transformer_infer.py index 3e425dd60..b6205a041 100644 --- a/lightx2v/models/networks/flux2/infer/transformer_infer.py +++ b/lightx2v/models/networks/flux2/infer/transformer_infer.py @@ -95,7 +95,7 @@ def infer_double_stream_block( query, key = self.apply_rope_func(query, key, image_rotary_emb) total_len = query.shape[0] - cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device) + cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32) model_cls = self.config.get("model_cls", "flux2_klein") @@ -193,7 +193,7 @@ def infer_single_stream_block( query, key = self.apply_rope_func(query, key, image_rotary_emb) total_len = query.shape[0] - cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device) + cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32) model_cls = self.config.get("model_cls", "flux2_klein") diff --git a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py index 848729221..87341b4f6 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py @@ -10,7 +10,6 @@ apply_rope_with_cos_sin_cache_inplace = None from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer -from lightx2v_platform.base.global_var import AI_DEVICE from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput from .triton_ops import fuse_scale_shift_kernel @@ -226,7 +225,7 @@ def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v): key = torch.cat([img_k, txt_k], dim=1) value = torch.cat([img_v, txt_v], dim=1) seqlen = query.shape[1] - cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu") if self.config["seq_parallel"]: attn_out = weights.self_attention_parallel.apply( diff --git a/lightx2v/models/networks/longcat_image/infer/transformer_infer.py b/lightx2v/models/networks/longcat_image/infer/transformer_infer.py index 70192f2d4..7c8229db6 100755 --- a/lightx2v/models/networks/longcat_image/infer/transformer_infer.py +++ b/lightx2v/models/networks/longcat_image/infer/transformer_infer.py @@ -138,7 +138,7 @@ def infer_double_stream_block( # Calculate cu_seqlens for flash attention (batch_size=1) total_len = query.shape[0] - cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device) + cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32) # Use registered attention module attn_output = block_weights.calculate.apply( @@ -248,7 +248,7 @@ def infer_single_stream_block( # Calculate cu_seqlens for flash attention (batch_size=1) total_len = query.shape[0] - cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device) + cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32) # Use registered attention module attn_output = block_weights.calculate.apply( diff --git a/lightx2v/models/networks/ltx2/infer/transformer_infer.py b/lightx2v/models/networks/ltx2/infer/transformer_infer.py index 87db25eaa..b69dc9d4e 100644 --- a/lightx2v/models/networks/ltx2/infer/transformer_infer.py +++ b/lightx2v/models/networks/ltx2/infer/transformer_infer.py @@ -111,21 +111,17 @@ def set_guidance_perturbation( self._mm_skip_a2v = bool(skip_a2v) self._mm_skip_v2a = bool(skip_v2a) - def _create_cu_seqlens(self, seq_len: int, device: torch.device) -> torch.Tensor: + def _create_cu_seqlens(self, seq_len: int) -> torch.Tensor: """ Create cumulative sequence lengths tensor for attention. Args: seq_len: Sequence length - device: Device to place the tensor on Returns: Cumulative sequence lengths tensor [0, seq_len] """ - if self.config["attn_type"] in ["flash_attn2", "flash_attn3"]: - return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True) - else: - return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32) + return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32) def _gather_cross_attn_context(self, context: torch.Tensor, k_pe=None): """ @@ -282,7 +278,7 @@ def _infer_attn( if is_self_attn and not is_audio and self.config.get("seq_parallel", False) and not use_tp: # Cache cu_seqlens_qkv for self-attention (q, k, v have same length) if self.v_attn_cu_seqlens_qkv is None: - self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0]) out = attn_phase.attn_func_parallel.apply( q=q, @@ -291,7 +287,6 @@ def _infer_attn( slice_qkv_len=seq_len, cu_seqlens_qkv=self.v_attn_cu_seqlens_qkv, attention_module=attn_phase.attn_func, - attention_type=self.config["attn_type"], seq_p_group=self.seq_p_group, use_fp8_comm=self.seq_p_fp8_comm, use_fp4_comm=self.seq_p_fp4_comm, @@ -303,17 +298,17 @@ def _infer_attn( # Cache cu_seqlens_qkv for self-attention only if is_self_attn: if not is_audio and self.v_attn_cu_seqlens_qkv is None: - self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0]) elif is_audio and self.a_attn_cu_seqlens_qkv is None: - self.a_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device) + self.a_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0]) cu_seqlens_q = self.v_attn_cu_seqlens_qkv if not is_audio else self.a_attn_cu_seqlens_qkv cu_seqlens_kv = cu_seqlens_q # For self-attn, q and k have same length else: # For cross-attention, always create cu_seqlens dynamically # because k length varies by context type (text, audio, video) - cu_seqlens_q = self._create_cu_seqlens(q.shape[0], q.device) - cu_seqlens_kv = self._create_cu_seqlens(k.shape[0], k.device) + cu_seqlens_q = self._create_cu_seqlens(q.shape[0]) + cu_seqlens_kv = self._create_cu_seqlens(k.shape[0]) out = attn_phase.attn_func.apply( q=q, diff --git a/lightx2v/models/networks/neopp/infer/transformer_infer.py b/lightx2v/models/networks/neopp/infer/transformer_infer.py index 60b7fe6ba..c3bcf889b 100755 --- a/lightx2v/models/networks/neopp/infer/transformer_infer.py +++ b/lightx2v/models/networks/neopp/infer/transformer_infer.py @@ -36,7 +36,6 @@ def __init__(self, config): self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") else: self.seq_p_group = None - self.cross_attn_type = self.config["attn_type"] self.kv_cache = KVCacheManager() @torch.no_grad() @@ -57,8 +56,8 @@ def infer(self, weights, pre_infer_out, inputs): if self._seqlen_cache.get(_cache_key, {}).get("seqlens") != (seq_len_q, seq_len_k): self._seqlen_cache[_cache_key] = { "seqlens": (seq_len_q, seq_len_k), - "cu_q": torch.tensor([0, seq_len_q], dtype=torch.int32, device=hidden_states.device), - "cu_k": torch.tensor([0, seq_len_k], dtype=torch.int32, device=hidden_states.device), + "cu_q": torch.tensor([0, seq_len_q], dtype=torch.int32), + "cu_k": torch.tensor([0, seq_len_k], dtype=torch.int32), } self._cu_seqlens_q = self._seqlen_cache[_cache_key]["cu_q"] self._cu_seqlens_k = self._seqlen_cache[_cache_key]["cu_k"] @@ -173,7 +172,6 @@ def _compute_attn(self, attn_w, query_states, key_states, value_states): slice_qkv_len=self._kvcache_len, cu_seqlens_qkv=self._cu_seqlens_k, attention_module=attn_w.cross_attn, - attention_type=self.cross_attn_type, seq_p_group=self.seq_p_group, img_first=False, q_only_img=True, diff --git a/lightx2v/models/networks/qwen_image/infer/transformer_infer.py b/lightx2v/models/networks/qwen_image/infer/transformer_infer.py index 25a39feac..67a6e5bf4 100755 --- a/lightx2v/models/networks/qwen_image/infer/transformer_infer.py +++ b/lightx2v/models/networks/qwen_image/infer/transformer_infer.py @@ -11,7 +11,7 @@ def calculate_q_k_len(q, k_lens): - q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) + q_lens = torch.tensor([q.size(0)], dtype=torch.int32) cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) return cu_seqlens_q, cu_seqlens_k @@ -205,7 +205,7 @@ def infer_cross_attn( joint_value = torch.cat([txt_value, img_value], dim=0) img_qkv_len = joint_query.shape[0] - cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(joint_query.device, non_blocking=True) + cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu") if self.config["seq_parallel"]: joint_hidden_states = cross_attn_phase.calculate_parallel.apply( diff --git a/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py b/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py index 63f2c8034..236f8075e 100755 --- a/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/causvid/transformer_infer.py @@ -116,7 +116,7 @@ def infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs self.kv_cache[block_idx]["k"][kv_start:kv_end] = k self.kv_cache[block_idx]["v"][kv_start:kv_end] = v - cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q=q, k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device)) + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q=q, k_lens=torch.tensor([kv_end], dtype=torch.int32)) if not self.parallel_attention: attn_out = weights.self_attn_1.apply( @@ -157,7 +157,7 @@ def infer_cross_attn(self, weights, x, context, block_idx): k = self.crossattn_cache[block_idx]["k"] v = self.crossattn_cache[block_idx]["v"] - cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device)) + cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=torch.tensor([k.size(0)], dtype=torch.int32)) attn_out = weights.cross_attn_1.apply( q=q, @@ -175,7 +175,7 @@ def infer_cross_attn(self, weights, x, context, block_idx): cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, - k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), + k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32), ) img_attn_out = weights.cross_attn_2.apply( diff --git a/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py b/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py index dfed29f1c..d8224d4c4 100644 --- a/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py @@ -78,15 +78,9 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa, block, conditiona v = phase.cross_attn_v.apply(context).view(-1, n, d) if self.cross_attn_cu_seqlens_q is None: - if self.cross_attn_1_type == "flash_attn2" or self.cross_attn_1_type == "flash_attn3": - self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True) - else: - self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32) + self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32) if self.cross_attn_cu_seqlens_kv is None: - if self.cross_attn_1_type == "flash_attn2" or self.cross_attn_1_type == "flash_attn3": - self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]]).cumsum(0, dtype=torch.int32).to(k.device, non_blocking=True) - else: - self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]]).cumsum(0, dtype=torch.int32) + self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]]).cumsum(0, dtype=torch.int32) attn_out = phase.cross_attn_1.apply( q=q, k=k, @@ -102,10 +96,7 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa, block, conditiona v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d) if self.cross_attn_cu_seqlens_kv_img is None: - if self.cross_attn_2_type == "flash_attn2" or self.cross_attn_2_type == "flash_attn3": - self.cross_attn_cu_seqlens_kv_img = torch.tensor([0, k_img.shape[0]]).cumsum(0, dtype=torch.int32).to(k_img.device, non_blocking=True) - else: - self.cross_attn_cu_seqlens_kv_img = torch.tensor([0, k_img.shape[0]]).cumsum(0, dtype=torch.int32) + self.cross_attn_cu_seqlens_kv_img = torch.tensor([0, k_img.shape[0]]).cumsum(0, dtype=torch.int32) img_attn_out = phase.cross_attn_2.apply( q=q, diff --git a/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py b/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py index 5110de169..1b02512ef 100644 --- a/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py @@ -89,7 +89,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): x = weights.patch_embedding.apply(x.unsqueeze(0)) grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] x = x.flatten(2).transpose(1, 2).contiguous() - seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0) + seq_lens = torch.tensor(x.size(1), dtype=torch.int32).unsqueeze(0) embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) embed = self.time_embedding(weights, embed_tmp) diff --git a/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py b/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py index 2d33679d2..38221b6b0 100755 --- a/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py @@ -175,8 +175,8 @@ def _sp_kvcache_attn(self, q, k_cache, v_cache, phase): full_k = self._a2a_seq_to_heads(k_cache, world_size, shard_heads, self.seq_p_group) full_v = self._a2a_seq_to_heads(v_cache, world_size, shard_heads, self.seq_p_group) - q_lens = torch.tensor([full_q.size(0)], dtype=torch.int32, device=full_q.device) - k_lens = torch.tensor([full_k.size(0)], dtype=torch.int32, device=full_k.device) + q_lens = torch.tensor([full_q.size(0)], dtype=torch.int32) + k_lens = torch.tensor([full_k.size(0)], dtype=torch.int32) cu_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) cu_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) @@ -407,7 +407,7 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, bloc cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, - k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device), + k_lens=torch.tensor([k.size(0)], dtype=torch.int32), ) attn_out = phase.cross_attn_1.apply( q=q, @@ -425,7 +425,7 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, bloc cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, - k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), + k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32), ) img_attn_out = phase.cross_attn_2.apply( q=q, diff --git a/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py b/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py index d828476cd..4b4cfb1a4 100644 --- a/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py @@ -75,7 +75,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w)) x = x.flatten(2).transpose(1, 2) # B FHW C' - seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long, device=torch.device("cuda")) + seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long) assert seq_lens[0] <= 15 * 1 * 880 embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) # torch.Size([3, 256]) diff --git a/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py b/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py index 915b5167c..c2f0a02a9 100755 --- a/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py @@ -264,7 +264,7 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa): cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, - k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device), + k_lens=torch.tensor([k.size(0)], dtype=torch.int32), ) attn_out = phase.cross_attn_1.apply( diff --git a/lightx2v/models/networks/wan/infer/pre_infer.py b/lightx2v/models/networks/wan/infer/pre_infer.py index 2a811aa6f..8d9e85d1a 100755 --- a/lightx2v/models/networks/wan/infer/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/pre_infer.py @@ -134,7 +134,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] x = x.flatten(2).transpose(1, 2).contiguous() - # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0) + # seq_lens = torch.tensor(x.size(1), dtype=torch.int32).unsqueeze(0) embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten()) if self.enable_dynamic_cfg: diff --git a/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py b/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py index 45a015095..c45ad2bc7 100755 --- a/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py @@ -82,7 +82,7 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): x = weights.patch_embedding.apply(x.unsqueeze(0)) grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:] x = x.flatten(2).transpose(1, 2).contiguous() - seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0) + seq_lens = torch.tensor(x.size(1), dtype=torch.int32).unsqueeze(0) embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) embed = self.time_embedding(weights, embed_tmp) diff --git a/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py b/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py index aa455a0dc..8445923c3 100755 --- a/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py @@ -58,7 +58,7 @@ def __init__(self, config): self._initialize_crossattn_cache(self.device, self.dtype) def _calculate_q_k_len(self, q, k_lens): - q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) + q_lens = torch.tensor([q.size(0)], dtype=torch.int32) cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) return cu_seqlens_q, cu_seqlens_k @@ -282,7 +282,7 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa): cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, - k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device), + k_lens=torch.tensor([k.size(0)], dtype=torch.int32), ) attn_out = phase.cross_attn_1.apply( q=q, @@ -300,7 +300,7 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa): cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, - k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), + k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32), ) img_attn_out = phase.cross_attn_2.apply( q=q, diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py index ee0af8b14..8daba4808 100755 --- a/lightx2v/models/networks/wan/infer/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -21,10 +21,6 @@ class WanTransformerInfer(BaseTransformerInfer): def __init__(self, config): self.config = config self.task = config["task"] - self.attention_type = config.get("attention_type", "flash_attn2") - self.self_attn_1_type = config.get("self_attn_1_type", "flash_attn2") - self.cross_attn_1_type = config.get("cross_attn_1_type", "flash_attn2") - self.cross_attn_2_type = config.get("cross_attn_2_type", "flash_attn2") self.blocks_num = config["num_layers"] self.phases_num = 3 self.has_post_adapter = False @@ -202,10 +198,7 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): q, k = self.apply_rope_func(q, k, cos_sin) img_qkv_len = q.shape[0] if self.self_attn_cu_seqlens_qkv is None: - if self.self_attn_1_type in ["flash_attn2", "flash_attn3"]: - self.self_attn_cu_seqlens_qkv = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True) - else: - self.self_attn_cu_seqlens_qkv = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32) + self.self_attn_cu_seqlens_qkv = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32) if self.clean_cuda_cache: del norm1_out, shift_msa, scale_msa @@ -224,7 +217,6 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): slice_qkv_len=img_qkv_len, cu_seqlens_qkv=self.self_attn_cu_seqlens_qkv, attention_module=phase.self_attn_1, - attention_type=self.self_attn_1_type, seq_p_group=self.seq_p_group, use_fp8_comm=self.seq_p_fp8_comm, use_fp4_comm=self.seq_p_fp4_comm, @@ -276,15 +268,9 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa): v = phase.cross_attn_v.apply(context).view(-1, n, d) if self.cross_attn_cu_seqlens_q is None: - if self.cross_attn_1_type == "flash_attn2" or self.cross_attn_1_type == "flash_attn3": - self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True) - else: - self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32) + self.cross_attn_cu_seqlens_q = torch.tensor([0, q.shape[0]]).cumsum(0, dtype=torch.int32) if self.cross_attn_cu_seqlens_kv is None: - if self.cross_attn_1_type == "flash_attn2" or self.cross_attn_1_type == "flash_attn3": - self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]]).cumsum(0, dtype=torch.int32).to(k.device, non_blocking=True) - else: - self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]]).cumsum(0, dtype=torch.int32) + self.cross_attn_cu_seqlens_kv = torch.tensor([0, k.shape[0]]).cumsum(0, dtype=torch.int32) attn_out = phase.cross_attn_1.apply( q=q, k=k, @@ -300,10 +286,7 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa): v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d) if self.cross_attn_cu_seqlens_kv_img is None: - if self.cross_attn_2_type == "flash_attn2" or self.cross_attn_2_type == "flash_attn3": - self.cross_attn_cu_seqlens_kv_img = torch.tensor([0, k_img.shape[0]]).cumsum(0, dtype=torch.int32).to(k_img.device, non_blocking=True) - else: - self.cross_attn_cu_seqlens_kv_img = torch.tensor([0, k_img.shape[0]]).cumsum(0, dtype=torch.int32) + self.cross_attn_cu_seqlens_kv_img = torch.tensor([0, k_img.shape[0]]).cumsum(0, dtype=torch.int32) img_attn_out = phase.cross_attn_2.apply( q=q, diff --git a/lightx2v/models/networks/worldplay/infer/ar_transformer_infer.py b/lightx2v/models/networks/worldplay/infer/ar_transformer_infer.py index 9871de932..92da97745 100644 --- a/lightx2v/models/networks/worldplay/infer/ar_transformer_infer.py +++ b/lightx2v/models/networks/worldplay/infer/ar_transformer_infer.py @@ -327,7 +327,7 @@ def _infer_causal_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v, # Use bidirectional attention for full-video inference # (causal=False, same as distill model) seqlen = query.shape[1] - cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu") attn_out = weights.self_attention.apply( q=query, k=key, @@ -541,7 +541,7 @@ def infer_txt_with_offload(self, weights, infer_module_out, cache_txt=True): block_weights = self.offload_manager.cuda_buffers[0] txt_q, txt_k, txt_v, txt_branch_out = self._infer_txt_branch_before_attn(block_weights, infer_module_out) txt_seqlen = txt_q.shape[1] - cu_seqlens_qkv = torch.tensor([0, txt_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_qkv = torch.tensor([0, txt_seqlen], dtype=torch.int32, device="cpu") txt_attn = block_weights.self_attention.apply( q=txt_q, k=txt_k, @@ -603,8 +603,8 @@ def infer_vision_with_offload(self, weights, infer_module_out, cache_vision=Fals value_full = value_full.transpose(1, 2) img_seqlen = query.shape[1] kv_seqlen = key_full.shape[1] - cu_seqlens_q = torch.tensor([0, img_seqlen, 2 * img_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) - cu_seqlens_kv = torch.tensor([0, kv_seqlen, 2 * kv_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_q = torch.tensor([0, img_seqlen, 2 * img_seqlen], dtype=torch.int32, device="cpu") + cu_seqlens_kv = torch.tensor([0, kv_seqlen, 2 * kv_seqlen], dtype=torch.int32, device="cpu") attn_out = block_weights.self_attention.apply( q=query, k=key_full, @@ -644,8 +644,8 @@ def infer_vision_with_offload(self, weights, infer_module_out, cache_vision=Fals value = value.transpose(1, 2) img_seqlen = img_q.shape[1] kv_seqlen = key.shape[1] - cu_seqlens_q = torch.tensor([0, img_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) - cu_seqlens_kv = torch.tensor([0, kv_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_q = torch.tensor([0, img_seqlen], dtype=torch.int32, device="cpu") + cu_seqlens_kv = torch.tensor([0, kv_seqlen], dtype=torch.int32, device="cpu") img_attn = block_weights.self_attention.apply( q=img_q, k=key, @@ -702,7 +702,7 @@ def infer_txt(self, weights, infer_module_out, cache_txt=True): # Text self-attention (is_causal=False) txt_seqlen = txt_q.shape[1] - cu_seqlens_qkv = torch.tensor([0, txt_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_qkv = torch.tensor([0, txt_seqlen], dtype=torch.int32, device="cpu") txt_attn = block_weights.self_attention.apply( q=txt_q, k=txt_k, @@ -817,8 +817,8 @@ def infer_vision(self, weights, infer_module_out, cache_vision=False): kv_seqlen = key_full.shape[1] # cu_seqlens for 2 sequences (normal and prope) - cu_seqlens_q = torch.tensor([0, img_seqlen, 2 * img_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) - cu_seqlens_kv = torch.tensor([0, kv_seqlen, 2 * kv_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_q = torch.tensor([0, img_seqlen, 2 * img_seqlen], dtype=torch.int32, device="cpu") + cu_seqlens_kv = torch.tensor([0, kv_seqlen, 2 * kv_seqlen], dtype=torch.int32, device="cpu") # Single attention call for both streams attn_out = block_weights.self_attention.apply( @@ -878,8 +878,8 @@ def infer_vision(self, weights, infer_module_out, cache_vision=False): # Attention computation img_seqlen = img_q.shape[1] kv_seqlen = key.shape[1] - cu_seqlens_q = torch.tensor([0, img_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) - cu_seqlens_kv = torch.tensor([0, kv_seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True) + cu_seqlens_q = torch.tensor([0, img_seqlen], dtype=torch.int32, device="cpu") + cu_seqlens_kv = torch.tensor([0, kv_seqlen], dtype=torch.int32, device="cpu") img_attn = block_weights.self_attention.apply( q=img_q, diff --git a/lightx2v/models/networks/z_image/infer/transformer_infer.py b/lightx2v/models/networks/z_image/infer/transformer_infer.py index 711019c2e..41145f3f7 100755 --- a/lightx2v/models/networks/z_image/infer/transformer_infer.py +++ b/lightx2v/models/networks/z_image/infer/transformer_infer.py @@ -65,7 +65,7 @@ def infer_attn(self, attn_phase, hidden_states, freqs_cis, scale_msa=None): query, key = self.apply_rope_func(query, key, freqs_cis) total_seq_len = query.shape[0] - cu_seqlens = torch.tensor([0, total_seq_len], dtype=torch.int32, device="cpu").to(query.device, non_blocking=True) + cu_seqlens = torch.tensor([0, total_seq_len], dtype=torch.int32, device="cpu") if self.config["seq_parallel"]: hidden_states_out = attn_phase.calculate_parallel.apply(