@@ -25,25 +25,19 @@ def load_model(
2525) -> torch .nn .Module :
2626 """Load a model from a file or return a randomly initialized model if none is provided."""
2727 # If device is not specified, use the default device (GPU if available, else CPU)
28- selected_device = device or torch .device (
29- "cuda" if torch .cuda .is_available () else "cpu"
30- )
28+ selected_device = device or torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
3129
3230 # Load model from file if a path is provided
3331 if isinstance (model , str | Path ):
34- loaded_model = torch .load (
35- model , weights_only = False , map_location = selected_device
36- )
32+ loaded_model = torch .load (model , weights_only = False , map_location = selected_device )
3733 elif isinstance (model , torch .nn .Module ):
3834 loaded_model = model
3935 elif model is None :
4036 # Initialize a new model with default architecture
4137 loaded_model = DeepLCModel ()
4238 logger .debug ("Initialized new DeepLCModel with default architecture" )
4339 else :
44- raise TypeError (
45- f"Expected a PyTorch Module or a file path, got { type (model )} instead."
46- )
40+ raise TypeError (f"Expected a PyTorch Module or a file path, got { type (model )} instead." )
4741
4842 # Ensure the model is on the specified device
4943 loaded_model .to (selected_device )
@@ -151,9 +145,7 @@ def predict(
151145) -> torch .Tensor :
152146 """Predict using the model for the given dataset."""
153147 model = load_model (model , device )
154- data_loader = DataLoader (
155- data , batch_size = batch_size , shuffle = False , num_workers = num_workers
156- )
148+ data_loader = DataLoader (data , batch_size = batch_size , shuffle = False , num_workers = num_workers )
157149 predictions = _predict_epoch (model , data_loader , device )
158150 return predictions .cpu ().detach ()
159151
@@ -167,9 +159,7 @@ def evaluate(
167159) -> float :
168160 """Evaluate the model on the given dataset."""
169161 model = load_model (model , device )
170- data_loader = DataLoader (
171- data , batch_size = batch_size , shuffle = False , num_workers = num_workers
172- )
162+ data_loader = DataLoader (data , batch_size = batch_size , shuffle = False , num_workers = num_workers )
173163 loss_fn = torch .nn .L1Loss ()
174164 avg_loss = _validate_epoch (model , data_loader , loss_fn , device )
175165 return avg_loss
@@ -181,9 +171,7 @@ def _freeze_layers(model: torch.nn.Module, unfreeze_keyword: str) -> None:
181171 param .requires_grad = unfreeze_keyword in name
182172
183173
184- def _get_optimizer (
185- model : torch .nn .Module , learning_rate : float
186- ) -> torch .optim .Optimizer :
174+ def _get_optimizer (model : torch .nn .Module , learning_rate : float ) -> torch .optim .Optimizer :
187175 return torch .optim .Adam (
188176 filter (lambda p : p .requires_grad , model .parameters ()),
189177 lr = learning_rate ,
0 commit comments