Skip to content

Commit 00ec10c

Browse files
committed
🐛 Add barrier
1 parent e3d1da4 commit 00ec10c

4 files changed

Lines changed: 54 additions & 7 deletions

File tree

mnist_ddp.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,19 @@ def main():
138138
transforms.ToTensor(),
139139
transforms.Normalize((0.1307,), (0.3081,))
140140
])
141+
142+
if args.distributed:
143+
if torch.distributed.get_rank() != 0:
144+
# might be downloading mnist data, let rank 0 download first
145+
torch.distributed.barrier()
146+
141147
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
148+
149+
if args.distributed:
150+
if torch.distributed.get_rank() == 0:
151+
# mnist data is downloaded, indicate other ranks can proceed
152+
torch.distributed.barrier()
153+
142154
val_dataset = datasets.MNIST('./data', train=False, transform=transform)
143155
if args.distributed:
144156
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
@@ -171,6 +183,9 @@ def main():
171183
# Only run validation on GPU 0 process, for simplicity, so we do not run validation on multi gpu.
172184
if dist.get_rank() == 0:
173185
test(model_without_ddp, device, test_loader)
186+
torch.distributed.barrier()
187+
else:
188+
torch.distributed.barrier()
174189
else:
175190
test(model, device, test_loader)
176191
scheduler.step()
@@ -183,7 +198,10 @@ def main():
183198
else:
184199
torch.save(model.state_dict(), f"mnist_cnn_.pt")
185200

186-
return dist.get_rank(), total_time
201+
if args.distributed:
202+
return dist.get_rank(), total_time
203+
else:
204+
return 0, total_time
187205

188206

189207
if __name__ == '__main__':

mnist_ds.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,17 @@ def main():
8585
transforms.ToTensor(),
8686
transforms.Normalize((0.1307,), (0.3081,))
8787
])
88-
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
88+
89+
if torch.distributed.get_rank() != 0:
90+
# might be downloading mnist data, let rank 0 download first
91+
torch.distributed.barrier()
92+
93+
dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
94+
95+
if torch.distributed.get_rank() == 0:
96+
# mnist data is downloaded, indicate other ranks can proceed
97+
torch.distributed.barrier()
98+
8999
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
90100
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
91101

mnist_hf.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
2424
accelerator.backward(loss)
2525
optimizer.step()
2626
if batch_idx % args.log_interval == 0:
27-
if local_rank == 0:
27+
if accelerator.is_main_process:
2828
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
2929
epoch, AcceleratorState().num_processes * batch_idx * len(data), len(train_loader.dataset),
3030
100. * batch_idx / len(train_loader), loss.item()))
@@ -46,7 +46,7 @@ def test(model, device, test_loader):
4646

4747
test_loss /= len(test_loader.dataset)
4848

49-
if local_rank == 0:
49+
if accelerator.is_main_process:
5050
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
5151
test_loss, correct, len(test_loader.dataset),
5252
100. * correct / len(test_loader.dataset)))
@@ -100,8 +100,18 @@ def main():
100100
transforms.ToTensor(),
101101
transforms.Normalize((0.1307,), (0.3081,))
102102
])
103-
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
104-
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
103+
104+
if not accelerator.is_main_process:
105+
# might be downloading mnist data, let rank 0 download first
106+
accelerator.wait_for_everyone()
107+
108+
dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
109+
110+
if accelerator.is_main_process:
111+
# mnist data is downloaded, indicate other ranks can proceed
112+
accelerator.wait_for_everyone()
113+
114+
dataset2 = datasets.MNIST('./data', train=False, transform=transform)
105115
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
106116
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
107117

mnist_hvd.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,18 @@ def main():
122122
transforms.ToTensor(),
123123
transforms.Normalize((0.1307,), (0.3081,))
124124
])
125+
126+
if hvd.rank() != 0:
127+
# might be downloading mnist data, let rank 0 download first
128+
hvd.barrier()
129+
125130
# train_dataset = datasets.MNIST('data-%d' % hvd.rank(), train=True, download=True, transform=transform)
126131
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
127-
132+
133+
if hvd.rank() == 0:
134+
# mnist data is downloaded, indicate other ranks can proceed
135+
hvd.barrier()
136+
128137
# Horovod: use DistributedSampler to partition the training data.
129138
train_sampler = dist.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
130139
train_loader = torch.utils.data.DataLoader(

0 commit comments

Comments (0)