-
Notifications
You must be signed in to change notification settings - Fork 16
Added inter document masking for manual and flash attention. #434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
5e039df
a5354c4
eba9c5b
3d583c1
cd00777
fdd6465
e952bd0
115259c
8d71928
91f67e4
ab4c79a
379420d
382a952
478628d
6d7e502
f173cff
4a31747
b2ec756
956b958
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,9 +20,10 @@ | |
| from modalities.util import parse_enum_by_name | ||
|
|
||
| try: | ||
| from flash_attn import flash_attn_func | ||
| from flash_attn import flash_attn_func, flash_attn_varlen_func | ||
| except ModuleNotFoundError: | ||
| flash_attn_func = None | ||
| flash_attn_varlen_func = None | ||
|
|
||
| # Logger configuration | ||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -501,6 +502,178 @@ def __init__( | |
| self.q_norm = None | ||
| self.k_norm = None | ||
|
|
||
def prepare_inter_document_masking(
    self, in_batch_seq_lens: list[list[int]], max_seq_len: int
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]:
    """
    Prepares the inter-document attention masking information for a batch of packed sequences.

    The returned object depends on `self.attention_impl`:
    - `MANUAL`: a boolean mask of shape (batch_size, max_seq_len, max_seq_len) in which
      entry (b, i, j) is True if token i may attend to token j, i.e. both tokens belong
      to the same document of batch item b. The causal constraint is NOT part of this mask.
    - `DAO_FLASH`: the tuple (indices, cu_seqlens, max_seqlen_in_batch) produced by
      `_get_unpad_data_for_concatenated_sequences`.
    - `PYTORCH_FLASH` and anything else: not supported.

    Args:
        in_batch_seq_lens (list[list[int]]): For each batch item, the lengths of the
            documents packed into that item, in order of appearance.
        max_seq_len (int): The (padded) sequence length of the batch, i.e. the size of
            the token dimension of the model input.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]: The inter-document masking information.

    Raises:
        NotImplementedError: If the configured attention implementation does not support
            inter-document masking.
    """
    device = self.c_proj.weight.device
    if self.attention_impl == AttentionImplementation.MANUAL:
        batch_size = len(in_batch_seq_lens)
        # Start from the identity so that every position can attend to itself.
        # Positions not covered by any document (trailing padding) would otherwise have an
        # all-False row, which becomes an all -inf row after masking and yields NaN in softmax.
        attn_mask = (
            torch.eye(max_seq_len, dtype=torch.bool, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
        )
        for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
            # Track document boundaries with Python ints: this avoids creating a device tensor
            # (and a host sync per slice) and also works if the lengths arrive as a tensor row.
            start_idx = 0
            for doc_len in doc_seq_lens:
                # Clamp to max_seq_len; this mirrors tensor slicing, which silently clips
                # ranges that run past the end of the sequence dimension.
                end_idx = min(start_idx + int(doc_len), max_seq_len)
                attn_mask[batch_idx, start_idx:end_idx, start_idx:end_idx] = True
                start_idx = end_idx
        return attn_mask

    if self.attention_impl == AttentionImplementation.DAO_FLASH:
        concatenated_lengths = self._build_concatenated_lengths_tensor(
            in_batch_seq_lens=in_batch_seq_lens,
            max_seq_len=max_seq_len,
            device=device,
        )
        return self._get_unpad_data_for_concatenated_sequences(concatenated_lengths)

    if self.attention_impl == AttentionImplementation.PYTORCH_FLASH:
        raise NotImplementedError(
            "Inter-document masking is not supported for `pytorch_flash`. Use `manual` or `dao_flash`."
        )
    raise NotImplementedError(
        f"Attention implementation {self.attention_impl} is not supported for inter-document masking."
    )
|
BlueCrescent marked this conversation as resolved.
Outdated
|
||
|
|
||
@staticmethod
def _build_concatenated_lengths_tensor(
    in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
) -> torch.Tensor:
    """
    Pack the per-item document lengths into one zero-padded int32 tensor.

    Row b of the result holds the document lengths of batch item b in its first
    `len(in_batch_seq_lens[b])` entries; all remaining entries are zero.

    Args:
        in_batch_seq_lens: One list per batch item, containing the lengths of the
            subsequences (documents) of that item.
        max_seq_len: Width of the output tensor. Both the number of subsequences and
            the sum of their lengths must not exceed this value.
        device: The torch device on which the output tensor is allocated.

    Returns:
        An int32 tensor of shape (batch_size, max_seq_len).

    Raises:
        ValueError: If a batch item has more than max_seq_len subsequences, or if the
            sum of its subsequence lengths exceeds max_seq_len.
    """
    lengths_per_item = torch.zeros((len(in_batch_seq_lens), max_seq_len), dtype=torch.int32, device=device)
    for item_idx, sub_lens in enumerate(in_batch_seq_lens):
        num_sub_seqs = len(sub_lens)
        if num_sub_seqs > max_seq_len:
            raise ValueError(
                f"Number of subsequences ({num_sub_seqs}) exceeds max_seq_len ({max_seq_len}) "
                f"for batch index {item_idx}."
            )
        total_len = sum(sub_lens)
        if total_len > max_seq_len:
            raise ValueError(
                f"Sum of subsequence lengths ({total_len}) exceeds max_seq_len ({max_seq_len}) "
                f"for batch index {item_idx}."
            )
        if num_sub_seqs == 0:
            # Nothing to write; the row stays all zeros.
            continue
        lengths_per_item[item_idx, :num_sub_seqs] = torch.tensor(sub_lens, dtype=torch.int32, device=device)
    return lengths_per_item
|
|
||
@staticmethod
def _get_unpad_data_for_concatenated_sequences(
    attention_mask_in_length: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Derive the unpadding information for a batch of packed (concatenated) sequences.

    Each row of `attention_mask_in_length` lists the lengths of the subsequences packed
    into one batch item, zero-padded on the right. The tokens of a row are taken to be
    the first `row.sum()` positions of that row; everything after is padding.

    Args:
        attention_mask_in_length (torch.Tensor): 2D integer tensor of shape
            (batch_size, seq_len) holding subsequence lengths per batch item.

    Returns:
        tuple[torch.Tensor, torch.Tensor, int]:
            - indices: 1D tensor with the positions of all non-padding tokens in the
              flattened (batch_size * seq_len) token axis.
            - cu_seqlens: 1D int32 tensor of cumulative subsequence lengths over the
              whole batch, with a leading zero (shape: num_subsequences + 1).
            - max_seqlen_in_batch: Length of the longest single subsequence.

    Raises:
        ValueError: If the input contains no non-zero subsequence length.
    """
    tokens_per_item = attention_mask_in_length.sum(dim=-1)
    row_width = attention_mask_in_length.size(-1)

    # True for every position that lies inside the packed tokens of its row.
    positions = torch.arange(row_width, device=tokens_per_item.device, dtype=tokens_per_item.dtype)
    valid_token_mask = positions.expand(len(tokens_per_item), row_width) < tokens_per_item.unsqueeze(1)

    # Boolean indexing flattens row-major, so subsequences keep their batch order.
    nonzero_lengths = attention_mask_in_length[attention_mask_in_length > 0]
    if nonzero_lengths.numel() == 0:
        raise ValueError("No subsequence lengths provided for inter-document masking.")

    indices = valid_token_mask.flatten().nonzero(as_tuple=False).flatten()
    max_seqlen_in_batch = int(nonzero_lengths.max().item())
    running_total = torch.cumsum(nonzero_lengths, dim=0, dtype=torch.int32)
    cu_seqlens = torch.cat((running_total.new_zeros(1), running_total))
    return indices, cu_seqlens, max_seqlen_in_batch
|
|
||
@classmethod
def _execute_dao_flash_with_inter_document_masking(
    cls,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float,
    attention_masking_information: tuple[torch.Tensor, torch.Tensor, int],
) -> torch.Tensor:
    """
    Run causal flash attention separately within each packed document via the varlen kernel.

    The inputs are flattened to a single token axis, padding tokens are dropped using the
    provided indices, `flash_attn_varlen_func` is called on the packed tokens, and the
    result is scattered back into a zero-initialised padded tensor.

    Args:
        q (torch.Tensor): Query tensor; dims 1 and 2 are swapped before flattening
            (expected layout (B, nh_q, T, hd)).
        k (torch.Tensor): Key tensor (expected layout (B, nh_kv, T, hd)).
        v (torch.Tensor): Value tensor (expected layout (B, nh_kv, T, hd)).
        dropout (float): Dropout probability passed to the kernel.
        attention_masking_information (tuple[torch.Tensor, torch.Tensor, int]):
            (indices, cu_seqlens, max_seqlen) describing the packed documents.

    Returns:
        torch.Tensor: Attention output of shape (B, T, nh_q, hd); positions that are not
        listed in `indices` are zero.

    Raises:
        NotImplementedError: If `flash_attn_varlen_func` could not be imported.
    """
    if flash_attn_varlen_func is None:
        raise NotImplementedError(
            "ERROR! Dao Flash Attention varlen kernel is not available. " "Install flash-attn with varlen support."
        )

    token_indices, cumulative_lens, longest_doc = attention_masking_information

    # (B, nh, T, hd) -> (B, T, nh, hd)
    q_bthd, k_bthd, v_bthd = (t.transpose(1, 2).contiguous() for t in (q, k, v))

    batch_size, seq_len, n_head_q, head_dim = q_bthd.shape
    n_head_kv = k_bthd.shape[2]
    num_tokens = batch_size * seq_len

    # Flatten batch and time into one token axis and keep only the non-padding tokens.
    q_packed = q_bthd.reshape(num_tokens, n_head_q, head_dim).index_select(0, token_indices)
    k_packed = k_bthd.reshape(num_tokens, n_head_kv, head_dim).index_select(0, token_indices)
    v_packed = v_bthd.reshape(num_tokens, n_head_kv, head_dim).index_select(0, token_indices)

    attn_out = flash_attn_varlen_func(
        q_packed,
        k_packed,
        v_packed,
        cu_seqlens_q=cumulative_lens,
        cu_seqlens_k=cumulative_lens,
        max_seqlen_q=longest_doc,
        max_seqlen_k=longest_doc,
        dropout_p=dropout,
        causal=True,
        softmax_scale=None,
        window_size=(-1, -1),
    )

    # Scatter the packed result back to the padded token axis; padding stays zero.
    padded_out = torch.zeros(
        (num_tokens, n_head_q, head_dim),
        dtype=attn_out.dtype,
        device=attn_out.device,
    )
    padded_out.index_copy_(0, token_indices, attn_out)
    return padded_out.reshape(batch_size, seq_len, n_head_q, head_dim)
|
|
||
| def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Applies projections to the input tensor to get queries, keys, and values. | ||
|
|
@@ -600,6 +773,7 @@ def execute_attention( | |
| v: torch.Tensor, | ||
| dropout: float, | ||
| attention_impl: AttentionImplementation, | ||
| attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Executes attention mechanism based on the specified implementation. | ||
|
|
@@ -611,6 +785,8 @@ def execute_attention( | |
| v (torch.Tensor): The value tensor. | ||
| dropout (float): The dropout rate. | ||
| attention_impl (AttentionImplementation): The attention implementation to use. | ||
| attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None): | ||
| Optional tensor containing masking information for inter-document attention. | ||
|
|
||
| Returns: | ||
| torch.Tensor: The output tensor. | ||
|
BlueCrescent marked this conversation as resolved.
|
||
|
|
@@ -624,7 +800,7 @@ def execute_attention( | |
| query=q, | ||
| key=k, | ||
| value=v, | ||
| attn_mask=None, | ||
| attn_mask=attention_masking_information, | ||
| dropout_p=dropout, | ||
| is_causal=True, | ||
| ) # (B, nh_q, T, hd) | ||
|
|
@@ -647,22 +823,37 @@ def execute_attention( | |
| if flash_attn_func is None: | ||
| raise NotImplementedError("ERROR! Dao Flash Attention is not installed.") | ||
| # the next three lines are only needed for flash-attn from Dao Lab | ||
| q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd) | ||
| k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd) | ||
| v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd) | ||
| y = flash_attn_func( | ||
| q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1) | ||
| ) # (B, T, nh_q, hd) | ||
| if attention_masking_information is None: | ||
| q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd) | ||
| k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd) | ||
| v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd) | ||
| y = flash_attn_func( | ||
| q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1) | ||
| ) # (B, T, nh_q, hd) | ||
| else: | ||
| y = cls._execute_dao_flash_with_inter_document_masking( | ||
| q=q, | ||
| k=k, | ||
| v=v, | ||
| dropout=dropout, | ||
| attention_masking_information=attention_masking_information, | ||
| ) | ||
|
BlueCrescent marked this conversation as resolved.
|
||
| else: | ||
| raise NotImplementedError(f"Attention implementation {attention_impl} not supported") | ||
| return y # (B, T, nh_q, hd) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| def forward( | ||
| self, | ||
| x: torch.Tensor, | ||
| attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Forward pass of the CausalSelfAttention module. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): Input tensor of shape (B, T, n_embd) | ||
| attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None): | ||
| Optional tensor containing masking information for inter-document attention. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection. | ||
|
|
@@ -675,7 +866,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| if self.q_norm is not None and self.k_norm is not None: | ||
| q = self.q_norm(q) | ||
| k = self.k_norm(k) | ||
| y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl) # (B, T, nh_q, hd) | ||
| y = CausalSelfAttention.execute_attention( | ||
| q, k, v, self.dropout, self.attention_impl, attention_masking_information | ||
| ) # (B, T, nh_q, hd) | ||
| y = y.reshape(B, T, -1) # (B, T, n_embd), re-assemble all head outputs side by side | ||
| return self.resid_dropout(self.c_proj(y)) # (B, T, n_embd), output projection | ||
|
|
||
|
|
@@ -798,7 +991,7 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None: | |
| f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`." | ||
| ) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| def forward(self, x: torch.Tensor, attention_masking_information: torch.Tensor | None = None) -> torch.Tensor: | ||
| """ | ||
|
BlueCrescent marked this conversation as resolved.
Outdated
|
||
| Forward pass of the GPT2Block. | ||
|
|
||
|
|
@@ -808,7 +1001,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| Returns: | ||
| torch.Tensor: Output tensor. | ||
| """ | ||
| x = x + self.attn(self.attention_norm(x)) | ||
| x = x + self.attn(self.attention_norm(x), attention_masking_information=attention_masking_information) | ||
| x = x + self.mlp(self.ffn_norm(x)) | ||
| return x | ||
|
|
||
|
|
@@ -839,6 +1032,7 @@ def __init__( | |
| use_weight_tying: bool, | ||
| seed: Optional[int] = None, | ||
| enforce_swiglu_hidden_dim_multiple_of: int = 256, | ||
| sub_seq_lengths_key: str | None = None, | ||
| ): | ||
| """ | ||
| Initializes the GPT2LLM object. | ||
|
|
@@ -867,6 +1061,8 @@ def __init__( | |
| enforce_swiglu_hidden_dim_multiple_of (int): Enforces | ||
| the hidden dimension in the SwiGLU layer to be a multiple of this value. | ||
| Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. | ||
| sub_seq_lengths_key (str, optional): The key for sub sequence lengths to be | ||
| used for inter document masking. | ||
| """ | ||
| weight_decay_groups = { | ||
| "linear": [".attn", ".mlp", ".lm_head.weight"], | ||
|
|
@@ -876,6 +1072,7 @@ def __init__( | |
| super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) | ||
| self.sample_key = sample_key | ||
| self.prediction_key = prediction_key | ||
| self.sub_seq_lengths_key = sub_seq_lengths_key | ||
| self.sequence_length = sequence_length | ||
| self.n_embd = n_embd | ||
| self.n_layer = n_layer | ||
|
|
@@ -981,16 +1178,22 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t | |
| dict[str, torch.Tensor] | torch.Tensor: Model output. | ||
| """ | ||
| if isinstance(inputs, dict): | ||
| return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} | ||
| return { | ||
| self.prediction_key: self.forward_impl( | ||
| inputs[self.sample_key], sub_seq_lengths=inputs.get(self.sub_seq_lengths_key) | ||
| ) | ||
| } | ||
|
BlueCrescent marked this conversation as resolved.
|
||
| else: | ||
| return self.forward_impl(inputs) | ||
|
|
||
| def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: | ||
| def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] | None = None) -> torch.Tensor: | ||
| """ | ||
| Forward pass implementation of the GPT2LLM module. | ||
|
|
||
| Args: | ||
| inputs (torch.Tensor): A tensor containing input token ids. | ||
| sub_seq_lengths (list[list[int]], optional): The lengths of the subsequences of each sequence | ||
| in the batch. To be used for inter document masking. | ||
|
|
||
| Returns: | ||
| torch.Tensor: A tensor containing output logits. | ||
|
|
@@ -1013,8 +1216,16 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: | |
| # TODO: use drop out also without absolute position embedding? | ||
| h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h | ||
|
|
||
| # TODO: Handle this in case of pipeline parallelism. | ||
|
BlueCrescent marked this conversation as resolved.
Outdated
|
||
| if sub_seq_lengths is not None: | ||
| attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking( | ||
| in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len | ||
| ) | ||
|
BlueCrescent marked this conversation as resolved.
Outdated
|
||
| else: | ||
| attention_masking_information = None | ||
|
|
||
| for layer_idx in self.transformer.h: | ||
| h = self.transformer.h[layer_idx](h) | ||
| h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information) | ||
| h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h | ||
| h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h | ||
| return h | ||
|
|
@@ -1047,18 +1258,32 @@ def manual_scaled_dot_product_attention( | |
| attn_bias = torch.zeros( | ||
| L, S, dtype=query.dtype, device=query.device | ||
| ) # device added (not part of the original code) | ||
| if attn_mask is not None and attn_mask.dim() == 3: | ||
| attn_bias = attn_bias.unsqueeze(0).repeat(attn_mask.size(0), 1, 1) | ||
| if is_causal: | ||
| assert attn_mask is None | ||
| temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) # device added | ||
|
BlueCrescent marked this conversation as resolved.
Outdated
|
||
| attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||
| if attn_mask is None: | ||
| attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||
| elif attn_mask.dtype == torch.bool: | ||
| if attn_mask.dim() == 3: | ||
| combined_mask = temp_mask.unsqueeze(0) & attn_mask | ||
| else: | ||
| combined_mask = temp_mask & attn_mask | ||
| attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf")) | ||
| else: | ||
| if attn_mask.dim() == 3: | ||
| temp_mask = temp_mask.unsqueeze(0) | ||
| attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||
| attn_bias += attn_mask | ||
|
Comment on lines
+1304
to
+1315
|
||
| attn_bias.to(query.dtype) | ||
|
|
||
| if attn_mask is not None: | ||
| elif attn_mask is not None: | ||
|
BlueCrescent marked this conversation as resolved.
BlueCrescent marked this conversation as resolved.
|
||
| if attn_mask.dtype == torch.bool: | ||
| attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) | ||
| else: | ||
| attn_bias += attn_mask | ||
| attn_weight = query @ key.transpose(-2, -1) * scale_factor | ||
| if attn_bias.dim() == 3: | ||
| attn_bias = attn_bias.unsqueeze(1) | ||
| attn_weight += attn_bias | ||
| attn_weight = torch.softmax(attn_weight, dim=-1) | ||
| attn_weight = torch.dropout(attn_weight, dropout_p, train=True) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.