Skip to content

Commit 972e569

Browse files
author
钮圣虓
committed
feat: fp8 dsa support
1 parent 4fc0835 commit 972e569

15 files changed

Lines changed: 654 additions & 20 deletions

docker/Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
44
ARG PYTHON_VERSION=3.10
55
ARG MAMBA_VERSION=24.7.1-0
66
ARG VLLM_VERSION=0.16.0
7+
ARG FLASH_MLA_REF=47c35a7
78
ARG TARGETPLATFORM
89
ARG ENABLE_DEEPEP=1
910
ARG ENABLE_NIXL=1
@@ -45,6 +46,11 @@ COPY ./requirements.txt /lightllm/requirements.txt
4546
RUN pip install -U pip
4647
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
4748
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
49+
RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
50+
cd /root/FlashMLA && \
51+
git checkout ${FLASH_MLA_REF} && \
52+
git submodule update --init --recursive && \
53+
FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
4854

4955
RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*
5056

lightllm/common/basemodel/attention/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
# NSA backend
1414
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
15+
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend
1516

1617
from .create_utils import (
1718
get_prefill_att_backend_class,

lightllm/common/basemodel/attention/create_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .flashinfer.fp import FlashInferAttBackend
1616
from .flashinfer.mla import MlaFlashInferAttBackend
1717
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
18+
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend
1819

1920
logger = init_logger(__name__)
2021

@@ -56,6 +57,9 @@
5657
"flashmla_sparse": NsaFlashMlaSparseAttBackend,
5758
# Future backends: "fa3", "tilelang", "aiter"
5859
},
60+
"fp8kv_dsa": {
61+
"flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
62+
},
5963
}
6064

6165

lightllm/common/basemodel/attention/nsa/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,17 @@
55
NsaFlashMlaSparsePrefillAttState,
66
NsaFlashMlaSparseDecodeAttState,
77
)
8+
from .fp8_flashmla_sparse import (
9+
NsaFlashMlaFp8SparseAttBackend,
10+
NsaFlashMlaFp8SparsePrefillAttState,
11+
NsaFlashMlaFp8SparseDecodeAttState,
12+
)
813

914
__all__ = [
1015
"NsaFlashMlaSparseAttBackend",
1116
"NsaFlashMlaSparsePrefillAttState",
1217
"NsaFlashMlaSparseDecodeAttState",
18+
"NsaFlashMlaFp8SparseAttBackend",
19+
"NsaFlashMlaFp8SparsePrefillAttState",
20+
"NsaFlashMlaFp8SparseDecodeAttState",
1321
]

lightllm/common/basemodel/attention/nsa/flashmla_sparse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def _nsa_decode_att(
165165
from sgl_kernel.flash_attn import flash_attn_with_kvcache
166166

167167
nsa_dict = att_control.nsa_decode_dict
168-
topk_indices = nsa_dict["topk_indices"]
168+
topk_mem_indices = nsa_dict["topk_mem_indices"]
169169
softmax_scale = nsa_dict["softmax_scale"]
170170
kv_lora_rank = nsa_dict["kv_lora_rank"]
171171
qk_rope_head_dim = nsa_dict["qk_rope_head_dim"]
@@ -181,7 +181,7 @@ def _nsa_decode_att(
181181
k_cache=k_rope,
182182
v_cache=kv_nope,
183183
qv=q_nope,
184-
page_table=topk_indices,
184+
page_table=topk_mem_indices,
185185
cache_seqlens=self.nsa_cache_seqlens,
186186
cu_seqlens_q=self.infer_state.b1_cu_q_seq_len,
187187
cu_seqlens_k_new=self.nsa_cu_seqlens_k_new,
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import dataclasses
2+
import torch
3+
from typing import TYPE_CHECKING, Tuple
4+
5+
from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
6+
from lightllm.utils.dist_utils import get_current_device_id
7+
8+
if TYPE_CHECKING:
9+
from lightllm.common.basemodel.infer_struct import InferStateInfo
10+
11+
12+
class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
    """Attention backend that hands out the FP8 FlashMLA sparse prefill/decode states.

    Construction preallocates two int32 scratch buffers on the current device
    (``ragged_mem_buffers``). The decode state defined in this module selects
    one of them by ``infer_state.microbatch_index`` instead of allocating.
    """

    def __init__(self, model):
        """Initialise the base backend and allocate the two scratch buffers.

        Args:
            model: object exposing ``graph_max_batch_size`` and ``max_seq_length``;
                their product is the element count of each buffer.
        """
        super().__init__(model=model)
        target_device = get_current_device_id()
        slot_count = model.graph_max_batch_size * model.max_seq_length
        buffers = []
        for _ in range(2):
            buffers.append(torch.empty(slot_count, dtype=torch.int32, device=target_device))
        self.ragged_mem_buffers = buffers

    def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState":
        """Build a prefill attention state bound to this backend and ``infer_state``."""
        prefill_state = NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state)
        return prefill_state

    def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState":
        """Build a decode attention state bound to this backend and ``infer_state``."""
        decode_state = NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state)
        return decode_state
