Skip to content

Commit 8358e03

Browse files
Reintroduced old validation loss calculation, let's see if this fixes something
1 parent dddccd4 commit 8358e03

1 file changed

Lines changed: 231 additions & 36 deletions

File tree

mala/network/trainer.py

Lines changed: 231 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -675,46 +675,241 @@ def _validate_network(self, data_set_fractions, metrics):
675675
)
676676
loader_id += 1
677677
else:
678-
with torch.no_grad():
679-
for snapshot_number in trange(
680-
offset_snapshots,
681-
number_of_snapshots + offset_snapshots,
682-
desc="Validation",
683-
disable=self.parameters_full.verbosity < 2,
684-
):
685-
# Get optimal batch size and number of batches per snapshots
686-
grid_size = (
687-
self.data.parameters.snapshot_directories_list[
688-
snapshot_number
689-
].grid_size
690-
)
678+
# If only the LDOS is in the validation metrics (as is the
679+
# case for, e.g., distributed network trainings), we can
680+
# use a faster (or at least better parallelizing) code
681+
if (
682+
len(self.parameters.validation_metrics) == 1
683+
and self.parameters.validation_metrics[0] == "ldos"
684+
):
685+
validation_loss_sum = torch.zeros(
686+
1, device=self.parameters._configuration["device"]
687+
)
688+
with torch.no_grad():
689+
if self.parameters._configuration["gpu"]:
690+
report_freq = self.parameters.training_log_interval
691+
torch.cuda.synchronize(
692+
self.parameters._configuration["device"]
693+
)
694+
tsample = time.time()
695+
batchid = 0
696+
for loader in data_loaders:
697+
for x, y in loader:
698+
x = x.to(
699+
self.parameters._configuration[
700+
"device"
701+
],
702+
non_blocking=True,
703+
)
704+
y = y.to(
705+
self.parameters._configuration[
706+
"device"
707+
],
708+
non_blocking=True,
709+
)
710+
711+
if (
712+
self.parameters.use_graphs
713+
and self.validation_graph is None
714+
):
715+
printout(
716+
"Capturing CUDA graph for validation.",
717+
min_verbosity=2,
718+
)
719+
s = torch.cuda.Stream(
720+
self.parameters._configuration[
721+
"device"
722+
]
723+
)
724+
s.wait_stream(
725+
torch.cuda.current_stream(
726+
self.parameters._configuration[
727+
"device"
728+
]
729+
)
730+
)
731+
# Warmup for graphs
732+
with torch.cuda.stream(s):
733+
for _ in range(20):
734+
with torch.cuda.amp.autocast(
735+
enabled=self.parameters.use_mixed_precision
736+
):
737+
prediction = self.network(
738+
x
739+
)
740+
if (
741+
self.parameters_full.use_ddp
742+
):
743+
loss = self.network.module.calculate_loss(
744+
prediction, y
745+
)
746+
else:
747+
loss = self.network.calculate_loss(
748+
prediction, y
749+
)
750+
torch.cuda.current_stream(
751+
self.parameters._configuration[
752+
"device"
753+
]
754+
).wait_stream(s)
755+
756+
# Create static entry point tensors to graph
757+
self.static_input_validation = (
758+
torch.empty_like(x)
759+
)
760+
self.static_target_validation = (
761+
torch.empty_like(y)
762+
)
763+
764+
# Capture graph
765+
self.validation_graph = (
766+
torch.cuda.CUDAGraph()
767+
)
768+
with torch.cuda.graph(
769+
self.validation_graph
770+
):
771+
with torch.cuda.amp.autocast(
772+
enabled=self.parameters.use_mixed_precision
773+
):
774+
self.static_prediction_validation = self.network(
775+
self.static_input_validation
776+
)
777+
if (
778+
self.parameters_full.use_ddp
779+
):
780+
self.static_loss_validation = self.network.module.calculate_loss(
781+
self.static_prediction_validation,
782+
self.static_target_validation,
783+
)
784+
else:
785+
self.static_loss_validation = self.network.calculate_loss(
786+
self.static_prediction_validation,
787+
self.static_target_validation,
788+
)
789+
790+
if self.validation_graph:
791+
self.static_input_validation.copy_(x)
792+
self.static_target_validation.copy_(y)
793+
self.validation_graph.replay()
794+
validation_loss_sum += (
795+
self.static_loss_validation
796+
)
797+
else:
798+
with torch.cuda.amp.autocast(
799+
enabled=self.parameters.use_mixed_precision
800+
):
801+
prediction = self.network(x)
802+
if self.parameters_full.use_ddp:
803+
loss = self.network.module.calculate_loss(
804+
prediction, y
805+
)
806+
else:
807+
loss = self.network.calculate_loss(
808+
prediction, y
809+
)
810+
validation_loss_sum += loss
811+
if (
812+
batchid != 0
813+
and (batchid + 1) % report_freq == 0
814+
):
815+
torch.cuda.synchronize(
816+
self.parameters._configuration[
817+
"device"
818+
]
819+
)
820+
sample_time = time.time() - tsample
821+
avg_sample_time = (
822+
sample_time / report_freq
823+
)
824+
avg_sample_tput = (
825+
report_freq
826+
* x.shape[0]
827+
/ sample_time
828+
)
829+
printout(
830+
f"batch {batchid + 1}, " # /{total_samples}, "
831+
f"validation avg time: {avg_sample_time} "
832+
f"validation avg throughput: {avg_sample_tput}",
833+
min_verbosity=2,
834+
)
835+
tsample = time.time()
836+
batchid += 1
837+
torch.cuda.synchronize(
838+
self.parameters._configuration["device"]
839+
)
840+
else:
841+
batchid = 0
842+
for loader in data_loaders:
843+
for x, y in loader:
844+
x = x.to(
845+
self.parameters._configuration[
846+
"device"
847+
]
848+
)
849+
y = y.to(
850+
self.parameters._configuration[
851+
"device"
852+
]
853+
)
854+
prediction = self.network(x)
855+
if self.parameters_full.use_ddp:
856+
validation_loss_sum += (
857+
self.network.module.calculate_loss(
858+
prediction, y
859+
).item()
860+
)
861+
else:
862+
validation_loss_sum += (
863+
self.network.calculate_loss(
864+
prediction, y
865+
).item()
866+
)
867+
batchid += 1
868+
869+
validation_loss = validation_loss_sum.item() / batchid
870+
errors[data_set_type]["ldos"] = validation_loss
691871

