Skip to content

Commit 74ff2cb

Browse files
committed
Ruff Ruff
1 parent 112c5c3 commit 74ff2cb

16 files changed

Lines changed: 104 additions & 102 deletions

im2deep/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
# Import main functionality for easier access
4242
from importlib.metadata import version
43-
from im2deep.utils import ccs2im, im2ccs
43+
4444
from im2deep.core import predict, predict_and_calibrate
4545
from im2deep.utils import ccs2im, im2ccs
4646

im2deep/__main__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,21 @@
4141

4242
from __future__ import annotations
4343

44-
import logging
4544
import cProfile
45+
import logging
46+
from pathlib import Path
4647

4748
import click
4849
from rich.console import Console
4950

5051
from im2deep import __version__, core
51-
from pathlib import Path
5252
from im2deep.utils import (
53-
setup_logging,
54-
parse_input,
53+
DefaultCommandGroup,
5554
build_credits,
56-
write_output,
5755
infer_output_name,
58-
DefaultCommandGroup,
56+
parse_input,
57+
setup_logging,
58+
write_output,
5959
)
6060

6161
console = Console()

im2deep/_architecture.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import sys
1+
import logging
22
from pathlib import Path
3+
4+
import lightning as L
35
import torch
46
import torch.nn as nn
57
import torch.nn.functional as F
6-
import lightning as L
7-
import logging
8+
from scipy.stats import pearsonr
89

910
try:
1011
import wandb
@@ -19,10 +20,12 @@
1920

2021
logger = logging.getLogger(__name__)
2122

23+
MAE = nn.L1Loss()
24+
2225

2326
class LogLowestMAE(L.Callback):
2427
def __init__(self, config):
25-
super(LogLowestMAE, self).__init__()
28+
super().__init__()
2629
self.bestMAE = float("inf")
2730
self.config = config
2831

@@ -40,7 +43,7 @@ def on_validation_end(self, trainer, pl_module):
4043