26+
27+
28+
@dataclasses.dataclass
class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
    """Prefill-side attention state for the FP8 FlashMLA sparse (NSA) backend.

    ``init_state`` populates ``ks`` / ``ke`` / ``lengths`` via ``gen_nsa_ks_ke``;
    ``prefill_att`` validates the NSA control block and runs
    ``flash_mla.flash_mla_sparse_fwd``.
    """

    # The three values returned by gen_nsa_ks_ke, in return order.
    ks: torch.Tensor = None
    ke: torch.Tensor = None
    lengths: torch.Tensor = None
    # int32 buffer with total_token_num entries, handed to gen_nsa_ks_ke.
    ragged_mem_index: torch.Tensor = None

    def init_state(self):
        """Allocate ``ragged_mem_index`` and compute ``ks`` / ``ke`` / ``lengths``."""
        # Self-assignment exists only to attach a narrower type annotation.
        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
        state = self.infer_state
        self.ragged_mem_index = torch.empty(
            state.total_token_num,
            dtype=torch.int32,
            device=get_current_device_id(),
        )
        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        # Query tokens are the total tokens minus the prefix tokens.
        new_token_count = state.total_token_num - state.prefix_total_token_num
        req_manager = state.req_manager
        ks_ke_lengths = gen_nsa_ks_ke(
            b_seq_len=state.b_seq_len,
            b_q_seq_len=state.b_q_seq_len,
            b_req_idx=state.b_req_idx,
            req_to_token_index=req_manager.req_to_token_indexs,
            q_token_num=new_token_count,
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=req_manager.HOLD_REQUEST_ID,
        )
        self.ks, self.ke, self.lengths = ks_ke_lengths
        return

    def prefill_att(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        """Run NSA sparse prefill attention.

        ``k`` is forwarded as the packed KV tensor; ``v`` and ``alloc_func`` are
        accepted for interface compatibility and are not used here.

        Raises:
            AssertionError: if ``att_control.nsa_prefill`` is falsy or
                ``att_control.nsa_prefill_dict`` is ``None``.
        """
        assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
        assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
        return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)

    def _nsa_prefill_att(
        self,
        q: torch.Tensor,
        packed_kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        """Call ``flash_mla.flash_mla_sparse_fwd`` and return its first output.

        When the batch has prefix tokens, the KV tensor and indices come from
        ``mem_manager.get_prefill_kv_cache_and_remap_indices``; otherwise the
        ``prefill_cache_kv`` entry and the ``topk_indices`` entry are used as is.
        """
        import flash_mla

        params = att_control.nsa_prefill_dict
        sparse_indices = params["topk_indices"]
        sm_scale = params["softmax_scale"]
        value_dim = params["kv_lora_rank"]
        mem_indices = params["topk_mem_indices"]
        cached_kv = params["prefill_cache_kv"]

        state = self.infer_state
        kv_source = cached_kv
        if state.prefix_total_token_num > 0:
            kv_source, sparse_indices = state.mem_manager.get_prefill_kv_cache_and_remap_indices(
                packed_kv=packed_kv,
                topk_indices=mem_indices,
                prefill_mem_index=state.mem_index,
                prefill_cache_kv=cached_kv,
            )

        # 2-D indices get a singleton axis inserted at position 1 before the kernel call.
        if sparse_indices.ndim == 2:
            sparse_indices = sparse_indices.unsqueeze(1)

        kernel_outputs = flash_mla.flash_mla_sparse_fwd(
            q=q,
            kv=kv_source,
            indices=sparse_indices,
            sm_scale=sm_scale,
            d_v=value_dim,
        )
        mla_out, _, _ = kernel_outputs
        return mla_out
103+
104+
105+
@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
    """Decode-side attention state for the FP8 FlashMLA sparse (NSA) backend.

    ``init_state`` picks or allocates ``ragged_mem_index``, computes
    ``ks`` / ``ke`` / ``lengths`` via ``gen_nsa_ks_ke`` and fetches the FlashMLA
    scheduler metadata; ``decode_att`` runs ``flash_mla.flash_mla_with_kvcache``
    with ``is_fp8_kvcache=True``.
    """

    # The three values returned by gen_nsa_ks_ke, in return order.
    ks: torch.Tensor = None
    ke: torch.Tensor = None
    lengths: torch.Tensor = None
    # Either one of the backend's preallocated buffers or a fresh int32 tensor.
    ragged_mem_index: torch.Tensor = None
    # First element of the tuple returned by flash_mla.get_mla_metadata().
    flashmla_sched_meta: object = None

    def init_state(self):
        """Prepare the ragged index buffer, ks/ke/lengths and scheduler metadata."""
        # Self-assignment exists only to attach a narrower type annotation.
        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
        model = self.backend.model
        state = self.infer_state

        # Allocate a fresh buffer only when the batch exceeds the model's graph
        # limits (presumably the CUDA-graph capture bounds; confirm). Otherwise
        # reuse the backend buffer indexed by the microbatch.
        exceeds_graph_limits = (
            state.batch_size > model.graph_max_batch_size or state.max_kv_seq_len > model.graph_max_len_in_batch
        )
        if exceeds_graph_limits:
            self.ragged_mem_index = torch.empty(
                state.total_token_num,
                dtype=torch.int32,
                device=get_current_device_id(),
            )
        else:
            self.ragged_mem_index = self.backend.ragged_mem_buffers[state.microbatch_index]

        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        req_manager = state.req_manager
        ks_ke_lengths = gen_nsa_ks_ke(
            b_seq_len=state.b_seq_len,
            b_q_seq_len=state.b_q_seq_len,
            b_req_idx=state.b_req_idx,
            req_to_token_index=req_manager.req_to_token_indexs,
            # The query token count is the length of b_seq_len.
            q_token_num=state.b_seq_len.shape[0],
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=req_manager.HOLD_REQUEST_ID,
        )
        self.ks, self.ke, self.lengths = ks_ke_lengths

        import flash_mla

        metadata = flash_mla.get_mla_metadata()
        self.flashmla_sched_meta, _ = metadata
        return

    def decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        """Run NSA sparse decode attention.

        ``q`` is a ``(q_nope, q_rope)`` pair and ``k`` is forwarded as the packed
        KV tensor; ``v`` and ``alloc_func`` are accepted for interface
        compatibility and are not used here.

        Raises:
            AssertionError: if ``att_control.nsa_decode`` is falsy or
                ``att_control.nsa_decode_dict`` is ``None``.
        """
        assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
        assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
        return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)

    def _nsa_decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        packed_kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        """Call ``flash_mla.flash_mla_with_kvcache`` and drop the query-length axis.

        Raises:
            AssertionError: if axis 1 of the (possibly unsqueezed) index tensor
                is not of size 1.
        """
        import flash_mla

        params = att_control.nsa_decode_dict
        sparse_indices = params["topk_mem_indices"]
        sm_scale = params["softmax_scale"]
        value_dim = params["kv_lora_rank"]

        # 2-D indices get a singleton axis inserted at position 1.
        if sparse_indices.ndim == 2:
            sparse_indices = sparse_indices.unsqueeze(1)
        assert sparse_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"

        q_nope, q_rope = q
        # Join the two query parts on the last axis, then add the singleton axis at position 1.
        q_joined = torch.cat([q_nope, q_rope], dim=-1)
        q_all = q_joined.unsqueeze(1).contiguous()

        # Zero-copy 4-D view of packed_kv: (rows, 1, 1, last_dim). Both inserted
        # singleton axes use last_dim as their stride.
        row_count = packed_kv.shape[0]
        last_dim = packed_kv.shape[-1]
        kv_view = torch.as_strided(
            packed_kv,
            size=(row_count, 1, 1, last_dim),
            stride=(packed_kv.stride(0), last_dim, last_dim, packed_kv.stride(-1)),
        )

        kernel_outputs = flash_mla.flash_mla_with_kvcache(
            q=q_all,
            k_cache=kv_view,
            block_table=None,
            cache_seqlens=None,
            head_dim_v=value_dim,
            tile_scheduler_metadata=self.flashmla_sched_meta,
            num_splits=None,
            softmax_scale=sm_scale,
            causal=False,
            is_fp8_kvcache=True,
            indices=sparse_indices,
        )
        o_tensor, _ = kernel_outputs
        return o_tensor[:, 0, :, :]  # [b, 1, h, d] -> [b, h, d]

