@@ -251,7 +251,7 @@ def _get_cluster_logits(
251251 hidden_states : torch .Tensor ,
252252 top_clusters : torch .Tensor ,
253253 use_identical_tiebreak : bool ,
254- ) -> tuple [torch .Tensor , Optional [torch .Tensor ]]:
254+ ) -> tuple [torch .Tensor , Optional [torch .Tensor ], torch . Tensor ]:
255255 B , T , _ = hidden_states .shape
256256 if B != 1 :
257257 raise NotImplementedError ("FlashHead currently supports batch size = 1 only" )
@@ -281,7 +281,7 @@ def _get_cluster_logits(
281281 bias = None ,
282282 )
283283
284- return logits , mapping
284+ return logits , mapping , indices
285285
286286 def get_next_token_standard (
287287 self ,
@@ -322,12 +322,12 @@ def get_next_token(
322322 do_sample = do_sample ,
323323 temperature = temperature ,
324324 )
325- cluster_logits , mapping = self ._get_cluster_logits (
325+ cluster_logits , mapping , indices = self ._get_cluster_logits (
326326 hidden_states , top_clusters , use_identical_tiebreak
327327 )
328328
329329 if do_sample :
330- probs = (cluster_logits / temperature ).softmax (dim = - 1 )
330+ probs = (cluster_logits [:, - 1 , :] / temperature ).softmax (dim = - 1 )
331331 cluster_token_idx = torch .multinomial (probs , num_samples = 1 )
332332 else :
333333 cluster_token_idx = cluster_logits .argmax (
@@ -336,21 +336,5 @@ def get_next_token(
336336 if use_identical_tiebreak :
337337 cluster_token_idx = mapping [cluster_token_idx ]
338338
339- # Handle both 1D (from T>1 with .unique()) and 3D (from T==1) cases
340- if top_clusters .dim () == 1 :
341- cluster_indices = top_clusters
342- else :
343- cluster_indices = top_clusters [0 , 0 ]
344-
345- maps = self .vocab_maps_tensor .index_select (0 , cluster_indices )
346- indices = maps .flatten ().to (torch .int64 )
347- if self .special_token_ids_tensor .numel () > 0 :
348- special_ids = self .special_token_ids_tensor .to (
349- device = indices .device
350- )
351- indices = torch .unique (torch .cat ([indices , special_ids ], dim = 0 ))
352- if use_identical_tiebreak :
353- indices = indices .sort ().values
354-
355339 vocab_index = indices [cluster_token_idx ]
356340 return vocab_index [0 ]
0 commit comments