Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 70 additions & 29 deletions lightx2v/common/ops/attn/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
from .utils.sparge_util import block_map_ordinal_lut_triton, get_block_map_meansim

try:
import flash_attn # noqa: F401
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn import flash_attn_func_v2
from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
Comment on lines +8 to +9
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The import statements for flash_attn_func_v2 and flash_attn_varlen_func_v2 appear to be missing the `as` keyword. The standard flash_attn package does not export these names directly; they should be imported under their standard names (flash_attn_func, flash_attn_varlen_func) and aliased, which also keeps them consistent with the v3 and v4 imports below.

Suggested change
- from flash_attn import flash_attn_func_v2
- from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
+ from flash_attn import flash_attn_func as flash_attn_func_v2
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v2

except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
logger.info("flash_attn2 not found, please install flash_attn2 first")
flash_attn_func_v2 = None
flash_attn_varlen_func_v2 = None

try:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
logger.info("flash_attn3 not found, please install flash_attn3 first")
flash_attn_func_v3 = None
flash_attn_varlen_func_v3 = None

try:
Expand Down Expand Up @@ -49,18 +52,37 @@ def apply(
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
total_seqlen = bs * max_seqlen_q

if bs == 1:
if len(q.shape) == 3:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
x = flash_attn_func_v2(q, k, v).reshape(bs * max_seqlen_q, -1)
else:
if cu_seqlens_q.is_cpu:
cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True)
if cu_seqlens_kv.is_cpu:
cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True)
if max_seqlen_q.is_cpu:
max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
if max_seqlen_kv.is_cpu:
max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
Comment on lines +68 to +75
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The `is_cpu` checks and `.to()` calls on max_seqlen_q and max_seqlen_kv will raise an AttributeError if these arguments are passed as plain integers (a Python int has neither attribute), which is the case in several models (e.g., Wan, Ulysses). Furthermore, Flash Attention kernels expect these values as host-side integers, so moving them to the GPU device is unnecessary and potentially incorrect.

Suggested change
- if max_seqlen_q.is_cpu:
- max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
- if max_seqlen_kv.is_cpu:
- max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
- if len(q.shape) == 4:
- q = q.reshape(-1, q.shape[-2], q.shape[-1])
- k = k.reshape(-1, k.shape[-2], k.shape[-1])
- v = v.reshape(-1, v.shape[-2], v.shape[-1])
+ if len(q.shape) == 4:
+ q = q.reshape(-1, q.shape[-2], q.shape[-1])
+ k = k.reshape(-1, k.shape[-2], k.shape[-1])
+ v = v.reshape(-1, v.shape[-2], v.shape[-1])

x = flash_attn_varlen_func_v2(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(total_seqlen, -1)

return x


Expand All @@ -84,18 +106,37 @@ def apply(
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
total_seqlen = bs * max_seqlen_q

if bs == 1:
if len(q.shape) == 3:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
x = flash_attn_func_v3(q, k, v).reshape(bs * max_seqlen_q, -1)
else:
if cu_seqlens_q.is_cpu:
cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True)
if cu_seqlens_kv.is_cpu:
cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True)
if max_seqlen_q.is_cpu:
max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
if max_seqlen_kv.is_cpu:
max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
Comment on lines +122 to +129
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to the issue in FlashAttn2Weight, the `is_cpu` checks on max_seqlen_q and max_seqlen_kv will raise an AttributeError if they are integers. These parameters should remain host-side integers for the Flash Attention kernel launch.

Suggested change
- if max_seqlen_q.is_cpu:
- max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
- if max_seqlen_kv.is_cpu:
- max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
- if len(q.shape) == 4:
- q = q.reshape(-1, q.shape[-2], q.shape[-1])
- k = k.reshape(-1, k.shape[-2], k.shape[-1])
- v = v.reshape(-1, v.shape[-2], v.shape[-1])
+ if len(q.shape) == 4:
+ q = q.reshape(-1, q.shape[-2], q.shape[-1])
+ k = k.reshape(-1, k.shape[-2], k.shape[-1])
+ v = v.reshape(-1, v.shape[-2], v.shape[-1])

x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(total_seqlen, -1)

return x


