 import time
 import argparse
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
 import torch.distributed as dist
+from net import Net
+
 
 def init_distributed_mode(args):
     """
-    initilize DDP
+    Initialize DDP
     """
     os.environ['OMP_NUM_THREADS'] = "1"
     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
@@ -37,36 +38,13 @@ def init_distributed_mode(args):
 
     torch.cuda.set_device(args.gpu)
     args.dist_backend = "nccl"
-    print(f"| distributed init (rank {args.rank}): {args.dist_url}, local rank:{args.gpu}, world size:{args.world_size}", flush=True)
+    print(
+        f"| distributed init (rank {args.rank}): {args.dist_url}, local rank:{args.gpu}, world size:{args.world_size}",
+        flush=True)
     dist.init_process_group(
         backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
     )
 
-class Net(nn.Module):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 32, 3, 1)
-        self.conv2 = nn.Conv2d(32, 64, 3, 1)
-        self.dropout1 = nn.Dropout(0.25)
-        self.dropout2 = nn.Dropout(0.5)
-        self.fc1 = nn.Linear(9216, 128)
-        self.fc2 = nn.Linear(128, 10)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = F.relu(x)
-        x = F.max_pool2d(x, 2)
-        x = self.dropout1(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = F.relu(x)
-        x = self.dropout2(x)
-        x = self.fc2(x)
-        output = F.log_softmax(x, dim=1)
-        return output
-
 
 def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
@@ -81,13 +59,13 @@ def train(args, model, device, train_loader, optimizer, epoch):
         if dist.get_rank() == 0:
             if batch_idx % args.log_interval == 0:
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                    epoch, dist.get_world_size() * batch_idx * len(data), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), loss.item()))
+                    epoch, dist.get_world_size() * batch_idx * len(data), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), loss.item()))
         else:
             if batch_idx % args.log_interval == 0:
                 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                    epoch, batch_idx * len(data), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), loss.item()))
+                    epoch, batch_idx * len(data), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), loss.item()))
         if args.dry_run:
             break
 
@@ -156,20 +134,20 @@ def main():
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
 
-    transform = transforms.Compose([
+    transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
-    ])
+    ])
     train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
-    val_dataset = datasets.MNIST('./data', train=False, transform=transform)
+    val_dataset = datasets.MNIST('./data', train=False, transform=transform)
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
     else:
         train_sampler = torch.utils.data.RandomSampler(train_dataset)
     test_sampler = torch.utils.data.SequentialSampler(val_dataset)
 
-    train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **train_kwargs)
-    test_loader = torch.utils.data.DataLoader(val_dataset, sampler=test_sampler, **test_kwargs)
+    train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(val_dataset, sampler=test_sampler, **test_kwargs)
 
     model = Net().to(device)
     model_without_ddp = model
@@ -180,12 +158,17 @@ def main():
     optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    total_time = 0.
+
     for epoch in range(1, args.epochs + 1):
         if args.distributed:
             train_sampler.set_epoch(epoch)
+        start = time.time()
         train(args, model, device, train_loader, optimizer, epoch)
+        total_time += time.time() - start
         if args.distributed:
-            # Only run validation on GPU 0 process, for simplity, so we do not run validation on multi gpu.
+            # For simplicity, run validation only on the rank 0 process rather than on every GPU.
             if dist.get_rank() == 0:
                 test(model_without_ddp, device, test_loader)
         else:
@@ -200,8 +183,9 @@ def main():
         else:
             torch.save(model.state_dict(), f"mnist_cnn_.pt")
 
+    return dist.get_rank() if args.distributed else 0, total_time
+
 
 if __name__ == '__main__':
-    start = time.time()
-    main()
-    print(f'Total time elapsed: {time.time() - start} seconds')
+    rk, tt = main()
+    print(f'[{rk}] Total time elapsed: {tt} seconds')
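
The commit moves the Net class out of this script into its own module, imported at the top as `from net import Net`. For reference, a minimal sketch of that net.py, assuming the class is carried over verbatim from the lines removed above:

# net.py -- sketch of the extracted module, assuming Net moves over unchanged
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN for 28x28 MNIST digits: two conv layers, two fully connected layers."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after 2x2 max pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities, to pair with a NLL loss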
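Note on launching: init_distributed_mode() keys off the RANK and WORLD_SIZE environment variables, which torchrun (and the older torch.distributed.launch) sets for each worker. Assuming the script is saved as main.py (the actual filename is not shown in this diff), a two-GPU single-machine run would presumably look like `torchrun --nproc_per_node=2 main.py`; started without those variables, the script falls back to its non-distributed path.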