File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1- --- mnist_main .py.orig 2025-05-09 10:51:06.814200110 -0500
2- +++ mnist_main .py 2025-05-09 11:15:17.198167820 -0500
1+ --- main .py.d0b7e37 2026-02-18 14:50:27.722389236 -0600
2+ +++ main .py 2026-02-18 14:59:02.737245645 -0600
33@@ -1,3 +1,8 @@
44+ #
55+ # Copyright (C) 2025, Northwestern University and Argonne National Laboratory
5555 args = parser.parse_args()
5656
5757 use_accel = not args.no_accel and torch.accelerator.is_available()
58- @@ -103,12 +119,11 @@
58+ @@ -103,7 +119,7 @@
5959 else:
6060 device = torch.device("cpu")
6161
6464 test_kwargs = {'batch_size': args.test_batch_size}
6565 if use_accel:
6666 accel_kwargs = {'num_workers': 1,
67- - 'pin_memory': True,
68- - 'shuffle': True}
69- + 'pin_memory': True}
70- train_kwargs.update(accel_kwargs)
71- test_kwargs.update(accel_kwargs)
72-
73- @@ -116,25 +131,53 @@
67+ @@ -117,25 +133,54 @@
7468 transforms.ToTensor(),
7569 transforms.Normalize((0.1307,), (0.3081,))
7670 ])
115109- torch.save(model.state_dict(), "mnist_cnn.pt")
116110+ if rank == 0:
117111+ torch.save(model.state_dict(), "mnist_cnn.pt")
118-
112+ +
119113+ # close files
120114+ train_file.close()
121115+ test_file.close()
122116
117+
123118 if __name__ == '__main__':
119+ - main()
124120+ ## initialize parallel environment
125121+ comm, device = comm_file.init_parallel()
126122+
127123+ rank = comm.get_rank()
128124+ nprocs = comm.get_size()
129125+
130- main()
126+ + main()
131127+
132128+ comm.finalize()
133129+
You can’t perform that action at this time.
0 commit comments