import dataclasses
import torch
from typing import TYPE_CHECKING, Tuple

from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
from lightllm.utils.dist_utils import get_current_device_id

if TYPE_CHECKING:
    from lightllm.common.basemodel.infer_struct import InferStateInfo


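# Attention backend for NSA top-k sparse multi-head latent attention built on FlashMLA
# kernels over an FP8 KV cache: prefill goes through flash_mla.flash_mla_sparse_fwd on the
# selected top-k indices, and decode goes through flash_mla.flash_mla_with_kvcache with
# is_fp8_kvcache=True and per-token sparse indices.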
class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
    def __init__(self, model):
        super().__init__(model=model)
        device = get_current_device_id()
        # Two reusable int32 index buffers (one per microbatch), sized for the largest
        # CUDA-graph batch so decode can run with fixed-size tensors under graph capture.
        self.ragged_mem_buffers = [
            torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device)
            for _ in range(2)
        ]

    def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState":
        return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state)

    def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState":
        return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
    ks: torch.Tensor = None
    ke: torch.Tensor = None
    lengths: torch.Tensor = None
    ragged_mem_index: torch.Tensor = None

    def init_state(self):
        # Narrow the declared type of `backend` so type checkers resolve backend attributes.
        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
        self.ragged_mem_index = torch.empty(
            self.infer_state.total_token_num,
            dtype=torch.int32,
            device=get_current_device_id(),
        )
        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        # Generate the (ks, ke) ranges and the ragged memory index consumed by the
        # NSA sparse attention kernels for the newly prefilled query tokens.
        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
            b_seq_len=self.infer_state.b_seq_len,
            b_q_seq_len=self.infer_state.b_q_seq_len,
            b_req_idx=self.infer_state.b_req_idx,
            req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
            q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num,
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
        )
        return

    def prefill_att(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
        assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
        # The packed latent KV arrives through `k`; `v` is unused on the NSA sparse path.
        return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)

    def _nsa_prefill_att(
        self,
        q: torch.Tensor,
        packed_kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        import flash_mla

        nsa_dict = att_control.nsa_prefill_dict
        topk_indices = nsa_dict["topk_indices"]
        softmax_scale = nsa_dict["softmax_scale"]
        kv_lora_rank = nsa_dict["kv_lora_rank"]
        topk_mem_indices = nsa_dict["topk_mem_indices"]
        prefill_cache_kv = nsa_dict["prefill_cache_kv"]

        if self.infer_state.prefix_total_token_num > 0:
            # With a prefix-cache hit, fetch the prefill KV cache and remap the top-k
            # memory-pool indices so they address the returned buffer.
            kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices(
                packed_kv=packed_kv,
                topk_indices=topk_mem_indices,
                prefill_mem_index=self.infer_state.mem_index,
                prefill_cache_kv=prefill_cache_kv,
            )
        else:
            kv = prefill_cache_kv

        # The sparse kernel takes 3-D indices; lift [q_tokens, topk] to [q_tokens, 1, topk].
        if topk_indices.ndim == 2:
            topk_indices = topk_indices.unsqueeze(1)

        mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
            q=q,
            kv=kv,
            indices=topk_indices,
            sm_scale=softmax_scale,
            d_v=kv_lora_rank,
        )
        return mla_out


@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
    ks: torch.Tensor = None
    ke: torch.Tensor = None
    lengths: torch.Tensor = None
    ragged_mem_index: torch.Tensor = None
    flashmla_sched_meta: object = None

    def init_state(self):
        # Narrow the declared type of `backend` so type checkers resolve backend attributes.
        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
        model = self.backend.model
        use_cuda_graph = (
            self.infer_state.batch_size <= model.graph_max_batch_size
            and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
        )

        if use_cuda_graph:
            # Reuse the backend's pre-allocated buffer so tensor addresses stay stable
            # across CUDA-graph replays.
            self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index]
        else:
            self.ragged_mem_index = torch.empty(
                self.infer_state.total_token_num,
                dtype=torch.int32,
                device=get_current_device_id(),
            )

        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
            b_seq_len=self.infer_state.b_seq_len,
            b_q_seq_len=self.infer_state.b_q_seq_len,
            b_req_idx=self.infer_state.b_req_idx,
            req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
            q_token_num=self.infer_state.b_seq_len.shape[0],
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
        )
        import flash_mla

        # Tile-scheduler metadata reused by flash_mla_with_kvcache in the decode call below.
        self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
        return

    def decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
        assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
        # `q` is a (q_nope, q_rope) pair; the packed latent KV arrives through `k` and `v` is unused.
        return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)

    def _nsa_decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        packed_kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        import flash_mla

        nsa_dict = att_control.nsa_decode_dict
        topk_mem_indices = nsa_dict["topk_mem_indices"]
        softmax_scale = nsa_dict["softmax_scale"]
        kv_lora_rank = nsa_dict["kv_lora_rank"]

        # Lift 2-D indices [batch, topk] to [batch, 1, topk]; decode runs with one query
        # position per request.
        if topk_mem_indices.ndim == 2:
            topk_mem_indices = topk_mem_indices.unsqueeze(1)
        assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"

        # Concatenate the no-RoPE and RoPE parts of q and add a seq_len_q == 1 axis: [b, 1, h, d].
        q_nope, q_rope = q
        q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()
        # View the flat FP8 KV buffer as a paged cache with page_size=1 and a single KV head,
        # without copying data.
        kv = torch.as_strided(
            packed_kv,
            size=(packed_kv.shape[0], 1, 1, packed_kv.shape[-1]),
            stride=(packed_kv.stride(0), packed_kv.shape[-1], packed_kv.shape[-1], packed_kv.stride(-1)),
        )

        o_tensor, _ = flash_mla.flash_mla_with_kvcache(
            q=q_all,
            k_cache=kv,
            block_table=None,
            cache_seqlens=None,
            head_dim_v=kv_lora_rank,
            tile_scheduler_metadata=self.flashmla_sched_meta,
            num_splits=None,
            softmax_scale=softmax_scale,
            causal=False,
            is_fp8_kvcache=True,
            indices=topk_mem_indices,
        )
        return o_tensor[:, 0, :, :]  # [b, 1, h, d] -> [b, h, d]
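

# Illustrative sketch only (not used by the backend): a tiny check of the `torch.as_strided`
# view used in `_nsa_decode_att`, showing that a flat KV buffer can be exposed as a
# page_size=1, single-KV-head paged layout without copying. The sizes below are made up.
if __name__ == "__main__":
    flat_kv = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)
    paged_view = torch.as_strided(
        flat_kv,
        size=(flat_kv.shape[0], 1, 1, flat_kv.shape[-1]),
        stride=(flat_kv.stride(0), flat_kv.shape[-1], flat_kv.shape[-1], flat_kv.stride(-1)),
    )
    assert paged_view.data_ptr() == flat_kv.data_ptr()  # same storage, no copy
    assert torch.equal(paged_view.reshape(4, 8), flat_kv)  # same element order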