Skip to content

Commit d24c050

Browse files
Optimize peptidoform feature processing: fix memory leak, add DIA-NN MS1 toggle, sequential processing
- Add min_occurrences config parameter (default 5) to filter peptides by PSM count - Add min_matched_peaks to test config (4 fragments minimum per PSM) - Fix DIA-NN feature generator memory leak: caches grew unbounded across peptidoforms because entries were never reused (unique fragment hashes per peptidoform) - Add enable_ms1_features toggle (default False) to skip slow MS1-based DIA-NN features (feature_ms1_accuracy_correlations took ~20ms/item due to iterating all MS1 scans) - Add prepare_ms1_dict() for 3.4x speedup when MS1 features are enabled (pre-converts mz arrays to sorted numpy, avoids np.asarray + sort check per scan) - Switch from ThreadPoolExecutor to sequential processing (GIL made threading 3-6x slower than single-threaded for CPU-bound numpy/pandas work: 2 it/s vs 13 it/s) - Use singleton DIA-NN generator to avoid 11k constructor/import calls - Fix quantification: derive stripped_peptide and proteins columns when missing - Fix create_model() to accept meta parameter from scikit-keras for dynamic input_dim - Remove unused CorrelationResults fields (8 commented-out matrix variants) - Add adaptive RT margin support in fragment intensity filtering - Reduce per-peptidoform logging noise (info -> debug) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 635c217 commit d24c050

9 files changed

Lines changed: 260 additions & 90 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ test_results.txt
1313
mumdia.egg-info/
1414
mzml_files/
1515
notebook_helpers/
16+
test_data/
1617

1718
# Byte-compiled / optimized / DLL files
1819
__pycache__/

config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def convert_legacy_config(legacy_data: Dict[str, Any]) -> Dict[str, Any]:
8181
"read_initial_search_pickle": "read_initial_search_pickle",
8282
"remove_intermediate_files": "remove_intermediate_files",
8383
"fdr_init_search": "fdr_init_search",
84+
"min_occurrences": "min_occurrences",
8485
}
8586

8687
if "mumdia" in legacy_data:
@@ -217,6 +218,9 @@ class MuMDIAConfig:
217218
read_full_search_pickle: bool = False
218219
read_initial_search_pickle: bool = False
219220

221+
# Filtering settings
222+
min_occurrences: int = 5 # Minimum PSMs per peptide to keep
223+
220224
# Processing settings
221225
remove_intermediate_files: bool = False
222226
dlc_transfer_learn: bool = False
@@ -233,6 +237,7 @@ class MuMDIAConfig:
233237
clean: bool = False
234238
sage_only: bool = False
235239
skip_mokapot: bool = False
240+
use_diann_features: bool = True
236241
verbose: bool = False
237242

238243
# Feature settings
@@ -382,7 +387,8 @@ def get_mumdia_config(self) -> Dict[str, Any]:
382387
"clean": self.clean,
383388
"sage_only": self.sage_only,
384389
"skip_mokapot": self.skip_mokapot,
385-
"verbose": self.verbose
390+
"verbose": self.verbose,
391+
"min_occurrences": self.min_occurrences
386392
}
387393

388394
def to_legacy_format(self) -> Dict[str, Any]:

data_structures.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,6 @@ class CorrelationResults:
2020
sum_pred_frag_intens: np.ndarray
2121
correlation_matrix_psm_ids: np.ndarray
2222
correlation_matrix_frag_ids: np.ndarray
23-
correlation_matrix_psm_ids_ignore_zeros: np.ndarray
24-
correlation_matrix_psm_ids_ignore_zeros_counts: np.ndarray
25-
correlation_matrix_psm_ids_missing: np.ndarray
26-
correlation_matrix_psm_ids_missing_zeros_counts: np.ndarray
27-
correlation_matrix_frag_ids_ignore_zeros: np.ndarray
28-
correlation_matrix_frag_ids_ignore_zeros_counts: np.ndarray
29-
correlation_matrix_frag_ids_missing: np.ndarray
30-
correlation_matrix_frag_ids_missing_zeros_counts: np.ndarray
3123
most_intens_cor: float
3224
most_intens_cos: float
3325
mse_avg_pred_intens: float

