-
Notifications
You must be signed in to change notification settings - Fork 201
fix fa varlen seqlen #1031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix fa varlen seqlen #1031
Changes from all commits
c47e001
3947f3b
d2a907c
c3bb834
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,16 +5,19 @@ | |||||||||||||||||||||||||
| from .utils.sparge_util import block_map_ordinal_lut_triton, get_block_map_meansim | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| import flash_attn # noqa: F401 | ||||||||||||||||||||||||||
| from flash_attn.flash_attn_interface import flash_attn_varlen_func | ||||||||||||||||||||||||||
| from flash_attn import flash_attn_func_v2 | ||||||||||||||||||||||||||
| from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2 | ||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||
| logger.info("flash_attn_varlen_func not found, please install flash_attn2 first") | ||||||||||||||||||||||||||
| flash_attn_varlen_func = None | ||||||||||||||||||||||||||
| logger.info("flash_attn2 not found, please install flash_attn2 first") | ||||||||||||||||||||||||||
| flash_attn_func_v2 = None | ||||||||||||||||||||||||||
| flash_attn_varlen_func_v2 = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| from flash_attn_interface import flash_attn_func as flash_attn_func_v3 | ||||||||||||||||||||||||||
| from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 | ||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||
| logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first") | ||||||||||||||||||||||||||
| logger.info("flash_attn3 not found, please install flash_attn3 first") | ||||||||||||||||||||||||||
| flash_attn_func_v3 = None | ||||||||||||||||||||||||||
| flash_attn_varlen_func_v3 = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
|
|
@@ -49,18 +52,37 @@ def apply( | |||||||||||||||||||||||||
| bs = 1 | ||||||||||||||||||||||||||
| elif len(q.shape) == 4: | ||||||||||||||||||||||||||
| bs = q.shape[0] | ||||||||||||||||||||||||||
| q = q.reshape(-1, q.shape[-2], q.shape[-1]) | ||||||||||||||||||||||||||
| k = k.reshape(-1, k.shape[-2], k.shape[-1]) | ||||||||||||||||||||||||||
| v = v.reshape(-1, v.shape[-2], v.shape[-1]) | ||||||||||||||||||||||||||
| x = flash_attn_varlen_func( | ||||||||||||||||||||||||||
| q, | ||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||
| cu_seqlens_q, | ||||||||||||||||||||||||||
| cu_seqlens_kv, | ||||||||||||||||||||||||||
| max_seqlen_q, | ||||||||||||||||||||||||||
| max_seqlen_kv, | ||||||||||||||||||||||||||
| ).reshape(bs * max_seqlen_q, -1) | ||||||||||||||||||||||||||
| total_seqlen = bs * max_seqlen_q | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if bs == 1: | ||||||||||||||||||||||||||
| if len(q.shape) == 3: | ||||||||||||||||||||||||||
| q = q.unsqueeze(0) | ||||||||||||||||||||||||||
| k = k.unsqueeze(0) | ||||||||||||||||||||||||||
| v = v.unsqueeze(0) | ||||||||||||||||||||||||||
| x = flash_attn_func_v2(q, k, v).reshape(bs * max_seqlen_q, -1) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| if cu_seqlens_q.is_cpu: | ||||||||||||||||||||||||||
| cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True) | ||||||||||||||||||||||||||
| if cu_seqlens_kv.is_cpu: | ||||||||||||||||||||||||||
| cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True) | ||||||||||||||||||||||||||
| if max_seqlen_q.is_cpu: | ||||||||||||||||||||||||||
| max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True) | ||||||||||||||||||||||||||
| if max_seqlen_kv.is_cpu: | ||||||||||||||||||||||||||
| max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True) | ||||||||||||||||||||||||||
| if len(q.shape) == 4: | ||||||||||||||||||||||||||
| q = q.reshape(-1, q.shape[-2], q.shape[-1]) | ||||||||||||||||||||||||||
| k = k.reshape(-1, k.shape[-2], k.shape[-1]) | ||||||||||||||||||||||||||
| v = v.reshape(-1, v.shape[-2], v.shape[-1]) | ||||||||||||||||||||||||||
|
Comment on lines
+68
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||
| x = flash_attn_varlen_func_v2( | ||||||||||||||||||||||||||
| q, | ||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||
| cu_seqlens_q, | ||||||||||||||||||||||||||
| cu_seqlens_kv, | ||||||||||||||||||||||||||
| max_seqlen_q, | ||||||||||||||||||||||||||
| max_seqlen_kv, | ||||||||||||||||||||||||||
| ).reshape(total_seqlen, -1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -84,18 +106,37 @@ def apply( | |||||||||||||||||||||||||
| bs = 1 | ||||||||||||||||||||||||||
| elif len(q.shape) == 4: | ||||||||||||||||||||||||||
| bs = q.shape[0] | ||||||||||||||||||||||||||
| q = q.reshape(-1, q.shape[-2], q.shape[-1]) | ||||||||||||||||||||||||||
| k = k.reshape(-1, k.shape[-2], k.shape[-1]) | ||||||||||||||||||||||||||
| v = v.reshape(-1, v.shape[-2], v.shape[-1]) | ||||||||||||||||||||||||||
| x = flash_attn_varlen_func_v3( | ||||||||||||||||||||||||||
| q, | ||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||
| cu_seqlens_q, | ||||||||||||||||||||||||||
| cu_seqlens_kv, | ||||||||||||||||||||||||||
| max_seqlen_q, | ||||||||||||||||||||||||||
| max_seqlen_kv, | ||||||||||||||||||||||||||
| ).reshape(bs * max_seqlen_q, -1) | ||||||||||||||||||||||||||
| total_seqlen = bs * max_seqlen_q | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if bs == 1: | ||||||||||||||||||||||||||
| if len(q.shape) == 3: | ||||||||||||||||||||||||||
| q = q.unsqueeze(0) | ||||||||||||||||||||||||||
| k = k.unsqueeze(0) | ||||||||||||||||||||||||||
| v = v.unsqueeze(0) | ||||||||||||||||||||||||||
| x = flash_attn_func_v3(q, k, v).reshape(bs * max_seqlen_q, -1) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| if cu_seqlens_q.is_cpu: | ||||||||||||||||||||||||||
| cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True) | ||||||||||||||||||||||||||
| if cu_seqlens_kv.is_cpu: | ||||||||||||||||||||||||||
| cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True) | ||||||||||||||||||||||||||
| if max_seqlen_q.is_cpu: | ||||||||||||||||||||||||||
| max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True) | ||||||||||||||||||||||||||
| if max_seqlen_kv.is_cpu: | ||||||||||||||||||||||||||
| max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True) | ||||||||||||||||||||||||||
| if len(q.shape) == 4: | ||||||||||||||||||||||||||
| q = q.reshape(-1, q.shape[-2], q.shape[-1]) | ||||||||||||||||||||||||||
| k = k.reshape(-1, k.shape[-2], k.shape[-1]) | ||||||||||||||||||||||||||
| v = v.reshape(-1, v.shape[-2], v.shape[-1]) | ||||||||||||||||||||||||||
|
Comment on lines
+122
to
+129
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the issue in
Suggested change
|
||||||||||||||||||||||||||
| x = flash_attn_varlen_func_v3( | ||||||||||||||||||||||||||
| q, | ||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||
| cu_seqlens_q, | ||||||||||||||||||||||||||
| cu_seqlens_kv, | ||||||||||||||||||||||||||
| max_seqlen_q, | ||||||||||||||||||||||||||
| max_seqlen_kv, | ||||||||||||||||||||||||||
| ).reshape(total_seqlen, -1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -160,8 +160,8 @@ def self_attn( | |||||||||
| merged_value_states = packed_value_states | ||||||||||
| key_values_lens = query_lens | ||||||||||
|
|
||||||||||
| cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE) | ||||||||||
| cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE) | ||||||||||
| cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) | ||||||||||
| cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) | ||||||||||
|
Comment on lines
+163
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the
Suggested change
|
||||||||||
|
|
||||||||||
| packed_attn_output = flash_attn_varlen_func( | ||||||||||
| q=packed_query_states, | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The import statements for
flash_attn_func_v2andflash_attn_varlen_func_v2appear to be missing theaskeyword. Standardflash_attnpackage does not export these names directly; they should likely be aliased from the standard function names to maintain consistency with thev3andv4imports below.