lightllm/common/kv_cache_mem_manager/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
44
from .deepseek2_mem_manager import Deepseek2MemoryManager
55
from .deepseek3_2mem_manager import Deepseek3_2MemoryManager
6+
from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
67
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
78
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
89

@@ -13,6 +14,7 @@
1314
"PPLINT8KVMemoryManager",
1415
"Deepseek2MemoryManager",
1516
"Deepseek3_2MemoryManager",
17+
"FP8PerTokenGroupQuantDeepseek3_2MemoryManager",
1618
"FP8StaticPerHeadQuantMemManager",
1719
"FP8StaticPerTensorQuantMemManager",
1820
]

lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
3434
def get_att_input_params(self, layer_index: int) -> Any:
    """Return one layer's KV buffer with the trailing 72 last-axis elements sliced off.

    The kept width is ``self.head_dim - (144 // 2)``; the result is a slice
    (view) of ``self.kv_buffer[layer_index]``, not a copy.
    """
    kept_width = self.head_dim - (144 // 2)
    layer_buffer = self.kv_buffer[layer_index]
    return layer_buffer[:, :, :kept_width]
37+
38+
def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
    """Return the last 132 bytes along the final axis of one layer's KV buffer.

    The layer buffer is first reinterpreted as ``torch.uint8`` (no copy), so the
    slice is counted in bytes rather than in elements of the buffer's own dtype.
    """
    raw_bytes = self.kv_buffer[layer_index].view(dtype=torch.uint8)
    return raw_bytes[:, :, -132:]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import torch
2+
from typing import Any
3+
4+
from .deepseek2_mem_manager import Deepseek2MemoryManager
5+
6+
7+
class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager):
    """KV-cache manager storing each token as one packed uint8 record.

    Per-token byte layout along the last axis of ``kv_buffer[layer]``:

    * ``[0, 512)``    FP8 (``float8_e4m3fn``) "nope" part, one byte per channel
    * ``[512, 528)``  four float32 values, one per 128-channel quantisation group
    * ``[528, 656)``  64 bfloat16 rope channels
    * ``[656, 788)``  indexer K region (128 + 4 bytes)
    * ``[788, 800)``  padding up to the 16-byte aligned record size

    All offsets used by the methods below are derived from the class constants,
    so the write path and the read-side slices stay in agreement.
    """

    kv_nope_dim = 512
    kv_rope_dim = 64
    # 576 = 512 + 64
    kv_head_dim = kv_nope_dim + kv_rope_dim

    quant_group_size = 128
    # 4 = 512 / 128
    quant_group_num = kv_nope_dim // quant_group_size
    # 4 * 4 = quant_group_num * sizeof(fp32)
    # 64 * 2 = kv_rope_dim * sizeof(bfloat16)
    # 656 bytes = 512 + (4 * 4) + (64 * 2)
    flashmla_bytes_per_token = kv_nope_dim + quant_group_num * 4 + kv_rope_dim * 2

    indexer_head_dim = 128
    # 132 bytes = indexer_head_dim + sizeof(fp32) = 128 + 4
    indexer_bytes_per_token = indexer_head_dim + 4

    # 16-byte alignment, to satisfy FlashMLA's alignment requirement.
    alignment = 16
    # Round the 788 payload bytes up to the next multiple of `alignment` (800).
    total_bytes_per_token = (
        (flashmla_bytes_per_token + indexer_bytes_per_token + alignment - 1) // alignment * alignment
    )

    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        """Create the manager with a uint8 buffer of ``total_bytes_per_token`` per token.

        ``dtype`` is only remembered as ``prefill_dtype``; the parent is always
        initialised with ``torch.uint8``. ``head_dim`` is accepted for interface
        compatibility and replaced by ``total_bytes_per_token``.

        Raises:
            AssertionError: if ``head_num`` is not 1.
        """
        assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1"
        self.prefill_dtype = dtype
        super().__init__(size, torch.uint8, head_num, self.total_bytes_per_token, layer_num, always_copy, mem_fraction)

    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
        """Write ``kv`` into the packed FlashMLA region at ``mem_index``.

        ``kv`` is split on its last axis into the leading ``kv_nope_dim`` channels
        and the trailing ``kv_rope_dim`` channels; the Triton kernel
        ``destindex_copy_kv_flashmla_fp8`` receives both parts plus typed views of
        the three destination regions.

        Raises:
            AssertionError: if ``kv.shape[2] - kv_rope_dim`` differs from ``kv_nope_dim``.
        """
        from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import (
            destindex_copy_kv_flashmla_fp8,
        )

        kv_lora_rank = kv.shape[2] - self.kv_rope_dim
        assert kv_lora_rank == self.kv_nope_dim, f"Expected kv_lora_rank={self.kv_nope_dim}, got {kv_lora_rank}"

        # Byte offsets of the region boundaries, derived from the layout
        # constants instead of repeating the literals 512 / 528.
        nope_end = self.kv_nope_dim
        scale_end = nope_end + self.quant_group_num * 4
        rope_end = self.flashmla_bytes_per_token

        layer_buffer = self.kv_buffer[layer_index]
        o_nope = layer_buffer[:, :, :nope_end].view(torch.float8_e4m3fn)
        o_scale = layer_buffer[:, :, nope_end:scale_end].view(torch.float32)
        o_rope = layer_buffer[:, :, scale_end:rope_end].view(torch.bfloat16)
        destindex_copy_kv_flashmla_fp8(
            kv[:, :, :kv_lora_rank],
            kv[:, :, kv_lora_rank:],
            mem_index,
            o_nope,
            o_scale,
            o_rope,
        )

    def get_att_input_params(self, layer_index: int) -> Any:
        """Return the first ``flashmla_bytes_per_token`` bytes of each token record."""
        return self.kv_buffer[layer_index][:, :, : self.flashmla_bytes_per_token]

    def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
        """Return the ``indexer_bytes_per_token`` bytes that follow the FlashMLA region."""
        begin = self.flashmla_bytes_per_token
        end = begin + self.indexer_bytes_per_token
        return self.kv_buffer[layer_index][:, :, begin:end]

    def get_prefill_kv_cache_and_remap_indices(
        self,
        packed_kv: torch.Tensor,
        topk_indices: torch.Tensor,
        prefill_mem_index: torch.Tensor,
        prefill_cache_kv: torch.Tensor,
    ):
        """Delegate to ``get_prefill_kv_cache_and_remap_indices_triton``.

        ``topk_indices`` is forwarded under the kernel's ``topk_mem_indices``
        keyword, and the dtype remembered at construction is passed as
        ``prefill_dtype``. The kernel's result is returned unchanged.
        """
        from lightllm.models.deepseek3_2.triton_kernel.prefill_compact_kv_flashmla_fp8 import (
            get_prefill_kv_cache_and_remap_indices_triton,
        )

        return get_prefill_kv_cache_and_remap_indices_triton(
            packed_kv=packed_kv,
            topk_mem_indices=topk_indices,
            prefill_mem_index=prefill_mem_index,
            prefill_cache_kv=prefill_cache_kv,
            prefill_dtype=self.prefill_dtype,
        )

0 commit comments

Comments
 (0)