Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
5e039df
feat(attention): Added inter document masking for manual and flash at…
BlueCrescent Feb 23, 2026
a5354c4
feat(data loading): GPT2LLMCollateFn can now determine the sub sequen…
BlueCrescent Feb 23, 2026
eba9c5b
fix(attention): NaNs when using padding + inter document masking with…
BlueCrescent Feb 24, 2026
3d583c1
feat(attention): added sub_seq_lengths_key to GPT2LLMConfig and renam…
BlueCrescent Feb 24, 2026
cd00777
fix(attention): added missing sub_seq_lengths_key parameter to get_gp…
BlueCrescent Feb 24, 2026
fdd6465
fix(attention): computing sub sequence lengths on correct input
BlueCrescent Feb 24, 2026
e952bd0
docs(attention): better _get_unpad_data_for_concatenated_sequences() …
BlueCrescent Feb 24, 2026
115259c
test(attention): fixed collator tests and improved collator config va…
BlueCrescent Feb 25, 2026
8d71928
fix: directly use correct device + dtype for eos positions extensions
BlueCrescent Feb 26, 2026
91f67e4
chore: remove comment
BlueCrescent Feb 27, 2026
ab4c79a
refactor(data): improved naming in collator
BlueCrescent Feb 28, 2026
379420d
test(attention): turned global manual seed into fixture
BlueCrescent Feb 28, 2026
382a952
refactor(attention): improved error handling
BlueCrescent Feb 28, 2026
478628d
fix(attention): added not supported assertion for inter document mask…
BlueCrescent Feb 28, 2026
6d7e502
refactor(data): improved detection and reporting of sequences in batc…
BlueCrescent Feb 28, 2026
f173cff
refactor(attention): removed duplicate exception
BlueCrescent Feb 28, 2026
4a31747
fix(attention): bug introduced in improved error handling for unsuppo…
BlueCrescent Feb 28, 2026
b2ec756
chore: Merge remote-tracking branch 'origin/main' into inter_document…
BlueCrescent Mar 12, 2026
956b958
chore: Merge remote-tracking branch 'origin/main' into inter_document…
BlueCrescent Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 244 additions & 19 deletions src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from modalities.util import parse_enum_by_name

try:
from flash_attn import flash_attn_func
from flash_attn import flash_attn_func, flash_attn_varlen_func
except ModuleNotFoundError:
Comment thread
BlueCrescent marked this conversation as resolved.
flash_attn_func = None
flash_attn_varlen_func = None