Expand Down
2 changes: 0 additions & 2 deletions lightx2v/common/ops/attn/ring_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def apply(
slice_qkv_len,
cu_seqlens_qkv,
attention_module=None,
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
use_fp4_comm=False,
Expand All @@ -76,7 +75,6 @@ def apply(
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
slice_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
use_fp8_comm: 是否使用 FP8 通信
use_fp4_comm: 是否使用 FP4 通信

Expand Down
10 changes: 0 additions & 10 deletions lightx2v/common/ops/attn/ulysses_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

from .template import AttnWeightTemplate
from .utils.all2all import all2all_head2seq
Expand All @@ -31,7 +30,6 @@ def apply(
slice_qkv_len,
cu_seqlens_qkv,
attention_module=None,
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
use_fp4_comm=False,
Expand All @@ -51,7 +49,6 @@ def apply(
v (torch.Tensor): 值张量,形状为 [shard_seqlen, kv_heads, hidden_dims]
slice_qkv_len (int): 图像或者文本查询、键和值的长度,根据 img_first 确定谁在前半部分
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
q_only_img (bool): 若为 True,q 只含图像 token,k/v 同时含图像和文本 token。
此时只对 k/v 做 img/txt 分割,q 整体参与图像侧 all-to-all。
支持 cross-attention 等 q 不含文本 token 的场景。
Expand Down Expand Up @@ -108,16 +105,12 @@ def apply(
cu_seqlens_kv[1] = txt_qkv_len + global_img_seqlen
if txt_mask_len:
cu_seqlens_kv = torch.cat((cu_seqlens_kv, torch.tensor([txt_mask_len + global_img_seqlen], dtype=torch.int32)))
if attention_type == "flash_attn2" or attention_type == "flash_attn3":
cu_seqlens_kv = cu_seqlens_kv.to(AI_DEVICE, non_blocking=True)
max_seqlen_kv = global_img_seqlen + txt_qkv_len

# q_only_img 时 q 只含图像 token,cu_seqlens_q 与 kv 侧不同
if q_only_img:
cu_seqlens_q = torch.zeros([2], dtype=torch.int32)
cu_seqlens_q[1] = global_img_seqlen
if attention_type == "flash_attn2" or attention_type == "flash_attn3":
cu_seqlens_q = cu_seqlens_q.to(AI_DEVICE, non_blocking=True)
max_seqlen_q = global_img_seqlen
else:
cu_seqlens_q = cu_seqlens_kv
Expand Down Expand Up @@ -599,7 +592,6 @@ def apply(
slice_qkv_len,
cu_seqlens_qkv,
attention_module=None,
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
use_fp4_comm=False,
Expand All @@ -616,7 +608,6 @@ def apply(
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
slice_qkv_len (int): 图像或者文本查询、键和值的长度,根据img_first确定谁在前半部分
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"

返回:
torch.Tensor: 计算得到的注意力结果
Expand Down Expand Up @@ -748,7 +739,6 @@ def apply(
max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0] # 最大序列长度

# 调用注意力函数计算注意力结果
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs)

# 分割图像和文本的注意力结果
Expand Down
4 changes: 2 additions & 2 deletions lightx2v/models/networks/bagel/infer/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def self_attn(
merged_value_states = packed_value_states
key_values_lens = query_lens

cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE)
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE)
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
Comment on lines +163 to +164
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Removing the .to(AI_DEVICE) call here will cause a device mismatch error. Unlike other models that use the attention wrappers in common/ops/attn/flash_attn.py (which handle device placement), this file calls the raw flash_attn_varlen_func directly. The raw function requires cu_seqlens to be on the same device as the input tensors.

Suggested change
- cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
- cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
+ cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE)
+ cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE)


packed_attn_output = flash_attn_varlen_func(
q=packed_query_states,
Expand Down
4 changes: 2 additions & 2 deletions lightx2v/models/networks/flux2/infer/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def infer_double_stream_block(
query, key = self.apply_rope_func(query, key, image_rotary_emb)

total_len = query.shape[0]
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device)
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32)

model_cls = self.config.get("model_cls", "flux2_klein")

Expand Down Expand Up @@ -193,7 +193,7 @@ def infer_single_stream_block(
query, key = self.apply_rope_func(query, key, image_rotary_emb)

total_len = query.shape[0]
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device)
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32)

model_cls = self.config.get("model_cls", "flux2_klein")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
apply_rope_with_cos_sin_cache_inplace = None

from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE

from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput
from .triton_ops import fuse_scale_shift_kernel
Expand Down Expand Up @@ -226,7 +225,7 @@ def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v):
key = torch.cat([img_k, txt_k], dim=1)
value = torch.cat([img_v, txt_v], dim=1)
seqlen = query.shape[1]
cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to(AI_DEVICE, non_blocking=True)
cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu")

if self.config["seq_parallel"]:
attn_out = weights.self_attention_parallel.apply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def infer_double_stream_block(

# Calculate cu_seqlens for flash attention (batch_size=1)
total_len = query.shape[0]
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device)
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32)

# Use registered attention module
attn_output = block_weights.calculate.apply(
Expand Down Expand Up @@ -248,7 +248,7 @@ def infer_single_stream_block(

# Calculate cu_seqlens for flash attention (batch_size=1)
total_len = query.shape[0]
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device=query.device)
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32)

