
Commit bdf82d2

Author: 钮圣虓
Commit message: fix
1 parent 8157948 · commit bdf82d2

1 file changed: lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
Lines changed: 7 additions & 12 deletions
@@ -72,12 +72,12 @@ def _nsa_prefill_att(
         import flash_mla
 
         nsa_dict = att_control.nsa_prefill_dict
-        layer_index = nsa_dict["layer_index"]
-        topk_mem_indices = nsa_dict["topk_mem_indices"]
         topk_indices = nsa_dict["topk_indices"]
-        prefill_cache_kv = nsa_dict["prefill_cache_kv"]
         softmax_scale = nsa_dict["softmax_scale"]
         kv_lora_rank = nsa_dict["kv_lora_rank"]
+        layer_index = nsa_dict["layer_index"]
+        topk_mem_indices = nsa_dict["topk_mem_indices"]
+        prefill_cache_kv = nsa_dict["prefill_cache_kv"]
 
         if self.infer_state.prefix_total_token_num > 0:
             kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices(
@@ -92,17 +92,12 @@ def _nsa_prefill_att(
         if topk_indices.ndim == 2:
             topk_indices = topk_indices.unsqueeze(1)
 
-        topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32)
-        if topk_length.ndim == 2 and topk_length.shape[1] == 1:
-            topk_length = topk_length[:, 0].contiguous()
-
         mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
-            q=q.contiguous(),
-            kv=kv.contiguous(),
-            indices=topk_indices.contiguous(),
+            q=q,
+            kv=kv,
+            indices=topk_indices,
             sm_scale=softmax_scale,
             d_v=kv_lora_rank,
-            topk_length=topk_length,
         )
         return mla_out
 
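For context, the deleted topk_length counted the valid (non -1) entries per row of topk_indices before being passed to the kernel. A minimal standalone sketch of that computation; the shapes and values here are illustrative assumptions, not taken from the repo:

import torch

# Assumed shape [batch, heads, topk]; -1 marks padded slots.
topk_indices = torch.tensor([[[3, 7, -1, -1]], [[1, -1, -1, -1]]], dtype=torch.int32)

# Count valid entries along the last dim, as the deleted lines did.
topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32)  # shape [2, 1]
if topk_length.ndim == 2 and topk_length.shape[1] == 1:
    topk_length = topk_length[:, 0].contiguous()  # -> shape [2]

print(topk_length)  # tensor([2, 1], dtype=torch.int32)

After this change, flash_mla.flash_mla_sparse_fwd is called without the topk_length argument, and q, kv, and indices are passed through without explicit .contiguous() coercion.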

@@ -193,4 +188,4 @@ def _nsa_decode_att(
             is_fp8_kvcache=True,
             indices=topk_mem_indices.contiguous(),
         )
-        return o_tensor[:, 0, :, :]
+        return o_tensor[:, 0, :, :]  # [b, 1, h, d] -> [b, h, d]
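The new trailing comment documents what the decode return slice does: it drops the singleton query-token dimension from the kernel output. A small sketch under assumed shapes (b, h, d values are hypothetical):

import torch

# Assumed decode output with a singleton token dim: [b, 1, h, d].
b, h, d = 2, 4, 8
o_tensor = torch.randn(b, 1, h, d)

out = o_tensor[:, 0, :, :]  # [b, 1, h, d] -> [b, h, d]
assert out.shape == (b, h, d)
assert torch.equal(out, o_tensor.squeeze(1))  # equivalent way to drop the dim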
