We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 600ef67 commit d684489Copy full SHA for d684489
1 file changed
seqchromloader/loader.py
@@ -49,7 +49,8 @@ class _SeqChromDatasetByWds(IterableDataset):
49
def __init__(self, wds, transforms:dict=None):
50
self.transforms = transforms
51
52
- self.wds = wds
+ self.rank = rank
53
+ self.world_size = world_size
54
55
def initialize(self):
56
# this function will be called by worker_init_function in DataLoader
@@ -59,6 +60,7 @@ def __iter__(self):
59
60
worker_info = torch.utils.data.get_worker_info()
61
pipeline = [
62
wds.SimpleShardList(self.wds),
63
+ split_by_node(self.rank, self.world_size),
64
wds.split_by_worker,
65
wds.tarfile_to_samples(),
66
wds.decode(),
0 commit comments