Skip to content

Commit 9236fa2

Browse files
committed
🐛 Fix barrier in DeepSpeed
1 parent fed4f02 commit 9236fa2

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

mnist_ds.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ def train(args, model, train_loader, epoch):
1919
loss = F.nll_loss(output, target)
2020
model.backward(loss)
2121
model.step()
22-
if dist.get_rank() == 0:
22+
if torch.distributed.get_rank() == 0:
2323
if batch_idx % args.log_interval == 0:
2424
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
2525
epoch, dist.get_world_size() * batch_idx * len(data), len(train_loader.dataset),
@@ -41,7 +41,7 @@ def test(model, device, test_loader):
4141

4242
test_loss /= len(test_loader.dataset)
4343

44-
if dist.get_rank() == 0:
44+
if torch.distributed.get_rank() == 0:
4545
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
4646
test_loss, correct, len(test_loader.dataset),
4747
100. * correct / len(test_loader.dataset)))
@@ -120,4 +120,4 @@ def main():
120120

121121

122122
if __name__ == '__main__':
123-
print(f'[{dist.get_rank()}] Total time elapsed: {main()} seconds')
123+
print(f'[{torch.distributed.get_rank()}] Total time elapsed: {main()} seconds')

0 commit comments

Comments
 (0)