# Use registered attention module
attn_output = block_weights.calculate.apply(
Expand Down
19 changes: 7 additions & 12 deletions lightx2v/models/networks/ltx2/infer/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,17 @@ def set_guidance_perturbation(
self._mm_skip_a2v = bool(skip_a2v)
self._mm_skip_v2a = bool(skip_v2a)

def _create_cu_seqlens(self, seq_len: int, device: torch.device) -> torch.Tensor:
def _create_cu_seqlens(self, seq_len: int) -> torch.Tensor:
"""
Create cumulative sequence lengths tensor for attention.

Args:
seq_len: Sequence length
device: Device to place the tensor on

Returns:
Cumulative sequence lengths tensor [0, seq_len]
"""
if self.config["attn_type"] in ["flash_attn2", "flash_attn3"]:
return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True)
else:
return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32)
return torch.tensor([0, seq_len]).cumsum(0, dtype=torch.int32)

def _gather_cross_attn_context(self, context: torch.Tensor, k_pe=None):
"""
Expand Down Expand Up @@ -282,7 +278,7 @@ def _infer_attn(
if is_self_attn and not is_audio and self.config.get("seq_parallel", False) and not use_tp:
# Cache cu_seqlens_qkv for self-attention (q, k, v have same length)
if self.v_attn_cu_seqlens_qkv is None:
self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device)
self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0])

out = attn_phase.attn_func_parallel.apply(
q=q,
Expand All @@ -291,7 +287,6 @@ def _infer_attn(
slice_qkv_len=seq_len,
cu_seqlens_qkv=self.v_attn_cu_seqlens_qkv,
attention_module=attn_phase.attn_func,
attention_type=self.config["attn_type"],
seq_p_group=self.seq_p_group,
use_fp8_comm=self.seq_p_fp8_comm,
use_fp4_comm=self.seq_p_fp4_comm,
Expand All @@ -303,17 +298,17 @@ def _infer_attn(
# Cache cu_seqlens_qkv for self-attention only
if is_self_attn:
if not is_audio and self.v_attn_cu_seqlens_qkv is None:
self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device)
self.v_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0])
elif is_audio and self.a_attn_cu_seqlens_qkv is None:
self.a_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0], q.device)
self.a_attn_cu_seqlens_qkv = self._create_cu_seqlens(q.shape[0])

cu_seqlens_q = self.v_attn_cu_seqlens_qkv if not is_audio else self.a_attn_cu_seqlens_qkv
cu_seqlens_kv = cu_seqlens_q # For self-attn, q and k have same length
else:
# For cross-attention, always create cu_seqlens dynamically
# because k length varies by context type (text, audio, video)
cu_seqlens_q = self._create_cu_seqlens(q.shape[0], q.device)
cu_seqlens_kv = self._create_cu_seqlens(k.shape[0], k.device)
cu_seqlens_q = self._create_cu_seqlens(q.shape[0])
cu_seqlens_kv = self._create_cu_seqlens(k.shape[0])

out = attn_phase.attn_func.apply(
q=q,
Expand Down
6 changes: 2 additions & 4 deletions lightx2v/models/networks/neopp/infer/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def __init__(self, config):
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.cross_attn_type = self.config["attn_type"]
self.kv_cache = KVCacheManager()

@torch.no_grad()
Expand All @@ -57,8 +56,8 @@ def infer(self, weights, pre_infer_out, inputs):
if self._seqlen_cache.get(_cache_key, {}).get("seqlens") != (seq_len_q, seq_len_k):
self._seqlen_cache[_cache_key] = {
"seqlens": (seq_len_q, seq_len_k),
"cu_q": torch.tensor([0, seq_len_q], dtype=torch.int32, device=hidden_states.device),
"cu_k": torch.tensor([0, seq_len_k], dtype=torch.int32, device=hidden_states.device),
"cu_q": torch.tensor([0, seq_len_q], dtype=torch.int32),
"cu_k": torch.tensor([0, seq_len_k], dtype=torch.int32),
}
self._cu_seqlens_q = self._seqlen_cache[_cache_key]["cu_q"]
self._cu_seqlens_k = self._seqlen_cache[_cache_key]["cu_k"]
Expand Down Expand Up @@ -173,7 +172,6 @@ def _compute_attn(self, attn_w, query_states, key_states, value_states):
slice_qkv_len=self._kvcache_len,
cu_seqlens_qkv=self._cu_seqlens_k,
attention_module=attn_w.cross_attn,
attention_type=self.cross_attn_type,
seq_p_group=self.seq_p_group,
img_first=False,
q_only_img=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def calculate_q_k_len(q, k_lens):
q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
q_lens = torch.tensor([q.size(0)], dtype=torch.int32)
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k
Expand Down Expand Up @@ -205,7 +205,7 @@ def infer_cross_attn(
joint_value = torch.cat([txt_value, img_value], dim=0)

img_qkv_len = joint_query.shape[0]
cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(joint_query.device, non_blocking=True)
cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu")

if self.config["seq_parallel"]:
joint_hidden_states = cross_attn_phase.calculate_parallel.apply(
Expand Down
Loading
Loading