Skip to content

Commit b806910

Browse files
committed
Update patch for MNIST main.py, as main.py has been updated
git hash of the latest main.py is d0b7e37
1 parent 59b8b18 commit b806910

1 file changed

Lines changed: 8 additions & 12 deletions

File tree

examples/MNIST/mnist.patch

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
--- mnist_main.py.orig 2025-05-09 10:51:06.814200110 -0500
2-
+++ mnist_main.py 2025-05-09 11:15:17.198167820 -0500
1+
--- main.py.d0b7e37 2026-02-18 14:50:27.722389236 -0600
2+
+++ main.py 2026-02-18 14:59:02.737245645 -0600
33
@@ -1,3 +1,8 @@
44
+#
55
+# Copyright (C) 2025, Northwestern University and Argonne National Laboratory
@@ -55,7 +55,7 @@
5555
args = parser.parse_args()
5656

5757
use_accel = not args.no_accel and torch.accelerator.is_available()
58-
@@ -103,12 +119,11 @@
58+
@@ -103,7 +119,7 @@
5959
else:
6060
device = torch.device("cpu")
6161

@@ -64,13 +64,7 @@
6464
test_kwargs = {'batch_size': args.test_batch_size}
6565
if use_accel:
6666
accel_kwargs = {'num_workers': 1,
67-
- 'pin_memory': True,
68-
- 'shuffle': True}
69-
+ 'pin_memory': True}
70-
train_kwargs.update(accel_kwargs)
71-
test_kwargs.update(accel_kwargs)
72-
73-
@@ -116,25 +131,53 @@
67+
@@ -117,25 +133,54 @@
7468
transforms.ToTensor(),
7569
transforms.Normalize((0.1307,), (0.3081,))
7670
])
@@ -115,19 +109,21 @@
115109
- torch.save(model.state_dict(), "mnist_cnn.pt")
116110
+ if rank == 0:
117111
+ torch.save(model.state_dict(), "mnist_cnn.pt")
118-
112+
+
119113
+ # close files
120114
+ train_file.close()
121115
+ test_file.close()
122116

117+
123118
if __name__ == '__main__':
119+
- main()
124120
+ ## initialize parallel environment
125121
+ comm, device = comm_file.init_parallel()
126122
+
127123
+ rank = comm.get_rank()
128124
+ nprocs = comm.get_size()
129125
+
130-
main()
126+
+ main()
131127
+
132128
+ comm.finalize()
133129
+

0 commit comments

Comments
 (0)