2323import sys
2424import tempfile
2525import traceback
26+ import warnings
2627from abc import abstractmethod
2728from contextlib import suppress
2829from dataclasses import dataclass
@@ -299,29 +300,65 @@ def _upload_fn(
299300 remove_queue .put ([local_filepath ])
300301
301302
302- def _map_items_to_workers_sequentially (num_workers : int , user_items : list [Any ]) -> list [list [Any ]]:
303+ def _map_items_to_workers_sequentially (
304+ num_workers : int , user_items : list [Any ], chunk_size : Optional [int ] = None
305+ ) -> list [list [Any ]]:
303306 """Map the items to the workers sequentially.
304307
308+ Args:
309+ num_workers: The number of workers to assign items to.
310+ user_items: The list of items to be distributed among workers.
311+ chunk_size: Optional `chunk_size` that enforces deterministic,
312+ single-worker-style chunk boundaries. When set, each worker is
313+ assigned only full chunks of this size, and the final worker
314+ receives any remaining items (which may form a partial chunk).
315+
305317 >>> workers_user_items = _map_items_to_workers_sequentially(2, list(range(5)))
306318 >>> assert workers_user_items == [[0, 1], [2, 3, 4]]
307319 """
320+ assert isinstance (chunk_size , (int , type (None ))), "chunk_size must be an integer or None"
321+
308322 num_nodes = _get_num_nodes ()
323+ node_rank = _get_node_rank ()
309324 world_size = num_nodes * num_workers
310- num_items_per_worker = len (user_items ) // world_size
311325
312- num_items_per_worker : list [ int ] = [ num_items_per_worker for _ in range ( world_size )]
313- reminder = len ( user_items ) % world_size
326+ if chunk_size is not None :
327+ assert chunk_size > 0 , "chunk_size must be a positive integer"
314328
315- for worker_idx in range (len (num_items_per_worker ) - 1 , - 1 , - 1 ):
316- if reminder == 0 :
317- break
318- num_items_per_worker [worker_idx ] += 1
319- reminder -= 1
329+ # Compute how many full chunks each worker can take
330+ full_chunks = len (user_items ) // chunk_size
331+ chunks_per_worker = full_chunks // world_size
332+
333+ if chunks_per_worker == 0 and node_rank == 0 :
334+ warnings .warn (
335+ f"chunk_size ({ chunk_size } ) is too large relative to dataset size ({ len (user_items )} ) "
336+ f"and world_size ({ world_size } ). This will result in idle workers. "
337+ f"Consider reducing chunk_size or using fewer workers."
338+ )
339+
340+ # Assign full chunks to all workers except the last
341+ num_items_per_worker = [chunks_per_worker * chunk_size for _ in range (world_size - 1 )]
342+
343+ # Last worker receives all remaining items (full chunks + optional tail)
344+ remaining = len (user_items ) - sum (num_items_per_worker )
345+ num_items_per_worker .append (remaining )
346+
347+ else :
348+ items_per_worker_count = len (user_items ) // world_size
349+
350+ num_items_per_worker : list [int ] = [items_per_worker_count for _ in range (world_size )]
351+ reminder = len (user_items ) % world_size
352+
353+ for worker_idx in range (len (num_items_per_worker ) - 1 , - 1 , - 1 ):
354+ if reminder == 0 :
355+ break
356+ num_items_per_worker [worker_idx ] += 1
357+ reminder -= 1
320358
321359 num_items_cumsum_per_worker = np .cumsum ([0 ] + num_items_per_worker )
322360
323361 out = []
324- node_rank = _get_node_rank ()
325362 worker_idx_start = node_rank * num_workers
326363 worker_idx_end = (node_rank + 1 ) * num_workers
327364
@@ -1080,6 +1117,7 @@ def __init__(
10801117 input_dir : Union [str , Dir ],
10811118 output_dir : Optional [Union [str , Dir ]] = None ,
10821119 num_workers : Optional [int ] = None ,
1120+ align_chunking : bool = False ,
10831121 num_downloaders : Optional [int ] = None ,
10841122 num_uploaders : Optional [int ] = None ,
10851123 delete_cached_files : bool = True ,
@@ -1102,6 +1140,8 @@ def __init__(
11021140 input_dir: The path to where the input data are stored.
11031141 output_dir: The path to where the output data are stored.
11041142 num_workers: The number of worker threads to use.
1143+ align_chunking: Ensures chunk boundaries match the single-worker layout by packing full chunks first
1144+ and placing all remaining items in the final worker.
11051145 num_downloaders: The number of file downloaders to use.
11061146 num_uploaders: The number of file uploaders to use.
11071147 delete_cached_files: Whether to delete the cached files.
@@ -1140,6 +1180,7 @@ def __init__(
11401180 self .output_dir = _resolve_dir (output_dir )
11411181
11421182 self .num_workers = num_workers or (1 if fast_dev_run else (os .cpu_count () or 1 ) * 4 )
1183+ self .align_chunking = align_chunking
11431184 self .num_downloaders = num_downloaders or 2
11441185 self .num_uploaders = num_uploaders or 1
11451186 self .delete_cached_files = delete_cached_files
@@ -1231,8 +1272,14 @@ def run(self, data_recipe: DataRecipe) -> None:
12311272 num_workers = self .num_workers , user_items = user_items , weights = item_sizes
12321273 )
12331274 else :
1275+ if self .align_chunking and data_recipe .chunk_size is None :
1276+ raise ValueError (
1277+ "`align_chunking` is set to True, but the `chunk_size` is not defined in the data recipe."
1278+ )
12341279 workers_user_items = _map_items_to_workers_sequentially (
1235- num_workers = self .num_workers , user_items = user_items
1280+ num_workers = self .num_workers ,
1281+ user_items = user_items ,
1282+ chunk_size = data_recipe .chunk_size if self .align_chunking else None ,
12361283 )
12371284 else :
12381285 assert isinstance (user_items , multiprocessing .queues .Queue )
0 commit comments