Skip to content

Commit d684489

Browse files
committed
add: split_by_node filter for wds loader
1 parent 600ef67 commit d684489

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

seqchromloader/loader.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,8 @@ class _SeqChromDatasetByWds(IterableDataset):
4949
def __init__(self, wds, transforms:dict=None):
5050
self.transforms = transforms
5151

52-
self.wds = wds
52+
self.rank = rank
53+
self.world_size = world_size
5354

5455
def initialize(self):
5556
# this function will be called by worker_init_function in DataLoader
@@ -59,6 +60,7 @@ def __iter__(self):
5960
worker_info = torch.utils.data.get_worker_info()
6061
pipeline = [
6162
wds.SimpleShardList(self.wds),
63+
split_by_node(self.rank, self.world_size),
6264
wds.split_by_worker,
6365
wds.tarfile_to_samples(),
6466
wds.decode(),

0 commit comments

Comments
 (0)