feature_generators/diann_feature_generator.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class FeatureConfig:
4747
ms1_accuracy_factors: List[float] = None
4848
ms2_accuracy_factors: List[float] = None
4949

50+
# Feature toggles
51+
enable_ms1_features: bool = False # MS1-based features (groups 2-3); slow, disabled by default
52+
5053
# Parallelization settings
5154
n_jobs: int = -1 # -1 means use all available CPU cores
5255

@@ -101,8 +104,40 @@ def __init__(self, config: Optional[FeatureConfig] = None):
101104
self._pivot_cache = {}
102105
self._correlation_cache = {}
103106

107+
# Pre-processed MS1 data (set via prepare_ms1_dict)
108+
self._ms1_prepared = None # list of (rt, mz_array, intensity_array) sorted by RT
109+
104110
logger.info("Initialized DIANNFeatureGenerator with built-in optimizations")
105111

112+
def prepare_ms1_dict(self, ms1_dict: Dict[str, Dict[str, Any]]) -> None:
113+
"""Pre-convert ms1_dict to sorted numpy arrays for fast elution profile building.
114+
115+
Call once before processing peptidoforms. Converts each scan's mz/intensity
116+
lists to numpy arrays and sorts the scan list by RT. This avoids repeated
117+
np.asarray + sort checks in build_elution_profile (~20ms -> ~2ms per call).
118+
"""
119+
prepared = []
120+
for scan_dict in ms1_dict.values():
121+
mzs = scan_dict.get("mz", [])
122+
intensities = scan_dict.get("intensity", [])
123+
rt = scan_dict.get("retention_time", None)
124+
if rt is None or len(mzs) == 0:
125+
continue
126+
# Convert RT from seconds to minutes if needed
127+
if isinstance(rt, (int, float)) and rt > 1000:
128+
rt = rt / 60
129+
mz_arr = np.asarray(mzs)
130+
int_arr = np.asarray(intensities)
131+
# Ensure sorted by m/z
132+
if len(mz_arr) > 1 and mz_arr[0] > mz_arr[-1]:
133+
order = np.argsort(mz_arr)
134+
mz_arr = mz_arr[order]
135+
int_arr = int_arr[order]
136+
prepared.append((rt, mz_arr, int_arr))
137+
# Sort by RT for potential windowed access
138+
prepared.sort(key=lambda x: x[0])
139+
self._ms1_prepared = prepared
140+
106141
def _setup_parallelization(self):
107142
"""Set up parallelization parameters."""
108143
import os
@@ -158,7 +193,7 @@ def clear_cache(self):
158193
self._cache.clear()
159194
self._pivot_cache.clear()
160195
self._correlation_cache.clear()
161-
logger.info("Cleared all caches")
196+
logger.debug("Cleared all caches")
162197

