Commit 382a952

refactor(attention): improved error handling
1 parent 379420d commit 382a952

1 file changed

Lines changed: 37 additions & 20 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

@@ -523,14 +523,7 @@ def prepare_inter_document_masking(
         """
         device = self.c_proj.weight.device
         if self.attention_impl == AttentionImplementation.MANUAL:
-            batch_size = len(in_batch_seq_lens)
-            attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
-            for i, doc_seq_lens in enumerate(in_batch_seq_lens):
-                doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
-                for j in range(len(doc_boundaries) - 1):
-                    start_idx = doc_boundaries[j]
-                    end_idx = doc_boundaries[j + 1]
-                    attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
+            attn_mask = self._build_3d_attention_mask(in_batch_seq_lens, max_seq_len, device)
             return attn_mask
         if self.attention_impl == AttentionImplementation.DAO_FLASH:
             concatenated_lengths = self._build_concatenated_lengths_tensor(
@@ -547,6 +540,21 @@ def prepare_inter_document_masking(
             f"Attention implementation {self.attention_impl} is not supported for inter-document masking."
         )

+    @staticmethod
+    def _build_3d_attention_mask(
+        in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
+    ) -> torch.Tensor:
+        batch_size = len(in_batch_seq_lens)
+        attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
+        for i, doc_seq_lens in enumerate(in_batch_seq_lens):
+            CausalSelfAttention._validate_subsequence_lengths(max_seq_len, i, doc_seq_lens)
+            doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
+            for j in range(len(doc_boundaries) - 1):
+                start_idx = doc_boundaries[j]
+                end_idx = doc_boundaries[j + 1]
+                attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
+        return attn_mask
+
     @staticmethod
     def _build_concatenated_lengths_tensor(
         in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
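As an illustrative aside (not part of the commit): the new _build_3d_attention_mask helper keeps the block-diagonal mask logic that was previously inlined in prepare_inter_document_masking. A minimal standalone sketch of that logic for a toy batch, assuming only torch:

import torch

# Two samples: the first packs documents of lengths 2 and 3, the second a single document of length 4.
in_batch_seq_lens = [[2, 3], [4]]
max_seq_len = 5
device = torch.device("cpu")

# One boolean block per document, so tokens attend only within their own document.
attn_mask = torch.zeros((len(in_batch_seq_lens), max_seq_len, max_seq_len), dtype=torch.bool, device=device)
for i, doc_seq_lens in enumerate(in_batch_seq_lens):
    doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
    for j in range(len(doc_boundaries) - 1):
        start_idx, end_idx = doc_boundaries[j], doc_boundaries[j + 1]
        attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True

print(attn_mask[0].int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)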
@@ -569,22 +577,26 @@ def _build_concatenated_lengths_tensor(
         batch_size = len(in_batch_seq_lens)
         concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
         for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
-            if len(doc_seq_lens) > max_seq_len:
-                raise ValueError(
-                    f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
-                    f"for batch index {batch_idx}."
-                )
-            if sum(doc_seq_lens) > max_seq_len:
-                raise ValueError(
-                    f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
-                    f"for batch index {batch_idx}."
-                )
+            CausalSelfAttention._validate_subsequence_lengths(max_seq_len, batch_idx, doc_seq_lens)
             if len(doc_seq_lens) > 0:
                 concatenated_lengths[batch_idx, : len(doc_seq_lens)] = torch.tensor(
                     doc_seq_lens, dtype=torch.int32, device=device
                 )
         return concatenated_lengths

+    @staticmethod
+    def _validate_subsequence_lengths(max_seq_len: int, batch_idx: int, doc_seq_lens: list[int]):
+        if len(doc_seq_lens) > max_seq_len:
+            raise ValueError(
+                f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
+                f"for batch index {batch_idx}. (Detected while building inter-document masking.)"
+            )
+        if sum(doc_seq_lens) > max_seq_len:
+            raise ValueError(
+                f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
+                f"for batch index {batch_idx}. (Detected while building inter-document masking.)"
+            )
+
     @staticmethod
     def _get_unpad_data_for_concatenated_sequences(
         attention_mask_in_length: torch.Tensor,
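Likewise an aside (not from the commit): both mask-building paths now funnel their checks through the shared _validate_subsequence_lengths helper. A hedged usage sketch, assuming the import path mirrors the file location shown above:

from modalities.models.gpt2.gpt2_model import CausalSelfAttention

# Fits: two documents with a total length of 8 within max_seq_len=8, so no error is raised.
CausalSelfAttention._validate_subsequence_lengths(max_seq_len=8, batch_idx=0, doc_seq_lens=[3, 5])

# Overflows: total length 10 exceeds max_seq_len=8, so the helper raises ValueError.
try:
    CausalSelfAttention._validate_subsequence_lengths(max_seq_len=8, batch_idx=1, doc_seq_lens=[6, 4])
except ValueError as err:
    print(err)  # Sum of subsequence lengths (10) exceeds max_seq_len (8) for batch index 1. ...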
@@ -1232,13 +1244,18 @@ def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] |
         # TODO: use drop out also without absolute position embedding?
         h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h

-        # TODO: Handle this in case of pipeline parallelism.
         if sub_seq_lengths is not None:
             attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
                 in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
             )
         else:
-            attention_masking_information = None
+            # TODO: Handle this in case of pipeline parallelism.
+            raise NotImplementedError(
+                "In the current model part, no attention layer exists from "
+                "which to build inter-document masking. Most likely, pipeline "
+                "parallelism is being used, for which inter-document masking is "
+                "currently not supported."
+            )

         for layer_idx in self.transformer.h:
             h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information)
