Skip to content

Commit d3043e6

Browse files
Forgot a renaming
1 parent 8358e03 commit d3043e6

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

mala/network/trainer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def _validate_network(self, data_set_fractions, metrics):
710710

711711
if (
712712
self.parameters.use_graphs
713-
and self.validation_graph is None
713+
and self._validation_graph is None
714714
):
715715
printout(
716716
"Capturing CUDA graph for validation.",
@@ -762,11 +762,11 @@ def _validate_network(self, data_set_fractions, metrics):
762762
)
763763

764764
# Capture graph
765-
self.validation_graph = (
765+
self._validation_graph = (
766766
torch.cuda.CUDAGraph()
767767
)
768768
with torch.cuda.graph(
769-
self.validation_graph
769+
self._validation_graph
770770
):
771771
with torch.cuda.amp.autocast(
772772
enabled=self.parameters.use_mixed_precision
@@ -787,10 +787,10 @@ def _validate_network(self, data_set_fractions, metrics):
787787
self.static_target_validation,
788788
)
789789

790-
if self.validation_graph:
790+
if self._validation_graph:
791791
self.static_input_validation.copy_(x)
792792
self.static_target_validation.copy_(y)
793-
self.validation_graph.replay()
793+
self._validation_graph.replay()
794794
validation_loss_sum += (
795795
self.static_loss_validation
796796
)

0 commit comments

Comments
 (0)