163198
def get_cache_stats(self) -> Dict[str, int]:
164199
"""Get cache statistics for monitoring."""
@@ -657,9 +692,27 @@ def build_elution_profile(
657692
if tolerance_ppm is None:
658693
tolerance_ppm = self.config.precursor_mass_tolerance
659694

660-
elution_profile = {}
661695
tol_mz = target_mz * tolerance_ppm / 1e6 * acc_factor
662696

697+
# Fast path: use pre-processed arrays (avoids np.asarray + sort per scan)
698+
if self._ms1_prepared is not None:
699+
elution_profile = {}
700+
for rt, mz_arr, int_arr in self._ms1_prepared:
701+
idx = np.searchsorted(mz_arr, target_mz)
702+
best_idx = None
703+
best_diff = tol_mz
704+
for check_idx in (idx - 1, idx, idx + 1):
705+
if 0 <= check_idx < len(mz_arr):
706+
diff = abs(mz_arr[check_idx] - target_mz)
707+
if diff < best_diff:
708+
best_diff = diff
709+
best_idx = check_idx
710+
if best_idx is not None:
711+
elution_profile[rt] = int_arr[best_idx]
712+
return elution_profile
713+
714+
# Slow fallback: original dict-based path
715+
elution_profile = {}
663716
for scan, scan_dict in ms1_dict.items():
664717
mzs = scan_dict.get("mz", [])
665718
intensities = scan_dict.get("intensity", [])
@@ -668,8 +721,7 @@ def build_elution_profile(
668721
if rt is None or len(mzs) == 0 or len(intensities) == 0:
669722
continue
670723

671-
# Convert RT from seconds to minutes if needed
672-
if isinstance(rt, (int, float)) and rt > 1000: # Likely in seconds
724+
if isinstance(rt, (int, float)) and rt > 1000:
673725
rt = rt / 60
674726

675727
best_idx, best_val = self._search_sorted_with_tolerance(
@@ -3272,7 +3324,7 @@ def _calculate_all_features_parallel(
32723324
except Exception as e:
32733325
logger.error(f"Error calculating {feature_name}: {e}")
32743326

3275-
logger.info(f"Calculated {len(features)} feature groups in parallel")
3327+
logger.debug(f"Calculated {len(features)} feature groups in parallel")
32763328
return features
32773329

32783330
def _safe_feature_calculation(self, func, args):
@@ -3317,7 +3369,7 @@ def _calculate_all_features_sequential(
33173369
"""
33183370
features = {}
33193371

3320-
logger.info("Calculating DIA-NN features...")
3372+
logger.debug("Calculating DIA-NN features...")
33213373

33223374
# Group 1: Ion co-elution (MS2 level)
33233375
try:
@@ -3436,7 +3488,7 @@ def _calculate_all_features_sequential(
34363488
except Exception as e:
34373489
logger.error(f"Error in group 10 features: {e}")
34383490

3439-
logger.info(f"Calculated {len(features)} feature groups")
3491+
logger.debug(f"Calculated {len(features)} feature groups")
34403492
return features
34413493

34423494

feature_generators/features_fragment_intensity.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -840,10 +840,32 @@ def get_features_fragment_intensity(
840840

841841
df_fragment = df_fragment.join(df_precursor_rt, on="precursor", how="left")
842842

843-
df_fragment = df_fragment.filter(
844-
(pl.col("rt_max_peptide_sub").is_not_null())
845-
& (abs(pl.col("rt") - pl.col("rt_max_peptide_sub")) < filter_max_apex_rt)
846-
)
843+
# Filter fragments to the retention time window around the apex.
844+
# If calibrated RT margins are available (rt_lower_margin / rt_higher_margin),
845+
# use them for per-peptidoform adaptive windows. Otherwise fall back to the
846+
# fixed ±filter_max_apex_rt seconds window.
847+
if "rt_lower_margin" in df_fragment.columns and "rt_higher_margin" in df_fragment.columns:
848+
df_fragment = df_fragment.filter(
849+
(pl.col("rt_max_peptide_sub").is_not_null())
850+
& (
851+
# Use calibrated margins where available, fall back to fixed window where NaN
852+
pl.when(pl.col("rt_lower_margin").is_not_null())
853+
.then(
854+
(pl.col("rt") >= pl.col("rt_lower_margin"))
855+
& (pl.col("rt") <= pl.col("rt_higher_margin"))
856+
)
857+
.otherwise(
858+
abs(pl.col("rt") - pl.col("rt_max_peptide_sub")) < filter_max_apex_rt
859+
)
860+
)
861+
)
862+
log_info("Fragment filtering: using calibrated RT margins (with fixed fallback)")
863+
else:
864+
df_fragment = df_fragment.filter(
865+
(pl.col("rt_max_peptide_sub").is_not_null())
866+
& (abs(pl.col("rt") - pl.col("rt_max_peptide_sub")) < filter_max_apex_rt)
867+
)
868+
log_info(f"Fragment filtering: using fixed ±{filter_max_apex_rt}s window (no margins available)")
847869

848870
for (peptidoform, charge), df_fragment_sub_peptidoform in tqdm(
849871
df_fragment.group_by(["peptide", "charge"])

0 commit comments

Comments
 (0)