Skip to content

Commit 0c21e00

Browse files
Add Rust-accelerated numerical functions (mumdia_rs) for 7.6x speedup
Phase 1 of Rust port: core numerical functions that replace Numba in the per-peptidoform hot path. Transparent fallback to Python/Numba when mumdia_rs is not installed. Rust crate (rust/mumdia_rs/): - percentile() and compute_percentiles() — quantile computation - compute_top() — top-k selection with zero padding - pearson_1d() — Pearson correlation with zero-variance guard - compute_correlations() — per-PSM row-wise Pearson vs prediction vector - 14 Rust unit tests Python integration: - mumdia.py: _RUST_BACKEND flag, dispatches compute_percentiles and compute_top to Rust in add_feature_columns_nb() - features_fragment_intensity.py: _RUST_CORRELATIONS flag, dispatches compute_correlations to Rust - Black + isort formatting applied to all modified files Testing: - tests/generate_rust_reference_data.py: generates 43 reference cases from Python implementations (percentile, top-k, correlation) - tests/rust_reference_data.json: serialized reference input/output pairs - tests/test_rust_equivalence.py: 9 tests validating Rust matches Python within 1e-12 tolerance, plus property-based tests Performance: 76ms -> 10ms per peptidoform (7.6x speedup on hot path) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d24c050 commit 0c21e00

14 files changed

Lines changed: 4686 additions & 224 deletions

config.py

Lines changed: 164 additions & 119 deletions
Large diffs are not rendered by default.

feature_generators/features_fragment_intensity.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,29 @@
2424
import polars as pl
2525
from numba import njit
2626
from rustyms import (
27+
CompoundPeptidoformIon,
2728
FragmentationModel,
2829
MassMode,
29-
RawSpectrum,
30-
CompoundPeptidoformIon,
3130
MatchingParameters,
31+
RawSpectrum,
3232
)
3333
from tqdm import tqdm
3434

3535
from data_structures import CorrelationResults, PickleConfig
3636
from utilities.logger import log_info
3737
from utilities.plotting import plot_XIC
3838

39+
# Optional Rust backend for compute_correlations
40+
try:
41+
import mumdia_rs
42+
43+
_RUST_CORRELATIONS = True
44+
except ImportError:
45+
_RUST_CORRELATIONS = False
46+
3947

4048
@njit
41-
def compute_correlations(intensity_matrix, pred_frag_intens):
49+
def _compute_correlations_numba(intensity_matrix, pred_frag_intens):
4250
"""
4351
Compute Pearson correlations between experimental and predicted intensities.
4452
@@ -76,6 +84,16 @@ def compute_correlations(intensity_matrix, pred_frag_intens):
7684
return correlations
7785

7886

87+
def compute_correlations(intensity_matrix, pred_frag_intens):
88+
"""Dispatch to Rust or Numba for per-PSM Pearson correlations."""
89+
if _RUST_CORRELATIONS:
90+
return mumdia_rs.compute_correlations(
91+
np.ascontiguousarray(intensity_matrix, dtype=np.float64),
92+
np.ascontiguousarray(pred_frag_intens, dtype=np.float64),
93+
)
94+
return _compute_correlations_numba(intensity_matrix, pred_frag_intens)
95+
96+
7997
def corrcoef_ignore_both_missing(data):
8098
"""
8199
Compute pairwise Pearson correlation coefficients between rows of the input
@@ -476,7 +494,6 @@ def match_fragments(
476494
mode=MassMode.Monoisotopic,
477495
)
478496