# Logger configuration
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -501,6 +502,178 @@ def __init__(
self.q_norm = None
self.k_norm = None

def prepare_inter_document_masking(
    self, in_batch_seq_lens: list[list[int]], max_seq_len: int
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]:
    """
    Prepares inter-document attention masking information from per-batch sequence lengths.

    For manual attention, a block-diagonal boolean mask of shape
    (batch_size, max_seq_len, max_seq_len) is returned, so tokens may only attend
    within their own document. For Dao flash attention, the unpad indices,
    cumulative sequence lengths (cu_seqlens) and the maximum subsequence length
    are returned for the varlen kernel. All other implementations raise.

    Args:
        in_batch_seq_lens (list[list[int]]): One list of document lengths per batch item.
        max_seq_len (int): The maximum sequence length in the batch.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]: The inter-document
        masking information for the configured attention implementation.

    Raises:
        NotImplementedError: If the configured attention implementation does not
            support inter-document masking.
    """
    # Allocate on the same device as the module weights.
    device = self.c_proj.weight.device

    if self.attention_impl == AttentionImplementation.MANUAL:
        num_samples = len(in_batch_seq_lens)
        block_diag_mask = torch.zeros((num_samples, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
        for sample_idx, doc_lengths in enumerate(in_batch_seq_lens):
            # Mark one square per document along the diagonal.
            offset = 0
            for doc_len in doc_lengths:
                block_diag_mask[sample_idx, offset : offset + doc_len, offset : offset + doc_len] = True
                offset += doc_len
        return block_diag_mask

    if self.attention_impl == AttentionImplementation.DAO_FLASH:
        packed_lengths = self._build_concatenated_lengths_tensor(
            in_batch_seq_lens=in_batch_seq_lens,
            max_seq_len=max_seq_len,
            device=device,
        )
        return self._get_unpad_data_for_concatenated_sequences(packed_lengths)

    if self.attention_impl == AttentionImplementation.PYTORCH_FLASH:
        raise NotImplementedError(
            "Inter-document masking is not supported for `pytorch_flash`. Use `manual` or `dao_flash`."
        )
    raise NotImplementedError(
        f"Attention implementation {self.attention_impl} is not supported for inter-document masking."
    )

@staticmethod
def _build_concatenated_lengths_tensor(
in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
) -> torch.Tensor:
"""
Build a tensor of concatenated subsequence lengths for each batch item.
Args:
in_batch_seq_lens: A list of per-batch lists, where each inner list contains
the lengths of subsequences for that batch item.
max_seq_len: The maximum allowed sequence length (number of subsequences and
total length constraints are validated against this value).
device: The torch device on which to allocate the output tensor.
Returns:
A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
for each batch item, padded with zeros beyond the number of subsequences.
Raises:
ValueError: If a batch item has more subsequences than max_seq_len or if the
sum of its subsequence lengths exceeds max_seq_len.
"""
batch_size = len(in_batch_seq_lens)
concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32, device=device)
for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
if len(doc_seq_lens) > max_seq_len:
raise ValueError(
f"Number of subsequences ({len(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
f"for batch index {batch_idx}."
)
if sum(doc_seq_lens) > max_seq_len:
raise ValueError(
f"Sum of subsequence lengths ({sum(doc_seq_lens)}) exceeds max_seq_len ({max_seq_len}) "
f"for batch index {batch_idx}."
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
)
if len(doc_seq_lens) > 0:
concatenated_lengths[batch_idx, : len(doc_seq_lens)] = torch.tensor(
doc_seq_lens, dtype=torch.int32, device=device
)
return concatenated_lengths

@staticmethod
def _get_unpad_data_for_concatenated_sequences(
attention_mask_in_length: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Compute unpadded indices and cumulative sequence lengths for concatenated sequences.
Given a batch of per-subsequence lengths in `attention_mask_in_length`, this
builds a boolean mask over the maximum sequence length, extracts flattened
indices of valid (unpadded) tokens, and returns cumulative sequence lengths
(CU) along with the maximum subsequence length in the batch.
Args:
attention_mask_in_length (torch.Tensor): Tensor of shape (num_subsequences,)
containing the lengths of each subsequence in the concatenated batch.
Returns:
tuple[torch.Tensor, torch.Tensor, int]:
- indices: 1D tensor of flattened indices for all valid (unpadded) tokens.
- cu_seqlens: 1D int32 tensor of cumulative sequence lengths with a
leading zero (shape: num_subsequences + 1).
- max_seqlen_in_batch: Maximum subsequence length as an int.
Raises:
ValueError: If no subsequence lengths are provided (all zeros).
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
"""

length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(
len(length), seqlen
) < length.unsqueeze(1)
seqlens_in_batch = attention_mask_in_length[attention_mask_in_length > 0]
if seqlens_in_batch.numel() == 0:
raise ValueError("No subsequence lengths provided for inter-document masking.")
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch

@classmethod
def _execute_dao_flash_with_inter_document_masking(
    cls,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float,
    attention_masking_information: tuple[torch.Tensor, torch.Tensor, int],
) -> torch.Tensor:
    """
    Execute Dao flash attention on unpadded, concatenated subsequences.

    Gathers only the valid (non-padding) tokens, runs the varlen flash-attn
    kernel so that attention is confined to each subsequence (given by
    cu_seqlens), and scatters the results back into a zero-initialized
    padded tensor.

    Args:
        q: Query tensor; transposed below from (B, nh_q, T, hd) to (B, T, nh_q, hd).
        k: Key tensor; transposed below from (B, nh_kv, T, hd) to (B, T, nh_kv, hd).
        v: Value tensor; transposed below from (B, nh_kv, T, hd) to (B, T, nh_kv, hd).
        dropout: Attention dropout probability passed to the kernel.
        attention_masking_information: Tuple of (indices, cu_seqlens, max_seqlen)
            as produced by `_get_unpad_data_for_concatenated_sequences`.

    Returns:
        torch.Tensor: Attention output of shape (B, T, nh_q, hd); padding
        positions are zero.

    Raises:
        NotImplementedError: If the flash-attn varlen kernel is not installed.
    """
    if flash_attn_varlen_func is None:
        raise NotImplementedError(
            "ERROR! Dao Flash Attention varlen kernel is not available. " "Install flash-attn with varlen support."
        )

    indices, cu_seqlens, max_seqlen = attention_masking_information

    # flash-attn expects the (B, T, heads, head_dim) layout.
    q = q.transpose(1, 2).contiguous()  # (B, T, nh_q, hd)
    k = k.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)
    v = v.transpose(1, 2).contiguous()  # (B, T, nh_kv, hd)

    batch_size, seq_len, n_head_q, head_dim = q.shape
    n_head_kv = k.shape[2]

    # Flatten batch and time so tokens can be gathered by flat index.
    q_flat = q.reshape(batch_size * seq_len, n_head_q, head_dim)
    k_flat = k.reshape(batch_size * seq_len, n_head_kv, head_dim)
    v_flat = v.reshape(batch_size * seq_len, n_head_kv, head_dim)

    # Keep only valid (unpadded) tokens for the varlen kernel.
    q_unpad = q_flat.index_select(0, indices)
    k_unpad = k_flat.index_select(0, indices)
    v_unpad = v_flat.index_select(0, indices)

    # Causal attention restricted to each subsequence via cu_seqlens boundaries.
    y_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=dropout,
        causal=True,
        softmax_scale=None,
        window_size=(-1, -1),
    )

    # Scatter results back to padded layout; padding positions stay zero.
    y = torch.zeros(
        (batch_size * seq_len, n_head_q, head_dim),
        dtype=y_unpad.dtype,
        device=y_unpad.device,
    )
    y.index_copy_(0, indices, y_unpad)
    y = y.reshape(batch_size, seq_len, n_head_q, head_dim)
    return y

def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies projections to the input tensor to get queries, keys, and values.
Expand Down Expand Up @@ -600,6 +773,7 @@ def execute_attention(
v: torch.Tensor,
dropout: float,
attention_impl: AttentionImplementation,
attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None,
) -> torch.Tensor:
"""
Executes attention mechanism based on the specified implementation.
Expand All @@ -611,6 +785,8 @@ def execute_attention(
v (torch.Tensor): The value tensor.
dropout (float): The dropout rate.
attention_impl (AttentionImplementation): The attention implementation to use.
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
Optional tensor containing masking information for inter-document attention.

Returns:
torch.Tensor: The output tensor.
Comment thread
BlueCrescent marked this conversation as resolved.
Expand All @@ -624,7 +800,7 @@ def execute_attention(
query=q,
key=k,
value=v,
attn_mask=None,
attn_mask=attention_masking_information,
dropout_p=dropout,
is_causal=True,
) # (B, nh_q, T, hd)
Expand All @@ -647,22 +823,37 @@ def execute_attention(
if flash_attn_func is None:
raise NotImplementedError("ERROR! Dao Flash Attention is not installed.")
# the next three lines are only needed for flash-attn from Dao AI Lab
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
y = flash_attn_func(
q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
) # (B, T, nh_q, hd)
if attention_masking_information is None:
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
y = flash_attn_func(
q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
) # (B, T, nh_q, hd)
else:
y = cls._execute_dao_flash_with_inter_document_masking(
q=q,
k=k,
v=v,
dropout=dropout,
attention_masking_information=attention_masking_information,
)
Comment thread
BlueCrescent marked this conversation as resolved.
else:
raise NotImplementedError(f"Attention implementation {attention_impl} not supported")
return y # (B, T, nh_q, hd)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None,
) -> torch.Tensor:
"""
Forward pass of the CausalSelfAttention module.

Args:
x (torch.Tensor): Input tensor of shape (B, T, n_embd)
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
Optional tensor containing masking information for inter-document attention.

Returns:
torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection.
Expand All @@ -675,7 +866,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.q_norm is not None and self.k_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)
y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl) # (B, T, nh_q, hd)
y = CausalSelfAttention.execute_attention(
q, k, v, self.dropout, self.attention_impl, attention_masking_information
) # (B, T, nh_q, hd)
y = y.reshape(B, T, -1) # (B, T, n_embd), re-assemble all head outputs side by side
return self.resid_dropout(self.c_proj(y)) # (B, T, n_embd), output projection

Expand Down Expand Up @@ -798,7 +991,7 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None:
f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`."
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, attention_masking_information: torch.Tensor | None = None) -> torch.Tensor:
"""
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
Forward pass of the GPT2Block.

Expand All @@ -808,7 +1001,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Output tensor.
"""
x = x + self.attn(self.attention_norm(x))
x = x + self.attn(self.attention_norm(x), attention_masking_information=attention_masking_information)
x = x + self.mlp(self.ffn_norm(x))
return x

Expand Down Expand Up @@ -839,6 +1032,7 @@ def __init__(
use_weight_tying: bool,
seed: Optional[int] = None,
enforce_swiglu_hidden_dim_multiple_of: int = 256,
sub_seq_lengths_key: str | None = None,
):
"""
Initializes the GPT2LLM object.
Expand Down Expand Up @@ -867,6 +1061,8 @@ def __init__(
enforce_swiglu_hidden_dim_multiple_of (int): Enforces
the hidden dimension in the SwiGLU layer to be a multiple of this value.
Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256.
sub_seq_lengths_key (str, optional): The key for sub sequence lengths to be
used for inter document masking.
"""
weight_decay_groups = {
"linear": [".attn", ".mlp", ".lm_head.weight"],
Expand All @@ -876,6 +1072,7 @@ def __init__(
super().__init__(weight_decay_groups=weight_decay_groups, seed=seed)
self.sample_key = sample_key
self.prediction_key = prediction_key
self.sub_seq_lengths_key = sub_seq_lengths_key
self.sequence_length = sequence_length
self.n_embd = n_embd
self.n_layer = n_layer
Expand Down Expand Up @@ -981,16 +1178,22 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t
dict[str, torch.Tensor] | torch.Tensor: Model output.
"""
if isinstance(inputs, dict):
return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
return {
self.prediction_key: self.forward_impl(
inputs[self.sample_key], sub_seq_lengths=inputs.get(self.sub_seq_lengths_key)
)
}
Comment thread
BlueCrescent marked this conversation as resolved.
else:
return self.forward_impl(inputs)

def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] | None = None) -> torch.Tensor:
"""
Forward pass implementation of the GPT2LLM module.

Args:
inputs (torch.Tensor): A tensor containing input token ids.
sub_seq_lengths (list[list[int]], optional): The lengths of the subsequences of each sequence
in the batch. To be used for inter document masking.

Returns:
torch.Tensor: A tensor containing output logits.
Expand All @@ -1013,8 +1216,16 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
# TODO: use drop out also without absolute position embedding?
h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h

# TODO: Handle this in case of pipeline parallelism.
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
if sub_seq_lengths is not None:
attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
)
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
else:
attention_masking_information = None

for layer_idx in self.transformer.h:
h = self.transformer.h[layer_idx](h)
h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information)
h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h
h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h
return h
Expand Down Expand Up @@ -1047,18 +1258,32 @@ def manual_scaled_dot_product_attention(
attn_bias = torch.zeros(
L, S, dtype=query.dtype, device=query.device
) # device added (not part of the original code)
if attn_mask is not None and attn_mask.dim() == 3:
attn_bias = attn_bias.unsqueeze(0).repeat(attn_mask.size(0), 1, 1)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) # device added
Comment thread
BlueCrescent marked this conversation as resolved.
Outdated
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
if attn_mask is None:
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
elif attn_mask.dtype == torch.bool:
if attn_mask.dim() == 3:
combined_mask = temp_mask.unsqueeze(0) & attn_mask
else:
combined_mask = temp_mask & attn_mask
attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
else:
if attn_mask.dim() == 3:
temp_mask = temp_mask.unsqueeze(0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias += attn_mask
Comment on lines +1304 to +1315
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fully_masked variable is computed to identify rows that have no valid attention positions after combining causal and inter-document masks. However, this is only used when attn_mask.dtype is torch.bool within the is_causal branch. If attn_mask is a float mask, fully_masked will remain None, which means the special handling for fully masked rows won't apply. This could lead to NaN values in attention weights after softmax on fully masked rows when using float masks. Consider computing fully_masked for float masks as well, or document this limitation.

Copilot uses AI. Check for mistakes.
attn_bias.to(query.dtype)

if attn_mask is not None:
elif attn_mask is not None:
Comment thread
BlueCrescent marked this conversation as resolved.
Comment thread
BlueCrescent marked this conversation as resolved.
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
if attn_bias.dim() == 3:
attn_bias = attn_bias.unsqueeze(1)
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
Expand Down
Loading