
Commit 4a31747

fix(attention): bug introduced by improved error handling for unsupported PP + inter-document masking
Parent: f173cff

1 file changed

Lines changed: 15 additions & 11 deletions


src/modalities/models/gpt2/gpt2_model.py

```diff
@@ -533,7 +533,8 @@ def prepare_inter_document_masking(
             )
             return self._get_unpad_data_for_concatenated_sequences(concatenated_lengths)
         raise NotImplementedError(
-            f"Attention implementation {self.attention_impl} is not supported for inter-document masking. Use `manual` or `dao_flash`."
+            f"Attention implementation {self.attention_impl} is not supported for "
+            "inter-document masking. Use `manual` or `dao_flash`."
         )
 
     @staticmethod
```
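For context, this hunk only re-wraps the error message of the fallthrough `raise` for unsupported attention backends. A minimal sketch of that dispatch pattern, inferred from the context lines above; the class name, constructor, and stub return values are hypothetical stand-ins, only the final `raise` mirrors the diff:

```python
# Hypothetical sketch of the dispatch in prepare_inter_document_masking:
# supported backends return early, any other backend falls through to the
# (now line-wrapped) NotImplementedError.
class CausalSelfAttention:
    def __init__(self, attention_impl: str):
        self.attention_impl = attention_impl

    def prepare_inter_document_masking(self, in_batch_seq_lens, max_seq_len):
        if self.attention_impl == "manual":
            return {"impl": "manual"}  # stub: real code builds a block-diagonal mask
        if self.attention_impl == "dao_flash":
            return {"impl": "dao_flash"}  # stub: real code returns unpad data
        # Reached only for unsupported attention implementations.
        raise NotImplementedError(
            f"Attention implementation {self.attention_impl} is not supported for "
            "inter-document masking. Use `manual` or `dao_flash`."
        )
```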
```diff
@@ -1244,17 +1245,20 @@ def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] |
         h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h
 
         if sub_seq_lengths is not None:
-            attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
-                in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
-            )
+            if hasattr(self.transformer, "h") and "0" in self.transformer.h:
+                attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
+                    in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
+                )
+            else:
+                # TODO: Handle this in case of pipeline parallelism.
+                raise NotImplementedError(
+                    "No attention layer exists in the current model part from "
+                    "which to build inter-document masking. Most likely, pipeline "
+                    "parallelism is being used, for which inter-document masking is "
+                    "currently not supported."
+                )
         else:
-            # TODO: Handle this in case of pipeline parallelism.
-            raise NotImplementedError(
-                "In the current document part, not attention layer exists from "
-                "which to build inter document masking. Most likely, pipeline "
-                "parallelism is being used for which inter document masking is "
-                "currently not supported."
-            )
+            attention_masking_information = None
 
         for layer_idx in self.transformer.h:
             h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information)
```
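To make the bug explicit: the parent commit attached the pipeline-parallelism guard to the wrong branch, so the `NotImplementedError` fired whenever `sub_seq_lengths` was `None`, i.e. on every ordinary forward pass without document packing. This commit moves the guard under the `sub_seq_lengths is not None` path (where a missing local layer `"0"` really does indicate pipeline parallelism) and restores `attention_masking_information = None` for the common case. Below is a runnable, heavily simplified sketch of the fixed control flow; `Block`, `Model`, and the mask payload are hypothetical stand-ins, only the branching mirrors the diff:

```python
# Hypothetical, heavily simplified reconstruction of the fixed branching in
# forward_impl. nn.ModuleDict keys are strings, hence the "0" lookup.
import torch
from torch import nn


class Block(nn.Module):
    def forward(self, h, attention_masking_information=None):
        return h  # stand-in for attention + MLP


class Model(nn.Module):
    def __init__(self, n_layers: int):
        super().__init__()
        self.transformer = nn.ModuleDict(
            {"h": nn.ModuleDict({str(i): Block() for i in range(n_layers)})}
        )

    def forward_impl(self, h, sub_seq_lengths=None):
        if sub_seq_lengths is not None:
            if hasattr(self.transformer, "h") and "0" in self.transformer.h:
                # The first local block owns the attention module that can
                # prepare the inter-document mask (stubbed as a plain dict).
                attention_masking_information = {"seq_lens": sub_seq_lengths}
            else:
                # Pipeline parallelism: this model part holds no layer "0",
                # so there is nothing to build the mask from.
                raise NotImplementedError("inter-document masking unsupported with PP")
        else:
            # Common path: no document packing, no mask needed. Before this
            # fix, the raise above lived in *this* branch by mistake.
            attention_masking_information = None

        for layer_idx in self.transformer.h:
            h = self.transformer.h[layer_idx](
                h, attention_masking_information=attention_masking_information
            )
        return h


# An ordinary forward pass (no packed documents) now runs through cleanly;
# with the parent commit's version this call would have raised.
out = Model(n_layers=2).forward_impl(torch.zeros(1, 4, 8))
assert out.shape == (1, 4, 8)
```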
