Skip to content

Commit e55883f

Browse files
Merge pull request mala-project#499 from mala-project/develop
v1.2.1 - Minor Bugfixes
2 parents 04e4d2d + e4880d4 commit e55883f

8 files changed

Lines changed: 50 additions & 41 deletions

File tree

CITATION.cff

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22
cff-version: 1.2.0
33
message: "If you use this software, please cite it using these metadata."
44
authors:
5+
- affiliation: "Center for Advanced Systems Understanding (CASUS), Helmholtz-Zentrum Dresden-Rossendorf e.V. (HZDR)"
6+
family-names: Cangi
7+
given-names: Attila
8+
orcid: https://orcid.org/0000-0001-9162-262X
9+
- affiliation: "Sandia National Laboratories (SNL)"
10+
family-names: Rajamanickam
11+
given-names: Sivasankaran
12+
orcid: https://orcid.org/0000-0002-5854-409X
513
- affiliation: "Center for Advanced Systems Understanding (CASUS), Helmholtz-Zentrum Dresden-Rossendorf e.V. (HZDR)"
614
family-names: Brzoza
715
given-names: Bartosz
816
- affiliation: "Center for Advanced Systems Understanding (CASUS), Helmholtz-Zentrum Dresden-Rossendorf e.V. (HZDR)"
917
family-names: Callow
1018
given-names: Timothy J.
11-
- affiliation: "Center for Advanced Systems Understanding (CASUS), Helmholtz-Zentrum Dresden-Rossendorf e.V. (HZDR)"
12-
family-names: Cangi
13-
given-names: Attila
14-
orcid: https://orcid.org/0000-0001-9162-262X
1519
- affiliation: "Oak Ridge National Laboratory (ORNL)"
1620
family-names: Ellis
1721
given-names: J. Austin
@@ -54,10 +58,6 @@ authors:
5458
- affiliation: "Center for Advanced Systems Understanding (CASUS), Helmholtz-Zentrum Dresden-Rossendorf e.V. (HZDR)"
5559
family-names: Pöschel
5660
given-names: Franz
57-
- affiliation: "Sandia National Laboratories (SNL)"
58-
family-names: Rajamanickam
59-
given-names: Sivasankaran
60-
orcid: https://orcid.org/0000-0002-5854-409X
6161
- affiliation: "Nvidia Corporation"
6262
family-names: Romero
6363
given-names: Josh
62.3 KB
Loading
35.8 KB
Loading
47.1 KB
Loading
39.7 KB
Loading

mala/common/parameters.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import pickle
77
from time import sleep
88

9+
horovod_available = False
910
try:
1011
import horovod.torch as hvd
12+
horovod_available = True
1113
except ModuleNotFoundError:
1214
pass
1315
import numpy as np
@@ -732,7 +734,7 @@ def __init__(self):
732734
self.use_mixed_precision = False
733735
self.use_graphs = False
734736
self.training_report_frequency = 1000
735-
self.profiler_range = [1000, 2000]
737+
self.profiler_range = None #[1000, 2000]
736738

737739
def _update_horovod(self, new_horovod):
738740
super(ParametersRunning, self)._update_horovod(new_horovod)
@@ -1257,19 +1259,25 @@ def use_horovod(self):
12571259

12581260
@use_horovod.setter
12591261
def use_horovod(self, value):
1260-
if value:
1261-
hvd.init()
1262+
if value is False:
1263+
self._use_horovod = False
1264+
else:
1265+
if horovod_available:
1266+
hvd.init()
1267+
# Invalidate, will be updated in setter.
1268+
set_horovod_status(value)
1269+
self.device = None
1270+
self._use_horovod = value
1271+
self.network._update_horovod(self.use_horovod)
1272+
self.descriptors._update_horovod(self.use_horovod)
1273+
self.targets._update_horovod(self.use_horovod)
1274+
self.data._update_horovod(self.use_horovod)
1275+
self.running._update_horovod(self.use_horovod)
1276+
self.hyperparameters._update_horovod(self.use_horovod)
1277+
else:
1278+
parallel_warn("Horovod requested, but not installed found. "
1279+
"MALA will operate without horovod only.")
12621280

1263-
# Invalidate, will be updated in setter.
1264-
set_horovod_status(value)
1265-
self.device = None
1266-
self._use_horovod = value
1267-
self.network._update_horovod(self.use_horovod)
1268-
self.descriptors._update_horovod(self.use_horovod)
1269-
self.targets._update_horovod(self.use_horovod)
1270-
self.data._update_horovod(self.use_horovod)
1271-
self.running._update_horovod(self.use_horovod)
1272-
self.hyperparameters._update_horovod(self.use_horovod)
12731281

12741282
@property
12751283
def device(self):

mala/network/tester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def __calculate_observable_error(self, snapshot_number, observable,
210210
target_calculator.read_from_array(predicted_target)
211211
predicted = target_calculator.band_energy
212212
return [actual, predicted,
213-
target_calculator.total_energy_dft_calculation]
213+
target_calculator.band_energy_dft_calculation]
214214

215215
elif observable == "number_of_electrons":
216216
target_calculator = self.data.target_calculator

mala/network/trainer.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,18 @@ def train_network(self):
279279
self.data.training_data_sets[0].shuffle()
280280

281281
if self.parameters._configuration["gpu"]:
282-
torch.cuda.synchronize()
282+
torch.cuda.synchronize(self.parameters._configuration["device"])
283283
tsample = time.time()
284284
t0 = time.time()
285285
batchid = 0
286286
for loader in self.training_data_loaders:
287287
for (inputs, outputs) in loader:
288288

