Skip to content

Commit 03f6b96

Browse files
Merge pull request mala-project#617 from RandomDefaultUser/fix_ddp_validation
Recovering DDP scalability
2 parents dddccd4 + 20a06f2 commit 03f6b96

1 file changed

Lines changed: 193 additions & 33 deletions

File tree

mala/network/trainer.py

Lines changed: 193 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -675,47 +675,207 @@ def _validate_network(self, data_set_fractions, metrics):
675675
)
676676
loader_id += 1
677677
else:
678-
with torch.no_grad():
679-
for snapshot_number in trange(
680-
offset_snapshots,
681-
number_of_snapshots + offset_snapshots,
682-
desc="Validation",
683-
disable=self.parameters_full.verbosity < 2,
684-
):
685-
# Get optimal batch size and number of batches per snapshot
686-
grid_size = (
687-
self.data.parameters.snapshot_directories_list[
688-
snapshot_number
689-
].grid_size
690-
)
678+
# If only the LDOS is in the validation metrics (as is the
679+
# case for, e.g., distributed network training runs), we can
680+
# use a faster (or at least better-parallelizing) code path
691681

692-
optimal_batch_size = self._correct_batch_size(
693-
grid_size, self.parameters.mini_batch_size
694-
)
695-
number_of_batches_per_snapshot = int(
696-
grid_size / optimal_batch_size
682+
if (
683+
len(self.parameters.validation_metrics) == 1
684+
and self.parameters.validation_metrics[0] == "ldos"
685+
):
686+
687+
errors[data_set_type]["ldos"] = (
688+
self.__calculate_validation_error_ldos_only(
689+
data_loaders
697690
)
691+
)
698692

699-
actual_outputs, predicted_outputs = (
700-
self._forward_entire_snapshot(
693+
else:
694+
with torch.no_grad():
695+
for snapshot_number in trange(
696+
offset_snapshots,
697+
number_of_snapshots + offset_snapshots,
698+
desc="Validation",
699+
disable=self.parameters_full.verbosity < 2,
700+
):
701+
# Get optimal batch size and number of batches per snapshot
702+
grid_size = (
703+
self.data.parameters.snapshot_directories_list[
704+
snapshot_number
705+
].grid_size
706+
)
707+
708+
optimal_batch_size = self._correct_batch_size(
709+
grid_size, self.parameters.mini_batch_size
710+
)
711+
number_of_batches_per_snapshot = int(
712+
grid_size / optimal_batch_size
713+
)
714+
715+
actual_outputs, predicted_outputs = (
716+
self._forward_entire_snapshot(
717+
snapshot_number,
718+
data_sets[0],
719+
data_set_type[0:2],
720+
number_of_batches_per_snapshot,
721+
optimal_batch_size,
722+
)
723+
)
724+
calculated_errors = self._calculate_errors(
725+
actual_outputs,
726+
predicted_outputs,
727+
metrics,
701728
snapshot_number,
702-
data_sets[0],
703-
data_set_type[0:2],
704-
number_of_batches_per_snapshot,
705-
optimal_batch_size,
706729
)
730+
for metric in metrics:
731+
errors[data_set_type][metric].append(
732+
calculated_errors[metric]
733+
)
734+
return errors
735+
736+
def __calculate_validation_error_ldos_only(self, data_loaders):
737+
validation_loss_sum = torch.zeros(
738+
1, device=self.parameters._configuration["device"]
739+
)
740+
with torch.no_grad():
741+
if self.parameters._configuration["gpu"]:
742+
report_freq = self.parameters.training_log_interval
743+
torch.cuda.synchronize(
744+
self.parameters._configuration["device"]
745+
)
746+
tsample = time.time()
747+
batchid = 0
748+
for loader in data_loaders:
749+
for x, y in loader:
750+
x = x.to(
751+
self.parameters._configuration["device"],
752+
non_blocking=True,
707753
)
708-
calculated_errors = self._calculate_errors(
709-
actual_outputs,
710-
predicted_outputs,
711-
metrics,
712-
snapshot_number,
754+
y = y.to(
755+
self.parameters._configuration["device"],
756+
non_blocking=True,
713757
)
714-
for metric in metrics:
715-
errors[data_set_type][metric].append(
716-
calculated_errors[metric]
758+
759+
if (
760+
self.parameters.use_graphs
761+
and self._validation_graph is None
762+
):
763+
printout(
764+
"Capturing CUDA graph for validation.",
765+
min_verbosity=2,
717766
)
718-
return errors
767+
s = torch.cuda.Stream(
768+
self.parameters._configuration["device"]
769+
)
770+
s.wait_stream(
771+
torch.cuda.current_stream(
772+
self.parameters._configuration["device"]
773+
)
774+
)
775+
# Warmup for graphs
776+
with torch.cuda.stream(s):
777+
for _ in range(20):
778+
with torch.cuda.amp.autocast(
779+
enabled=self.parameters.use_mixed_precision
780+
):
781+
prediction = self.network(x)
782+
if self.parameters_full.use_ddp:
783+
loss = self.network.module.calculate_loss(
784+
prediction, y
785+
)
786+
else:
787+
loss = self.network.calculate_loss(
788+
prediction, y
789+
)
790+
torch.cuda.current_stream(
791+
self.parameters._configuration["device"]
792+
).wait_stream(s)
793+
794+
# Create static entry point tensors to graph
795+
self.static_input_validation = torch.empty_like(x)
796+
self.static_target_validation = torch.empty_like(y)
797+
798+
# Capture graph
799+
self._validation_graph = torch.cuda.CUDAGraph()
800+
with torch.cuda.graph(self._validation_graph):
801+
with torch.cuda.amp.autocast(
802+
enabled=self.parameters.use_mixed_precision
803+
):
804+
self.static_prediction_validation = (
805+
self.network(
806+
self.static_input_validation
807+
)
808+
)
809+
if self.parameters_full.use_ddp:
810+
self.static_loss_validation = self.network.module.calculate_loss(
811+
self.static_prediction_validation,
812+
self.static_target_validation,
813+
)
814+
else:
815+
self.static_loss_validation = self.network.calculate_loss(
816+
self.static_prediction_validation,
817+
self.static_target_validation,
818+
)
819+
820+
if self._validation_graph:
821+
self.static_input_validation.copy_(x)
822+
self.static_target_validation.copy_(y)
823+
self._validation_graph.replay()
824+
validation_loss_sum += self.static_loss_validation
825+
else:
826+
with torch.cuda.amp.autocast(
827+
enabled=self.parameters.use_mixed_precision
828+
):
829+
prediction = self.network(x)
830+
if self.parameters_full.use_ddp:
831+
loss = self.network.module.calculate_loss(
832+
prediction, y
833+
)
834+
else:
835+
loss = self.network.calculate_loss(
836+
prediction, y
837+
)
838+
validation_loss_sum += loss
839+
if batchid != 0 and (batchid + 1) % report_freq == 0:
840+
torch.cuda.synchronize(
841+
self.parameters._configuration["device"]
842+
)
843+
sample_time = time.time() - tsample
844+
avg_sample_time = sample_time / report_freq
845+
avg_sample_tput = (
846+
report_freq * x.shape[0] / sample_time
847+
)
848+
printout(
849+
f"batch {batchid + 1}, " # /{total_samples}, "
850+
f"validation avg time: {avg_sample_time} "
851+
f"validation avg throughput: {avg_sample_tput}",
852+
min_verbosity=2,
853+
)
854+
tsample = time.time()
855+
batchid += 1
856+
torch.cuda.synchronize(
857+
self.parameters._configuration["device"]
858+
)
859+
else:
860+
batchid = 0
861+
for loader in data_loaders:
862+
for x, y in loader:
863+
x = x.to(self.parameters._configuration["device"])
864+
y = y.to(self.parameters._configuration["device"])
865+
prediction = self.network(x)
866+
if self.parameters_full.use_ddp:
867+
validation_loss_sum += (
868+
self.network.module.calculate_loss(
869+
prediction, y
870+
).item()
871+
)
872+
else:
873+
validation_loss_sum += self.network.calculate_loss(
874+
prediction, y
875+
).item()
876+
batchid += 1
877+
878+
return validation_loss_sum.item() / batchid
719879

720880
def __prepare_to_train(self, optimizer_dict):
721881
"""Prepare everything for training."""

0 commit comments

Comments
 (0)