@@ -678,196 +678,17 @@ def _validate_network(self, data_set_fractions, metrics):
678678 # If only the LDOS is in the validation metrics (as is the
679679 # case for, e.g., distributed network trainings), we can
680680 # use a faster (or at least better parallelizing) code
681+
681682 if (
682683 len (self .parameters .validation_metrics ) == 1
683684 and self .parameters .validation_metrics [0 ] == "ldos"
684685 ):
685- validation_loss_sum = torch .zeros (
686- 1 , device = self .parameters ._configuration ["device" ]
687- )
688- with torch .no_grad ():
689- if self .parameters ._configuration ["gpu" ]:
690- report_freq = self .parameters .training_log_interval
691- torch .cuda .synchronize (
692- self .parameters ._configuration ["device" ]
693- )
694- tsample = time .time ()
695- batchid = 0
696- for loader in data_loaders :
697- for x , y in loader :
698- x = x .to (
699- self .parameters ._configuration [
700- "device"
701- ],
702- non_blocking = True ,
703- )
704- y = y .to (
705- self .parameters ._configuration [
706- "device"
707- ],
708- non_blocking = True ,
709- )
710686
711- if (
712- self .parameters .use_graphs
713- and self ._validation_graph is None
714- ):
715- printout (
716- "Capturing CUDA graph for validation." ,
717- min_verbosity = 2 ,
718- )
719- s = torch .cuda .Stream (
720- self .parameters ._configuration [
721- "device"
722- ]
723- )
724- s .wait_stream (
725- torch .cuda .current_stream (
726- self .parameters ._configuration [
727- "device"
728- ]
729- )
730- )
731- # Warmup for graphs
732- with torch .cuda .stream (s ):
733- for _ in range (20 ):
734- with torch .cuda .amp .autocast (
735- enabled = self .parameters .use_mixed_precision
736- ):
737- prediction = self .network (
738- x
739- )
740- if (
741- self .parameters_full .use_ddp
742- ):
743- loss = self .network .module .calculate_loss (
744- prediction , y
745- )
746- else :
747- loss = self .network .calculate_loss (
748- prediction , y
749- )
750- torch .cuda .current_stream (
751- self .parameters ._configuration [
752- "device"
753- ]
754- ).wait_stream (s )
755-
756- # Create static entry point tensors to graph
757- self .static_input_validation = (
758- torch .empty_like (x )
759- )
760- self .static_target_validation = (
761- torch .empty_like (y )
762- )
763-
764- # Capture graph
765- self ._validation_graph = (
766- torch .cuda .CUDAGraph ()
767- )
768- with torch .cuda .graph (
769- self ._validation_graph
770- ):
771- with torch .cuda .amp .autocast (
772- enabled = self .parameters .use_mixed_precision
773- ):
774- self .static_prediction_validation = self .network (
775- self .static_input_validation
776- )
777- if (
778- self .parameters_full .use_ddp
779- ):
780- self .static_loss_validation = self .network .module .calculate_loss (
781- self .static_prediction_validation ,
782- self .static_target_validation ,
783- )
784- else :
785- self .static_loss_validation = self .network .calculate_loss (
786- self .static_prediction_validation ,
787- self .static_target_validation ,
788- )
789-
790- if self ._validation_graph :
791- self .static_input_validation .copy_ (x )
792- self .static_target_validation .copy_ (y )
793- self ._validation_graph .replay ()
794- validation_loss_sum += (
795- self .static_loss_validation
796- )
797- else :
798- with torch .cuda .amp .autocast (
799- enabled = self .parameters .use_mixed_precision
800- ):
801- prediction = self .network (x )
802- if self .parameters_full .use_ddp :
803- loss = self .network .module .calculate_loss (
804- prediction , y
805- )
806- else :
807- loss = self .network .calculate_loss (
808- prediction , y
809- )
810- validation_loss_sum += loss
811- if (
812- batchid != 0
813- and (batchid + 1 ) % report_freq == 0
814- ):
815- torch .cuda .synchronize (
816- self .parameters ._configuration [
817- "device"
818- ]
819- )
820- sample_time = time .time () - tsample
821- avg_sample_time = (
822- sample_time / report_freq
823- )
824- avg_sample_tput = (
825- report_freq
826- * x .shape [0 ]
827- / sample_time
828- )
829- printout (
830- f"batch { batchid + 1 } , " # /{total_samples}, "
831- f"validation avg time: { avg_sample_time } "
832- f"validation avg throughput: { avg_sample_tput } " ,
833- min_verbosity = 2 ,
834- )
835- tsample = time .time ()
836- batchid += 1
837- torch .cuda .synchronize (
838- self .parameters ._configuration ["device" ]
839- )
840- else :
841- batchid = 0
842- for loader in data_loaders :
843- for x , y in loader :
844- x = x .to (
845- self .parameters ._configuration [
846- "device"
847- ]
848- )
849- y = y .to (
850- self .parameters ._configuration [
851- "device"
852- ]
853- )
854- prediction = self .network (x )
855- if self .parameters_full .use_ddp :
856- validation_loss_sum += (
857- self .network .module .calculate_loss (
858- prediction , y
859- ).item ()
860- )
861- else :
862- validation_loss_sum += (
863- self .network .calculate_loss (
864- prediction , y
865- ).item ()
866- )
867- batchid += 1
868-
869- validation_loss = validation_loss_sum .item () / batchid
870- errors [data_set_type ]["ldos" ] = validation_loss
687+ errors [data_set_type ]["ldos" ] = (
688+ self .__calculate_validation_error_ldos_only (
689+ data_loaders
690+ )
691+ )
871692
872693 else :
873694 with torch .no_grad ():
@@ -912,6 +733,150 @@ def _validate_network(self, data_set_fractions, metrics):
912733 )
913734 return errors
914735
736+ def __calculate_validation_error_ldos_only (self , data_loaders ):
737+ validation_loss_sum = torch .zeros (
738+ 1 , device = self .parameters ._configuration ["device" ]
739+ )
740+ with torch .no_grad ():
741+ if self .parameters ._configuration ["gpu" ]:
742+ report_freq = self .parameters .training_log_interval
743+ torch .cuda .synchronize (
744+ self .parameters ._configuration ["device" ]
745+ )
746+ tsample = time .time ()
747+ batchid = 0
748+ for loader in data_loaders :
749+ for x , y in loader :
750+ x = x .to (
751+ self .parameters ._configuration ["device" ],
752+ non_blocking = True ,
753+ )
754+ y = y .to (
755+ self .parameters ._configuration ["device" ],
756+ non_blocking = True ,
757+ )
758+
759+ if (
760+ self .parameters .use_graphs
761+ and self ._validation_graph is None
762+ ):
763+ printout (
764+ "Capturing CUDA graph for validation." ,
765+ min_verbosity = 2 ,
766+ )
767+ s = torch .cuda .Stream (
768+ self .parameters ._configuration ["device" ]
769+ )
770+ s .wait_stream (
771+ torch .cuda .current_stream (
772+ self .parameters ._configuration ["device" ]
773+ )
774+ )
775+ # Warmup for graphs
776+ with torch .cuda .stream (s ):
777+ for _ in range (20 ):
778+ with torch .cuda .amp .autocast (
779+ enabled = self .parameters .use_mixed_precision
780+ ):
781+ prediction = self .network (x )
782+ if self .parameters_full .use_ddp :
783+ loss = self .network .module .calculate_loss (
784+ prediction , y
785+ )
786+ else :
787+ loss = self .network .calculate_loss (
788+ prediction , y
789+ )
790+ torch .cuda .current_stream (
791+ self .parameters ._configuration ["device" ]
792+ ).wait_stream (s )
793+
794+ # Create static entry point tensors to graph
795+ self .static_input_validation = torch .empty_like (x )
796+ self .static_target_validation = torch .empty_like (y )
797+
798+ # Capture graph
799+ self ._validation_graph = torch .cuda .CUDAGraph ()
800+ with torch .cuda .graph (self ._validation_graph ):
801+ with torch .cuda .amp .autocast (
802+ enabled = self .parameters .use_mixed_precision
803+ ):
804+ self .static_prediction_validation = (
805+ self .network (
806+ self .static_input_validation
807+ )
808+ )
809+ if self .parameters_full .use_ddp :
810+ self .static_loss_validation = self .network .module .calculate_loss (
811+ self .static_prediction_validation ,
812+ self .static_target_validation ,
813+ )
814+ else :
815+ self .static_loss_validation = self .network .calculate_loss (
816+ self .static_prediction_validation ,
817+ self .static_target_validation ,
818+ )
819+
820+ if self ._validation_graph :
821+ self .static_input_validation .copy_ (x )
822+ self .static_target_validation .copy_ (y )
823+ self ._validation_graph .replay ()
824+ validation_loss_sum += self .static_loss_validation
825+ else :
826+ with torch .cuda .amp .autocast (
827+ enabled = self .parameters .use_mixed_precision
828+ ):
829+ prediction = self .network (x )
830+ if self .parameters_full .use_ddp :
831+ loss = self .network .module .calculate_loss (
832+ prediction , y
833+ )
834+ else :
835+ loss = self .network .calculate_loss (
836+ prediction , y
837+ )
838+ validation_loss_sum += loss
839+ if batchid != 0 and (batchid + 1 ) % report_freq == 0 :
840+ torch .cuda .synchronize (
841+ self .parameters ._configuration ["device" ]
842+ )
843+ sample_time = time .time () - tsample
844+ avg_sample_time = sample_time / report_freq
845+ avg_sample_tput = (
846+ report_freq * x .shape [0 ] / sample_time
847+ )
848+ printout (
849+ f"batch { batchid + 1 } , " # /{total_samples}, "
850+ f"validation avg time: { avg_sample_time } "
851+ f"validation avg throughput: { avg_sample_tput } " ,
852+ min_verbosity = 2 ,
853+ )
854+ tsample = time .time ()
855+ batchid += 1
856+ torch .cuda .synchronize (
857+ self .parameters ._configuration ["device" ]
858+ )
859+ else :
860+ batchid = 0
861+ for loader in data_loaders :
862+ for x , y in loader :
863+ x = x .to (self .parameters ._configuration ["device" ])
864+ y = y .to (self .parameters ._configuration ["device" ])
865+ prediction = self .network (x )
866+ if self .parameters_full .use_ddp :
867+ validation_loss_sum += (
868+ self .network .module .calculate_loss (
869+ prediction , y
870+ ).item ()
871+ )
872+ else :
873+ validation_loss_sum += self .network .calculate_loss (
874+ prediction , y
875+ ).item ()
876+ batchid += 1
877+
878+ return validation_loss_sum .item () / batchid
879+
915880 def __prepare_to_train (self , optimizer_dict ):
916881 """Prepare everything for training."""
917882 # Configure keyword arguments for DataSampler.
0 commit comments