@@ -675,46 +675,241 @@ def _validate_network(self, data_set_fractions, metrics):
675675 )
676676 loader_id += 1
677677 else :
678- with torch .no_grad ():
679- for snapshot_number in trange (
680- offset_snapshots ,
681- number_of_snapshots + offset_snapshots ,
682- desc = "Validation" ,
683- disable = self .parameters_full .verbosity < 2 ,
684- ):
685- # Get optimal batch size and number of batches per snapshotss
686- grid_size = (
687- self .data .parameters .snapshot_directories_list [
688- snapshot_number
689- ].grid_size
690- )
678+ # If only the LDOS is in the validation metrics (as is the
679+ # case for, e.g., distributed network trainings), we can
680+ # use a faster (or at least better parallelizing) code
681+ if (
682+ len (self .parameters .validation_metrics ) == 1
683+ and self .parameters .validation_metrics [0 ] == "ldos"
684+ ):
685+ validation_loss_sum = torch .zeros (
686+ 1 , device = self .parameters ._configuration ["device" ]
687+ )
688+ with torch .no_grad ():
689+ if self .parameters ._configuration ["gpu" ]:
690+ report_freq = self .parameters .training_log_interval
691+ torch .cuda .synchronize (
692+ self .parameters ._configuration ["device" ]
693+ )
694+ tsample = time .time ()
695+ batchid = 0
696+ for loader in data_loaders :
697+ for x , y in loader :
698+ x = x .to (
699+ self .parameters ._configuration [
700+ "device"
701+ ],
702+ non_blocking = True ,
703+ )
704+ y = y .to (
705+ self .parameters ._configuration [
706+ "device"
707+ ],
708+ non_blocking = True ,
709+ )
710+
711+ if (
712+ self .parameters .use_graphs
713+ and self .validation_graph is None
714+ ):
715+ printout (
716+ "Capturing CUDA graph for validation." ,
717+ min_verbosity = 2 ,
718+ )
719+ s = torch .cuda .Stream (
720+ self .parameters ._configuration [
721+ "device"
722+ ]
723+ )
724+ s .wait_stream (
725+ torch .cuda .current_stream (
726+ self .parameters ._configuration [
727+ "device"
728+ ]
729+ )
730+ )
731+ # Warmup for graphs
732+ with torch .cuda .stream (s ):
733+ for _ in range (20 ):
734+ with torch .cuda .amp .autocast (
735+ enabled = self .parameters .use_mixed_precision
736+ ):
737+ prediction = self .network (
738+ x
739+ )
740+ if (
741+ self .parameters_full .use_ddp
742+ ):
743+ loss = self .network .module .calculate_loss (
744+ prediction , y
745+ )
746+ else :
747+ loss = self .network .calculate_loss (
748+ prediction , y
749+ )
750+ torch .cuda .current_stream (
751+ self .parameters ._configuration [
752+ "device"
753+ ]
754+ ).wait_stream (s )
755+
756+ # Create static entry point tensors to graph
757+ self .static_input_validation = (
758+ torch .empty_like (x )
759+ )
760+ self .static_target_validation = (
761+ torch .empty_like (y )
762+ )
763+
764+ # Capture graph
765+ self .validation_graph = (
766+ torch .cuda .CUDAGraph ()
767+ )
768+ with torch .cuda .graph (
769+ self .validation_graph
770+ ):
771+ with torch .cuda .amp .autocast (
772+ enabled = self .parameters .use_mixed_precision
773+ ):
774+ self .static_prediction_validation = self .network (
775+ self .static_input_validation
776+ )
777+ if (
778+ self .parameters_full .use_ddp
779+ ):
780+ self .static_loss_validation = self .network .module .calculate_loss (
781+ self .static_prediction_validation ,
782+ self .static_target_validation ,
783+ )
784+ else :
785+ self .static_loss_validation = self .network .calculate_loss (
786+ self .static_prediction_validation ,
787+ self .static_target_validation ,
788+ )
789+
790+ if self .validation_graph :
791+ self .static_input_validation .copy_ (x )
792+ self .static_target_validation .copy_ (y )
793+ self .validation_graph .replay ()
794+ validation_loss_sum += (
795+ self .static_loss_validation
796+ )
797+ else :
798+ with torch .cuda .amp .autocast (
799+ enabled = self .parameters .use_mixed_precision
800+ ):
801+ prediction = self .network (x )
802+ if self .parameters_full .use_ddp :
803+ loss = self .network .module .calculate_loss (
804+ prediction , y
805+ )
806+ else :
807+ loss = self .network .calculate_loss (
808+ prediction , y
809+ )
810+ validation_loss_sum += loss
811+ if (
812+ batchid != 0
813+ and (batchid + 1 ) % report_freq == 0
814+ ):
815+ torch .cuda .synchronize (
816+ self .parameters ._configuration [
817+ "device"
818+ ]
819+ )
820+ sample_time = time .time () - tsample
821+ avg_sample_time = (
822+ sample_time / report_freq
823+ )
824+ avg_sample_tput = (
825+ report_freq
826+ * x .shape [0 ]
827+ / sample_time
828+ )
829+ printout (
830+ f"batch { batchid + 1 } , " # /{total_samples}, "
831+ f"validation avg time: { avg_sample_time } "
832+ f"validation avg throughput: { avg_sample_tput } " ,
833+ min_verbosity = 2 ,
834+ )
835+ tsample = time .time ()
836+ batchid += 1
837+ torch .cuda .synchronize (
838+ self .parameters ._configuration ["device" ]
839+ )
840+ else :
841+ batchid = 0
842+ for loader in data_loaders :
843+ for x , y in loader :
844+ x = x .to (
845+ self .parameters ._configuration [
846+ "device"
847+ ]
848+ )
849+ y = y .to (
850+ self .parameters ._configuration [
851+ "device"
852+ ]
853+ )
854+ prediction = self .network (x )
855+ if self .parameters_full .use_ddp :
856+ validation_loss_sum += (
857+ self .network .module .calculate_loss (
858+ prediction , y
859+ ).item ()
860+ )
861+ else :
862+ validation_loss_sum += (
863+ self .network .calculate_loss (
864+ prediction , y
865+ ).item ()
866+ )
867+ batchid += 1
868+
869+ validation_loss = validation_loss_sum .item () / batchid
870+ errors [data_set_type ]["ldos" ] = validation_loss
691871
692- optimal_batch_size = self ._correct_batch_size (
693- grid_size , self .parameters .mini_batch_size
694- )
695- number_of_batches_per_snapshot = int (
696- grid_size / optimal_batch_size
697- )
872+ else :
873+ with torch .no_grad ():
874+ for snapshot_number in trange (
875+ offset_snapshots ,
876+ number_of_snapshots + offset_snapshots ,
877+ desc = "Validation" ,
878+ disable = self .parameters_full .verbosity < 2 ,
879+ ):
880+                         # Get optimal batch size and number of batches per snapshots
881+ grid_size = (
882+ self .data .parameters .snapshot_directories_list [
883+ snapshot_number
884+ ].grid_size
885+ )
698886
699- actual_outputs , predicted_outputs = (
700- self ._forward_entire_snapshot (
701- snapshot_number ,
702- data_sets [0 ],
703- data_set_type [0 :2 ],
704- number_of_batches_per_snapshot ,
705- optimal_batch_size ,
887+ optimal_batch_size = self ._correct_batch_size (
888+ grid_size , self .parameters .mini_batch_size
706889 )
707- )
708- calculated_errors = self ._calculate_errors (
709- actual_outputs ,
710- predicted_outputs ,
711- metrics ,
712- snapshot_number ,
713- )
714- for metric in metrics :
715- errors [data_set_type ][metric ].append (
716- calculated_errors [metric ]
890+ number_of_batches_per_snapshot = int (
891+ grid_size / optimal_batch_size
892+ )
893+
894+ actual_outputs , predicted_outputs = (
895+ self ._forward_entire_snapshot (
896+ snapshot_number ,
897+ data_sets [0 ],
898+ data_set_type [0 :2 ],
899+ number_of_batches_per_snapshot ,
900+ optimal_batch_size ,
901+ )
717902 )
903+ calculated_errors = self ._calculate_errors (
904+ actual_outputs ,
905+ predicted_outputs ,
906+ metrics ,
907+ snapshot_number ,
908+ )
909+ for metric in metrics :
910+ errors [data_set_type ][metric ].append (
911+ calculated_errors [metric ]
912+ )
718913 return errors
719914
720915 def __prepare_to_train (self , optimizer_dict ):
0 commit comments