Skip to content

Commit c5d8bd1

Browse files
committed
Improve API flexibility, input validation, and robustness
- Accept lists of PSMs, Peptidoforms, or strings as input for predict and calibrate functions, not just PSMList objects - Rename psm_list to psm_list_reference in finetune/train for clarity - Add num_threads parameter to train, predict, and evaluate - Add automatic device detection if device is set to `None` - Add input validation for empty data loaders and dataset splitting - Handle empty predictions gracefully by returning empty tensor - Fix target tensor shape in dataset when no target is available - Rename _exceptions module to exceptions (public API) - Add unit tests for edge cases in model ops and data splitting
1 parent c6d0e9f commit c5d8bd1

6 files changed

Lines changed: 138 additions & 23 deletions

File tree

deeplc/_model_ops.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def load_model(
3232
selected_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
3333

3434
# Load model from file if a path is provided
35-
if isinstance(model, str | Path):
35+
if isinstance(model, (str, PathLike, Path)):
3636
loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
3737
elif isinstance(model, torch.nn.Module):
3838
loaded_model = model
@@ -54,8 +54,9 @@ def train(
5454
model: torch.nn.Module | PathLike | str | None,
5555
train_dataset: DeepLCDataset | Subset[DeepLCDataset],
5656
validation_dataset: DeepLCDataset | Subset[DeepLCDataset],
57-
device: str = "cpu",
57+
device: str | None = None,
5858
num_workers: int = 0,
59+
num_threads: int | None = None,
5960
learning_rate: float = 0.001,
6061
epochs: int = 25,
6162
batch_size: int = 512,
@@ -77,6 +78,8 @@ def train(
7778
Device to train on ('cpu' or 'cuda').
7879
num_workers
7980
Number of worker processes for data loading.
81+
num_threads
82+
Number of threads for model operations on CPU (ignored if using GPU).
8083
learning_rate
8184
Learning rate for optimizer.
8285
epochs
@@ -94,6 +97,8 @@ def train(
9497
Trained model.
9598
9699
"""
100+
torch.set_num_threads(num_threads or torch.get_num_threads())
101+
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
97102
model = load_model(model, device)
98103

99104
# Parse datasets; setup loaders
@@ -107,6 +112,13 @@ def train(
107112
num_workers=num_workers,
108113
)
109114

115+
if len(train_loader) == 0:
116+
raise ValueError("Training data loader is empty. Provide at least one training sample.")
117+
if len(val_loader) == 0:
118+
raise ValueError(
119+
"Validation data loader is empty. Adjust validation data or validation_split."
120+
)
121+
110122
optimizer = _get_optimizer(model, learning_rate)
111123
loss_fn = torch.nn.L1Loss()
112124

@@ -145,12 +157,15 @@ def train(
145157
def predict(
146158
model: torch.nn.Module | PathLike | str | None,
147159
data: Dataset,
148-
device: str = "cpu",
160+
device: str | None = None,
149161
batch_size: int = 512,
150162
num_workers: int = 0,
163+
num_threads: int | None = None,
151164
show_progress: bool = True,
152165
) -> torch.Tensor:
153166
"""Predict using the model for the given dataset."""
167+
torch.set_num_threads(num_threads or torch.get_num_threads())
168+
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
154169
model = load_model(model, device)
155170
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
156171
predictions = _predict_epoch(model, data_loader, device, show_progress=show_progress)
@@ -160,11 +175,14 @@ def predict(
160175
def evaluate(
161176
model: torch.nn.Module | PathLike | str | None,
162177
data: Dataset,
163-
device: str = "cpu",
178+
device: str | None = None,
164179
batch_size: int = 512,
165180
num_workers: int = 0,
181+
num_threads: int | None = None,
166182
) -> float:
167183
"""Evaluate the model on the given dataset."""
184+
torch.set_num_threads(num_threads or torch.get_num_threads())
185+
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
168186
model = load_model(model, device)
169187
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
170188
loss_fn = torch.nn.L1Loss()
@@ -235,6 +253,8 @@ def _predict_epoch(
235253
features = [feature_tensor.to(device) for feature_tensor in features]
236254
outputs = model(*features)
237255
predictions.append(outputs.cpu())
256+
if not predictions:
257+
return torch.empty(0, dtype=torch.float32)
238258
return torch.cat(predictions, dim=0).squeeze()
239259

240260

deeplc/calibration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.pipeline import Pipeline, make_pipeline # type: ignore[import]
1212
from sklearn.preprocessing import SplineTransformer # type: ignore[import]
1313

14-
from deeplc._exceptions import CalibrationError
14+
from deeplc.exceptions import CalibrationError
1515

1616
LOGGER = logging.getLogger(__name__)
1717

deeplc/core.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010
import torch
11-
from psm_utils.psm_list import PSMList
11+
from psm_utils import PSM, Peptidoform, PSMList
1212

1313
from deeplc import _model_ops
1414
from deeplc.calibration import (
@@ -25,7 +25,7 @@
2525

2626

2727
def predict(
28-
psm_list: PSMList,
28+
psm_list: PSMList | list[PSM | Peptidoform | str],
2929
model: torch.nn.Module | PathLike | str | None = None,
3030
predict_kwargs: dict | None = None,
3131
) -> np.ndarray:
@@ -49,7 +49,7 @@ def predict(
4949
"""
5050
return _model_ops.predict(
5151
model=model or DEFAULT_MODEL,
52-
data=DeepLCDataset.from_psm_list(psm_list),
52+
data=DeepLCDataset.from_psm_list(_parse_psms(psm_list)),
5353
**(predict_kwargs or {}),
5454
).numpy()
5555

@@ -116,7 +116,7 @@ def calibrate(
116116

117117

118118
def predict_and_calibrate(
119-
psm_list: PSMList,
119+
psm_list: PSMList | list[PSM | Peptidoform | str],
120120
psm_list_reference: PSMList,
121121
model: torch.nn.Module | PathLike | str | None = None,
122122
calibration: Calibration | None = None,
@@ -147,7 +147,7 @@ def predict_and_calibrate(
147147
# Predict initial retention times
148148
LOGGER.info("Predicting retention times...")
149149
predicted_rt = predict(
150-
psm_list=psm_list,
150+
psm_list=_parse_psms(psm_list),
151151
model=model,
152152
predict_kwargs=predict_kwargs,
153153
)
@@ -175,7 +175,7 @@ def predict_and_calibrate(
175175

176176

177177
def finetune_and_predict(
178-
psm_list: PSMList,
178+
psm_list: PSMList | list[PSM | Peptidoform | str],
179179
psm_list_reference: PSMList,
180180
model: torch.nn.Module | PathLike | str | None = None,
181181
train_kwargs: dict | None = None,
@@ -205,15 +205,15 @@ def finetune_and_predict(
205205
"""
206206
# Fine-tune the model
207207
finetuned_model = finetune(
208-
psm_list=psm_list_reference,
208+
psm_list_reference=psm_list_reference,
209209
model=model,
210210
train_kwargs=train_kwargs,
211211
)
212212

213213
# Predict retention times with fine-tuned model
214214
LOGGER.info("Predicting retention times with fine-tuned model...")
215215
predicted_rt = predict(
216-
psm_list=psm_list,
216+
psm_list=_parse_psms(psm_list),
217217
model=finetuned_model,
218218
predict_kwargs=predict_kwargs,
219219
)
@@ -233,7 +233,7 @@ def finetune_and_predict(
233233

234234

235235
def finetune(
236-
psm_list: PSMList,
236+
psm_list_reference: PSMList,
237237
psm_list_validation: PSMList | None = None,
238238
validation_split: float = 0.1,
239239
model: torch.nn.Module | PathLike | str | None = None,
@@ -244,7 +244,7 @@ def finetune(
244244
245245
Parameters
246246
----------
247-
psm_list
247+
psm_list_reference
248248
List of PSMs to use as reference for fine-tuning.
249249
psm_list_validation
250250
List of PSMs to use for validation during fine-tuning. If None, a split from psm_list_reference is
@@ -261,10 +261,10 @@ def finetune(
261261
262262
"""
263263
LOGGER.info("Fine-tuning model...")
264-
if any(psm_list["is_decoy"]):
264+
if any(psm_list_reference["is_decoy"]):
265265
# TODO: Move to reusable validation step?
266266
LOGGER.warning("PSM list contains decoy PSMs. These will be used for fine tuning.")
267-
training_data = DeepLCDataset.from_psm_list(psm_list)
267+
training_data = DeepLCDataset.from_psm_list(psm_list_reference)
268268
validation_data = (
269269
DeepLCDataset.from_psm_list(psm_list_validation) if psm_list_validation else None
270270
)
@@ -281,7 +281,7 @@ def finetune(
281281

282282

283283
def train(
284-
psm_list: PSMList,
284+
psm_list_reference: PSMList,
285285
psm_list_validation: PSMList | None = None,
286286
validation_split: float = 0.1,
287287
train_kwargs: dict | None = None,
@@ -291,8 +291,8 @@ def train(
291291
292292
Parameters
293293
----------
294-
psm_list
295-
List of PSMs to use for training.
294+
psm_list_reference
295+
List of PSMs to use as reference for training.
296296
psm_list_validation
297297
List of PSMs to use for validation. If None, a split from psm_list_reference is used.
298298
validation_split
@@ -306,7 +306,7 @@ def train(
306306
Trained model.
307307
308308
"""
309-
training_data = DeepLCDataset.from_psm_list(psm_list)
309+
training_data = DeepLCDataset.from_psm_list(psm_list_reference)
310310
validation_data = (
311311
DeepLCDataset.from_psm_list(psm_list_validation) if psm_list_validation else None
312312
)
@@ -321,3 +321,29 @@ def train(
321321
**(train_kwargs or {}),
322322
)
323323
return trained_model
324+
325+
326+
def _parse_psms(psm_list: PSMList | list[PSM | Peptidoform | str]) -> PSMList:
327+
"""
328+
Parse a list of PSMs, Peptidoforms, or strings into a PSMList.
329+
330+
Note that this function can only be used for inputs that do not require additional data,
331+
such as retention times or decoy status. It cannot be used for reference or validation
332+
data sets that require observed retention times for calibration or training.
333+
334+
"""
335+
if isinstance(psm_list, PSMList):
336+
return psm_list
337+
elif isinstance(psm_list, list):
338+
if all(isinstance(psm, PSM) for psm in psm_list):
339+
return PSMList(psm_list=psm_list)
340+
elif all(isinstance(psm, Peptidoform) for psm in psm_list) or all(
341+
isinstance(psm, str) for psm in psm_list
342+
):
343+
return PSMList(
344+
psm_list=[PSM(spectrum_id=i, peptidoform=pf) for i, pf in enumerate(psm_list)]
345+
)
346+
else:
347+
raise ValueError("List must contain either PSMs, Peptidoforms, or strings.")
348+
else:
349+
raise ValueError("Input must be a PSMList or a list of PSMs, Peptidoforms, or strings.")

deeplc/data.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __getitem__(self, idx) -> tuple:
7474
targets = (
7575
self.target_retention_times[idx]
7676
if self.target_retention_times is not None
77-
else torch.full_like(feature_tuples[0], fill_value=float("nan"), dtype=torch.float32)
77+
else torch.tensor(float("nan"), dtype=torch.float32)
7878
)
7979
return feature_tuples, targets
8080

@@ -160,10 +160,19 @@ def split_datasets(
160160
"""
161161
# TODO: Implement stratified splitting based on stripped sequence
162162
if validation_data is None:
163+
if not 0 < validation_split < 1:
164+
raise ValueError(
165+
f"validation_split must be between 0 and 1 (exclusive), got {validation_split}."
166+
)
163167
if not hasattr(train_data, "__len__"):
164168
raise ValueError("Dataset must implement __len__ method for automatic splitting")
165169
dataset_len = len(train_data) # type: ignore[arg-type]
166-
val_size = int(dataset_len * validation_split)
170+
if dataset_len < 2:
171+
raise ValueError(
172+
"Need at least 2 samples in train_data when validation_data is not provided."
173+
)
174+
val_size = max(1, int(dataset_len * validation_split))
175+
val_size = min(val_size, dataset_len - 1)
167176
train_size = dataset_len - val_size
168177
train_dataset, val_dataset = torch.utils.data.random_split(
169178
train_data, [train_size, val_size]

tests/test_model_ops.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
import torch
5+
from torch.utils.data import Dataset
6+
7+
from deeplc import _model_ops
8+
from deeplc.data import split_datasets
9+
10+
11+
class _TinyDeepLCDataset(Dataset):
12+
def __init__(self, length: int):
13+
self.length = length
14+
15+
def __len__(self) -> int:
16+
return self.length
17+
18+
def __getitem__(self, index: int):
19+
features = (
20+
torch.zeros((60, 6), dtype=torch.float32),
21+
torch.zeros((30, 6), dtype=torch.float32),
22+
torch.zeros((55,), dtype=torch.float32),
23+
torch.zeros((60, 20), dtype=torch.float32),
24+
)
25+
target = torch.tensor(0.0, dtype=torch.float32)
26+
return features, target
27+
28+
29+
class _DummyModel(torch.nn.Module):
30+
def forward(self, matrix, matrix_sum, matrix_global, matrix_hc): # noqa: ARG002
31+
batch_size = matrix.shape[0]
32+
return torch.zeros((batch_size, 1), dtype=torch.float32)
33+
34+
35+
def test_predict_returns_empty_tensor_for_empty_dataset():
36+
empty_data = _TinyDeepLCDataset(length=0)
37+
preds = _model_ops.predict(model=_DummyModel(), data=empty_data, show_progress=False)
38+
assert isinstance(preds, torch.Tensor)
39+
assert preds.numel() == 0
40+
41+
42+
def test_split_datasets_rejects_too_small_dataset_without_validation_data():
43+
with pytest.raises(ValueError, match="Need at least 2 samples"):
44+
split_datasets(
45+
train_data=_TinyDeepLCDataset(length=1),
46+
validation_data=None,
47+
validation_split=0.1,
48+
)
49+
50+
51+
def test_train_rejects_empty_validation_loader():
52+
with pytest.raises(ValueError, match="Validation data loader is empty"):
53+
_model_ops.train(
54+
model=_DummyModel(),
55+
train_dataset=_TinyDeepLCDataset(length=2),
56+
validation_dataset=_TinyDeepLCDataset(length=0),
57+
epochs=1,
58+
batch_size=2,
59+
show_progress=False,
60+
)

0 commit comments

Comments
 (0)