Skip to content

Commit d507fba

Browse files
authored
Manually added LRT logic from scverse#178
1 parent 11aed8c commit d507fba

2 files changed

Lines changed: 174 additions & 4 deletions

File tree

pydeseq2/ds.py

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
import warnings
44
from typing import Literal
5+
from typing import Tuple
56

67
# import anndata as ad
78
import numpy as np
@@ -14,7 +15,7 @@
1415
from pydeseq2.inference import Inference
1516
from pydeseq2.utils import lowess
1617
from pydeseq2.utils import make_MA_plot
17-
18+
from pydeseq2.utils import lrt_test
1819

1920
class DeseqStats:
2021
"""PyDESeq2 statistical tests for differential expression.
@@ -131,6 +132,7 @@ class DeseqStats:
131132
def __init__(
132133
self,
133134
dds: DeseqDataSet,
135+
test: Literal["wald", "LRT"] = "wald",
134136
contrast: list[str] | np.ndarray,
135137
alpha: float = 0.05,
136138
cooks_filter: bool = True,
@@ -150,6 +152,10 @@ def __init__(
150152

151153
self.dds = dds
152154

155+
if test not in ("wald", "LRT"):
156+
raise ValueError(f"Available tests are `wald` and `LRT`. Got: {test}.")
157+
self.test = test
158+
153159
self.alpha = alpha
154160
self.cooks_filter = cooks_filter
155161
self.independent_filter = independent_filter
@@ -213,12 +219,23 @@ def __init__(
213219
"refitted. Please run 'dds.refit()' first or set 'dds.refit_cooks' "
214220
"to False."
215221
)
222+
223+
self.p_values: pd.Series
224+
self.statistics: pd.Series
225+
self.SE: pd.Series
226+
227+
228+
229+
230+
231+
216232

217233
@property
218234
def variables(self):
219235
"""Get the names of the variables used in the model definition."""
220236
return self.dds.variables
221237

238+
222239
def summary(
223240
self,
224241
**kwargs,
@@ -262,6 +279,11 @@ def summary(
262279
rerun_summary = True
263280
self.run_wald_test()
264281

282+
if self.test == "wald":
283+
self.run_wald_test()
284+
else:
285+
self.run_likelihood_ratio_test()
286+
265287
if self.cooks_filter:
266288
# Filter p-values based on Cooks outliers
267289
self._cooks_filtering()
@@ -280,6 +302,8 @@ def summary(
280302
self.results_df["baseMean"] = self.base_mean
281303
self.results_df["log2FoldChange"] = self.LFC @ self.contrast_vector / np.log(2)
282304
self.results_df["lfcSE"] = self.SE / np.log(2)
305+
if self.test == "wald":
306+
self.results_df["lfcSE"] = self.SE / np.log(2)
283307
self.results_df["stat"] = self.statistics
284308
self.results_df["pvalue"] = self.p_values
285309
self.results_df["padj"] = self.padj
@@ -348,16 +372,85 @@ def run_wald_test(self) -> None:
348372
if not self.quiet:
349373
print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)
350374

351-
self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names)
352-
self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names)
353-
self.SE: pd.Series = pd.Series(se, index=self.dds.var_names)
375+
self.p_values = pd.Series(pvals, index=self.dds.var_names)
376+
self.statistics = pd.Series(stats, index=self.dds.var_names)
377+
self.SE = pd.Series(se, index=self.dds.var_names)
354378

355379
# Account for possible all_zeroes due to outlier refitting in DESeqDataSet
356380
if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0:
357381
self.SE.loc[self.dds.new_all_zeroes_genes] = 0.0
358382
self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
359383
self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0
360384

385+
def run_likelihood_ratio_test(self) -> None:
386+
"""Perform a Likelihood Ratio test.
387+
Get gene-wise p-values for gene over/under-expression.
388+
"""
389+
390+
num_genes = self.dds.n_vars
391+
num_vars = self.design_matrix.shape[1]
392+
393+
# XXX: Raise a warning if LFCs are shrunk.
394+
395+
def reduce(
396+
design_matrix: np.ndarray, ridge_factor: np.ndarray
397+
) -> Tuple[np.ndarray, np.ndarray]:
398+
indices = np.full(design_matrix.shape[1], True, dtype=bool)
399+
indices[self.contrast_idx] = False
400+
return design_matrix[:, indices], ridge_factor[:, indices][indices]
401+
402+
# Set regularization factors.
403+
if self.prior_LFC_var is not None:
404+
ridge_factor = np.diag(1 / self.prior_LFC_var**2)
405+
else:
406+
ridge_factor = np.diag(np.repeat(1e-6, num_vars))
407+
408+
design_matrix = self.design_matrix.values
409+
LFCs = self.LFC.values
410+
411+
reduced_design_matrix, reduced_ridge_factor = reduce(design_matrix, ridge_factor)
412+
self.dds.obsm["reduced_design_matrix"] = reduced_design_matrix
413+
414+
if not self.quiet:
415+
print("Running LRT tests...", file=sys.stderr)
416+
start = time.time()
417+
with parallel_backend("loky", inner_max_num_threads=1):
418+
res = Parallel(
419+
n_jobs=self.n_processes,
420+
verbose=self.joblib_verbosity,
421+
batch_size=self.batch_size,
422+
)(
423+
delayed(lrt_test)(
424+
counts=self.dds.X[:, i],
425+
design_matrix=design_matrix,
426+
reduced_design_matrix=reduced_design_matrix,
427+
size_factors=self.dds.obsm["size_factors"],
428+
disp=self.dds.varm["dispersions"][i],
429+
lfc=LFCs[i],
430+
min_mu=self.dds.min_mu,
431+
ridge_factor=ridge_factor,
432+
reduced_ridge_factor=reduced_ridge_factor,
433+
beta_tol=self.dds.beta_tol,
434+
)
435+
for i in range(num_genes)
436+
)
437+
end = time.time()
438+
if not self.quiet:
439+
print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)
440+
441+
pvals, stats = zip(*res)
442+
443+
self.p_values = pd.Series(pvals, index=self.dds.var_names)
444+
self.statistics = pd.Series(stats, index=self.dds.var_names)
445+
446+
# Account for possible all_zeroes due to outlier refitting in DESeqDataSet
447+
if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0:
448+
self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
449+
self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0
450+
451+
452+
453+
361454
# TODO update this to reflect the new contrast format
362455
def lfc_shrink(self, coeff: str, adapt: bool = True) -> None:
363456
"""LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`.

pydeseq2/utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from scipy.optimize import minimize # type: ignore
1313
from scipy.special import gammaln # type: ignore
1414
from scipy.special import polygamma # type: ignore
15+
from scipy.stats import chi2 # type: ignore
1516
from scipy.stats import norm # type: ignore
1617
from sklearn.linear_model import LinearRegression # type: ignore
1718

@@ -811,6 +812,82 @@ def less_abs(lfc_null):
811812

812813
return wald_p_value, wald_statistic, wald_se
813814

815+
def lrt_test(
816+
counts: np.ndarray,
817+
design_matrix: np.ndarray,
818+
reduced_design_matrix: np.ndarray,
819+
size_factors: np.ndarray,
820+
disp: float,
821+
lfc: np.ndarray,
822+
min_mu: float,
823+
ridge_factor: np.ndarray,
824+
reduced_ridge_factor: np.ndarray,
825+
beta_tol: float,
826+
) -> Tuple[float, float]:
827+
"""Run likelihood ratio test for differential expression.
828+
Compute likelihood ratio test statistics and p-values from
829+
dispersion and LFC estimates.
830+
Parameters
831+
----------
832+
counts : ndarray
833+
Raw counts for a given gene.
834+
design_matrix : ndarray
835+
Design matrix.
836+
reduced_design_matrix : ndarray
837+
Reduced design matrix.
838+
size_factors : ndarray
839+
DESeq2 normalization factors.
840+
disp : float
841+
Dispersion estimate.
842+
lfc : ndarray
843+
Log-fold change estimate (in natural log scale).
844+
min_mu : float
845+
Lower bound on estimated means, to ensure numerical stability.
846+
(default: ``0.5``).
847+
ridge_factor : ndarray
848+
Regularization factors.
849+
reduced_ridge_factor : ndarray
850+
Reduced regularization factors.
851+
beta_tol : float
852+
Stopping criterion for IRWLS:
853+
:math:`\vert dev - dev_{old}\vert / \vert dev + 0.1 \vert < \beta_{tol}`.
854+
(default: ``1e-8``).
855+
Returns
856+
-------
857+
lrt_p_value : float
858+
Estimated p-value.
859+
lrt_statistic : float
860+
LRT statistic.
861+
"""
862+
863+
def reg_nb_nll(
864+
beta: np.ndarray, design_matrix: np.ndarray, ridge_factor: np.ndarray
865+
) -> float:
866+
# closure to minimize
867+
mu_ = np.maximum(size_factors * np.exp(design_matrix @ beta), min_mu)
868+
val = nb_nll(counts, mu_, disp) + 0.5 * (ridge_factor @ beta**2).sum()
869+
return -1.0 * val # maximize the likelihood
870+
871+
beta_reduced, *_ = irls_solver(
872+
counts=counts,
873+
size_factors=size_factors,
874+
design_matrix=reduced_design_matrix,
875+
disp=disp,
876+
min_mu=min_mu,
877+
beta_tol=beta_tol,
878+
)
879+
880+
reduced_ll = reg_nb_nll(beta_reduced, reduced_design_matrix, reduced_ridge_factor)
881+
full_ll = reg_nb_nll(lfc, design_matrix, ridge_factor)
882+
883+
lrt_statistic = 2 * (full_ll - reduced_ll)
884+
# df = 1 since contrast_idx is the only variable removed
885+
lrt_p_value = chi2.sf(lrt_statistic, df=1)
886+
887+
print(lrt_p_value)
888+
print(lrt_statistic)
889+
890+
return lrt_p_value, lrt_statistic
814891

815892
def fit_rough_dispersions(
816893
normed_counts: np.ndarray, design_matrix: pd.DataFrame

0 commit comments

Comments
 (0)