Skip to content

Commit 6d7e502

Browse files
committed
refactor(data): improved detection and reporting of sequences in a batch that consist entirely of padding tokens
1 parent 478628d commit 6d7e502

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/modalities/models/gpt2/collator.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -59,9 +59,9 @@ def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.T
5959
for batch_seq in sample_tensor:
6060
eos_positions = (batch_seq == self.eos_token_id).nonzero(as_tuple=True)[0]
6161
if len(eos_positions) == 0:
62-
assert (
63-
self.padding_token_id is None or batch_seq[0] != self.padding_token_id
64-
), "Sequence starts with padding token"
62+
assert self.padding_token_id is None or (
63+
batch_seq[0] != self.padding_token_id and torch.all(batch_seq != self.padding_token_id)
64+
), "Whole batch sequence consists of padding tokens."
6565
sub_seq_lengths_in_batch.append([len(batch_seq)])
6666
else:
6767
lens_in_seq = self._compute_subsequence_length_in_sequence(batch_seq, eos_positions)

0 commit comments

Comments (0)