@@ -523,14 +523,7 @@ def prepare_inter_document_masking(
         """
         device = self.c_proj.weight.device
         if self.attention_impl == AttentionImplementation.MANUAL:
-            batch_size = len(in_batch_seq_lens)
-            attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
-            for i, doc_seq_lens in enumerate(in_batch_seq_lens):
-                doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
-                for j in range(len(doc_boundaries) - 1):
-                    start_idx = doc_boundaries[j]
-                    end_idx = doc_boundaries[j + 1]
-                    attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
+            attn_mask = self._build_3d_attention_mask(in_batch_seq_lens, max_seq_len, device)
             return attn_mask
         if self.attention_impl == AttentionImplementation.DAO_FLASH:
             concatenated_lengths = self._build_concatenated_lengths_tensor(
@@ -547,6 +540,21 @@ def prepare_inter_document_masking(
                 f"Attention implementation {self.attention_impl} is not supported for inter-document masking."
             )
 
+    @staticmethod
+    def _build_3d_attention_mask(
+        in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
+    ) -> torch.Tensor:
+        batch_size = len(in_batch_seq_lens)
+        attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
+        for i, doc_seq_lens in enumerate(in_batch_seq_lens):
+            CausalSelfAttention._validate_subsequence_lengths(max_seq_len, i, doc_seq_lens)
+            doc_boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
+            for j in range(len(doc_boundaries) - 1):
+                start_idx = doc_boundaries[j]
+                end_idx = doc_boundaries[j + 1]
+                attn_mask[i, start_idx:end_idx, start_idx:end_idx] = True
+        return attn_mask
+
     @staticmethod
     def _build_concatenated_lengths_tensor(
         in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
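For reference, here is a minimal standalone sketch of the block-diagonal masking that the extracted `_build_3d_attention_mask` helper performs for the `MANUAL` attention path. The free function and example values below are illustrative assumptions, not part of the change:

```python
import torch

# Illustrative free-function version of the block-diagonal mask construction
# used for the MANUAL attention path (the function name is an assumption).
def build_block_diagonal_mask(in_batch_seq_lens, max_seq_len, device):
    batch_size = len(in_batch_seq_lens)
    mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
    for i, doc_seq_lens in enumerate(in_batch_seq_lens):
        # Cumulative document lengths give the [start, end) offsets of each packed document.
        boundaries = torch.cumsum(torch.tensor([0] + doc_seq_lens, device=device), dim=0)
        for j in range(len(boundaries) - 1):
            start, end = boundaries[j], boundaries[j + 1]
            # Tokens may only attend within their own document's block.
            mask[i, start:end, start:end] = True
    return mask

# Two documents of lengths 2 and 3 packed into a single row of length 5:
print(build_block_diagonal_mask([[2, 3]], max_seq_len=5, device=torch.device("cpu"))[0].int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```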
@@ -569,22 +577,26 @@ def _build_concatenated_lengths_tensor(
         batch_size = len(in_batch_seq_lens)
         concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
         for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
-            if len(doc_seq_lens) > max_seq_len:
-                raise ValueError(
-                    f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
-                    f"for batch index {batch_idx}."
-                )
-            if sum(doc_seq_lens) > max_seq_len:
-                raise ValueError(
-                    f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
-                    f"for batch index {batch_idx}."
-                )
+            CausalSelfAttention._validate_subsequence_lengths(max_seq_len, batch_idx, doc_seq_lens)
             if len(doc_seq_lens) > 0:
                 concatenated_lengths[batch_idx, : len(doc_seq_lens)] = torch.tensor(
                     doc_seq_lens, dtype=torch.int32, device=device
                 )
         return concatenated_lengths
 
+    @staticmethod
+    def _validate_subsequence_lengths(max_seq_len: int, batch_idx: int, doc_seq_lens: list[int]):
+        if len(doc_seq_lens) > max_seq_len:
+            raise ValueError(
+                f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
+                f"for batch index {batch_idx}. (Detected while building inter-document masking.)"
+            )
+        if sum(doc_seq_lens) > max_seq_len:
+            raise ValueError(
+                f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
+                f"for batch index {batch_idx}. (Detected while building inter-document masking.)"
+            )
+
     @staticmethod
     def _get_unpad_data_for_concatenated_sequences(
         attention_mask_in_length: torch.Tensor,
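Similarly, a sketch of the per-row length tensor that `_build_concatenated_lengths_tensor` produces for the `DAO_FLASH` path, together with the bounds checks this change factors out into `_validate_subsequence_lengths`. The free function name and condensed error wording below are assumptions for illustration:

```python
import torch

# Illustrative sketch: packed document lengths per batch row, left-aligned
# and zero-padded to max_seq_len, as consumed by the flash-attention path.
def build_concatenated_lengths(in_batch_seq_lens, max_seq_len, device):
    lengths = torch.zeros((len(in_batch_seq_lens), max_seq_len), dtype=torch.int32, device=device)
    for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
        # Both the number of documents and their total length must fit into max_seq_len.
        if len(doc_seq_lens) > max_seq_len or sum(doc_seq_lens) > max_seq_len:
            raise ValueError(f"Subsequence lengths exceed max_seq_len for batch index {batch_idx}.")
        if doc_seq_lens:
            lengths[batch_idx, : len(doc_seq_lens)] = torch.tensor(doc_seq_lens, dtype=torch.int32, device=device)
    return lengths

print(build_concatenated_lengths([[2, 3], [4]], max_seq_len=5, device=torch.device("cpu")))
# tensor([[2, 3, 0, 0, 0],
#         [4, 0, 0, 0, 0]], dtype=torch.int32)
```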
@@ -1232,13 +1244,18 @@ def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] |
         # TODO: use drop out also without absolute position embedding?
         h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h
 
-        # TODO: Handle this in case of pipeline parallelism.
         if sub_seq_lengths is not None:
             attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
                 in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
             )
         else:
-            attention_masking_information = None
+            # TODO: Handle this in case of pipeline parallelism.
+            raise NotImplementedError(
+                "In the current model part, no attention layer exists from "
+                "which to build inter-document masking. Most likely, pipeline "
+                "parallelism is being used, for which inter-document masking is "
+                "currently not supported."
+            )
 
         for layer_idx in self.transformer.h:
             h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information)