Skip to content

Commit b2ec756

Browse files
committed
chore: Merge remote-tracking branch 'origin/main' into inter_document_masking_for_attention
2 parents 4a31747 + 8f84b2d commit b2ec756

12 files changed

Lines changed: 498 additions & 61 deletions

File tree

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ It is recommended to install Modalities via uv or install PyTorch, psutil and Ni
4444
# Get uv (tested with uv version 0.9.13)
4545
curl -LsSf https://astral.sh/uv/install.sh | sh
4646

47-
uv sync
47+
uv sync --extra [cpu|cu126|cu128|cu130] # Get CUDA version via nvidia-smi
4848
source .venv/bin/activate
4949

5050
# For developers: use [tests,linting] and install pre-commit hooks
51-
uv sync --extra tests --extra linting
51+
uv sync --extra [cpu|cu126|cu128|cu130] --extra tests --extra linting
5252
pre-commit install --install-hooks
5353
```
5454

@@ -60,7 +60,8 @@ conda create -n modalities python=3.13
6060
conda activate modalities
6161

6262
# Install PyTorch, psutil, Ninja and Flash Attention
63-
pip install "torch<2.11.0"
63+
# For PyTorch, select the correct index URL for your CUDA/CPU setup from https://pytorch.org/get-started/locally/ e.g.:
64+
pip install "torch>=2.10,<2.11.0" torchvision --index-url https://download.pytorch.org/whl/cu130
6465
pip install psutil ninja # Ninja lowers compilation time of flash attention significantly
6566
pip install flash-attn==2.8.3 --no-build-isolation
6667
```

