Commit 8157948

Author: 钮圣虓
Commit message: refine4
1 parent: d233475

3 files changed

Lines changed: 595 additions & 50 deletions

File tree

lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py

Lines changed: 9 additions & 50 deletions
@@ -91,58 +91,17 @@ def get_prefill_kv_cache_and_remap_indices(
         prefill_mem_index: torch.Tensor,
         prefill_cache_kv: torch.Tensor,
     ):
-        squeeze_h_kv = topk_indices.ndim == 2
-        if squeeze_h_kv:
-            topk_indices = topk_indices.unsqueeze(1)
-
-        valid_mask = topk_indices != -1
-        valid_indices = topk_indices[valid_mask]
-
-        if valid_indices.numel() == 0:
-            empty_kv = torch.empty(
-                (0, 1, self.kv_head_dim),
-                dtype=self.prefill_dtype,
-                device=topk_indices.device,
-            )
-            remapped = topk_indices.clone()
-            if squeeze_h_kv:
-                remapped = remapped.squeeze(1)
-            return empty_kv, remapped
-
-        unique_mem_index, inverse = torch.unique(valid_indices, sorted=False, return_inverse=True)
-        unique_mem_index_i64 = unique_mem_index.to(torch.int64)
-        prefill_mem_index_i64 = prefill_mem_index.to(torch.int64)
-
-        sorted_prefill_mem_index, sorted_prefill_pos = torch.sort(prefill_mem_index_i64)
-        prefill_insert_pos = torch.searchsorted(sorted_prefill_mem_index, unique_mem_index_i64)
-        prefill_hit_mask = prefill_insert_pos < sorted_prefill_mem_index.numel()
-        if torch.any(prefill_hit_mask):
-            hit_insert_pos = prefill_insert_pos[prefill_hit_mask]
-            prefill_hit_mask[prefill_hit_mask] = (
-                sorted_prefill_mem_index.index_select(0, hit_insert_pos) == unique_mem_index_i64[prefill_hit_mask]
-            )
-
-        compact_kv = torch.empty(
-            (unique_mem_index.shape[0], prefill_cache_kv.shape[1], self.kv_head_dim),
-            dtype=self.prefill_dtype,
-            device=prefill_cache_kv.device,
+        from lightllm.models.deepseek3_2.triton_kernel.prefill_compact_kv_flashmla_fp8 import (
+            get_prefill_kv_cache_and_remap_indices_triton,
         )

-        if torch.any(prefill_hit_mask):
-            prefill_rows = sorted_prefill_pos[prefill_insert_pos[prefill_hit_mask]]
-            compact_kv[prefill_hit_mask] = prefill_cache_kv.index_select(0, prefill_rows)
-
-        if torch.any(~prefill_hit_mask):
-            prefix_mem_index = unique_mem_index_i64[~prefill_hit_mask]
-            packed_kv = self.kv_buffer[layer_index].index_select(0, prefix_mem_index)
-            compact_kv[~prefill_hit_mask] = self._dequantize_packed_kv(packed_kv)
-
-        remapped = torch.full_like(topk_indices, -1)
-        remapped[valid_mask] = inverse.to(remapped.dtype)
-
-        if squeeze_h_kv:
-            remapped = remapped.squeeze(1)
-        return compact_kv, remapped
+        return get_prefill_kv_cache_and_remap_indices_triton(
+            packed_kv=self.kv_buffer[layer_index],
+            topk_mem_indices=topk_indices,
+            prefill_mem_index=prefill_mem_index,
+            prefill_cache_kv=prefill_cache_kv,
+            prefill_dtype=self.prefill_dtype,
+        )

     def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
         return self.indexer_k_buffer[layer_index]
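
The replaced PyTorch path matched top-k pool indices against the prefill batch with sort + searchsorted; the new Triton path does the same matching with a scatter-built row table and a fused copy/dequant kernel. The contract is unchanged: top-k memory indices are remapped to rows of a compacted KV tensor that holds each unique valid index once, with -1 entries preserved. A minimal standalone sketch of that remapping contract (illustrative only, not part of this commit):

import torch

topk = torch.tensor([3, 7, 3, -1])  # pool indices; -1 is padding
valid = topk != -1
unique_idx, inverse = torch.unique(topk[valid], return_inverse=True)
remapped = torch.full_like(topk, -1)
remapped[valid] = inverse
print(unique_idx.tolist())  # [3, 7]    -> the pool rows gathered into compact_kv
print(remapped.tolist())    # [0, 1, 0, -1] -> indices into compact_kv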
lightllm/models/deepseek3_2/triton_kernel/prefill_compact_kv_flashmla_fp8.py (new file)

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _build_prefill_row_table_kernel(
    prefill_mem_index_ptr,
    row_table_ptr,
    prefill_token_num,
):
    # One program per prefill token: scatter its batch position into the row
    # table, so that row_table[mem_index] gives the token's row in the batch.
    pid = tl.program_id(0)
    if pid < prefill_token_num:
        mem_index = tl.load(prefill_mem_index_ptr + pid)
        tl.store(row_table_ptr + mem_index, pid)


@triton.jit
def _fill_compact_kv_kernel(
    packed_nope_ptr,
    packed_scale_ptr,
    packed_rope_ptr,
    unique_mem_index_ptr,
    prefill_row_table_ptr,
    prefill_kv_ptr,
    compact_kv_ptr,
    packed_nope_stride_s,
    packed_nope_stride_d,
    packed_scale_stride_s,
    packed_scale_stride_d,
    packed_rope_stride_s,
    packed_rope_stride_d,
    prefill_kv_stride_s,
    prefill_kv_stride_d,
    compact_kv_stride_s,
    compact_kv_stride_d,
    unique_num,
    KV_NOPE_DIM: tl.constexpr,
    KV_ROPE_DIM: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Grid: (unique index, dim block). The first KV_NOPE_DIM // GROUP_SIZE
    # blocks cover the nope dims in GROUP_SIZE-wide groups; the last block
    # covers the rope dims. Indices found in the prefill batch are copied
    # as-is; the rest are dequantized from the fp8 pool.
    pid_s = tl.program_id(0)
    pid_block = tl.program_id(1)

    if pid_s >= unique_num:
        return

    mem_index = tl.load(unique_mem_index_ptr + pid_s)
    prefill_row = tl.load(prefill_row_table_ptr + mem_index)
    offs_d = tl.arange(0, BLOCK_D)

    if prefill_row != -1:
        if pid_block < (KV_NOPE_DIM // GROUP_SIZE):
            mask = offs_d < GROUP_SIZE
            value = tl.load(
                prefill_kv_ptr
                + prefill_row * prefill_kv_stride_s
                + (pid_block * GROUP_SIZE + offs_d) * prefill_kv_stride_d,
                mask=mask,
            ).to(tl.float32)
            tl.store(
                compact_kv_ptr + pid_s * compact_kv_stride_s + (pid_block * GROUP_SIZE + offs_d) * compact_kv_stride_d,
                value,
                mask=mask,
            )
        else:
            mask = offs_d < KV_ROPE_DIM
            value = tl.load(
                prefill_kv_ptr + prefill_row * prefill_kv_stride_s + (KV_NOPE_DIM + offs_d) * prefill_kv_stride_d,
                mask=mask,
            ).to(tl.float32)
            tl.store(
                compact_kv_ptr + pid_s * compact_kv_stride_s + (KV_NOPE_DIM + offs_d) * compact_kv_stride_d,
                value,
                mask=mask,
            )
    else:
        if pid_block < (KV_NOPE_DIM // GROUP_SIZE):
            mask = offs_d < GROUP_SIZE
            src_fp8 = tl.load(
                packed_nope_ptr
                + mem_index * packed_nope_stride_s
                + (pid_block * GROUP_SIZE + offs_d) * packed_nope_stride_d,
                mask=mask,
            )
            # One float32 scale per GROUP_SIZE-wide group of fp8 values.
            scale = tl.load(packed_scale_ptr + mem_index * packed_scale_stride_s + pid_block * packed_scale_stride_d)
            value = src_fp8.to(tl.float32) * scale
            tl.store(
                compact_kv_ptr + pid_s * compact_kv_stride_s + (pid_block * GROUP_SIZE + offs_d) * compact_kv_stride_d,
                value,
                mask=mask,
            )
        else:
            mask = offs_d < KV_ROPE_DIM
            value = tl.load(
                packed_rope_ptr + mem_index * packed_rope_stride_s + offs_d * packed_rope_stride_d,
                mask=mask,
            ).to(tl.float32)
            tl.store(
                compact_kv_ptr + pid_s * compact_kv_stride_s + (KV_NOPE_DIM + offs_d) * compact_kv_stride_d,
                value,
                mask=mask,
            )


@torch.no_grad()
def get_prefill_kv_cache_and_remap_indices_triton(
    packed_kv: torch.Tensor,
    topk_mem_indices: torch.Tensor,
    prefill_mem_index: torch.Tensor,
    prefill_cache_kv: torch.Tensor,
    prefill_dtype: torch.dtype,
):
    squeeze_h_kv = topk_mem_indices.ndim == 2
    if squeeze_h_kv:
        topk_mem_indices = topk_mem_indices.unsqueeze(1)

    original_shape = topk_mem_indices.shape
    flat_topk = topk_mem_indices.reshape(-1).contiguous().to(torch.int32)

    if flat_topk.numel() == 0:
        empty_kv = torch.empty((0, 1, 576), dtype=prefill_dtype, device=packed_kv.device)
        remapped = topk_mem_indices.clone()
        if squeeze_h_kv:
            remapped = remapped.squeeze(1)
        return empty_kv, remapped

    valid_mask = flat_topk != -1
    valid_topk = flat_topk[valid_mask]
    if valid_topk.numel() == 0:
        empty_kv = torch.empty((0, 1, 576), dtype=prefill_dtype, device=packed_kv.device)
        remapped = torch.full(original_shape, -1, dtype=torch.int32, device=packed_kv.device)
        if squeeze_h_kv:
            remapped = remapped.squeeze(1)
        return empty_kv, remapped

    table_size = packed_kv.shape[0]

    # -1 marks pool slots that are not part of the current prefill batch.
    prefill_row_table = torch.full((table_size,), -1, dtype=torch.int32, device=packed_kv.device)
    _build_prefill_row_table_kernel[(prefill_mem_index.numel(),)](
        prefill_mem_index_ptr=prefill_mem_index.to(torch.int32).contiguous(),
        row_table_ptr=prefill_row_table,
        prefill_token_num=prefill_mem_index.numel(),
        num_warps=4,
    )

    unique_mem_index, inverse = torch.unique(valid_topk, sorted=False, return_inverse=True)
    unique_mem_index = unique_mem_index.to(torch.int32)
    unique_count = unique_mem_index.numel()
    remapped_flat = torch.full_like(flat_topk, -1)
    remapped_flat[valid_mask] = inverse.to(torch.int32)

    # Packed pool layout per slot (uint8): 512 fp8 nope bytes | 16 scale bytes
    # (4 float32, one per 128-group) | 128 rope bytes (64 bfloat16).
    compact_kv = torch.empty((unique_count, 1, 576), dtype=prefill_dtype, device=packed_kv.device)
    packed_nope = packed_kv[:, :, :512].view(torch.float8_e4m3fn).view(-1, 512)
    packed_scale = packed_kv[:, :, 512:528].view(torch.float32).view(-1, 4)
    packed_rope = packed_kv[:, :, 528:].view(torch.bfloat16).view(-1, 64)
    prefill_kv_2d = prefill_cache_kv.view(-1, 576)
    compact_kv_2d = compact_kv.view(-1, 576)

    # 5 dim blocks per unique index: 4 nope groups of 128 plus 1 rope block.
    _fill_compact_kv_kernel[(unique_count, 5)](
        packed_nope_ptr=packed_nope,
        packed_scale_ptr=packed_scale,
        packed_rope_ptr=packed_rope,
        unique_mem_index_ptr=unique_mem_index,
        prefill_row_table_ptr=prefill_row_table,
        prefill_kv_ptr=prefill_kv_2d,
        compact_kv_ptr=compact_kv_2d,
        packed_nope_stride_s=packed_nope.stride(0),
        packed_nope_stride_d=packed_nope.stride(1),
        packed_scale_stride_s=packed_scale.stride(0),
        packed_scale_stride_d=packed_scale.stride(1),
        packed_rope_stride_s=packed_rope.stride(0),
        packed_rope_stride_d=packed_rope.stride(1),
        prefill_kv_stride_s=prefill_kv_2d.stride(0),
        prefill_kv_stride_d=prefill_kv_2d.stride(1),
        compact_kv_stride_s=compact_kv_2d.stride(0),
        compact_kv_stride_d=compact_kv_2d.stride(1),
        unique_num=unique_count,
        KV_NOPE_DIM=512,
        KV_ROPE_DIM=64,
        GROUP_SIZE=128,
        BLOCK_D=128,
        num_warps=4,
    )

    remapped = remapped_flat.view(original_shape)
    if squeeze_h_kv:
        remapped = remapped.squeeze(1)
    return compact_kv, remapped
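
A minimal usage sketch of the new entry point, assuming the packed pool layout the host code implies (656 uint8 per slot: 512 fp8 nope bytes, 16 scale bytes holding 4 float32 group scales, 128 rope bytes holding 64 bfloat16 values); the pool size and index values here are illustrative:

import torch
from lightllm.models.deepseek3_2.triton_kernel.prefill_compact_kv_flashmla_fp8 import (
    get_prefill_kv_cache_and_remap_indices_triton,
)

device = "cuda"
packed_kv = torch.zeros((1024, 1, 656), dtype=torch.uint8, device=device)  # quantized pool
# Two tokens of the current prefill batch live at pool slots 10 and 11.
prefill_mem_index = torch.tensor([10, 11], dtype=torch.int32, device=device)
prefill_cache_kv = torch.randn((2, 1, 576), dtype=torch.bfloat16, device=device)
# Top-k selection mixing a prefill hit (10), a quantized prefix slot (3), and -1 padding.
topk = torch.tensor([[10, 3, 10, -1]], dtype=torch.int32, device=device)

compact_kv, remapped = get_prefill_kv_cache_and_remap_indices_triton(
    packed_kv=packed_kv,
    topk_mem_indices=topk,
    prefill_mem_index=prefill_mem_index,
    prefill_cache_kv=prefill_cache_kv,
    prefill_dtype=torch.bfloat16,
)
print(compact_kv.shape)   # torch.Size([2, 1, 576]): one row per unique valid index
print(remapped.tolist())  # [[1, 0, 1, -1]]: torch.unique returns sorted values, so 3 -> row 0, 10 -> row 1

Relative to the replaced sort/searchsorted matching, the row table makes the prefill-hit test an O(1) lookup per unique index inside the fill kernel, and the copy and fp8 dequantization are fused into a single launch.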
