@@ -279,17 +279,18 @@ def train_network(self):
279279 self .data .training_data_sets [0 ].shuffle ()
280280
281281 if self .parameters ._configuration ["gpu" ]:
282- torch .cuda .synchronize ()
282+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
283283 tsample = time .time ()
284284 t0 = time .time ()
285285 batchid = 0
286286 for loader in self .training_data_loaders :
287287 for (inputs , outputs ) in loader :
288288
289- if batchid == self .parameters .profiler_range [0 ]:
290- torch .cuda .profiler .start ()
291- if batchid == self .parameters .profiler_range [1 ]:
292- torch .cuda .profiler .stop ()
289+ if self .parameters .profiler_range is not None :
290+ if batchid == self .parameters .profiler_range [0 ]:
291+ torch .cuda .profiler .start ()
292+ if batchid == self .parameters .profiler_range [1 ]:
293+ torch .cuda .profiler .stop ()
293294
294295 torch .cuda .nvtx .range_push (f"step { batchid } " )
295296
@@ -309,7 +310,7 @@ def train_network(self):
309310 training_loss_sum += loss
310311
311312 if batchid != 0 and (batchid + 1 ) % self .parameters .training_report_frequency == 0 :
312- torch .cuda .synchronize ()
313+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
313314 sample_time = time .time () - tsample
314315 avg_sample_time = sample_time / self .parameters .training_report_frequency
315316 avg_sample_tput = self .parameters .training_report_frequency * inputs .shape [0 ] / sample_time
@@ -319,14 +320,14 @@ def train_network(self):
319320 min_verbosity = 2 )
320321 tsample = time .time ()
321322 batchid += 1
322- torch .cuda .synchronize ()
323+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
323324 t1 = time .time ()
324325 printout (f"training time: { t1 - t0 } " , min_verbosity = 2 )
325326
326327 training_loss = training_loss_sum .item () / batchid
327328
328329 # Calculate the validation loss. and output it.
329- torch .cuda .synchronize ()
330+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
330331 else :
331332 batchid = 0
332333 for loader in self .training_data_loaders :
@@ -375,14 +376,14 @@ def train_network(self):
375376 self .tensor_board .close ()
376377
377378 if self .parameters ._configuration ["gpu" ]:
378- torch .cuda .synchronize ()
379+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
379380
380381 # Mix the DataSets up (this function only does something
381382 # in the lazy loading case).
382383 if self .parameters .use_shuffling_for_samplers :
383384 self .data .mix_datasets ()
384385 if self .parameters ._configuration ["gpu" ]:
385- torch .cuda .synchronize ()
386+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
386387
387388 # If a scheduler is used, update it.
388389 if self .scheduler is not None :
@@ -636,8 +637,8 @@ def __process_mini_batch(self, network, input_data, target_data):
636637 if self .parameters ._configuration ["gpu" ]:
637638 if self .parameters .use_graphs and self .train_graph is None :
638639 printout ("Capturing CUDA graph for training." , min_verbosity = 2 )
639- s = torch .cuda .Stream ()
640- s .wait_stream (torch .cuda .current_stream ())
640+ s = torch .cuda .Stream (self . parameters . _configuration [ "device" ] )
641+ s .wait_stream (torch .cuda .current_stream (self . parameters . _configuration [ "device" ] ))
641642 # Warmup for graphs
642643 with torch .cuda .stream (s ):
643644 for _ in range (20 ):
@@ -651,7 +652,7 @@ def __process_mini_batch(self, network, input_data, target_data):
651652 self .gradscaler .scale (loss ).backward ()
652653 else :
653654 loss .backward ()
654- torch .cuda .current_stream ().wait_stream (s )
655+ torch .cuda .current_stream (self . parameters . _configuration [ "device" ] ).wait_stream (s )
655656
656657 # Create static entry point tensors to graph
657658 self .static_input_data = torch .empty_like (input_data )
@@ -742,7 +743,7 @@ def __validate_network(self, network, data_set_type, validation_type):
742743 with torch .no_grad ():
743744 if self .parameters ._configuration ["gpu" ]:
744745 report_freq = self .parameters .training_report_frequency
745- torch .cuda .synchronize ()
746+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
746747 tsample = time .time ()
747748 batchid = 0
748749 for loader in data_loaders :
@@ -754,15 +755,15 @@ def __validate_network(self, network, data_set_type, validation_type):
754755
755756 if self .parameters .use_graphs and self .validation_graph is None :
756757 printout ("Capturing CUDA graph for validation." , min_verbosity = 2 )
757- s = torch .cuda .Stream ()
758- s .wait_stream (torch .cuda .current_stream ())
758+ s = torch .cuda .Stream (self . parameters . _configuration [ "device" ] )
759+ s .wait_stream (torch .cuda .current_stream (self . parameters . _configuration [ "device" ] ))
759760 # Warmup for graphs
760761 with torch .cuda .stream (s ):
761762 for _ in range (20 ):
762763 with torch .cuda .amp .autocast (enabled = self .parameters .use_mixed_precision ):
763764 prediction = network (x )
764765 loss = network .calculate_loss (prediction , y )
765- torch .cuda .current_stream ().wait_stream (s )
766+ torch .cuda .current_stream (self . parameters . _configuration [ "device" ] ).wait_stream (s )
766767
767768 # Create static entry point tensors to graph
768769 self .static_input_validation = torch .empty_like (x )
@@ -786,7 +787,7 @@ def __validate_network(self, network, data_set_type, validation_type):
786787 loss = network .calculate_loss (prediction , y )
787788 validation_loss_sum += loss
788789 if batchid != 0 and (batchid + 1 ) % report_freq == 0 :
789- torch .cuda .synchronize ()
790+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
790791 sample_time = time .time () - tsample
791792 avg_sample_time = sample_time / report_freq
792793 avg_sample_tput = report_freq * x .shape [0 ] / sample_time
@@ -796,7 +797,7 @@ def __validate_network(self, network, data_set_type, validation_type):
796797 min_verbosity = 2 )
797798 tsample = time .time ()
798799 batchid += 1
799- torch .cuda .synchronize ()
800+ torch .cuda .synchronize (self . parameters . _configuration [ "device" ] )
800801 else :
801802 batchid = 0
802803 for loader in data_loaders :
0 commit comments