692-
optimal_batch_size = self._correct_batch_size(
693-
grid_size, self.parameters.mini_batch_size
694-
)
695-
number_of_batches_per_snapshot = int(
696-
grid_size / optimal_batch_size
697-
)
872+
else:
873+
with torch.no_grad():
874+
for snapshot_number in trange(
875+
offset_snapshots,
876+
number_of_snapshots + offset_snapshots,
877+
desc="Validation",
878+
disable=self.parameters_full.verbosity < 2,
879+
):
880+
# Get optimal batch size and number of batches per snapshots
881+
grid_size = (
882+
self.data.parameters.snapshot_directories_list[
883+
snapshot_number
884+
].grid_size
885+
)
698886

699-
actual_outputs, predicted_outputs = (
700-
self._forward_entire_snapshot(
701-
snapshot_number,
702-
data_sets[0],
703-
data_set_type[0:2],
704-
number_of_batches_per_snapshot,
705-
optimal_batch_size,
887+
optimal_batch_size = self._correct_batch_size(
888+
grid_size, self.parameters.mini_batch_size
706889
)
707-
)
708-
calculated_errors = self._calculate_errors(
709-
actual_outputs,
710-
predicted_outputs,
711-
metrics,
712-
snapshot_number,
713-
)
714-
for metric in metrics:
715-
errors[data_set_type][metric].append(
716-
calculated_errors[metric]
890+
number_of_batches_per_snapshot = int(
891+
grid_size / optimal_batch_size
892+
)
893+
894+
actual_outputs, predicted_outputs = (
895+
self._forward_entire_snapshot(
896+
snapshot_number,
897+
data_sets[0],
898+
data_set_type[0:2],
899+
number_of_batches_per_snapshot,
900+
optimal_batch_size,
901+
)
717902
)
903+
calculated_errors = self._calculate_errors(
904+
actual_outputs,
905+
predicted_outputs,
906+
metrics,
907+
snapshot_number,
908+
)
909+
for metric in metrics:
910+
errors[data_set_type][metric].append(
911+
calculated_errors[metric]
912+
)
718913
return errors
719914

720915
def __prepare_to_train(self, optimizer_dict):

0 commit comments

Comments
 (0)