@@ -164,20 +164,34 @@ def _nsa_decode_att(
         import flash_mla

         nsa_dict = att_control.nsa_decode_dict
+        layer_index = nsa_dict["layer_index"]
         topk_mem_indices = nsa_dict["topk_mem_indices"]
         softmax_scale = nsa_dict["softmax_scale"]
         kv_lora_rank = nsa_dict["kv_lora_rank"]

+        mem_manager = self.infer_state.mem_manager
+        if hasattr(mem_manager, "get_decode_kv_cache_and_remap_indices"):
+            kv, topk_mem_indices = mem_manager.get_decode_kv_cache_and_remap_indices(
+                layer_index=layer_index,
+                topk_mem_indices=topk_mem_indices,
+            )
+
         if topk_mem_indices.ndim == 2:
             topk_mem_indices = topk_mem_indices.unsqueeze(1)
         assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"

         q_nope, q_rope = q
         q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()
+        if kv.shape[0] == 0:
+            return torch.zeros(
+                (q_nope.shape[0], q_nope.shape[1], kv_lora_rank),
+                dtype=q_nope.dtype,
+                device=q_nope.device,
+            )

         o_tensor, _ = flash_mla.flash_mla_with_kvcache(
             q=q_all,
-            k_cache=kv.contiguous(),
+            k_cache=kv if kv.is_contiguous() else kv.contiguous(),
             block_table=None,
             cache_seqlens=None,
             head_dim_v=kv_lora_rank,
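For context, the guarded call above assumes the memory manager may optionally expose a get_decode_kv_cache_and_remap_indices(layer_index, topk_mem_indices) hook returning the per-layer KV cache together with top-k indices remapped into that cache. The sketch below is purely illustrative: the class name, buffer layout, and gather logic are assumptions and not the implementation behind this PR; it only shows one way such a hook could satisfy the contract implied by the call site.

    import torch

    class _SketchMemManager:
        """Hypothetical memory manager; names and layout are assumptions for illustration."""

        def __init__(self, kv_buffer_per_layer):
            # Assumed layout: one [num_cached_tokens, num_heads, head_dim] tensor per layer.
            self.kv_buffer_per_layer = kv_buffer_per_layer

        def get_decode_kv_cache_and_remap_indices(self, layer_index, topk_mem_indices):
            # Gather only the tokens selected by the NSA top-k for this layer, then
            # remap the global indices so they address the gathered (compacted) cache.
            kv = self.kv_buffer_per_layer[layer_index]
            unique_ids, remapped = torch.unique(topk_mem_indices, return_inverse=True)
            return kv[unique_ids], remapped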