@@ -72,12 +72,12 @@ def _nsa_prefill_att(
         import flash_mla
 
         nsa_dict = att_control.nsa_prefill_dict
-        layer_index = nsa_dict["layer_index"]
-        topk_mem_indices = nsa_dict["topk_mem_indices"]
         topk_indices = nsa_dict["topk_indices"]
-        prefill_cache_kv = nsa_dict["prefill_cache_kv"]
         softmax_scale = nsa_dict["softmax_scale"]
         kv_lora_rank = nsa_dict["kv_lora_rank"]
+        layer_index = nsa_dict["layer_index"]
+        topk_mem_indices = nsa_dict["topk_mem_indices"]
+        prefill_cache_kv = nsa_dict["prefill_cache_kv"]
 
         if self.infer_state.prefix_total_token_num > 0:
             kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices(
@@ -92,17 +92,12 @@ def _nsa_prefill_att(
         if topk_indices.ndim == 2:
             topk_indices = topk_indices.unsqueeze(1)
 
-        topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32)
-        if topk_length.ndim == 2 and topk_length.shape[1] == 1:
-            topk_length = topk_length[:, 0].contiguous()
-
         mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
-            q=q.contiguous(),
-            kv=kv.contiguous(),
-            indices=topk_indices.contiguous(),
+            q=q,
+            kv=kv,
+            indices=topk_indices,
             sm_scale=softmax_scale,
             d_v=kv_lora_rank,
-            topk_length=topk_length,
         )
         return mla_out
 
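As an aside to the hunk above (not part of the patch): the removed lines derived an explicit per-request `topk_length` from the `-1` padding in `topk_indices` before passing it to the kernel; the updated call drops that argument, presumably because `flash_mla_sparse_fwd` now handles `-1` padding internally. A minimal sketch of what the removed computation did, with hypothetical shapes:

```python
import torch

# Hypothetical example: batch=2, heads=1, topk=4 (shapes are assumptions,
# not taken from the real model).
topk_indices = torch.tensor(
    [[[3, 7, 9, -1]],     # request 0: 3 valid KV slots, -1 is padding
     [[1, -1, -1, -1]]],  # request 1: 1 valid KV slot
    dtype=torch.int32,
)

# Mirrors the removed code: count the non-padding entries per row...
topk_length = torch.sum(topk_indices != -1, dim=-1, dtype=torch.int32)
# ...and squeeze the singleton head dim to get one length per request.
if topk_length.ndim == 2 and topk_length.shape[1] == 1:
    topk_length = topk_length[:, 0].contiguous()

print(topk_length)  # tensor([3, 1], dtype=torch.int32)
```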
@@ -193,4 +188,4 @@ def _nsa_decode_att(
             is_fp8_kvcache=True,
             indices=topk_mem_indices.contiguous(),
         )
-        return o_tensor[:, 0, :, :]
+        return o_tensor[:, 0, :, :]  # [b, 1, h, d] -> [b, h, d]
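The added comment documents the slice in `_nsa_decode_att`: decode processes one query token per request, so the attention output carries a singleton sequence dim that indexing at 0 drops. A toy check with assumed shapes:

```python
import torch

# Assumed shapes for illustration: batch=2, q_len=1, heads=16, head_dim=512.
o_tensor = torch.randn(2, 1, 16, 512)  # [b, 1, h, d]

out = o_tensor[:, 0, :, :]             # drop the singleton query-length dim
assert out.shape == (2, 16, 512)       # [b, h, d]
```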