@@ -45,24 +45,9 @@ def init_state(self):
             torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
         )
         # Initialize k_descale and v_descale outside of inference to reduce inference-time compute
-        self.k_descale = (
-            offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
-        self.v_descale = (
-            offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
+        self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+

     def prefill_att(
         self,
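This hunk drops the torch.ones fallback, making offline_scales mandatory on this path. A minimal shape sketch of what the new initializer builds, assuming (this layout is implied by the slicing, not stated in the diff) that offline_scales is (layer_num, 2 * head_num) with the K scales in the first half of each row:

```python
import torch

# Assumed layout: offline_scales is (layer_num, 2 * head_num), K scales in
# the first head_num columns, V scales in the rest. view() inserts a batch
# axis of size 1 and expand() broadcasts it to batch_size without copying.
layer_num, batch_size, head_num = 32, 8, 16
offline_scales = torch.rand(layer_num, 2 * head_num, dtype=torch.float32)

k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)
v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)

assert k_descale.shape == (layer_num, batch_size, head_num)
assert v_descale.shape == (layer_num, batch_size, head_num)
```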
@@ -89,19 +74,21 @@ def _fp8_prefill_att(
     ) -> torch.Tensor:
         self.backend: Fp8Fa3AttBackend = self.backend  # for typing

+        q_head_num = q.shape[1]
+        q_head_dim = q.shape[2]
+        k_head_num = k.shape[1]
         q, q_scale = q_per_head_fp8_quant(
-            q,
+            q.reshape(q.shape[0], k_head_num, -1),
             self.infer_state.b_seq_len,
             self.cu_seqlens_q,
-            self.mid_token_batch_ids,
+            token_batch_ids=self.mid_token_batch_ids,
         )
-        k_head_num = k.shape[1]
         k_head_dim = k.shape[2]
         cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.reshape(-1, q_head_num, q_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,
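The prefill path now folds query heads into their KV-head groups before quantization, so q_per_head_fp8_quant emits one scale per KV head rather than per query head, and unfolds the result back for FA3. A shape walk-through with illustrative GQA sizes (the tensor returned by the real quant kernel is FP8; only the shape bookkeeping is shown):

```python
import torch

# Illustrative GQA sizes, not from the PR: 32 query heads share 8 KV heads,
# so each KV-head group packs 4 query heads' features into one row before
# quantization, and the round trip restores the original layout losslessly.
total_tokens, q_head_num, k_head_num, head_dim = 10, 32, 8, 128
q = torch.randn(total_tokens, q_head_num, head_dim)

grouped = q.reshape(q.shape[0], k_head_num, -1)  # (10, 8, 4 * 128)
assert grouped.shape == (total_tokens, k_head_num, (q_head_num // k_head_num) * head_dim)

restored = grouped.reshape(-1, q_head_num, head_dim)
assert torch.equal(restored, q)  # the round trip is lossless
```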
@@ -141,24 +128,9 @@ def init_state(self):
         head_num = mem_manager.head_num

         # Initialize k_descale and v_descale outside of inference to reduce inference-time compute
-        self.k_descale = (
-            offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
-        self.v_descale = (
-            offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
-            if offline_scales is not None
-            else torch.ones(
-                (mem_manager.layer_num, batch_size, head_num),
-                dtype=torch.float32,
-                device=device,
-            )
-        )
+        self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+        self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
+
         return

     def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"):
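The decode state gets the identical simplification. One property worth noting about this construction: expand() does not materialize batch_size copies of the scales; it returns a stride-0 view over the batch dimension, so memory stays proportional to layer_num * head_num at any batch size. A minimal illustration:

```python
import torch

# expand() broadcasts a size-1 dimension by giving it stride 0: every batch
# entry aliases the same per-layer, per-head scale row, and no copy is made.
layer_num, batch_size, head_num = 2, 4, 8
scales = torch.ones(layer_num, 1, head_num)
expanded = scales.expand(layer_num, batch_size, head_num)

assert expanded.data_ptr() == scales.data_ptr()  # no copy was made
assert expanded.stride(1) == 0                   # batch dim is broadcast
```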
@@ -200,9 +172,11 @@ def _fp8_decode_att(
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)

         q_head_num = q.shape[1]
-        q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
+        if scaled_fp8_quant is None:
+            raise ImportError("scaled_fp8_quant is unavailable. Please install vllm to enable FP8 decode attention.")
+        q, q_scale = scaled_fp8_quant(q.reshape(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
         o = flash_attn_with_kvcache(
-            q=q.view(-1, q_head_num, k_head_dim),
+            q=q.reshape(-1, q_head_num, k_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,