289-
if batchid == self.parameters.profiler_range[0]:
290-
torch.cuda.profiler.start()
291-
if batchid == self.parameters.profiler_range[1]:
292-
torch.cuda.profiler.stop()
289+
if self.parameters.profiler_range is not None:
290+
if batchid == self.parameters.profiler_range[0]:
291+
torch.cuda.profiler.start()
292+
if batchid == self.parameters.profiler_range[1]:
293+
torch.cuda.profiler.stop()
293294

294295
torch.cuda.nvtx.range_push(f"step {batchid}")
295296

@@ -309,7 +310,7 @@ def train_network(self):
309310
training_loss_sum += loss
310311

311312
if batchid != 0 and (batchid + 1) % self.parameters.training_report_frequency == 0:
312-
torch.cuda.synchronize()
313+
torch.cuda.synchronize(self.parameters._configuration["device"])
313314
sample_time = time.time() - tsample
314315
avg_sample_time = sample_time / self.parameters.training_report_frequency
315316
avg_sample_tput = self.parameters.training_report_frequency * inputs.shape[0] / sample_time
@@ -319,14 +320,14 @@ def train_network(self):
319320
min_verbosity=2)
320321
tsample = time.time()
321322
batchid += 1
322-
torch.cuda.synchronize()
323+
torch.cuda.synchronize(self.parameters._configuration["device"])
323324
t1 = time.time()
324325
printout(f"training time: {t1 - t0}", min_verbosity=2)
325326

326327
training_loss = training_loss_sum.item() / batchid
327328

328329
# Calculate the validation loss. and output it.
329-
torch.cuda.synchronize()
330+
torch.cuda.synchronize(self.parameters._configuration["device"])
330331
else:
331332
batchid = 0
332333
for loader in self.training_data_loaders:
@@ -375,14 +376,14 @@ def train_network(self):
375376
self.tensor_board.close()
376377

377378
if self.parameters._configuration["gpu"]:
378-
torch.cuda.synchronize()
379+
torch.cuda.synchronize(self.parameters._configuration["device"])
379380

380381
# Mix the DataSets up (this function only does something
381382
# in the lazy loading case).
382383
if self.parameters.use_shuffling_for_samplers:
383384
self.data.mix_datasets()
384385
if self.parameters._configuration["gpu"]:
385-
torch.cuda.synchronize()
386+
torch.cuda.synchronize(self.parameters._configuration["device"])
386387

387388
# If a scheduler is used, update it.
388389
if self.scheduler is not None:
@@ -636,8 +637,8 @@ def __process_mini_batch(self, network, input_data, target_data):
636637
if self.parameters._configuration["gpu"]:
637638
if self.parameters.use_graphs and self.train_graph is None:
638639
printout("Capturing CUDA graph for training.", min_verbosity=2)
639-
s = torch.cuda.Stream()
640-
s.wait_stream(torch.cuda.current_stream())
640+
s = torch.cuda.Stream(self.parameters._configuration["device"])
641+
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
641642
# Warmup for graphs
642643
with torch.cuda.stream(s):
643644
for _ in range(20):
@@ -651,7 +652,7 @@ def __process_mini_batch(self, network, input_data, target_data):
651652
self.gradscaler.scale(loss).backward()
652653
else:
653654
loss.backward()
654-
torch.cuda.current_stream().wait_stream(s)
655+
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
655656

656657
# Create static entry point tensors to graph
657658
self.static_input_data = torch.empty_like(input_data)
@@ -742,7 +743,7 @@ def __validate_network(self, network, data_set_type, validation_type):
742743
with torch.no_grad():
743744
if self.parameters._configuration["gpu"]:
744745
report_freq = self.parameters.training_report_frequency
745-
torch.cuda.synchronize()
746+
torch.cuda.synchronize(self.parameters._configuration["device"])
746747
tsample = time.time()
747748
batchid = 0
748749
for loader in data_loaders:
@@ -754,15 +755,15 @@ def __validate_network(self, network, data_set_type, validation_type):
754755

755756
if self.parameters.use_graphs and self.validation_graph is None:
756757
printout("Capturing CUDA graph for validation.", min_verbosity=2)
757-
s = torch.cuda.Stream()
758-
s.wait_stream(torch.cuda.current_stream())
758+
s = torch.cuda.Stream(self.parameters._configuration["device"])
759+
s.wait_stream(torch.cuda.current_stream(self.parameters._configuration["device"]))
759760
# Warmup for graphs
760761
with torch.cuda.stream(s):
761762
for _ in range(20):
762763
with torch.cuda.amp.autocast(enabled=self.parameters.use_mixed_precision):
763764
prediction = network(x)
764765
loss = network.calculate_loss(prediction, y)
765-
torch.cuda.current_stream().wait_stream(s)
766+
torch.cuda.current_stream(self.parameters._configuration["device"]).wait_stream(s)
766767

767768
# Create static entry point tensors to graph
768769
self.static_input_validation = torch.empty_like(x)
@@ -786,7 +787,7 @@ def __validate_network(self, network, data_set_type, validation_type):
786787
loss = network.calculate_loss(prediction, y)
787788
validation_loss_sum += loss
788789
if batchid != 0 and (batchid + 1) % report_freq == 0:
789-
torch.cuda.synchronize()
790+
torch.cuda.synchronize(self.parameters._configuration["device"])
790791
sample_time = time.time() - tsample
791792
avg_sample_time = sample_time / report_freq
792793
avg_sample_tput = report_freq * x.shape[0] / sample_time
@@ -796,7 +797,7 @@ def __validate_network(self, network, data_set_type, validation_type):
796797
min_verbosity=2)
797798
tsample = time.time()
798799
batchid += 1
799-
torch.cuda.synchronize()
800+
torch.cuda.synchronize(self.parameters._configuration["device"])
800801
else:
801802
batchid = 0
802803
for loader in data_loaders:

0 commit comments

Comments (0)