Skip to content

Commit 20a06f2

Browse files
Refactored code internally
1 parent d3043e6 commit 20a06f2

1 file changed

Lines changed: 150 additions & 185 deletions

File tree

mala/network/trainer.py

Lines changed: 150 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -678,196 +678,17 @@ def _validate_network(self, data_set_fractions, metrics):
678678
# If only the LDOS is in the validation metrics (as is the
679679
# case for, e.g., distributed network trainings), we can
680680
# use a faster (or at least better parallelizing) code
681+
681682
if (
682683
len(self.parameters.validation_metrics) == 1
683684
and self.parameters.validation_metrics[0] == "ldos"
684685
):
685-
validation_loss_sum = torch.zeros(
686-
1, device=self.parameters._configuration["device"]
687-
)
688-
with torch.no_grad():
689-
if self.parameters._configuration["gpu"]:
690-
report_freq = self.parameters.training_log_interval
691-
torch.cuda.synchronize(
692-
self.parameters._configuration["device"]
693-
)
694-
tsample = time.time()
695-
batchid = 0
696-
for loader in data_loaders:
697-
for x, y in loader:
698-
x = x.to(
699-
self.parameters._configuration[
700-
"device"
701-
],
702-
non_blocking=True,
703-
)
704-
y = y.to(
705-
self.parameters._configuration[
706-
"device"
707-
],
708-
non_blocking=True,
709-
)
710686

711-
if (
712-
self.parameters.use_graphs
713-
and self._validation_graph is None
714-
):
715-
printout(
716-
"Capturing CUDA graph for validation.",
717-
min_verbosity=2,
718-
)
719-
s = torch.cuda.Stream(
720-
self.parameters._configuration[
721-
"device"
722-
]
723-
)
724-
s.wait_stream(
725-
torch.cuda.current_stream(
726-
self.parameters._configuration[
727-
"device"
728-
]
729-
)
730-
)
731-
# Warmup for graphs
732-
with torch.cuda.stream(s):
733-
for _ in range(20):
734-
with torch.cuda.amp.autocast(
735-
enabled=self.parameters.use_mixed_precision
736-
):
737-
prediction = self.network(
738-
x
739-
)
740-
if (
741-
self.parameters_full.use_ddp
742-
):
743-
loss = self.network.module.calculate_loss(
744-
prediction, y
745-
)
746-
else:
747-
loss = self.network.calculate_loss(
748-
prediction, y
749-
)
750-
torch.cuda.current_stream(
751-
self.parameters._configuration[
752-
"device"
753-
]
754-
).wait_stream(s)
755-
756-
# Create static entry point tensors to graph
757-
self.static_input_validation = (
758-
torch.empty_like(x)
759-
)
760-
self.static_target_validation = (
761-
torch.empty_like(y)
762-
)
763-
764-
# Capture graph
765-
self._validation_graph = (
766-
torch.cuda.CUDAGraph()
767-
)
768-
with torch.cuda.graph(
769-
self._validation_graph
770-
):
771-
with torch.cuda.amp.autocast(
772-
enabled=self.parameters.use_mixed_precision
773-
):
774-
self.static_prediction_validation = self.network(
775-
self.static_input_validation
776-
)
777-
if (
778-
self.parameters_full.use_ddp
779-
):
780-
self.static_loss_validation = self.network.module.calculate_loss(
781-
self.static_prediction_validation,
782-
self.static_target_validation,
783-
)
784-
else:
785-
self.static_loss_validation = self.network.calculate_loss(
786-
self.static_prediction_validation,
787-
self.static_target_validation,
788-
)
789-
790-
if self._validation_graph:
791-
self.static_input_validation.copy_(x)
792-
self.static_target_validation.copy_(y)
793-
self._validation_graph.replay()
794-
validation_loss_sum += (
795-
self.static_loss_validation
796-
)
797-
else:
798-
with torch.cuda.amp.autocast(
799-
enabled=self.parameters.use_mixed_precision
800-
):
801-
prediction = self.network(x)
802-
if self.parameters_full.use_ddp:
803-
loss = self.network.module.calculate_loss(
804-
prediction, y
805-
)
806-
else:
807-
loss = self.network.calculate_loss(
808-
prediction, y
809-
)
810-
validation_loss_sum += loss
811-
if (
812-
batchid != 0
813-
and (batchid + 1) % report_freq == 0
814-
):
815-
torch.cuda.synchronize(
816-
self.parameters._configuration[
817-
"device"
818-
]
819-
)
820-
sample_time = time.time() - tsample
821-
avg_sample_time = (
822-
sample_time / report_freq
823-
)
824-
avg_sample_tput = (
825-
report_freq
826-
* x.shape[0]
827-
/ sample_time
828-
)
829-
printout(
830-
f"batch {batchid + 1}, " # /{total_samples}, "
831-
f"validation avg time: {avg_sample_time} "
832-
f"validation avg throughput: {avg_sample_tput}",
833-
min_verbosity=2,
834-
)
835-
tsample = time.time()
836-
batchid += 1
837-
torch.cuda.synchronize(
838-
self.parameters._configuration["device"]
839-
)
840-
else:
841-
batchid = 0
842-
for loader in data_loaders:
843-
for x, y in loader:
844-
x = x.to(
845-
self.parameters._configuration[
846-
"device"
847-
]
848-
)
849-
y = y.to(
850-
self.parameters._configuration[
851-
"device"
852-
]
853-
)
854-
prediction = self.network(x)
855-
if self.parameters_full.use_ddp:
856-
validation_loss_sum += (
857-
self.network.module.calculate_loss(
858-
prediction, y
859-
).item()
860-
)
861-
else:
862-
validation_loss_sum += (
863-
self.network.calculate_loss(
864-
prediction, y
865-
).item()
866-
)
867-
batchid += 1
868-
869-
validation_loss = validation_loss_sum.item() / batchid
870-
errors[data_set_type]["ldos"] = validation_loss
687+
errors[data_set_type]["ldos"] = (
688+
self.__calculate_validation_error_ldos_only(
689+
data_loaders
690+
)
691+
)
871692

