Skip to content

Commit 69b4d8f

Browse files
committed
Remove duplicate indices calculation
1 parent 4775805 commit 69b4d8f

2 files changed

Lines changed: 5 additions & 21 deletions

File tree

src/flash_head/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
"""FlashHead package version. Bump this to trigger a PyPI release."""
44

5-
__version__ = "0.1.5"
5+
__version__ = "0.1.6"

src/flash_head/flash_head.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _get_cluster_logits(
251251
hidden_states: torch.Tensor,
252252
top_clusters: torch.Tensor,
253253
use_identical_tiebreak: bool,
254-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
254+
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
255255
B, T, _ = hidden_states.shape
256256
if B != 1:
257257
raise NotImplementedError("FlashHead currently supports batch size = 1 only")
@@ -281,7 +281,7 @@ def _get_cluster_logits(
281281
bias=None,
282282
)
283283

284-
return logits, mapping
284+
return logits, mapping, indices
285285

286286
def get_next_token_standard(
287287
self,
@@ -322,12 +322,12 @@ def get_next_token(
322322
do_sample=do_sample,
323323
temperature=temperature,
324324
)
325-
cluster_logits, mapping = self._get_cluster_logits(
325+
cluster_logits, mapping, indices = self._get_cluster_logits(
326326
hidden_states, top_clusters, use_identical_tiebreak
327327
)
328328

329329
if do_sample:
330-
probs = (cluster_logits / temperature).softmax(dim=-1)
330+
probs = (cluster_logits[:, -1, :] / temperature).softmax(dim=-1)
331331
cluster_token_idx = torch.multinomial(probs, num_samples=1)
332332
else:
333333
cluster_token_idx = cluster_logits.argmax(
@@ -336,21 +336,5 @@ def get_next_token(
336336
if use_identical_tiebreak:
337337
cluster_token_idx = mapping[cluster_token_idx]
338338

339-
# Handle both 1D (from T>1 with .unique()) and 3D (from T==1) cases
340-
if top_clusters.dim() == 1:
341-
cluster_indices = top_clusters
342-
else:
343-
cluster_indices = top_clusters[0, 0]
344-
345-
maps = self.vocab_maps_tensor.index_select(0, cluster_indices)
346-
indices = maps.flatten().to(torch.int64)
347-
if self.special_token_ids_tensor.numel() > 0:
348-
special_ids = self.special_token_ids_tensor.to(
349-
device=indices.device
350-
)
351-
indices = torch.unique(torch.cat([indices, special_ids], dim=0))
352-
if use_identical_tiebreak:
353-
indices = indices.sort().values
354-
355339
vocab_index = indices[cluster_token_idx]
356340
return vocab_index[0]

0 commit comments

Comments (0)