4144
class LRelu_with_saturation(nn.Module):
4245
def __init__(self, negative_slope, saturation):
43-
super(LRelu_with_saturation, self).__init__()
46+
super().__init__()
4447
self.negative_slope = negative_slope
4548
self.saturation = saturation
4649
self.leaky_relu = nn.LeakyReLU(self.negative_slope)
@@ -61,7 +64,7 @@ def __init__(
6164
negative_slope,
6265
saturation,
6366
):
64-
super(Conv1dActivation, self).__init__()
67+
super().__init__()
6568
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
6669
self.initializer = initializer
6770
self.activation = LRelu_with_saturation(
@@ -76,7 +79,7 @@ def forward(self, x):
7679

7780
class DenseActivation(nn.Module):
7881
def __init__(self, in_features, out_features, initializer, negative_slope, saturation):
79-
super(DenseActivation, self).__init__()
82+
super().__init__()
8083
self.linear = nn.Linear(in_features, out_features)
8184
self.initializer = initializer
8285
self.activation = LRelu_with_saturation(
@@ -91,7 +94,7 @@ def forward(self, x):
9194

9295
class SelfAttention(nn.Module):
9396
def __init__(self, feature_dim, heads=1):
94-
super(SelfAttention, self).__init__()
97+
super().__init__()
9598
self.feature_dim = feature_dim
9699
self.heads = heads
97100
# self.padded_dim = self.feature_dim + (self.feature_dim % self.heads)
@@ -152,7 +155,7 @@ def forward(self, x):
152155

153156
class Branch(nn.Module):
154157
def __init__(self, input_size, output_size, add_layer=1, dropout_rate=0.0):
155-
super(Branch, self).__init__()
158+
super().__init__()
156159
self.add_layer = add_layer
157160
if self.add_layer:
158161
self.fc1 = nn.Linear(input_size, output_size)
@@ -172,7 +175,7 @@ def forward(self, x):
172175

173176
class IM2Deep(L.LightningModule):
174177
def __init__(self, config, criterion):
175-
super(IM2Deep, self).__init__()
178+
super().__init__()
176179
self.config = config
177180
self.criterion = criterion
178181
self.mae = nn.L1Loss()
@@ -628,7 +631,7 @@ def configure_init(self):
628631

629632
class IM2DeepMulti(L.LightningModule):
630633
def __init__(self, config, criterion):
631-
super(IM2DeepMulti, self).__init__()
634+
super().__init__()
632635
self.config = config
633636
self.criterion = criterion
634637

@@ -1106,7 +1109,7 @@ def configure_init(self):
11061109

11071110
class IM2DeepMultiTransfer(L.LightningModule):
11081111
def __init__(self, config, criterion):
1109-
super(IM2DeepMultiTransfer, self).__init__()
1112+
super().__init__()
11101113
# TODO: config should be adapted in config file
11111114
self.config = config
11121115
self.criterion = criterion
@@ -1123,7 +1126,7 @@ def __init__(self, config, criterion):
11231126
self.ConvGlobal = self.backbone.ConvGlobal
11241127
self.OneHot = self.backbone.OneHot
11251128

1126-
if self.config.get("add_X_mol", False) == True:
1129+
if self.config.get("add_X_mol", False):
11271130
self.MolDesc = self.backbone.MolDesc
11281131

11291132
self.concat = list(self.backbone.Concat.children())[:-1]
@@ -1294,7 +1297,7 @@ def configure_optimizers(self):
12941297

12951298
class IM2DeepTransfer(L.LightningModule):
12961299
def __init__(self, config, criterion):
1297-
super(IM2DeepTransfer, self).__init__()
1300+
super().__init__()
12981301

12991302
self.config = config
13001303
self.criterion = criterion
@@ -1312,7 +1315,7 @@ def __init__(self, config, criterion):
13121315
self.ConvGlobal = self.backbone.ConvGlobal
13131316
self.OneHot = self.backbone.OneHot
13141317

1315-
if self.config.get("add_X_mol", False) == True:
1318+
if self.config.get("add_X_mol", False):
13161319
self.MolDesc = self.backbone.MolDesc
13171320

13181321
self.concat = self.backbone.Concat
@@ -1439,7 +1442,7 @@ def configure_optimizers(self):
14391442

14401443
class FlexibleLossSorted(nn.Module):
14411444
def __init__(self, diversity_weight=0.1):
1442-
super(FlexibleLossSorted, self).__init__()
1445+
super().__init__()
14431446
self.diversity_weight = diversity_weight
14441447

14451448
def forward(self, y1, y2, y_hat1, y_hat2):
@@ -1479,7 +1482,7 @@ def forward(self, y1, y2, y_hat1, y_hat2):
14791482

14801483
class FlexibleLoss(nn.Module):
14811484
def __init__(self, diversity_weight=0.1):
1482-
super(FlexibleLoss, self).__init__()
1485+
super().__init__()
14831486
self.diversity_weight = diversity_weight
14841487

14851488
def forward(self, y1, y2, y_hat1, y_hat2):

im2deep/_model_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
"""Training, predicting, and evaluating using IM2Deep (PyTorch)."""
33

44
from __future__ import annotations
5-
import copy
5+
66
import logging
77
import warnings
88
from os import PathLike
99
from pathlib import Path
1010

11+
import lightning as L
1112
import torch
1213
from rich.progress import track
1314
from torch.utils.data import DataLoader, Dataset
14-
import lightning as L
1515

1616
# Suppress PyTorch padding warning for conv1d with even kernels and odd dilation
1717
warnings.filterwarnings(

im2deep/calibration.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@
88

99
import logging
1010
from abc import ABC, abstractmethod
11-
from typing import cast
1211

13-
import pandas as pd
1412
import numpy as np
15-
from psm_utils import PSMList, Peptidoform
13+
import pandas as pd
14+
from psm_utils import Peptidoform, PSMList
1615

1716
from im2deep._exceptions import CalibrationError
18-
from im2deep.utils import parse_input
19-
from im2deep.constants import DEFAULT_REFERENCE_DATASET_PATH, DEFAULT_MULTI_REFERENCE_DATASET_PATH
17+
from im2deep.constants import DEFAULT_MULTI_REFERENCE_DATASET_PATH, DEFAULT_REFERENCE_DATASET_PATH
2018

2119
LOGGER = logging.getLogger(__name__)
2220

@@ -110,7 +108,7 @@ def fit(
110108
LOGGER.warning(
111109
f"Could not calculate charge-specific shift factors: {e}. Using 0.0 as fallback."
112110
)
113-
self.charge_shifts = {charge: 0.0 for charge in range(1, 7)}
111+
self.charge_shifts = dict.fromkeys(range(1, 7), 0.0)
114112

115113
# Set general shift as the mean of calculated charge shifts or charge 2 if available
116114
if 2 in self.charge_shifts and self.charge_shifts[2] != 0.0:
@@ -153,7 +151,7 @@ def fit(
153151
f"Could not calculate general shift factor: {e}. Using 0.0 as fallback."
154152
)
155153
self.general_shift = 0.0
156-
self.charge_shifts = {charge: self.general_shift for charge in range(1, 7)}
154+
self.charge_shifts = dict.fromkeys(range(1, 7), self.general_shift)
157155

158156
self.used_charges = set(self.charge_shifts.keys())
159157
self.fitted = True
@@ -172,7 +170,7 @@ def transform(
172170
if "peptidoform" not in psm_df.columns:
173171
raise CalibrationError("Input DataFrame must contain 'peptidoform' column.")
174172

175-
if not "predicted_CCS_uncalibrated" in psm_df.columns and "metadata" in psm_df.columns:
173+
if "predicted_CCS_uncalibrated" not in psm_df.columns and "metadata" in psm_df.columns:
176174
psm_df["predicted_CCS_uncalibrated"] = psm_df["metadata"].apply(
177175
lambda x: (
178176
x["predicted_CCS_uncalibrated"]

im2deep/core.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@
44

55
import logging
66
from os import PathLike
7-
from pathlib import Path
87

98
import numpy as np
10-
from psm_utils.psm_list import PSMList
119
import torch
1210
from deeplc.data import DeepLCDataset
11+
from psm_utils.psm_list import PSMList
1312

14-
from im2deep.utils import validate_psm_list
1513
from im2deep import _model_ops
16-
from im2deep.calibration import LinearCCSCalibration, Calibration
14+
from im2deep.calibration import Calibration, LinearCCSCalibration
1715
from im2deep.constants import DEFAULT_MODEL, DEFAULT_MULTI_MODEL
16+
from im2deep.utils import validate_psm_list
1817

1918
LOGGER = logging.getLogger(__name__)
2019

im2deep/utils.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,25 @@
1515

1616
from __future__ import annotations
1717

18-
import sys
19-
from pathlib import Path
20-
from typing import Any
2118
import logging
22-
from rich.text import Text
23-
import gzip
19+
from pathlib import Path
2420

2521
import click
2622
import numpy as np
27-
import psm_utils.io
2823
import pandas as pd
24+
import psm_utils.io
25+
from psm_utils.psm import PSM
26+
from psm_utils.psm_list import PSMList
2927
from rich.console import Console
3028
from rich.logging import RichHandler
31-
from psm_utils.psm_list import PSMList
32-
from psm_utils.psm import PSM
29+
from rich.text import Text
3330

3431
from im2deep._exceptions import IM2DeepError
3532
from im2deep.constants import (
36-
SUMMARY_CONSTANT,
3733
MASS_GAS_N2,
38-
TEMP,
34+
SUMMARY_CONSTANT,
3935
T_DIFF,
36+
TEMP,
4037
)
4138

4239
console = Console()
@@ -286,7 +283,7 @@ def parse_input(
286283
is_legacy_format = False
287284
try:
288285
# Read first line to check column names
289-
with open(input_file, "r") as f:
286+
with open(input_file) as f:
290287
first_line = f.readline().strip()
291288

292289
# Check if it has legacy format columns
@@ -307,7 +304,7 @@ def parse_input(
307304
# Try to parse with psm_utils
308305
try:
309306
psm_list = psm_utils.io.read_file(input_file, filetype=filetype or "infer")
310-
LOGGER.debug(f"Successfully read file using psm_utils.")
307+
LOGGER.debug("Successfully read file using psm_utils.")
311308
except Exception as e:
312309
# If psm_utils fails, try legacy format as fallback
313310
LOGGER.warning(f"Failed to read PSM file using psm_utils: {e}")
@@ -345,7 +342,7 @@ def _parse_legacy_format(input_file: str | Path) -> PSMList:
345342
df = pd.read_csv(input_file, sep=None, engine="python")
346343
df = df.fillna("") # Replace NaN with empty strings
347344
except Exception as e:
348-
raise IM2DeepError(f"Failed to read file as delimited text: {e}")
345+
raise IM2DeepError(f"Failed to read file as delimited text: {e}") from e
349346

350347
required_cols_legacy = ["seq", "modifications", "charge"]
351348
missing_cols = set(required_cols_legacy) - set(df.columns)

tests/conftest.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Pytest configuration and fixtures for IM2Deep tests."""
22

3-
import pytest
3+
44
import numpy as np
55
import pandas as pd
6-
from pathlib import Path
7-
from psm_utils import PSM, PSMList, Peptidoform
6+
import pytest
7+
from psm_utils import PSM, Peptidoform, PSMList
88

99

1010
@pytest.fixture
@@ -51,7 +51,7 @@ def sample_psm_list_with_ccs():
5151
is_decoy=False,
5252
score=0.95,
5353
retention_time=100.5,
54-
metadata={"CCS": float(450.5)},
54+
metadata={"CCS": 450.5},
5555
),
5656
PSM(
5757
peptidoform=Peptidoform("SEQUENCE/3"),
@@ -61,7 +61,7 @@ def sample_psm_list_with_ccs():
6161
is_decoy=False,
6262
score=0.92,
6363
retention_time=120.3,
64-
metadata={"CCS": float(520.8)},
64+
metadata={"CCS": 520.8},
6565
),
6666
PSM(
6767
peptidoform=Peptidoform("TESTPEPTIDE/2"),
@@ -71,7 +71,7 @@ def sample_psm_list_with_ccs():
7171
is_decoy=False,
7272
score=0.88,
7373
retention_time=135.7,
74-
metadata={"CCS": float(480.2)},
74+
metadata={"CCS": 480.2},
7575
),
7676
PSM(
7777
peptidoform=Peptidoform("ANOTHER/3"),
@@ -81,7 +81,7 @@ def sample_psm_list_with_ccs():
8181
is_decoy=False,
8282
score=0.90,
8383
retention_time=142.1,
84-
metadata={"CCS": float(510.5)},
84+
metadata={"CCS": 510.5},
8585
),
8686
]
8787
return PSMList(psm_list=psms)

tests/test_calibration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Tests for calibration module."""
22

3-
import pytest
43
import numpy as np
54
import pandas as pd
6-
from psm_utils import Peptidoform, PSM, PSMList
5+
import pytest
6+
from psm_utils import Peptidoform
77

8-
from im2deep.calibration import LinearCCSCalibration, get_default_reference
98
from im2deep._exceptions import CalibrationError
9+
from im2deep.calibration import LinearCCSCalibration, get_default_reference
1010

1111

1212
class TestLinearCCSCalibration:

0 commit comments

Comments
 (0)