Skip to content

Commit 6e60ff2

Browse files
committed
Fix remaining issues from code review
1 parent 1626bec commit 6e60ff2

6 files changed

Lines changed: 40 additions & 46 deletions

File tree

im2deep/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def cli(ctx, logging_level, profile, profile_name):
9494

9595
console.print(_build_credits())
9696

97-
97+
# TODO: Check that parameters match predict function in core
9898
# Implement psm_utils reading for calibration and prediction PSMLists
9999
@cli.command()
100100
@click.pass_context

im2deep/_io_helpers.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ def validate_psm_list(psm_list: PSMList, needs_target: bool = False) -> PSMList:
239239
# Filter missing and high charge states (IM2Deep predictions are not reliable for charges >6)
240240
original_size = len(psm_list)
241241
charges = np.array([psm.peptidoform.precursor_charge for psm in psm_list])
242-
psm_list_filtered = psm_list[charges != None] # noqa: E711
243-
psm_list_filtered = psm_list_filtered[charges <= 6]
242+
psm_list_filtered = psm_list[(charges != None) & (charges <= 6)] # noqa: E711
244243

245244
# TODO: Is deepcopy really necessary or can it be avoided?
246245
psm_list_filtered = deepcopy(psm_list_filtered)
@@ -263,25 +262,20 @@ def validate_psm_list(psm_list: PSMList, needs_target: bool = False) -> PSMList:
263262
for psm in psm_list_filtered
264263
)
265264

