@@ -675,47 +675,207 @@ def _validate_network(self, data_set_fractions, metrics):
675675 )
676676 loader_id += 1
677677 else :
678- with torch .no_grad ():
679- for snapshot_number in trange (
680- offset_snapshots ,
681- number_of_snapshots + offset_snapshots ,
682- desc = "Validation" ,
683- disable = self .parameters_full .verbosity < 2 ,
684- ):
685- # Get optimal batch size and number of batches per snapshotss
686- grid_size = (
687- self .data .parameters .snapshot_directories_list [
688- snapshot_number
689- ].grid_size
690- )
678+ # If only the LDOS is in the validation metrics (as is the
679+ # case for, e.g., distributed network trainings), we can
680+ # use a faster (or at least better parallelizing) code
691681
692- optimal_batch_size = self ._correct_batch_size (
693- grid_size , self .parameters .mini_batch_size
694- )
695- number_of_batches_per_snapshot = int (
696- grid_size / optimal_batch_size
682+ if (
683+ len (self .parameters .validation_metrics ) == 1
684+ and self .parameters .validation_metrics [0 ] == "ldos"
685+ ):
686+
687+ errors [data_set_type ]["ldos" ] = (
688+ self .__calculate_validation_error_ldos_only (
689+ data_loaders
697690 )
691+ )
698692
699- actual_outputs , predicted_outputs = (
700- self ._forward_entire_snapshot (
693+ else :
694+ with torch .no_grad ():
695+ for snapshot_number in trange (
696+ offset_snapshots ,
697+ number_of_snapshots + offset_snapshots ,
698+ desc = "Validation" ,
699+ disable = self .parameters_full .verbosity < 2 ,
700+ ):
701+ # Get optimal batch size and number of batches per snapshot
702+ grid_size = (
703+ self .data .parameters .snapshot_directories_list [
704+ snapshot_number
705+ ].grid_size
706+ )
707+
708+ optimal_batch_size = self ._correct_batch_size (
709+ grid_size , self .parameters .mini_batch_size
710+ )
711+ number_of_batches_per_snapshot = int (
712+ grid_size / optimal_batch_size
713+ )
714+
715+ actual_outputs , predicted_outputs = (
716+ self ._forward_entire_snapshot (
717+ snapshot_number ,
718+ data_sets [0 ],
719+ data_set_type [0 :2 ],
720+ number_of_batches_per_snapshot ,
721+ optimal_batch_size ,
722+ )
723+ )
724+ calculated_errors = self ._calculate_errors (
725+ actual_outputs ,
726+ predicted_outputs ,
727+ metrics ,
701728 snapshot_number ,
702- data_sets [0 ],
703- data_set_type [0 :2 ],
704- number_of_batches_per_snapshot ,
705- optimal_batch_size ,
706729 )
730+ for metric in metrics :
731+ errors [data_set_type ][metric ].append (
732+ calculated_errors [metric ]
733+ )
734+ return errors
735+
def __calculate_validation_error_ldos_only(self, data_loaders):
    """
    Compute the mean per-batch validation loss over all given loaders.

    Every batch of every loader is passed through ``self.network`` and
    the loss is evaluated with the network's ``calculate_loss`` method
    (taken from ``self.network.module`` when DDP is enabled). The
    per-batch losses are summed and divided by the number of batches.

    On the GPU path, if ``self.parameters.use_graphs`` is set and no
    validation graph has been captured yet, the first batch is used to
    warm up and capture a CUDA graph, which is then replayed for all
    batches whose shapes match the captured static tensors. Batches
    with a different shape are evaluated eagerly instead.

    Parameters
    ----------
    data_loaders : iterable
        Iterable of loaders, each yielding ``(x, y)`` tensor pairs.

    Returns
    -------
    validation_loss : float
        Sum of the per-batch losses divided by the number of batches.

    Raises
    ------
    ZeroDivisionError
        If the loaders do not yield a single batch.
    """
    device = self.parameters._configuration["device"]

    # With DDP the loss function lives on the wrapped module. The
    # dispatch does not change during validation, so resolve it once.
    if self.parameters_full.use_ddp:
        calculate_loss = self.network.module.calculate_loss
    else:
        calculate_loss = self.network.calculate_loss

    validation_loss_sum = torch.zeros(1, device=device)
    batchid = 0
    with torch.no_grad():
        if self.parameters._configuration["gpu"]:
            report_freq = self.parameters.training_log_interval
            torch.cuda.synchronize(device)
            tsample = time.time()
            for loader in data_loaders:
                for x, y in loader:
                    x = x.to(device, non_blocking=True)
                    y = y.to(device, non_blocking=True)

                    if (
                        self.parameters.use_graphs
                        and self._validation_graph is None
                    ):
                        printout(
                            "Capturing CUDA graph for validation.",
                            min_verbosity=2,
                        )
                        s = torch.cuda.Stream(device)
                        s.wait_stream(torch.cuda.current_stream(device))
                        # Warmup for graphs; the results are discarded.
                        with torch.cuda.stream(s):
                            for _ in range(20):
                                with torch.cuda.amp.autocast(
                                    enabled=(
                                        self.parameters.use_mixed_precision
                                    )
                                ):
                                    prediction = self.network(x)
                                    calculate_loss(prediction, y)
                        torch.cuda.current_stream(device).wait_stream(s)

                        # Create static entry point tensors to graph
                        self.static_input_validation = torch.empty_like(x)
                        self.static_target_validation = torch.empty_like(y)

                        # Capture graph
                        self._validation_graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(self._validation_graph):
                            with torch.cuda.amp.autocast(
                                enabled=self.parameters.use_mixed_precision
                            ):
                                self.static_prediction_validation = (
                                    self.network(
                                        self.static_input_validation
                                    )
                                )
                                self.static_loss_validation = (
                                    calculate_loss(
                                        self.static_prediction_validation,
                                        self.static_target_validation,
                                    )
                                )

                    # A captured graph only works for the tensor shapes
                    # it was captured with. A batch of a different shape
                    # (e.g. a smaller final batch) is evaluated eagerly
                    # rather than copied into mismatching static buffers.
                    replay_graph = (
                        self._validation_graph is not None
                        and x.shape == self.static_input_validation.shape
                        and y.shape == self.static_target_validation.shape
                    )
                    if replay_graph:
                        self.static_input_validation.copy_(x)
                        self.static_target_validation.copy_(y)
                        self._validation_graph.replay()
                        validation_loss_sum += self.static_loss_validation
                    else:
                        with torch.cuda.amp.autocast(
                            enabled=self.parameters.use_mixed_precision
                        ):
                            prediction = self.network(x)
                            loss = calculate_loss(prediction, y)
                            validation_loss_sum += loss

                    if batchid != 0 and (batchid + 1) % report_freq == 0:
                        # Synchronize so the measured wall time covers
                        # all work queued on the device so far.
                        torch.cuda.synchronize(device)
                        sample_time = time.time() - tsample
                        avg_sample_time = sample_time / report_freq
                        avg_sample_tput = (
                            report_freq * x.shape[0] / sample_time
                        )
                        printout(
                            f"batch { batchid + 1 } , "
                            f"validation avg time: { avg_sample_time } "
                            f"validation avg throughput: { avg_sample_tput } ",
                            min_verbosity=2,
                        )
                        tsample = time.time()
                    batchid += 1
            torch.cuda.synchronize(device)
        else:
            for loader in data_loaders:
                for x, y in loader:
                    x = x.to(device)
                    y = y.to(device)
                    prediction = self.network(x)
                    validation_loss_sum += calculate_loss(
                        prediction, y
                    ).item()
                    batchid += 1

    if batchid == 0:
        # Same exception type as the bare division raised before, but
        # with a message that points at the actual cause.
        raise ZeroDivisionError(
            "Cannot average the validation loss: the validation data "
            "loaders did not yield any batches."
        )
    return validation_loss_sum.item() / batchid
719879
720880 def __prepare_to_train (self , optimizer_dict ):
721881 """Prepare everything for training."""
0 commit comments