Skip to content

Commit 35bc90f

Browse files
committed
Fix formatting
1 parent bce8675 commit 35bc90f

2 files changed

Lines changed: 7 additions & 18 deletions

File tree

deeplc/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# TODO: Add CLI functionality
99

10+
1011
def _setup_logging(passed_level):
1112
log_mapping = {
1213
"critical": logging.CRITICAL,

deeplc/_model_ops.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,19 @@ def load_model(
2525
) -> torch.nn.Module:
2626
"""Load a model from a file or return a randomly initialized model if none is provided."""
2727
# If device is not specified, use the default device (GPU if available, else CPU)
28-
selected_device = device or torch.device(
29-
"cuda" if torch.cuda.is_available() else "cpu"
30-
)
28+
selected_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
3129

3230
# Load model from file if a path is provided
3331
if isinstance(model, str | Path):
34-
loaded_model = torch.load(
35-
model, weights_only=False, map_location=selected_device
36-
)
32+
loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
3733
elif isinstance(model, torch.nn.Module):
3834
loaded_model = model
3935
elif model is None:
4036
# Initialize a new model with default architecture
4137
loaded_model = DeepLCModel()
4238
logger.debug("Initialized new DeepLCModel with default architecture")
4339
else:
44-
raise TypeError(
45-
f"Expected a PyTorch Module or a file path, got {type(model)} instead."
46-
)
40+
raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.")
4741

4842
# Ensure the model is on the specified device
4943
loaded_model.to(selected_device)
@@ -151,9 +145,7 @@ def predict(
151145
) -> torch.Tensor:
152146
"""Predict using the model for the given dataset."""
153147
model = load_model(model, device)
154-
data_loader = DataLoader(
155-
data, batch_size=batch_size, shuffle=False, num_workers=num_workers
156-
)
148+
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
157149
predictions = _predict_epoch(model, data_loader, device)
158150
return predictions.cpu().detach()
159151

@@ -167,9 +159,7 @@ def evaluate(
167159
) -> float:
168160
"""Evaluate the model on the given dataset."""
169161
model = load_model(model, device)
170-
data_loader = DataLoader(
171-
data, batch_size=batch_size, shuffle=False, num_workers=num_workers
172-
)
162+
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
173163
loss_fn = torch.nn.L1Loss()
174164
avg_loss = _validate_epoch(model, data_loader, loss_fn, device)
175165
return avg_loss
@@ -181,9 +171,7 @@ def _freeze_layers(model: torch.nn.Module, unfreeze_keyword: str) -> None:
181171
param.requires_grad = unfreeze_keyword in name
182172

183173

184-
def _get_optimizer(
185-
model: torch.nn.Module, learning_rate: float
186-
) -> torch.optim.Optimizer:
174+
def _get_optimizer(model: torch.nn.Module, learning_rate: float) -> torch.optim.Optimizer:
187175
return torch.optim.Adam(
188176
filter(lambda p: p.requires_grad, model.parameters()),
189177
lr=learning_rate,

0 commit comments

Comments
 (0)