265+
# TODO: Could be vectorized over all ion mobility values
266266
# If ion_mobility is present, convert to CCS
267267
for psm in psm_list_filtered:
268-
if (
269-
psm.ion_mobility is not None
270-
and psm.metadata is not None
271-
and psm.metadata.get("CCS") is None
272-
):
273-
psm.metadata["CCS"] = str(
274-
im2ccs(
275-
psm.ion_mobility,
276-
psm.peptidoform.theoretical_mz,
277-
psm.peptidoform.precursor_charge,
268+
if psm.ion_mobility is not None:
269+
if psm.metadata is None:
270+
psm.metadata = {}
271+
if "CCS" not in psm.metadata:
272+
psm.metadata["CCS"] = str(
273+
im2ccs(
274+
psm.ion_mobility,
275+
psm.peptidoform.theoretical_mz,
276+
psm.peptidoform.precursor_charge,
277+
)
278278
)
279-
)
280-
# Ensure CCS is always stored as float
281-
elif psm.metadata.get("CCS") is not None:
282-
ccs_value = psm.metadata["CCS"]
283-
if not isinstance(ccs_value, float):
284-
psm.metadata["CCS"] = float(ccs_value)
285279

286280
if needs_target and not all_has_targets:
287281
raise IM2DeepError("PSMList must contain 'ion_mobility' or 'CCS' metadata for all PSMs.")

im2deep/_model_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def predict(
130130
config=_get_model_config(multi=multi),
131131
criterion=_get_loss_function(multi=multi),
132132
)
133+
model.to(device)
133134
model.eval()
134135
LOGGER.debug(f"Model loaded on device: {device}")
135136

im2deep/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from psm_utils.psm_list import PSMList
1212

1313
from im2deep import _model_ops
14+
from im2deep._io_helpers import validate_psm_list
1415
from im2deep.calibration import Calibration, LinearCCSCalibration
1516
from im2deep.constants import DEFAULT_MODEL, DEFAULT_MULTI_MODEL
16-
from im2deep._io_helpers import validate_psm_list
1717

1818
LOGGER = logging.getLogger(__name__)
1919

tests/test_model_ops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_load_model_from_path(self, temp_model_path):
2020
loaded_model = _model_ops.load_model(temp_model_path)
2121

2222
assert isinstance(loaded_model, torch.nn.Module)
23-
assert loaded_model.training is False # Should be in eval mode by default
23+
# load_model does not force eval mode; caller controls train/eval state.
24+
assert loaded_model.training is model.training
2425

2526
def test_load_model_from_module(self):
2627
"""Test loading model from existing module."""
@@ -38,7 +39,7 @@ def test_load_model_none(self):
3839
def test_load_model_invalid_type(self):
3940
"""Test loading model with invalid type raises TypeError."""
4041
with pytest.raises(TypeError):
41-
_model_ops.load_model(12345)
42+
_model_ops.load_model(12345) # type: ignore
4243

4344
def test_load_model_dict_checkpoint(self, temp_model_path):
4445
"""Test loading model from dict checkpoint."""
@@ -169,7 +170,7 @@ def test_predict_loop_single_output(self):
169170

170171
with patch("im2deep._model_ops.track", return_value=mock_data):
171172
predictions = _model_ops._predict_loop(
172-
model=model, data_loader=mock_data, device="cpu"
173+
model=model, data_loader=mock_data, device="cpu" # type: ignore
173174
)
174175

175176
assert isinstance(predictions, torch.Tensor)
@@ -190,7 +191,7 @@ def test_predict_loop_multi_output(self):
190191

191192
with patch("im2deep._model_ops.track", return_value=mock_data):
192193
predictions = _model_ops._predict_loop(
193-
model=model, data_loader=mock_data, device="cpu"
194+
model=model, data_loader=mock_data, device="cpu" # type: ignore
194195
)
195196

196197
assert isinstance(predictions, torch.Tensor)
@@ -211,7 +212,7 @@ def test_predict_loop_no_grad(self):
211212
# Mock track to return our mock data
212213
with patch("im2deep._model_ops.track", return_value=mock_data):
213214
predictions = _model_ops._predict_loop(
214-
model=model, data_loader=mock_data, device="cpu"
215+
model=model, data_loader=mock_data, device="cpu" # type: ignore
215216
)
216217

217218
assert not predictions.requires_grad
@@ -220,14 +221,14 @@ def test_predict_loop_no_grad(self):
220221
class TestGetArchitecture:
221222
"""Tests for _get_architecture function."""
222223

223-
@patch("im2deep._architecture.IM2Deep")
224+
@patch("im2deep._model_ops.IM2Deep")
224225
def test_get_architecture_single(self, mock_im2deep):
225226
"""Test getting single-output architecture."""
226227
arch = _model_ops._get_architecture(multi=False)
227228
# Should import IM2Deep
228229
assert arch is mock_im2deep
229230

230-
@patch("im2deep._architecture.IM2DeepMultiTransfer")
231+
@patch("im2deep._model_ops.IM2DeepMultiTransfer")
231232
def test_get_architecture_multi(self, mock_multi):
232233
"""Test getting multi-output architecture."""
233234
arch = _model_ops._get_architecture(multi=True)
@@ -257,7 +258,7 @@ def test_get_loss_function_single(self):
257258
loss = _model_ops._get_loss_function(multi=False)
258259
assert isinstance(loss, torch.nn.modules.loss._Loss)
259260

260-
@patch("im2deep._architecture.FlexibleLossSorted")
261+
@patch("im2deep._model_ops.FlexibleLossSorted")
261262
def test_get_loss_function_multi(self, mock_loss):
262263
"""Test getting multi-output loss function."""
263264
mock_instance = MagicMock()

tests/test_utils.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,9 @@
55
import pytest
66
from psm_utils import PSMList
77

8+
from im2deep._io_helpers import parse_input, validate_psm_list
89
from im2deep.exceptions import IM2DeepError
9-
from im2deep.utils import (
10-
ccs2im,
11-
im2ccs,
12-
parse_input,
13-
validate_psm_list,
14-
)
10+
from im2deep.utils import ccs2im, im2ccs
1511

1612

1713
class TestValidatePSMList:
@@ -28,6 +24,7 @@ def test_validate_psm_list_with_ccs(self, sample_psm_list_with_ccs):
2824
result = validate_psm_list(sample_psm_list_with_ccs, needs_target=True)
2925
assert isinstance(result, PSMList)
3026
for psm in result:
27+
assert psm.metadata is not None
3128
assert "CCS" in psm.metadata
3229
# CCS should always be stored as float
3330
assert isinstance(psm.metadata["CCS"], float)
@@ -46,7 +43,7 @@ def test_validate_psm_list_empty(self):
4643
def test_validate_psm_list_not_psm_list(self):
4744
"""Test validation fails with non-PSMList input."""
4845
with pytest.raises(IM2DeepError, match="PSMList"):
49-
validate_psm_list([1, 2, 3])
46+
validate_psm_list([1, 2, 3]) # type: ignore
5047

5148

5249
class TestParseInput:
@@ -131,6 +128,7 @@ def test_parse_input_legacy_format_detection(self, tmp_path):
131128
assert len(result) == 2
132129
# Check that CCS values are preserved in metadata
133130
for psm in result:
131+
assert psm.metadata is not None
134132
assert "CCS" in psm.metadata
135133

136134

@@ -142,7 +140,7 @@ def test_ccs2im_basic(self):
142140
ccs = 450.0
143141
charge = 2
144142
mz = 500.0
145-
im = ccs2im(ccs, charge, mz)
143+
im = ccs2im(ccs, mz, charge)
146144

147145
assert isinstance(im, float)
148146
assert im > 0
@@ -153,7 +151,7 @@ def test_ccs2im_array(self):
153151
charge = np.array([2, 3, 2])
154152
mz = np.array([500.0, 600.0, 550.0])
155153

156-
im = ccs2im(ccs, charge, mz)
154+
im = ccs2im(ccs, mz, charge)
157155

158156
assert isinstance(im, np.ndarray)
159157
assert len(im) == len(ccs)
@@ -164,7 +162,7 @@ def test_im2ccs_basic(self):
164162
im = 1.0
165163
charge = 2
166164
mz = 500.0
167-
ccs = im2ccs(im, charge, mz)
165+
ccs = im2ccs(im, mz, charge)
168166

169167
assert isinstance(ccs, float)
170168
assert ccs > 0
@@ -175,7 +173,7 @@ def test_im2ccs_array(self):
175173
charge = np.array([2, 3, 2])
176174
mz = np.array([500.0, 600.0, 550.0])
177175

178-
ccs = im2ccs(im, charge, mz)
176+
ccs = im2ccs(im, mz, charge)
179177

180178
assert isinstance(ccs, np.ndarray)
181179
assert len(ccs) == len(im)
@@ -187,34 +185,34 @@ def test_ccs2im_im2ccs_roundtrip(self):
187185
charge = 2
188186
mz = 500.0
189187

190-
im = ccs2im(ccs_original, charge, mz)
191-
ccs_roundtrip = im2ccs(im, charge, mz)
188+
im = ccs2im(ccs_original, mz, charge)
189+
ccs_roundtrip = im2ccs(im, mz, charge)
192190

193191
assert abs(ccs_roundtrip - ccs_original) < 0.01
194192

195193
def test_ccs2im_zero_values(self):
196194
"""Test handling of zero values."""
197195
with pytest.raises((ValueError, ZeroDivisionError)):
198-
ccs2im(0, 2, 500)
196+
ccs2im(0, 500, 2)
199197

200198
def test_im2ccs_zero_values(self):
201199
"""Test handling of zero values."""
202200
with pytest.raises((ValueError, ZeroDivisionError)):
203-
im2ccs(0, 2, 500)
201+
im2ccs(0, 500, 2)
204202

205203
def test_ccs2im_negative_values(self):
206204
"""Test handling of negative values."""
207205
# Function should raise ValueError for negative CCS values
208206
with pytest.raises(ValueError, match="CCS must be positive"):
209-
ccs2im(-450.0, 2, 500)
207+
ccs2im(-450.0, 500, 2)
210208

211209
def test_im2ccs_different_charges(self):
212210
"""Test conversions with different charge states."""
213211
im = 1.0
214212
mz = 500
215213

216-
ccs_z2 = im2ccs(im, 2, mz)
217-
ccs_z3 = im2ccs(im, 3, mz)
214+
ccs_z2 = im2ccs(im, mz, 2)
215+
ccs_z3 = im2ccs(im, mz, 3)
218216

219217
assert ccs_z2 != ccs_z3
220218
assert ccs_z2 > 0 and ccs_z3 > 0

0 commit comments

Comments
 (0)