#########################################################################
# Reference: https://blog.csdn.net/u010900574/article/details/122780585 #
#########################################################################

from __future__ import print_function

import argparse
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

def init_distributed_mode(args):
    """Initialize DistributedDataParallel (DDP) state from the environment.

    Mutates ``args`` in place: sets ``rank``, ``world_size``, ``gpu``,
    ``distributed`` and ``dist_backend``, then joins the NCCL process group
    via ``args.dist_url``. Falls back to single-process mode (and returns
    early) when no launcher environment is detected.
    """
    # One OpenMP thread per DDP worker avoids CPU oversubscription.
    os.environ['OMP_NUM_THREADS'] = "1"
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Launched via torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # Launched via SLURM (srun).
        args.rank = int(os.environ["SLURM_PROCID"])
        # Fix: the original left args.world_size at the CLI default here,
        # so init_process_group could be given the wrong group size.
        if "SLURM_NTASKS" in os.environ:
            args.world_size = int(os.environ["SLURM_NTASKS"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass  # caller pre-populated rank / gpu / world_size
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"  # NCCL is the recommended backend for CUDA
    print(f"| distributed init (rank {args.rank}): {args.dist_url}, local rank:{args.gpu}, world size:{args.world_size}", flush=True)
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank
    )

class Net(nn.Module):
    """Small CNN for 28x28 single-channel MNIST digits.

    Returns per-class log-probabilities (log-softmax over 10 classes),
    suitable for ``F.nll_loss``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extractor: two ReLU convs followed by one 2x2 max-pool.
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Classifier head on the flattened map (64 * 12 * 12 = 9216 inputs).
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch of NLL-loss optimization over ``train_loader``.

    Logs progress every ``args.log_interval`` batches (rank 0 only under
    DDP) and stops after the first batch when ``args.dry_run`` is set.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Deduplicated logging (the original repeated the print in both
        # branches): hoist the interval check, then pick rank/scaling.
        if batch_idx % args.log_interval == 0:
            if args.distributed:
                if dist.get_rank() == 0:
                    # Each rank sees 1/world_size of the data, so scale the
                    # processed-sample count by the world size.
                    seen = dist.get_world_size() * batch_idx * len(data)
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, seen, len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.item()))
            else:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        if args.dry_run:
            break


95+ def test (model , device , test_loader ):
96+ model .eval ()
97+ test_loss = 0
98+ correct = 0
99+ with torch .no_grad ():
100+ for data , target in test_loader :
101+ data , target = data .to (device ), target .to (device )
102+ output = model (data )
103+ test_loss += F .nll_loss (output , target , reduction = 'sum' ).item () # sum up batch loss
104+ pred = output .argmax (dim = 1 , keepdim = True ) # get the index of the max log-probability
105+ correct += pred .eq (target .view_as (pred )).sum ().item ()
106+
107+ test_loss /= len (test_loader .dataset )
108+
109+ print ('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
110+ test_loss , correct , len (test_loader .dataset ),
111+ 100. * correct / len (test_loader .dataset )))
112+
113+
def main():
    """Entry point: parse CLI args, set up (optional) DDP, train and evaluate on MNIST."""
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--local_rank', type=int, help='local rank, passed in by the DDP launcher')
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    init_distributed_mode(args)  # sets args.distributed (and rank/gpu/world_size)

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST('./data', train=False, transform=transform)
    if args.distributed:
        # Each rank draws a distinct shard of the training data per epoch.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
    test_sampler = torch.utils.data.SequentialSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, sampler=test_sampler, **test_kwargs)

    model = Net().to(device)
    # Keep a handle on the bare module for evaluation and checkpointing.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        if args.distributed:
            # Reseed the sampler so every epoch gets a different shuffle.
            train_sampler.set_epoch(epoch)
        train(args, model, device, train_loader, optimizer, epoch)
        if args.distributed:
            # Only run validation on the rank-0 process, for simplicity; we do
            # not shard validation across GPUs.
            if dist.get_rank() == 0:
                test(model_without_ddp, device, test_loader)
        else:
            test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        if args.distributed:
            if dist.get_rank() == 0:
                # Only save on the rank-0 process. Fix: save the unwrapped
                # module (the original saved the DDP wrapper, which prefixes
                # every state-dict key with 'module.'), and use the same file
                # name as the single-process path.
                torch.save(model_without_ddp.state_dict(), "mnist_cnn.pt")
        else:
            torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    # Time the full run (argument parsing, training, evaluation).
    start = time.time()
    main()
    print(f'Total time elapsed: {time.time() - start} seconds')