872693
else:
873694
with torch.no_grad():
@@ -912,6 +733,150 @@ def _validate_network(self, data_set_fractions, metrics):
912733
)
913734
return errors
914735

736+
def __calculate_validation_error_ldos_only(self, data_loaders):
    """Compute the average LDOS validation loss over all given data loaders.

    This is the fast path used when "ldos" is the only validation metric
    (e.g. for distributed trainings): it only accumulates the network's own
    loss function instead of computing a full metric suite.

    Parameters
    ----------
    data_loaders : iterable
        Iterable of torch DataLoader-like objects yielding (input, target)
        batches for the validation set.

    Returns
    -------
    float
        Sum of per-batch losses divided by the number of batches processed.
        NOTE(review): raises ZeroDivisionError if the loaders yield no
        batches — presumably never the case in practice; confirm upstream.
    """
    # Accumulate the loss on the target device so GPU additions (tensor +=
    # tensor) stay on-device; the scalar is only transferred at the end.
    validation_loss_sum = torch.zeros(
        1, device=self.parameters._configuration["device"]
    )
    # Validation only — no gradients needed anywhere below.
    with torch.no_grad():
        if self.parameters._configuration["gpu"]:
            # GPU path: async host-to-device copies, optional CUDA graph
            # replay, and periodic throughput reporting.
            report_freq = self.parameters.training_log_interval
            # Synchronize before timing so tsample measures compute, not
            # previously queued work.
            torch.cuda.synchronize(
                self.parameters._configuration["device"]
            )
            tsample = time.time()
            batchid = 0
            for loader in data_loaders:
                for x, y in loader:
                    # non_blocking=True overlaps the copy with compute
                    # (effective when the loader uses pinned memory —
                    # assumed here, TODO confirm).
                    x = x.to(
                        self.parameters._configuration["device"],
                        non_blocking=True,
                    )
                    y = y.to(
                        self.parameters._configuration["device"],
                        non_blocking=True,
                    )

                    # One-time CUDA graph capture on the first batch, if
                    # graphs are enabled and no graph exists yet. All later
                    # batches replay the captured graph.
                    if (
                        self.parameters.use_graphs
                        and self._validation_graph is None
                    ):
                        printout(
                            "Capturing CUDA graph for validation.",
                            min_verbosity=2,
                        )
                        # Warm up on a side stream so capture does not
                        # record work queued on the default stream.
                        s = torch.cuda.Stream(
                            self.parameters._configuration["device"]
                        )
                        s.wait_stream(
                            torch.cuda.current_stream(
                                self.parameters._configuration["device"]
                            )
                        )
                        # Warmup for graphs
                        with torch.cuda.stream(s):
                            for _ in range(20):
                                with torch.cuda.amp.autocast(
                                    enabled=self.parameters.use_mixed_precision
                                ):
                                    prediction = self.network(x)
                                    # Under DDP the real model sits behind
                                    # .module; dispatch accordingly.
                                    if self.parameters_full.use_ddp:
                                        loss = self.network.module.calculate_loss(
                                            prediction, y
                                        )
                                    else:
                                        loss = self.network.calculate_loss(
                                            prediction, y
                                        )
                        torch.cuda.current_stream(
                            self.parameters._configuration["device"]
                        ).wait_stream(s)

                        # Create static entry point tensors to graph
                        # (graph replay reads/writes these fixed buffers;
                        # assumes every batch has the same shape as this
                        # first one — TODO confirm for the last batch).
                        self.static_input_validation = torch.empty_like(x)
                        self.static_target_validation = torch.empty_like(y)

                        # Capture graph
                        self._validation_graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(self._validation_graph):
                            with torch.cuda.amp.autocast(
                                enabled=self.parameters.use_mixed_precision
                            ):
                                self.static_prediction_validation = (
                                    self.network(
                                        self.static_input_validation
                                    )
                                )
                                if self.parameters_full.use_ddp:
                                    self.static_loss_validation = self.network.module.calculate_loss(
                                        self.static_prediction_validation,
                                        self.static_target_validation,
                                    )
                                else:
                                    self.static_loss_validation = self.network.calculate_loss(
                                        self.static_prediction_validation,
                                        self.static_target_validation,
                                    )

                    if self._validation_graph:
                        # Replay path: copy the batch into the static
                        # buffers, replay, and read the static loss.
                        self.static_input_validation.copy_(x)
                        self.static_target_validation.copy_(y)
                        self._validation_graph.replay()
                        validation_loss_sum += self.static_loss_validation
                    else:
                        # Eager path (graphs disabled).
                        with torch.cuda.amp.autocast(
                            enabled=self.parameters.use_mixed_precision
                        ):
                            prediction = self.network(x)
                            if self.parameters_full.use_ddp:
                                loss = self.network.module.calculate_loss(
                                    prediction, y
                                )
                            else:
                                loss = self.network.calculate_loss(
                                    prediction, y
                                )
                            validation_loss_sum += loss
                    # Periodic progress report every report_freq batches
                    # (skipping batch 0); synchronize so the measured wall
                    # time covers the queued GPU work.
                    if batchid != 0 and (batchid + 1) % report_freq == 0:
                        torch.cuda.synchronize(
                            self.parameters._configuration["device"]
                        )
                        sample_time = time.time() - tsample
                        avg_sample_time = sample_time / report_freq
                        avg_sample_tput = (
                            report_freq * x.shape[0] / sample_time
                        )
                        printout(
                            f"batch {batchid + 1}, "  # /{total_samples}, "
                            f"validation avg time: {avg_sample_time} "
                            f"validation avg throughput: {avg_sample_tput}",
                            min_verbosity=2,
                        )
                        tsample = time.time()
                    batchid += 1
            # Final sync so validation_loss_sum is complete before .item().
            torch.cuda.synchronize(
                self.parameters._configuration["device"]
            )
        else:
            # CPU path: plain synchronous evaluation, no timing report.
            batchid = 0
            for loader in data_loaders:
                for x, y in loader:
                    x = x.to(self.parameters._configuration["device"])
                    y = y.to(self.parameters._configuration["device"])
                    prediction = self.network(x)
                    if self.parameters_full.use_ddp:
                        validation_loss_sum += (
                            self.network.module.calculate_loss(
                                prediction, y
                            ).item()
                        )
                    else:
                        validation_loss_sum += self.network.calculate_loss(
                            prediction, y
                        ).item()
                    batchid += 1

    # Average loss per batch, as a Python float.
    return validation_loss_sum.item() / batchid
879+
915880
def __prepare_to_train(self, optimizer_dict):
916881
"""Prepare everything for training."""
917882
# Configure keyword arguments for DataSampler.

0 commit comments

Comments
 (0)