@@ -21,18 +21,26 @@ def load_model(
2121) -> torch .nn .Module :
2222 """Load a model from a file or return a randomly initialized model if none is provided."""
2323 # If device is not specified, use the default device (GPU if available, else CPU)
24- selected_device = device or torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
24+ selected_device = device or torch .device (
25+ "cuda" if torch .cuda .is_available () else "cpu"
26+ )
2527
2628 # Load model from file if a path is provided
2729 if isinstance (model , str | Path ):
28- loaded_model = torch .load (model , weights_only = False , map_location = selected_device )
30+ loaded_model = torch .load (
31+ model , weights_only = False , map_location = selected_device
32+ )
2933 elif isinstance (model , torch .nn .Module ):
3034 loaded_model = model
3135 elif model is None :
3236 # TODO: Implement randomly initialized model; requires model architecture definition
33- raise NotImplementedError ("Loading randomly initialized model is not implemented yet." )
37+ raise NotImplementedError (
38+ "Loading randomly initialized model is not implemented yet."
39+ )
3440 else :
35- raise TypeError (f"Expected a PyTorch Module or a file path, got { type (model )} instead." )
41+ raise TypeError (
42+ f"Expected a PyTorch Module or a file path, got { type (model )} instead."
43+ )
3644
3745 # Ensure the model is on the specified device
3846 loaded_model .to (selected_device )
@@ -47,7 +55,7 @@ def train(
4755 validation_split : float = 0.1 ,
4856 device : str = "cpu" ,
4957 learning_rate : float = 0.001 ,
50- epochs : int = 10 ,
58+ epochs : int = 25 ,
5159 batch_size : int = 512 ,
5260 patience : int = 5 ,
5361 trainable_layers : str | None = None ,
@@ -93,9 +101,15 @@ def train(
93101 logger .debug (f"Frozen all layers except those containing '{ trainable_layers } '" )
94102
95103 # Parse dataset and split arguments; setup loaders
96- train_dataset , val_dataset = split_datasets (train_data , validation_data , validation_split )
97- train_loader = DataLoader (train_dataset , batch_size = batch_size , shuffle = True , num_workers = 0 )
98- val_loader = DataLoader (val_dataset , batch_size = batch_size , shuffle = False , num_workers = 0 )
104+ train_dataset , val_dataset = split_datasets (
105+ train_data , validation_data , validation_split
106+ )
107+ train_loader = DataLoader (
108+ train_dataset , batch_size = batch_size , shuffle = True , num_workers = 0
109+ )
110+ val_loader = DataLoader (
111+ val_dataset , batch_size = batch_size , shuffle = False , num_workers = 0
112+ )
99113
100114 optimizer = _get_optimizer (model , learning_rate )
101115 loss_fn = torch .nn .L1Loss ()
@@ -137,7 +151,9 @@ def predict(
137151) -> torch .Tensor :
138152 """Predict using the model for the given dataset."""
139153 model = load_model (model , device )
140- data_loader = DataLoader (data , batch_size = batch_size , shuffle = False , num_workers = num_workers )
154+ data_loader = DataLoader (
155+ data , batch_size = batch_size , shuffle = False , num_workers = num_workers
156+ )
141157 predictions = _predict_epoch (model , data_loader , device )
142158 return predictions .cpu ().detach ()
143159
@@ -151,7 +167,9 @@ def evaluate(
151167) -> float :
152168 """Evaluate the model on the given dataset."""
153169 model = load_model (model , device )
154- data_loader = DataLoader (data , batch_size = batch_size , shuffle = False , num_workers = num_workers )
170+ data_loader = DataLoader (
171+ data , batch_size = batch_size , shuffle = False , num_workers = num_workers
172+ )
155173 loss_fn = torch .nn .L1Loss ()
156174 avg_loss = _validate_epoch (model , data_loader , loss_fn , device )
157175 return avg_loss
@@ -166,7 +184,9 @@ def _freeze_layers(model: torch.nn.Module, unfreeze_keyword: str) -> None:
166184 param .requires_grad = unfreeze_keyword in name
167185
168186
169- def _get_optimizer (model : torch .nn .Module , learning_rate : float ) -> torch .optim .Optimizer :
187+ def _get_optimizer (
188+ model : torch .nn .Module , learning_rate : float
189+ ) -> torch .optim .Optimizer :
170190 return torch .optim .Adam (
171191 filter (lambda p : p .requires_grad , model .parameters ()),
172192 lr = learning_rate ,
0 commit comments