pyproject.toml

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ description = "Modalities, a PyTorch-native framework for distributed and reprod
66
readme = "README.md"
77
dependencies = [
88
"numpy",
9-
"torch<2.11.0",
109
"ninja",
1110
"packaging",
1211
"tqdm",
@@ -25,25 +24,86 @@ dependencies = [
2524
"matplotlib",
2625
"wandb",
2726
"einops>=0.7.0",
28-
"flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'",
2927
"debugpy", # For VSCode debugging support
3028
]
3129

3230
[project.urls]
3331
Homepage = "https://github.com/Modalities/modalities"
3432
Issues = "https://github.com/Modalities/modalities/issues"
3533

36-
[project.optional-dependencies]
37-
linting = ["pre-commit"]
38-
tests = ["pytest", "pytest-cov", "debugpy"]
39-
4034
[project.scripts]
4135
modalities = "modalities.__main__:main"
4236

4337
[build-system]
4438
requires = ["setuptools >= 61.0.0"]
4539
build-backend = "setuptools.build_meta"
4640

41+
[project.optional-dependencies]
42+
linting = ["pre-commit"]
43+
tests = ["pytest", "pytest-cov", "debugpy"]
44+
45+
cpu = ["torch>=2.10,<2.11.0", "torchvision"]
46+
cu126 = [
47+
"torch>=2.10,<2.11.0",
48+
"torchvision",
49+
"flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
50+
]
51+
cu128 = [
52+
"torch>=2.10,<2.11.0",
53+
"torchvision",
54+
"flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
55+
]
56+
cu130 = [
57+
"torch>=2.10,<2.11.0",
58+
"torchvision",
59+
"flash-attn==2.8.3; platform_system != 'Darwin' and platform_machine != 'aarch64'"
60+
]
61+
62+
[tool.uv]
63+
conflicts = [
64+
[
65+
{ extra = "cpu" },
66+
{ extra = "cu126" },
67+
{ extra = "cu128" },
68+
{ extra = "cu130" },
69+
],
70+
]
71+
72+
[tool.uv.sources]
73+
torch = [
74+
{ index = "pytorch-cpu", extra = "cpu" },
75+
{ index = "pytorch-cu126", extra = "cu126" },
76+
{ index = "pytorch-cu128", extra = "cu128" },
77+
{ index = "pytorch-cu130", extra = "cu130" },
78+
]
79+
torchvision = [
80+
{ index = "pytorch-cpu", extra = "cpu" },
81+
{ index = "pytorch-cu126", extra = "cu126" },
82+
{ index = "pytorch-cu128", extra = "cu128" },
83+
{ index = "pytorch-cu130", extra = "cu130" },
84+
]
85+
86+
[[tool.uv.index]]
87+
name = "pytorch-cpu"
88+
url = "https://download.pytorch.org/whl/cpu"
89+
explicit = true
90+
91+
[[tool.uv.index]]
92+
name = "pytorch-cu126"
93+
url = "https://download.pytorch.org/whl/cu126"
94+
explicit = true
95+
96+
[[tool.uv.index]]
97+
name = "pytorch-cu128"
98+
url = "https://download.pytorch.org/whl/cu128"
99+
explicit = true
100+
101+
[[tool.uv.index]]
102+
name = "pytorch-cu130"
103+
url = "https://download.pytorch.org/whl/cu130"
104+
explicit = true
105+
106+
47107
[tool.uv.extra-build-dependencies]
48108
flash-attn = [
49109
{ requirement = "torch", match-runtime = true },

src/modalities/config/instantiation_models.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import os
23
from pathlib import Path
34
from typing import Annotated, Any, Optional
@@ -27,6 +28,8 @@
2728
from modalities.util import warn_rank_0
2829
from modalities.utils.profilers.profilers import SteppableNoProfiler
2930

31+
logger = logging.getLogger(__name__)
32+
3033

3134
class CudaEnvSettings(BaseModel):
3235
local_rank: Annotated[int, Field(strict=True, ge=0)]
@@ -46,6 +49,7 @@ class ConsistencyEnforcement(BaseModel):
4649
enforce_last_step_logged: bool = True
4750
enforce_last_step_evaluated: bool = True
4851
enforce_last_step_checkpointed: bool = True
52+
enforce_enough_tokens_in_dataset: bool = True
4953

5054

5155
class Intervals(BaseModel):
@@ -192,15 +196,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
192196

193197
@model_validator(mode="after")
194198
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel":
195-
if (
196-
len(self.train_dataset) * self.settings.step_profile.sequence_length
197-
< self.settings.training_target.num_target_tokens
198-
):
199-
raise ValueError(
200-
"Not enough tokens in the dataset. "
201-
f"Actual: {len(self.train_dataset) * self.settings.step_profile.sequence_length}, "
202-
f"Expected: >={self.settings.training_target.num_target_tokens}"
203-
)
199+
dataset_tokens = len(self.train_dataset) * self.settings.step_profile.sequence_length
200+
expected_tokens = self.settings.training_target.num_target_tokens
201+
if dataset_tokens < expected_tokens:
202+
msg = f"Not enough tokens in dataset. Actual: {dataset_tokens}, Expected: >={expected_tokens}"
203+
if self.settings.consistency_enforcement.enforce_enough_tokens_in_dataset:
204+
raise ValueError(msg)
205+
else:
206+
logger.warning(msg)
204207
return self
205208

206209

src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import math
22
import os
33
import pickle
4-
from itertools import repeat
54
from pathlib import Path
65
from typing import BinaryIO
76

@@ -82,30 +81,56 @@ def _write_index_segment(file_descriptor: BinaryIO, index_list: list[tuple[int,
8281
def _write_data_segment(
8382
file_descriptor: BinaryIO, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
8483
) -> list[tuple[int, int]]:
85-
def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
86-
# Converts an token_ids to its byte representation.
87-
try:
88-
token_bytes = encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False)
89-
except OverflowError as e:
90-
raise ValueError(f"Token {encoded_token} cannot be represented by {token_size_in_bytes} bytes.") from e
91-
return token_bytes
92-
93-
samples = []
94-
index_list = []
84+
# Fast path: vectorized cast + tobytes (no per-token Python work).
85+
# Preserves little-endian unsigned representation and overflow checks.
86+
87+
if token_size_in_bytes == 1:
88+
dtype = np.dtype("u1")
89+
elif token_size_in_bytes == 2:
90+
dtype = np.dtype("<u2") # force little-endian
91+
elif token_size_in_bytes == 4:
92+
dtype = np.dtype("<u4") # force little-endian
93+
else:
94+
raise ValueError("Currently only support token byte sizes of 1, 2, and 4.")
95+
96+
max_allowed = 2 ** (8 * token_size_in_bytes) - 1
97+
98+
samples: list[bytes] = []
99+
index_list: list[tuple[int, int]] = []
95100
curr_offset = 0
101+
pending = 0
102+
96103
for sample_tokens in token_data:
97-
# convert token_ids to byte representation
98-
sample_token_byte_string = b"".join(
99-
map(encoded_token_to_bytes, sample_tokens.tolist(), repeat(token_size_in_bytes))
100-
)
104+
arr = np.asarray(sample_tokens)
105+
106+
# ---- Overflow / range check (preserves original semantics) ----
107+
if arr.size:
108+
min_val = int(arr.min())
109+
max_val = int(arr.max())
110+
if min_val < 0 or max_val > max_allowed:
111+
raise ValueError(
112+
f"Token values out of range for {token_size_in_bytes} bytes: "
113+
f"min={min_val}, max={max_val}, allowed=[0, {max_allowed}]"
114+
)
115+
# ----------------------------------------------------------------
116+
117+
# Cast to correct unsigned little-endian dtype
118+
arr = np.asarray(arr, dtype=dtype, order="C")
119+
sample_token_byte_string = arr.tobytes(order="C")
120+
101121
samples.append(sample_token_byte_string)
102122
index_list.append((curr_offset, len(sample_token_byte_string)))
103123
curr_offset += len(sample_token_byte_string)
104-
if len(samples) % write_batch_size == 0:
124+
125+
pending += 1
126+
if pending >= write_batch_size:
105127
file_descriptor.write(b"".join(samples))
106-
samples = []
128+
samples.clear()
129+
pending = 0
130+
107131
if len(samples) > 0:
108132
file_descriptor.write(b"".join(samples))
133+
109134
return index_list
110135

111136
@staticmethod

0 commit comments

Comments
 (0)