Skip to content

Commit 37f4b86

Browse files
change default epochs in train
1 parent f93cb83 commit 37f4b86

1 file changed

Lines changed: 31 additions & 11 deletions

File tree

deeplc/_model_ops.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,26 @@ def load_model(
2121
) -> torch.nn.Module:
2222
"""Load a model from a file or return a randomly initialized model if none is provided."""
2323
# If device is not specified, use the default device (GPU if available, else CPU)
24-
selected_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
24+
selected_device = device or torch.device(
25+
"cuda" if torch.cuda.is_available() else "cpu"
26+
)
2527

2628
# Load model from file if a path is provided
2729
if isinstance(model, str | Path):
28-
loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
30+
loaded_model = torch.load(
31+
model, weights_only=False, map_location=selected_device
32+
)
2933
elif isinstance(model, torch.nn.Module):
3034
loaded_model = model
3135
elif model is None:
3236
# TODO: Implement randomly initialized model; requires model architecture definition
33-
raise NotImplementedError("Loading randomly initialized model is not implemented yet.")
37+
raise NotImplementedError(
38+
"Loading randomly initialized model is not implemented yet."
39+
)
3440
else:
35-
raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.")
41+
raise TypeError(
42+
f"Expected a PyTorch Module or a file path, got {type(model)} instead."
43+
)
3644

3745
# Ensure the model is on the specified device
3846
loaded_model.to(selected_device)
@@ -47,7 +55,7 @@ def train(
4755
validation_split: float = 0.1,
4856
device: str = "cpu",
4957
learning_rate: float = 0.001,
50-
epochs: int = 10,
58+
epochs: int = 25,
5159
batch_size: int = 512,
5260
patience: int = 5,
5361
trainable_layers: str | None = None,
@@ -93,9 +101,15 @@ def train(
93101
logger.debug(f"Frozen all layers except those containing '{trainable_layers}'")
94102

95103
# Parse dataset and split arguments; setup loaders
96-
train_dataset, val_dataset = split_datasets(train_data, validation_data, validation_split)
97-
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
98-
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
104+
train_dataset, val_dataset = split_datasets(
105+
train_data, validation_data, validation_split
106+
)
107+
train_loader = DataLoader(
108+
train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
109+
)
110+
val_loader = DataLoader(
111+
val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
112+
)
99113

100114
optimizer = _get_optimizer(model, learning_rate)
101115
loss_fn = torch.nn.L1Loss()
@@ -137,7 +151,9 @@ def predict(
137151
) -> torch.Tensor:
138152
"""Predict using the model for the given dataset."""
139153
model = load_model(model, device)
140-
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
154+
data_loader = DataLoader(
155+
data, batch_size=batch_size, shuffle=False, num_workers=num_workers
156+
)
141157
predictions = _predict_epoch(model, data_loader, device)
142158
return predictions.cpu().detach()
143159

@@ -151,7 +167,9 @@ def evaluate(
151167
) -> float:
152168
"""Evaluate the model on the given dataset."""
153169
model = load_model(model, device)
154-
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
170+
data_loader = DataLoader(
171+
data, batch_size=batch_size, shuffle=False, num_workers=num_workers
172+
)
155173
loss_fn = torch.nn.L1Loss()
156174
avg_loss = _validate_epoch(model, data_loader, loss_fn, device)
157175
return avg_loss
@@ -166,7 +184,9 @@ def _freeze_layers(model: torch.nn.Module, unfreeze_keyword: str) -> None:
166184
param.requires_grad = unfreeze_keyword in name
167185

168186

169-
def _get_optimizer(model: torch.nn.Module, learning_rate: float) -> torch.optim.Optimizer:
187+
def _get_optimizer(
188+
model: torch.nn.Module, learning_rate: float
189+
) -> torch.optim.Optimizer:
170190
return torch.optim.Adam(
171191
filter(lambda p: p.requires_grad, model.parameters()),
172192
lr=learning_rate,

0 commit comments

Comments
 (0)