479-
480497
# Filter annotated peaks to keep only singly-charged b and y ions.
481498
# RustyMS annotations are accessed via repr() strings, so regex is
482499
# used to extract the ion type (e.g. "b3", "y7") from the annotation
@@ -577,7 +594,6 @@ def match_fragments(
577594
└──────────┴─────────────┴─────────────┴──────────────┴────────────┴─────────────┴─────────────┴────────────┘
578595
"""
579596

580-
581597
# Max-normalize MS2PIP predictions to [0, 1] range. This is necessary because
582598
# MS2PIP outputs raw predicted intensities on an arbitrary scale, while the
583599
# experimental intensities will also be max-normalized per PSM later (line ~722).
@@ -608,7 +624,6 @@ def match_fragments(
608624
]
609625
)
610626

611-
612627
"""
613628
Get pearson and cosine similarity of spectrum with highest intensity
614629
"""
@@ -617,13 +632,11 @@ def match_fragments(
617632
pred_frag_intens_individual, most_abundant_frag_psm["fragment_intensity"]
618633
)[0][1]
619634

620-
621635
# Compute cosine similarity between predicted and observed intensities for the apex spectrum
622636
most_intens_cos = cosine_similarity(
623637
pred_frag_intens_individual, most_abundant_frag_psm["fragment_intensity"]
624638
)
625639

626-
627640
"""
628641
Get the intensity matrix of observations
629642
"""
@@ -637,7 +650,6 @@ def match_fragments(
637650
[ms2pip_predictions.get(fid, 0.0) for fid in fragment_names]
638651
)
639652

640-
641653
# Collect predictions for keys not listed in fragment_names (i.e., fragments predicted but not observed)
642654
non_matched_predictions = np.array(
643655
[v for k, v in ms2pip_predictions.items() if k not in fragment_names]
@@ -649,7 +661,6 @@ def match_fragments(
649661
sum([ms2pip_predictions.get(fid, 0.0) for fid in fragment_names])
650662
)
651663

652-
653664
# Ensure data types are consistent for downstream calculations
654665
intensity_matrix = intensity_matrix.astype(np.float32)
655666
pred_frag_intens = pred_frag_intens.astype(np.float32)
@@ -700,7 +711,6 @@ def match_fragments(
700711
.ravel() # Flatten the array to 1D
701712
)
702713

703-
704714
# Compute mean squared error between normalized observed and predicted intensities (per PSM, then averaged)
705715
mse_avg_pred_intens = (
706716
abs(intensity_matrix_normalized - pred_frag_intens).sum(axis=1)
@@ -712,7 +722,6 @@ def match_fragments(
712722
+ sum(non_matched_predictions)
713723
) / intensity_matrix_normalized.shape[0]
714724

715-
716725
# Compute correlation matrix for PSM IDs (rows of intensity matrix)
717726
if intensity_matrix_normalized.shape[0] > 1: # Ensure there are multiple PSMs
718727
correlation_matrix_psm_ids = np.corrcoef(
@@ -729,15 +738,13 @@ def match_fragments(
729738
# NOTE: this converts r to R², unlike the fragment correlation matrix below.
730739
correlation_matrix_psm_ids = np.sort(correlation_matrix_psm_ids**2)
731740
else:
732-
733741
# If only one PSM, set all correlation matrices to empty
734742
correlation_matrix_psm_ids = np.array([])
735743

736744
# Compute correlation matrix for fragment IDs (columns of intensity matrix)
737745
if intensity_matrix_normalized.shape[1] > 1:
738746
correlation_matrix_frag_ids = np.corrcoef(intensity_matrix_normalized.T)
739747

740-
741748
# Remove diagonal elements (self-correlation) and flatten to 1D
742749
correlation_matrix_frag_ids = correlation_matrix_frag_ids[
743750
~np.eye(correlation_matrix_frag_ids.shape[0], dtype=bool)
@@ -747,12 +754,9 @@ def match_fragments(
747754
# preserves the sign, allowing detection of anti-correlated fragment pairs.
748755
correlation_matrix_frag_ids = np.sort(correlation_matrix_frag_ids)
749756
else:
750-
751757
# If only one fragment, set all correlation matrices to empty
752758
correlation_matrix_frag_ids = np.array([])
753759

754-
755-
756760
return CorrelationResults(
757761
correlations=correlation_result, # Pearson correlation between predicted and observed intensities
758762
correlations_count=correlation_result_counts, # Count of non-zero fragments entries per PSM
@@ -818,15 +822,13 @@ def get_features_fragment_intensity(
818822
(pl.col("peptide") + "/" + pl.col("charge").cast(pl.Utf8)).alias("precursor")
819823
)
820824

821-
822825
precursor_to_rt_max = dict(
823826
zip(
824827
df_fragment_max_peptide["precursor"].to_list(),
825828
df_fragment_max_peptide["rt"].to_list(),
826829
)
827830
)
828831

829-
830832
df_precursor_rt = pl.DataFrame(
831833
{
832834
"precursor": list(precursor_to_rt_max.keys()),
@@ -844,7 +846,10 @@ def get_features_fragment_intensity(
844846
# If calibrated RT margins are available (rt_lower_margin / rt_higher_margin),
845847
# use them for per-peptidoform adaptive windows. Otherwise fall back to the
846848
# fixed ±filter_max_apex_rt seconds window.
847-
if "rt_lower_margin" in df_fragment.columns and "rt_higher_margin" in df_fragment.columns:
849+
if (
850+
"rt_lower_margin" in df_fragment.columns
851+
and "rt_higher_margin" in df_fragment.columns
852+
):
848853
df_fragment = df_fragment.filter(
849854
(pl.col("rt_max_peptide_sub").is_not_null())
850855
& (
@@ -855,17 +860,22 @@ def get_features_fragment_intensity(
855860
& (pl.col("rt") <= pl.col("rt_higher_margin"))
856861
)
857862
.otherwise(
858-
abs(pl.col("rt") - pl.col("rt_max_peptide_sub")) < filter_max_apex_rt
863+
abs(pl.col("rt") - pl.col("rt_max_peptide_sub"))
864+
< filter_max_apex_rt
859865
)
860866
)
861867
)
862-
log_info("Fragment filtering: using calibrated RT margins (with fixed fallback)")
868+
log_info(
869+
"Fragment filtering: using calibrated RT margins (with fixed fallback)"
870+
)
863871
else:
864872
df_fragment = df_fragment.filter(
865873
(pl.col("rt_max_peptide_sub").is_not_null())
866874
& (abs(pl.col("rt") - pl.col("rt_max_peptide_sub")) < filter_max_apex_rt)
867875
)
868-
log_info(f"Fragment filtering: using fixed ±{filter_max_apex_rt}s window (no margins available)")
876+
log_info(
877+
f"Fragment filtering: using fixed ±{filter_max_apex_rt}s window (no margins available)"
878+
)
869879

870880
for (peptidoform, charge), df_fragment_sub_peptidoform in tqdm(
871881
df_fragment.group_by(["peptide", "charge"])

0 commit comments

Comments
 (0)