Skip to content

Commit 9103fa5

Browse files
committed
Fix remaining typing issues
1 parent 50e39cf commit 9103fa5

5 files changed

Lines changed: 22 additions & 17 deletions

File tree

im2deep/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ def _run_predict(*args, **kwargs):
215215

216216
# Parse input files
217217
LOGGER.info("Parsing input files...")
218-
psm_list = parse_input(kwargs.get("precursors"))
218+
psm_list = parse_input(Path(kwargs.get("precursors"))) # type: ignore[invalid-arg]
219219

220220
# Run prediction
221221
LOGGER.info("Running CCS prediction...")
222222
if kwargs.get("calibration_precursors"):
223223
LOGGER.info("Calibration file provided, performing calibration and prediction...")
224-
psm_list_cal = parse_input(kwargs.get("calibration_precursors"))
224+
psm_list_cal = parse_input(Path(kwargs.get("calibration_precursors"))) # type: ignore[invalid-arg]
225225
predictions = core.predict_and_calibrate(psm_list, psm_list_cal, *args, **kwargs)
226226
else:
227227
LOGGER.info(

im2deep/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def predict_and_calibrate(
5858
psm_list_cal: PSMList,
5959
psm_list_reference: PSMList | None = None,
6060
model: torch.nn.Module | PathLike | str | None = None,
61-
calibration: Calibration | None = None,
61+
calibration: LinearCCSCalibration | None = None,
6262
multi: bool = False,
6363
predict_kwargs: dict | None = None,
6464
**kwargs,
@@ -128,6 +128,11 @@ def predict_and_calibrate(
128128
)
129129

130130
if not calibration.is_fitted:
131+
if psm_df_reference is None:
132+
raise ValueError(
133+
"Reference PSM list must be provided for calibration fitting when using a custom" \
134+
"calibration object."
135+
)
131136
LOGGER.info("Fitting calibration...")
132137
if any(psm_list_cal["is_decoy"]):
133138
LOGGER.warning(

tests/test_cli.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def test_setup_logging_default(self):
196196
"""Test setup_logging with info level."""
197197
import logging
198198

199-
from im2deep.utils import setup_logging
199+
from im2deep._io_helpers import setup_logging
200200

201201
setup_logging("info")
202202

@@ -207,7 +207,7 @@ def test_setup_logging_debug(self):
207207
"""Test setup_logging with debug level."""
208208
import logging
209209

210-
from im2deep.utils import setup_logging
210+
from im2deep._io_helpers import setup_logging
211211

212212
setup_logging("debug")
213213

@@ -218,7 +218,7 @@ def test_setup_logging_warning(self):
218218
"""Test setup_logging with warning level."""
219219
import logging
220220

221-
from im2deep.utils import setup_logging
221+
from im2deep._io_helpers import setup_logging
222222

223223
setup_logging("warning")
224224

@@ -229,7 +229,7 @@ def test_setup_logging_affects_submodules(self):
229229
"""Test that setup_logging affects all im2deep submodules."""
230230
import logging
231231

232-
from im2deep.utils import setup_logging
232+
from im2deep._io_helpers import setup_logging
233233

234234
setup_logging("debug")
235235

@@ -245,6 +245,6 @@ def test_default_command_group_import(self):
245245
"""Test that DefaultCommandGroup can be imported."""
246246
import click
247247

248-
from im2deep.utils import DefaultCommandGroup
248+
from im2deep._io_helpers import DefaultCommandGroup
249249

250250
assert issubclass(DefaultCommandGroup, click.Group)

tests/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_predict_with_kwargs(self, mock_dataset, mock_predict, sample_psm_list):
7474
def test_predict_invalid_psm_list(self):
7575
"""Test prediction with invalid PSMList."""
7676
with pytest.raises(IM2DeepError):
77-
core.predict([1, 2, 3])
77+
core.predict([1, 2, 3]) # type: ignore[invalid-arg]
7878

7979

8080
class TestPredictAndCalibrate:
@@ -171,7 +171,7 @@ def transform(self, *args, **kwargs):
171171
predictions = core.predict_and_calibrate(
172172
psm_list=sample_psm_list,
173173
psm_list_cal=sample_psm_list_with_ccs,
174-
calibration=custom_calibration,
174+
calibration=custom_calibration, # type: ignore[invalid-arg]
175175
)
176176

177177
assert isinstance(predictions, np.ndarray)

tests/test_model_ops.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def test_predict_loop_single_output(self):
171171
with patch("im2deep._model_ops.track", return_value=mock_data):
172172
predictions = _model_ops._predict_loop(
173173
model=model,
174-
data_loader=mock_data,
175-
device="cpu", # type: ignore
174+
data_loader=mock_data, # type: ignore[unknown-arg]
175+
device="cpu",
176176
)
177177

178178
assert isinstance(predictions, torch.Tensor)
@@ -194,8 +194,8 @@ def test_predict_loop_multi_output(self):
194194
with patch("im2deep._model_ops.track", return_value=mock_data):
195195
predictions = _model_ops._predict_loop(
196196
model=model,
197-
data_loader=mock_data,
198-
device="cpu", # type: ignore
197+
data_loader=mock_data, # type: ignore[unknown-arg]
198+
device="cpu",
199199
)
200200

201201
assert isinstance(predictions, torch.Tensor)
@@ -217,8 +217,8 @@ def test_predict_loop_no_grad(self):
217217
with patch("im2deep._model_ops.track", return_value=mock_data):
218218
predictions = _model_ops._predict_loop(
219219
model=model,
220-
data_loader=mock_data,
221-
device="cpu", # type: ignore
220+
data_loader=mock_data, # type: ignore[unknown-arg]
221+
device="cpu",
222222
)
223223

224224
assert not predictions.requires_grad
@@ -262,7 +262,7 @@ class TestGetLossFunction:
262262
def test_get_loss_function_single(self):
263263
"""Test getting single-output loss function."""
264264
loss = _model_ops._get_loss_function(multi=False)
265-
assert isinstance(loss, torch.nn.modules.loss._Loss)
265+
assert isinstance(loss, torch.nn.modules.loss._Loss) # type: ignore[unresolved-attr]
266266

267267
@patch("im2deep._model_ops.FlexibleLossSorted")
268268
def test_get_loss_function_multi(self, mock_loss):

0 commit comments

Comments
 (0)