Commit bbeaf8c

deependujha and Copilot authored
feat: add align_chunking option to preserve deterministic chunk boundaries across workers (#768)
* claymore
* update
* update
* update
* Update src/litdata/processing/data_processor.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update src/litdata/processing/data_processor.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* update
* add tests
* update
* update
* multi-node support for align-chunking
* update
* update
* address review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 4195db0 commit bbeaf8c

5 files changed

Lines changed: 207 additions & 11 deletions

src/litdata/processing/data_processor.py

Lines changed: 58 additions & 11 deletions
@@ -23,6 +23,7 @@
 import sys
 import tempfile
 import traceback
+import warnings
 from abc import abstractmethod
 from contextlib import suppress
 from dataclasses import dataclass
@@ -299,29 +300,65 @@ def _upload_fn(
         remove_queue.put([local_filepath])


-def _map_items_to_workers_sequentially(num_workers: int, user_items: list[Any]) -> list[list[Any]]:
+def _map_items_to_workers_sequentially(
+    num_workers: int, user_items: list[Any], chunk_size: Optional[int] = None
+) -> list[list[Any]]:
     """Map the items to the workers sequentially.

+    Args:
+        num_workers: The number of workers to assign items to.
+        user_items: The list of items to be distributed among workers.
+        chunk_size: Optional chunk size that enforces deterministic,
+            single-worker-style chunk boundaries. When set, each worker is
+            assigned only full chunks of this size, and the final worker
+            receives any remaining items (which may form a partial chunk).
+
     >>> workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
     >>> assert workers_user_items == [[0, 1], [2, 3, 4]]
     """
+    assert isinstance(chunk_size, (int, type(None))), "chunk_size must be an integer or None"
+
     num_nodes = _get_num_nodes()
+    node_rank = _get_node_rank()
     world_size = num_nodes * num_workers
-    num_items_per_worker = len(user_items) // world_size

-    num_items_per_worker: list[int] = [num_items_per_worker for _ in range(world_size)]
-    reminder = len(user_items) % world_size
+    if chunk_size is not None:
+        assert chunk_size > 0, "chunk_size must be a positive integer"

-    for worker_idx in range(len(num_items_per_worker) - 1, -1, -1):
-        if reminder == 0:
-            break
-        num_items_per_worker[worker_idx] += 1
-        reminder -= 1
+        # Compute how many full chunks each worker can take
+        full_chunks = len(user_items) // chunk_size
+        chunks_per_worker = full_chunks // world_size
+
+        if chunks_per_worker == 0 and node_rank == 0:
+            warnings.warn(
+                f"chunk_size ({chunk_size}) is too large relative to dataset size ({len(user_items)}) "
+                f"and world_size ({world_size}). This will result in idle workers. "
+                f"Consider reducing chunk_size or using fewer workers."
+            )
+
+        # Assign full chunks to all workers except the last
+        num_items_per_worker = [chunks_per_worker * chunk_size for _ in range(world_size - 1)]
+
+        # Last worker receives all remaining items (full chunks + optional tail)
+        remaining = len(user_items) - sum(num_items_per_worker)
+        num_items_per_worker.append(remaining)
+
+    else:
+        items_per_worker_count = len(user_items) // world_size
+
+        num_items_per_worker: list[int] = [items_per_worker_count for _ in range(world_size)]
+        reminder = len(user_items) % world_size
+
+        for worker_idx in range(len(num_items_per_worker) - 1, -1, -1):
+            if reminder == 0:
+                break
+            num_items_per_worker[worker_idx] += 1
+            reminder -= 1

     num_items_cumsum_per_worker = np.cumsum([0] + num_items_per_worker)

     out = []
-    node_rank = _get_node_rank()
     worker_idx_start = node_rank * num_workers
     worker_idx_end = (node_rank + 1) * num_workers

@@ -1080,6 +1117,7 @@ def __init__(
         input_dir: Union[str, Dir],
         output_dir: Optional[Union[str, Dir]] = None,
         num_workers: Optional[int] = None,
+        align_chunking: bool = False,
         num_downloaders: Optional[int] = None,
         num_uploaders: Optional[int] = None,
         delete_cached_files: bool = True,
@@ -1102,6 +1140,8 @@ def __init__(
             input_dir: The path to where the input data are stored.
             output_dir: The path to where the output data are stored.
             num_workers: The number of worker threads to use.
+            align_chunking: Ensures chunk boundaries match the single-worker layout by packing full chunks first
+                and placing all remaining items in the final worker.
             num_downloaders: The number of file downloaders to use.
             num_uploaders: The number of file uploaders to use.
             delete_cached_files: Whether to delete the cached files.
@@ -1140,6 +1180,7 @@ def __init__(
         self.output_dir = _resolve_dir(output_dir)

         self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
+        self.align_chunking = align_chunking
         self.num_downloaders = num_downloaders or 2
         self.num_uploaders = num_uploaders or 1
         self.delete_cached_files = delete_cached_files
@@ -1231,8 +1272,14 @@ def run(self, data_recipe: DataRecipe) -> None:
                     num_workers=self.num_workers, user_items=user_items, weights=item_sizes
                 )
             else:
+                if self.align_chunking and data_recipe.chunk_size is None:
+                    raise ValueError(
+                        "`align_chunking` is set to True, but the `chunk_size` is not defined in the data recipe."
+                    )
                 workers_user_items = _map_items_to_workers_sequentially(
-                    num_workers=self.num_workers, user_items=user_items
+                    num_workers=self.num_workers,
+                    user_items=user_items,
+                    chunk_size=data_recipe.chunk_size if self.align_chunking else None,
                 )
         else:
            assert isinstance(user_items, multiprocessing.queues.Queue)
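
For reference, here is a minimal standalone sketch of the chunk-aligned assignment introduced above. `split_aligned` is a hypothetical helper written for illustration, not part of litdata; the node count and rank are passed as arguments instead of being read from environment variables:

def split_aligned(num_workers, user_items, chunk_size, num_nodes=1, node_rank=0):
    # Global assignment over all workers across all nodes.
    world_size = num_nodes * num_workers

    # Every worker except the last receives only full chunks.
    full_chunks = len(user_items) // chunk_size
    chunks_per_worker = full_chunks // world_size
    counts = [chunks_per_worker * chunk_size] * (world_size - 1)
    counts.append(len(user_items) - sum(counts))  # last worker takes the tail

    # Slice out this node's workers from the global assignment.
    starts = [sum(counts[:i]) for i in range(world_size + 1)]
    lo, hi = node_rank * num_workers, (node_rank + 1) * num_workers
    return [user_items[starts[i] : starts[i + 1]] for i in range(lo, hi)]

# 6 items, 2 workers, chunk_size=2: worker 0 gets one full chunk and the last
# worker gets everything that remains, matching the tests added below.
assert split_aligned(2, list(range(6)), 2) == [[0, 1], [2, 3, 4, 5]]
assert split_aligned(2, list(range(5)), 2) == [[0, 1], [2, 3, 4]]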

src/litdata/processing/functions.py

Lines changed: 9 additions & 0 deletions
@@ -393,6 +393,7 @@ def optimize(
     weights: Optional[list[int]] = None,
     chunk_size: Optional[int] = None,
     chunk_bytes: Optional[Union[int, str]] = None,
+    align_chunking: bool = False,
     compression: Optional[str] = None,
     encryption: Optional[Encryption] = None,
     num_workers: Optional[int] = None,
@@ -428,6 +429,10 @@
         weights: Provide an associated weight to each input. This is used to balance work among workers.
         chunk_size: The maximum number of elements to hold within a chunk.
         chunk_bytes: The maximum number of bytes to hold within a chunk.
+        align_chunking: Ensures chunk boundaries match the single-worker layout by packing full chunks first
+            and placing all remaining items in the final worker. Each worker receives chunks of `chunk_size`
+            items, except possibly the last worker, which may also produce a smaller partial chunk. Note: this
+            results in uneven workload distribution, and the last worker may receive more data than the others.
         compression: The compression algorithm to use over the chunks.
         encryption: The encryption algorithm to use over the chunks.
         num_workers: The number of workers to use during processing
@@ -489,6 +494,9 @@
     if chunk_size is None and chunk_bytes is None:
         raise ValueError("Either `chunk_size` or `chunk_bytes` needs to be defined.")

+    if align_chunking and chunk_size is None:
+        raise ValueError("When `align_chunking` is set to True, `chunk_size` needs to be defined.")
+
     if not _IS_IN_STUDIO and (machine is not None or num_nodes is not None):
         raise ValueError(
             "Only https://lightning.ai/ supports multiple nodes or selecting a machine.Create an account to try it out."
@@ -555,6 +563,7 @@
         input_dir=resolved_dir,
         output_dir=_output_dir,
         num_workers=num_workers or _get_default_num_workers(),
+        align_chunking=align_chunking,
         fast_dev_run=fast_dev_run,
         num_downloaders=num_downloaders,
         num_uploaders=num_uploaders,
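
For context, a minimal usage sketch of the extended optimize() API; the identity function and output path are illustrative, not taken from this commit:

from litdata import optimize

def identity(index):
    return index

if __name__ == "__main__":
    optimize(
        fn=identity,
        inputs=list(range(1000)),
        output_dir="my_optimized_dataset",  # illustrative path
        chunk_size=64,                      # required whenever align_chunking=True
        num_workers=4,
        align_chunking=True,                # chunk boundaries match a single-worker run
    )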

tests/processing/test_data_processor.py

Lines changed: 51 additions & 0 deletions
@@ -373,6 +373,36 @@ def test_map_items_to_workers_sequentially(monkeypatch):
     assert workers_user_items == [[24, 25], [26, 27], [28, 29], [30, 31]]


+def test_map_items_to_workers_sequentially_align_chunking(monkeypatch):
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)), chunk_size=2)
+    assert workers_user_items == [list(range(5))]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)), chunk_size=2)
+    assert workers_user_items == [[0, 1], [2, 3, 4]]
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(6)), chunk_size=2)
+    assert workers_user_items == [[0, 1], [2, 3, 4, 5]]
+
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)), chunk_size=2)
+    assert workers_user_items == [[0, 1]]
+
+    # 2 nodes, 2 workers per node, chunk_size=2.
+    # Total items = 5 => only the final worker should receive them, because no worker
+    # except the last can form even one full chunk: (5 // 2) full chunks // 4 workers = 0.
+    with pytest.warns(UserWarning, match="Consider reducing chunk_size or using fewer workers"):
+        workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)), chunk_size=2)
+    assert workers_user_items == [[], []]
+
+    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
+    monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
+    workers_user_items = _map_items_to_workers_sequentially(1, list(range(5)), chunk_size=2)
+    assert workers_user_items == [[2, 3, 4]]
+
+    # On node 1 (rank 1), the last worker should receive all items.
+    workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)), chunk_size=2)
+    assert workers_user_items == [[], [0, 1, 2, 3, 4]]
+
+
 def test_fake_queue():
     q = FakeQueue()
     index = [1, 2]
@@ -400,6 +430,16 @@ def prepare_item(self, item):
         return item


+class DummyDataChunkRecipe(DataChunkRecipe):
+    is_generator = False
+
+    def prepare_structure(self, input_dir: str) -> list[Any]:
+        return []
+
+    def prepare_item(self, item):
+        return item
+
+
 @pytest.mark.parametrize("delete_cached_files", [True])
 @pytest.mark.parametrize("fast_dev_run", [10])
 @pytest.mark.skipif(condition=not _PIL_AVAILABLE or sys.platform == "win32", reason="Requires: ['pil']")
@@ -477,6 +517,17 @@ def test_data_processsor(fast_dev_run, delete_cached_files, tmpdir, monkeypatch)
     assert len(files) == expected


+def test_data_processor_align_chunking_requires_chunk_size(tmpdir):
+    output_dir = str(tmpdir / "output_dir")
+    data_processor = DataProcessor(input_dir=Dir(), output_dir=output_dir, num_workers=1, align_chunking=True)
+    with pytest.raises(ValueError, match="`chunk_size` is not defined in the data recipe"):
+        data_processor.run(
+            DummyDataChunkRecipe(
+                chunk_bytes="10MB"  # chunk_size is not defined here to trigger the error
+            )
+        )
+
+
 class TestDataProcessor(DataProcessor):
     def _broadcast_object(self, obj: Any) -> Any:
         return obj
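
The multi-node cases above follow directly from the assignment being computed globally over world_size = num_nodes * num_workers workers, with each node then keeping only its own slice. Reusing the hypothetical split_aligned sketch from the data_processor.py section:

# Global assignment for 5 items, chunk_size=2, 2 nodes x 2 workers is
# [[], [], [], [0, 1, 2, 3, 4]]: no worker but the last can form a full chunk,
# and each node sees only its own two workers.
assert split_aligned(2, list(range(5)), 2, num_nodes=2, node_rank=0) == [[], []]
assert split_aligned(2, list(range(5)), 2, num_nodes=2, node_rank=1) == [[], [0, 1, 2, 3, 4]]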

tests/processing/test_functions.py

Lines changed: 56 additions & 0 deletions
@@ -1,5 +1,6 @@
 import glob
 import io
+import math
 import os
 import random
 import shutil
@@ -123,6 +124,61 @@ def random_image(index):
     return {"image": fake_img, "class": index}


+@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
+def test_optimize_align_chunking_requires_chunk_size(tmp_path):
+    output_dir = tmp_path / "output_requires_chunk_size"
+
+    with pytest.raises(ValueError, match="`chunk_size` needs to be defined"):
+        optimize(
+            fn=compress,
+            inputs=list(range(7 * 64)),
+            chunk_bytes="1MB",
+            output_dir=str(output_dir),
+            num_workers=1,
+            align_chunking=True,
+        )
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
+@pytest.mark.parametrize("num_workers", [1, 2])
+@pytest.mark.parametrize("chunk_size", [16, 32, 64])
+def test_optimize_align_chunking_creates_expected_chunks(tmp_path, chunk_size, num_workers):
+    output_dir = tmp_path / f"output_workers_{num_workers}"
+
+    inputs = list(range(7 * 64))
+
+    optimize(
+        fn=compress,
+        inputs=inputs,
+        chunk_size=chunk_size,
+        output_dir=str(output_dir),
+        num_workers=num_workers,
+        align_chunking=True,
+    )
+
+    assert output_dir.exists()
+
+    actual_files = set(os.listdir(output_dir))
+
+    total_items = len(inputs)
+    items_per_worker = total_items / num_workers
+    chunks_per_worker = items_per_worker / chunk_size
+
+    # Each worker should create `math.floor(chunks_per_worker)` chunks, except the last
+    # worker, which also writes the remaining items and creates `math.ceil(chunks_per_worker)`.
+    expected_chunks_by_worker = {
+        worker_id: (math.floor(chunks_per_worker) if worker_id < num_workers - 1 else math.ceil(chunks_per_worker))
+        for worker_id in range(num_workers)
+    }
+
+    expected_chunk_files = {
+        f"chunk-{worker_id}-{i}.bin" for worker_id, indices in expected_chunks_by_worker.items() for i in range(indices)
+    }
+    expected_files = expected_chunk_files | {"index.json"}
+
+    assert actual_files == expected_files
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="too slow")
 def test_optimize_append_overwrite(tmpdir):
     output_dir = str(tmpdir / "output_dir")
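
The expected-file arithmetic in test_optimize_align_chunking_creates_expected_chunks can be checked by hand. A small sketch for one parametrization (448 inputs, chunk_size=64, num_workers=2, single node):

import math

total_items, chunk_size, num_workers = 7 * 64, 64, 2

full_chunks = total_items // chunk_size         # 7
chunks_per_worker = full_chunks // num_workers  # 3 -> worker 0 writes chunk-0-0.bin .. chunk-0-2.bin
last_worker_items = total_items - chunks_per_worker * chunk_size * (num_workers - 1)  # 256
last_worker_chunks = math.ceil(last_worker_items / chunk_size)  # 4 -> chunk-1-0.bin .. chunk-1-3.bin

assert chunks_per_worker == 3 and last_worker_chunks == 4
# Plus index.json, for 3 + 4 + 1 = 8 expected files.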

tests/streaming/test_dataloader.py

Lines changed: 33 additions & 0 deletions
@@ -6,6 +6,7 @@
 from torch import tensor

 from litdata.constants import _VIZ_TRACKER_AVAILABLE
+from litdata.processing.functions import optimize
 from litdata.streaming import (
     Cache,
     CombinedStreamingDataset,
@@ -496,3 +497,35 @@ def test_dataloader_dataset_transform_inheritance(tmpdir, shuffle):
     # Verify that the transform is applied correctly
     for i, item in enumerate(complete_data):
         assert item == i * 2, f"Expected {i * 2}, got {item}"
+
+
+def getter(index: int):
+    return index
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
+@pytest.mark.parametrize("num_workers", [1, 2])
+def test_dataloader_with_align_chunking(tmp_path, num_workers):
+    output_dir = tmp_path / f"output_workers_{num_workers}"
+
+    optimize(
+        fn=getter,
+        inputs=list(range(7 * 64)),
+        chunk_size=64,
+        output_dir=str(output_dir),
+        num_workers=num_workers,
+        align_chunking=True,
+    )
+
+    # Ensure batches contain elements from the same chunk when using align_chunking
+    dataset = StreamingDataset(str(output_dir), shuffle=True)
+
+    # make sure batch_size of dataloader is equal to chunk_size used during optimize
+    dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=num_workers, shuffle=True)
+
+    for i, batch in enumerate(dataloader):
+        min_element_in_batch = torch.min(batch).item()
+        max_element_in_batch = torch.max(batch).item()
+        assert max_element_in_batch - min_element_in_batch < 64, (
+            f"Batch {i} contains elements from multiple chunks: min {min_element_in_batch}, max {max_element_in_batch}"
+        )
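
As a complementary check, one could inspect the index that optimize() writes. A hedged sketch, assuming index.json exposes a "chunks" list with a per-chunk "chunk_size" item count (the field names are an assumption, not something this commit confirms):

import json
import os

output_dir = "output_workers_2"  # illustrative path matching the test above
with open(os.path.join(output_dir, "index.json")) as f:
    index = json.load(f)

# 448 sequential items with chunk_size=64 split into 7 full chunks across
# the workers, so every chunk should hold exactly 64 items.
sizes = [chunk["chunk_size"] for chunk in index["chunks"]]
assert all(size == 64 for size in sizes), sizes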
