@@ -372,6 +372,7 @@ class ModelWrapper(torch.nn.Module):
372372 Wraps the encoder, decoder, and neural ODE into a single model.
373373 Chooses architecture based on the config.model_version flag.
374374 """
375+
375376 def __init__ (self , config , n_quantities : int ):
376377 super ().__init__ ()
377378 self .config = config
@@ -383,7 +384,9 @@ def __init__(self, config, n_quantities: int):
383384 self .encoder = OldEncoder (
384385 in_features = n_quantities ,
385386 latent_features = config .latent_features ,
386- layers_factor = getattr (config , "layers_factor" , 8 ), # for backward compatibility
387+ layers_factor = getattr (
388+ config , "layers_factor" , 8
389+ ), # for backward compatibility
387390 activation = config .activation ,
388391 )
389392 self .decoder = OldDecoder (
@@ -640,7 +643,8 @@ def __init__(
640643
641644 def forward (self , x : torch .Tensor ) -> torch .Tensor :
642645 """Forward pass to encode the input into the latent space."""
643- return self .mlp (x )
646+ out = self .mlp (x )
647+ return out
644648
645649
646650class Decoder (torch .nn .Module ):
@@ -683,7 +687,8 @@ def __init__(
683687
684688 def forward (self , x : torch .Tensor ) -> torch .Tensor :
685689 """Forward pass to decode the latent representation into output features."""
686- return self .mlp (x )
690+ out = self .mlp (x )
691+ return out
687692
688693
689694class OldEncoder (torch .nn .Module ):
0 commit comments