
Commit e662607

Author: 钮圣虓
Commit message: refine5
1 parent bdf82d2 commit e662607

2 files changed: 76 additions & 278 deletions

File tree

lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py

Lines changed: 15 additions & 1 deletion
@@ -164,20 +164,34 @@ def _nsa_decode_att(
         import flash_mla

         nsa_dict = att_control.nsa_decode_dict
+        layer_index = nsa_dict["layer_index"]
         topk_mem_indices = nsa_dict["topk_mem_indices"]
         softmax_scale = nsa_dict["softmax_scale"]
         kv_lora_rank = nsa_dict["kv_lora_rank"]

+        mem_manager = self.infer_state.mem_manager
+        if hasattr(mem_manager, "get_decode_kv_cache_and_remap_indices"):
+            kv, topk_mem_indices = mem_manager.get_decode_kv_cache_and_remap_indices(
+                layer_index=layer_index,
+                topk_mem_indices=topk_mem_indices,
+            )
+
         if topk_mem_indices.ndim == 2:
             topk_mem_indices = topk_mem_indices.unsqueeze(1)
         assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"

         q_nope, q_rope = q
         q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()
+        if kv.shape[0] == 0:
+            return torch.zeros(
+                (q_nope.shape[0], q_nope.shape[1], kv_lora_rank),
+                dtype=q_nope.dtype,
+                device=q_nope.device,
+            )

         o_tensor, _ = flash_mla.flash_mla_with_kvcache(
             q=q_all,
-            k_cache=kv.contiguous(),
+            k_cache=kv if kv.is_contiguous() else kv.contiguous(),
             block_table=None,
             cache_seqlens=None,
             head_dim_v=kv_lora_rank,
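
For context, below is a minimal, self-contained sketch (not the lightllm implementation) of the duck-typed hook this decode path probes with hasattr(). It assumes the memory manager gathers the KV rows selected by NSA for one layer into a dense tensor and returns the top-k indices remapped into that tensor. The class name ToyMemManager, the buffer layout, and the torch.unique-based gather are illustrative assumptions; only the method name and its keyword arguments come from the diff above.

import torch


class ToyMemManager:
    """Hypothetical memory manager exposing the hook probed by _nsa_decode_att."""

    def __init__(self, num_layers: int, num_tokens: int, kv_dim: int):
        # One flat KV buffer per layer, indexed by global memory-slot ids.
        self.kv_buffer = torch.randn(num_layers, num_tokens, 1, kv_dim)

    def get_decode_kv_cache_and_remap_indices(self, layer_index, topk_mem_indices):
        # topk_mem_indices: (batch, topk) global slot ids selected by NSA.
        unique_slots, remapped = torch.unique(topk_mem_indices, return_inverse=True)
        # Gather only the referenced slots for this layer into a dense tensor.
        kv = self.kv_buffer[layer_index, unique_slots]
        # `remapped` now indexes rows of the gathered `kv` tensor.
        return kv, remapped.view_as(topk_mem_indices)


mem_manager = ToyMemManager(num_layers=2, num_tokens=128, kv_dim=576)
topk_mem_indices = torch.randint(0, 128, (4, 16))  # batch=4, top-16 tokens per request

# Same duck-typed probe as in the patched decode path.
if hasattr(mem_manager, "get_decode_kv_cache_and_remap_indices"):
    kv, topk_mem_indices = mem_manager.get_decode_kv_cache_and_remap_indices(
        layer_index=0, topk_mem_indices=topk_mem_indices
    )

# The remapped indices stay within the gathered cache.
print(kv.shape, int(topk_mem_indices.max()) < kv.shape[0])

Under these assumptions, kv.shape[0] can legitimately be zero when no slots are selected, which appears to be what the new early-return guard in the diff handles before calling flash_mla.flash_mla_with_kvcache.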
