@@ -24,7 +24,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
         accelerator.backward(loss)
         optimizer.step()
         if batch_idx % args.log_interval == 0:
-            if local_rank == 0:
+            if accelerator.is_main_process:
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                     epoch, AcceleratorState().num_processes * batch_idx * len(data), len(train_loader.dataset),
                     100. * batch_idx / len(train_loader), loss.item()))
@@ -46,7 +46,7 @@ def test(model, device, test_loader):
 
     test_loss /= len(test_loader.dataset)
 
-    if local_rank == 0:
+    if accelerator.is_main_process:
         print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
             test_loss, correct, len(test_loader.dataset),
             100. * correct / len(test_loader.dataset)))
@@ -100,8 +100,18 @@ def main():
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
         ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False, transform=transform)
+
+    if not accelerator.is_main_process:
+        # might be downloading mnist data, let rank 0 download first
+        accelerator.wait_for_everyone()
+
+    dataset1 = datasets.MNIST('./data', train=True, download=True, transform=transform)
+
+    if accelerator.is_main_process:
+        # mnist data is downloaded, indicate other ranks can proceed
+        accelerator.wait_for_everyone()
+
+    dataset2 = datasets.MNIST('./data', train=False, transform=transform)
     train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
 
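Aside: the two wait_for_everyone() calls added in main() form a "main process downloads first" barrier. Non-main ranks park at the first barrier, rank 0 downloads the data and then reaches the same barrier, which releases everyone to load the now-cached files. A minimal standalone sketch of that idiom, assuming Accelerate is installed; the marker-file path and the touch stand-in for the download are illustrative and not part of this diff:

    import os
    from accelerate import Accelerator

    accelerator = Accelerator()
    marker = "./data/.downloaded"  # illustrative marker file, not from the diff

    # Non-main ranks block here until rank 0 has finished preparing the data.
    if not accelerator.is_main_process:
        accelerator.wait_for_everyone()

    if accelerator.is_main_process and not os.path.exists(marker):
        os.makedirs(os.path.dirname(marker), exist_ok=True)
        open(marker, "w").close()  # stands in for the actual download

    # Rank 0 arrives at the barrier last, releasing the waiting ranks.
    if accelerator.is_main_process:
        accelerator.wait_for_everyone()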