@@ -533,7 +533,8 @@ def prepare_inter_document_masking(
533533 )
534534 return self ._get_unpad_data_for_concatenated_sequences (concatenated_lengths )
535535 raise NotImplementedError (
536- f"Attention implementation { self .attention_impl } is not supported for inter-document masking. Use `manual` or `dao_flash`."
536+ f"Attention implementation { self .attention_impl } is not supported for "
537+ "inter-document masking. Use `manual` or `dao_flash`."
537538 )
538539
539540 @staticmethod
@@ -1244,17 +1245,20 @@ def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] |
12441245 h = self .transformer .drop (h ) if hasattr (self .transformer , "drop" ) else h
12451246
12461247 if sub_seq_lengths is not None :
1247- attention_masking_information = self .transformer .h ["0" ].attn .prepare_inter_document_masking (
1248- in_batch_seq_lens = sub_seq_lengths , max_seq_len = seq_len
1249- )
1248+ if hasattr (self .transformer , "h" ) and "0" in self .transformer .h :
1249+ attention_masking_information = self .transformer .h ["0" ].attn .prepare_inter_document_masking (
1250+ in_batch_seq_lens = sub_seq_lengths , max_seq_len = seq_len
1251+ )
1252+ else :
1253+ # TODO: Handle this in case of pipeline parallelism.
1254+ raise NotImplementedError (
1255+ "In the current document part, not attention layer exists from "
1256+ "which to build inter document masking. Most likely, pipeline "
1257+ "parallelism is being used for which inter document masking is "
1258+ "currently not supported."
1259+ )
12501260 else :
1251- # TODO: Handle this in case of pipeline parallelism.
1252- raise NotImplementedError (
1253- "In the current document part, not attention layer exists from "
1254- "which to build inter document masking. Most likely, pipeline "
1255- "parallelism is being used for which inter document masking is "
1256- "currently not supported."
1257- )
1261+ attention_masking_information = None
12581262
12591263 for layer_idx in self .transformer .h :
12601264 h = self .transformer .h [layer_idx ](h , attention_masking_information = attention_masking_information )
0 commit comments