
Commit 391d2ea

Fp8kv support (#1220)
1 parent 8cc0095 commit 391d2ea

26 files changed

Lines changed: 2047 additions & 1867 deletions

docs/CN/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ Lightllm integrates the strengths of many open-source solutions, including but not limited to FasterTran

     :caption: Deployment Tutorials

     DeepSeek R1 Deployment <tutorial/deepseek_deployment>
+    FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
     Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
     Multimodal Deployment <tutorial/multimodal>
     Reward Model Deployment <tutorial/reward_model>

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 4 additions & 1 deletion
@@ -337,7 +337,10 @@ PD Disaggregation Mode Parameters

 .. option:: --llm_kv_type

-    The data type the inference backend uses to store the KV cache. Options: "None", "int8kv", "int4kv", "fp8kv"
+    The data type the inference backend uses to store the KV cache. Options: "None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"
+
+    - ``fp8kv_sph``: FP8 static per-head quantization, uses the fa3 backend
+    - ``fp8kv_spt``: FP8 static per-tensor quantization, uses the flashinfer backend

 .. option:: --disable_cudagraph

docs/CN/source/tutorial/fp8_kv_quantization.rst (new file)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
.. _tutorial/fp8_kv_quantization_cn:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes how to use FP8 KV inference in LightLLM, including:

- Running inference with a calibration file (``fp8kv_sph`` or ``fp8kv_spt``)
- The FP8 static per-head and per-tensor quantization modes
- Common errors and troubleshooting tips

Overview
--------

FP8 KV inference in LightLLM requires a prepared calibration file (``kv_cache_calib.json``),
loaded via ``--kv_quant_calibration_config_path``.
You can use the calibration files shipped in the ``test/advanced_config/`` directory,
export one with the `LightCompress <https://github.com/ModelTC/LightCompress>`_ tool, or use your own compatible file.

Quantization Modes and Backend Mapping
--------------------------------------

LightLLM supports two FP8 KV quantization modes:

- ``fp8kv_sph``: FP8 static per-head quantization, one independent scale per head, uses the ``fa3`` backend
- ``fp8kv_spt``: FP8 static per-tensor quantization, one scalar scale each for K and V, uses the ``flashinfer`` backend

Calibration files are tied to the quantization mode:

- ``fp8kv_sph`` requires a ``per_head`` calibration file
- ``fp8kv_spt`` requires a ``per_tensor`` calibration file

Do not mix calibration files across modes.

Starting FP8 Inference with a Calibration File
----------------------------------------------

Inference examples:

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv_sph \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv_spt \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- The ``fp8kv_sph`` and ``fp8kv_spt`` modes require ``--kv_quant_calibration_config_path``.
- The attention backend is selected automatically from the quantization mode; there is no need to specify it manually.

.. note::

    When using ``fp8kv_spt`` mode (FP8 static per-tensor quantization with the flashinfer backend),
    you must install ``flashinfer-python==0.6.5``. The version installed by default is 0.6.3,
    which can cause runtime errors. Install the correct version with:

    .. code-block:: console

        $ pip install flashinfer-python==0.6.5

Calibration File Format
-----------------------

The main fields of ``kv_cache_calib.json`` are:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: the actual scale values
- ``qmin`` / ``qmax``: FP8 range parameters

When the calibration file is loaded, the model architecture, layer count, head count, and quantization type are all validated against it.

Multi-GPU Notes
---------------

In multi-GPU (TP) setups, the system automatically slices out the scales for the heads owned by the current rank,
as sketched below. You still only need to provide one full ``kv_cache_calib.json``.
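
A minimal sketch of this per-rank slicing, assuming heads are split evenly and
contiguously across TP ranks (the helper below is illustrative, not LightLLM's
actual API; the K-first/V-second layout mirrors the ``offline_scales[:, :head_num]`` /
``offline_scales[:, head_num:]`` split used by the fa3 backend):

.. code-block:: python

    import torch

    def slice_scales_for_rank(scales: torch.Tensor, num_head: int,
                              tp_rank: int, tp_world_size: int) -> torch.Tensor:
        # scales: (num_layers, 2 * num_head), K scales first, then V scales.
        heads_per_rank = num_head // tp_world_size
        start = tp_rank * heads_per_rank
        end = start + heads_per_rank
        k_scales = scales[:, start:end]                        # this rank's K heads
        v_scales = scales[:, num_head + start:num_head + end]  # this rank's V heads
        # (num_layers, 2 * heads_per_rank), still K first, then V
        return torch.cat([k_scales, v_scales], dim=1)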
Common Issues
-------------

1. Startup error saying ``--kv_quant_calibration_config_path`` is required

   You used ``--llm_kv_type fp8kv_sph`` or ``fp8kv_spt`` without passing a calibration file path.

2. ``quant_type not match`` error

   The quantization mode usually does not match the calibration file type, e.g. running ``fp8kv_sph`` with a ``per_tensor`` file.

3. Abnormal output quality after switching quantization modes

   Use a calibration file that matches the target quantization mode; do not reuse an incompatible file across modes.

docs/EN/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ Documentation List

     :caption: Deployment Tutorials

     DeepSeek R1 Deployment <tutorial/deepseek_deployment>
+    FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
     Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
     Multimodal Deployment <tutorial/multimodal>
     Reward Model Deployment <tutorial/reward_model>

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 10 additions & 0 deletions
@@ -333,6 +333,16 @@ Performance Optimization Parameters

     * ``flashinfer``: Use FlashInfer backend
     * ``triton``: Use Triton backend

+.. option:: --llm_kv_type
+
+    Set the KV cache data type for inference. Available options:
+
+    * ``None``: Use the dtype from the model's config.json
+    * ``int8kv``: INT8 KV quantization
+    * ``int4kv``: INT4 KV quantization
+    * ``fp8kv_sph``: FP8 static per-head quantization, uses the fa3 backend
+    * ``fp8kv_spt``: FP8 static per-tensor quantization, uses the flashinfer backend

 .. option:: --disable_cudagraph

     Disable cudagraph in the decoding phase
docs/EN/source/tutorial/fp8_kv_quantization.rst (new file)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
.. _tutorial/fp8_kv_quantization_en:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes FP8 KV inference in LightLLM, including:

- Running inference with calibration data (``fp8kv_sph`` or ``fp8kv_spt``)
- FP8 static per-head and per-tensor quantization modes
- Common errors and troubleshooting

Overview
--------

LightLLM FP8 KV inference requires a prepared calibration file (``kv_cache_calib.json``),
which is loaded via ``--kv_quant_calibration_config_path``.
You can use the calibration files provided in ``test/advanced_config/``,
export one with `LightCompress <https://github.com/ModelTC/LightCompress>`_, or use your own compatible file.

Quantization Modes and Backend Mapping
--------------------------------------

LightLLM supports two FP8 KV quantization modes:

- ``fp8kv_sph``: FP8 static per-head quantization, an independent scale per head, uses the ``fa3`` backend
- ``fp8kv_spt``: FP8 static per-tensor quantization, one scalar scale for K and one for V, uses the ``flashinfer`` backend

Calibration files are mode-dependent:

- ``fp8kv_sph`` corresponds to ``per_head`` calibration files
- ``fp8kv_spt`` corresponds to ``per_tensor`` calibration files

Avoid mixing calibration files across different modes; a sketch of the two granularities follows.
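
To make the two granularities concrete, here is a minimal sketch of how static
scales could be derived from absolute-maximum statistics gathered during
calibration (illustrative only; the actual LightCompress calibration procedure
may differ):

.. code-block:: python

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

    def per_head_scales(kv_abs_max: torch.Tensor) -> torch.Tensor:
        # kv_abs_max: (num_layers, 2 * num_head), max |activation| per head,
        # K heads first, V heads second -> one scale per (layer, head).
        return kv_abs_max / FP8_MAX

    def per_tensor_scales(k_abs_max: torch.Tensor, v_abs_max: torch.Tensor):
        # Collapse everything: one scalar scale for K and one for V.
        return k_abs_max.max() / FP8_MAX, v_abs_max.max() / FP8_MAX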
Starting FP8 Inference with a Calibration File
----------------------------------------------

Inference examples:

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv_sph \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

.. code-block:: console

    $ python -m lightllm.server.api_server \
        --model_dir /path/to/model \
        --llm_kv_type fp8kv_spt \
        --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv_sph`` and ``fp8kv_spt`` require ``--kv_quant_calibration_config_path``.
- The attention backend is selected automatically from the quantization mode; there is no need to specify it manually.

.. note::

    When using ``fp8kv_spt`` mode (FP8 static per-tensor quantization with the flashinfer backend),
    you must install ``flashinfer-python==0.6.5``. The version installed by default is 0.6.3,
    which may cause runtime issues. Install the correct version with:

    .. code-block:: console

        $ pip install flashinfer-python==0.6.5

Calibration File Schema
-----------------------

Key fields in ``kv_cache_calib.json``:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: actual scale values
- ``qmin`` / ``qmax``: FP8 numeric range parameters

At load time, LightLLM validates the model architecture, layer count, head count, and quantization type against the file.
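
For orientation, a minimal sketch that writes a ``per_head`` calibration file
with this schema (all values are made-up placeholders; only the field names
come from the schema above):

.. code-block:: python

    import json

    calib = {
        "quant_type": "per_head",
        "num_layers": 32,
        "num_head": 8,
        "scales_shape": [32, 16],      # assumed (num_layers, 2 * num_head): K scales, then V scales
        "scales": [[0.02] * 16] * 32,  # placeholder per-head scales
        "qmin": -448.0,                # float8_e4m3 range
        "qmax": 448.0,
    }

    with open("kv_cache_calib.json", "w") as f:
        json.dump(calib, f, indent=2)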
Multi-GPU Note
--------------

In multi-GPU (TP) setups, LightLLM automatically slices the global scales down to the heads owned by the local rank.
You only need to provide one full ``kv_cache_calib.json`` file.

Common Issues
-------------

1. Error saying ``--kv_quant_calibration_config_path`` is required

   You are using ``--llm_kv_type fp8kv_sph`` or ``fp8kv_spt`` without a calibration file path.

2. ``quant_type not match`` error

   Usually caused by a quantization mode/file mismatch (for example, using a ``per_tensor`` file with ``fp8kv_sph``).

3. Abnormal quality after switching modes

   Use a calibration file that matches the target quantization mode instead of reusing an incompatible file.

lightllm/common/basemodel/attention/create_utils.py

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@
         # "fa3": Fp8Fa3AttBackend,
         # "flashinfer": Fp8FlashInferAttBackend,
     },
+    "fp8kv_sph": {
+        "fa3": Fp8Fa3AttBackend,
+    },
+    "fp8kv_spt": {
+        "flashinfer": Fp8FlashInferAttBackend,
+    },
 }

 mla_data_type_to_backend = {
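
This table is what makes the backend choice automatic: each FP8 mode maps to
exactly one backend class. A hedged sketch of the lookup (the helper below is
illustrative, not the repo's actual dispatch function):

    def resolve_backend(kv_type: str, data_type_to_backend: dict):
        # data_type_to_backend: kv type -> {backend_name: backend_class},
        # mirroring the table above.
        backends = data_type_to_backend[kv_type]
        if len(backends) == 1:
            # "fp8kv_sph" -> Fp8Fa3AttBackend, "fp8kv_spt" -> Fp8FlashInferAttBackend:
            # only one valid choice, so it is selected automatically.
            return next(iter(backends.values()))
        raise ValueError(f"kv type {kv_type!r} needs an explicit backend choice")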

lightllm/common/basemodel/attention/fa3/fp8.py

Lines changed: 16 additions & 42 deletions
@@ -45,24 +45,9 @@ def init_state(self):
             torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
         )
         # Initialize k_descale and v_descale outside of inference to reduce inference-time compute
-        self.k_descale = (
-            offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
-        self.v_descale = (
-            offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
+        self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+

     def prefill_att(
         self,
@@ -89,19 +74,21 @@ def _fp8_prefill_att(
     ) -> torch.Tensor:
         self.backend: Fp8Fa3AttBackend = self.backend  # for typing

+        q_head_num = q.shape[1]
+        q_head_dim = q.shape[2]
+        k_head_num = k.shape[1]
         q, q_scale = q_per_head_fp8_quant(
-            q,
+            q.reshape(q.shape[0], k_head_num, -1),
             self.infer_state.b_seq_len,
             self.cu_seqlens_q,
-            self.mid_token_batch_ids,
+            token_batch_ids=self.mid_token_batch_ids,
         )
-        k_head_num = k.shape[1]
         k_head_dim = k.shape[2]
         cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.reshape(-1, q_head_num, q_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,
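
A note on the reshapes in this prefill path: grouping q to ``(tokens, k_head_num, -1)``
packs the query heads that share one KV head into a single row, so the per-head
quantization scales line up with KV heads under GQA; the second reshape restores
the layout flash attention expects. A small illustration with made-up sizes:

    import torch

    tokens, q_head_num, head_dim, k_head_num = 5, 32, 128, 8
    q = torch.randn(tokens, q_head_num, head_dim)

    # 32 query heads / 8 KV heads = 4 query heads per KV-head group.
    grouped = q.reshape(tokens, k_head_num, -1)           # (5, 8, 4 * 128)
    restored = grouped.reshape(-1, q_head_num, head_dim)  # (5, 32, 128)
    assert torch.equal(q, restored)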
@@ -141,24 +128,9 @@ def init_state(self):
         head_num = mem_manager.head_num

         # Initialize k_descale and v_descale outside of inference to reduce inference-time compute
-        self.k_descale = (
-            offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
-        self.v_descale = (
-            offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
+        self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+
         return

     def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"):
@@ -200,9 +172,11 @@ def _fp8_decode_att(
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)

         q_head_num = q.shape[1]
-        q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
+        if scaled_fp8_quant is None:
+            raise ImportError("scaled_fp8_quant is unavailable. Please install vllm to enable FP8 decode attention.")
+        q, q_scale = scaled_fp8_quant(q.reshape(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
         o = flash_attn_with_kvcache(
-            q=q.view(-1, q_head_num, k_head_dim),
+            q=q.reshape(-1, q_head_num, k_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,
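
The simplified ``k_descale``/``v_descale`` initialization drops the old
``torch.ones`` fallback, consistent with the docs above: the FP8 KV modes now
always require offline calibration scales. A sketch of the resulting shape
bookkeeping, assuming ``offline_scales`` is laid out as K scales then V scales:

    import torch

    layer_num, head_num, batch_size = 32, 8, 4
    offline_scales = torch.rand(layer_num, 2 * head_num)  # loaded from kv_cache_calib.json

    k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)
    v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)

    # expand() broadcasts over the batch dimension without copying memory.
    assert k_descale.shape == (layer_num, batch_size, head_num)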
