22import time
33import warnings
44from typing import Literal
5+ from typing import Tuple
56
67# import anndata as ad
78import numpy as np
1415from pydeseq2 .inference import Inference
1516from pydeseq2 .utils import lowess
1617from pydeseq2 .utils import make_MA_plot
17-
18+ from pydeseq2 . utils import lrt_test
1819
1920class DeseqStats :
2021 """PyDESeq2 statistical tests for differential expression.
@@ -131,6 +132,7 @@ class DeseqStats:
131132 def __init__ (
132133 self ,
133134 dds : DeseqDataSet ,
135+ test : Literal ["wald" , "LRT" ] = "wald" ,
134136 contrast : list [str ] | np .ndarray ,
135137 alpha : float = 0.05 ,
136138 cooks_filter : bool = True ,
@@ -150,6 +152,10 @@ def __init__(
150152
151153 self .dds = dds
152154
155+ if test not in ("wald" , "LRT" ):
156+ raise ValueError (f"Available tests are `wald` and `LRT`. Got: { test } ." )
157+ self .test = test
158+
153159 self .alpha = alpha
154160 self .cooks_filter = cooks_filter
155161 self .independent_filter = independent_filter
@@ -213,12 +219,23 @@ def __init__(
213219 "refitted. Please run 'dds.refit()' first or set 'dds.refit_cooks' "
214220 "to False."
215221 )
222+
223+ self .p_values : pd .Series
224+ self .statistics : pd .Series
225+ self .SE : pd .Series
226+
227+
228+
229+
230+
231+
216232
@property
def variables(self):
    """Return the model's variable names, as exposed by the underlying dataset."""
    dataset = self.dds
    return dataset.variables
221237
238+
222239 def summary (
223240 self ,
224241 ** kwargs ,
@@ -262,6 +279,11 @@ def summary(
262279 rerun_summary = True
263280 self .run_wald_test ()
264281
282+ if self .test == "wald" :
283+ self .run_wald_test ()
284+ else :
285+ self .run_likelihood_ratio_test ()
286+
265287 if self .cooks_filter :
266288 # Filter p-values based on Cooks outliers
267289 self ._cooks_filtering ()
@@ -280,6 +302,8 @@ def summary(
280302 self .results_df ["baseMean" ] = self .base_mean
281303 self .results_df ["log2FoldChange" ] = self .LFC @ self .contrast_vector / np .log (2 )
282304 self .results_df ["lfcSE" ] = self .SE / np .log (2 )
305+ if self .test == "wald" :
306+ self .results_df ["lfcSE" ] = self .SE / np .log (2 )
283307 self .results_df ["stat" ] = self .statistics
284308 self .results_df ["pvalue" ] = self .p_values
285309 self .results_df ["padj" ] = self .padj
@@ -348,16 +372,85 @@ def run_wald_test(self) -> None:
348372 if not self .quiet :
349373 print (f"... done in { end - start :.2f} seconds.\n " , file = sys .stderr )
350374
351- self .p_values : pd . Series = pd .Series (pvals , index = self .dds .var_names )
352- self .statistics : pd . Series = pd .Series (stats , index = self .dds .var_names )
353- self .SE : pd . Series = pd .Series (se , index = self .dds .var_names )
375+ self .p_values = pd .Series (pvals , index = self .dds .var_names )
376+ self .statistics = pd .Series (stats , index = self .dds .var_names )
377+ self .SE = pd .Series (se , index = self .dds .var_names )
354378
355379 # Account for possible all_zeroes due to outlier refitting in DESeqDataSet
356380 if self .dds .refit_cooks and self .dds .varm ["replaced" ].sum () > 0 :
357381 self .SE .loc [self .dds .new_all_zeroes_genes ] = 0.0
358382 self .statistics .loc [self .dds .new_all_zeroes_genes ] = 0.0
359383 self .p_values .loc [self .dds .new_all_zeroes_genes ] = 1.0
360384
def run_likelihood_ratio_test(self) -> None:
    """Perform a likelihood ratio test (LRT) for every gene.

    ``lrt_test`` is run once per gene, in parallel, with the full design
    matrix and a reduced design matrix from which the column(s) selected by
    ``self.contrast_idx`` have been removed.

    Results are stored on the instance, indexed by ``self.dds.var_names``:

    - ``self.p_values``: gene-wise p-values.
    - ``self.statistics``: gene-wise test statistics.

    ``self.SE`` is *not* set by this method.

    As a side effect, the reduced design matrix is written to
    ``self.dds.obsm["reduced_design_matrix"]``.
    """
    num_genes = self.dds.n_vars
    num_vars = self.design_matrix.shape[1]

    # XXX: Raise a warning if LFCs are shrunk.

    # Regularization: use the LFC prior when one is set, otherwise fall back
    # to a small constant on the diagonal.
    if self.prior_LFC_var is not None:
        ridge_factor = np.diag(1 / self.prior_LFC_var**2)
    else:
        ridge_factor = np.diag(np.repeat(1e-6, num_vars))

    design_matrix = self.design_matrix.values
    LFCs = self.LFC.values

    # Reduced model: drop the tested coefficient(s) from the design matrix
    # and the matching rows/columns from the ridge matrix.
    keep = np.full(design_matrix.shape[1], True, dtype=bool)
    keep[self.contrast_idx] = False
    reduced_design_matrix = design_matrix[:, keep]
    reduced_ridge_factor = ridge_factor[:, keep][keep]
    self.dds.obsm["reduced_design_matrix"] = reduced_design_matrix

    if not self.quiet:
        print("Running LRT tests...", file=sys.stderr)
    start = time.time()
    with parallel_backend("loky", inner_max_num_threads=1):
        res = Parallel(
            n_jobs=self.n_processes,
            verbose=self.joblib_verbosity,
            batch_size=self.batch_size,
        )(
            delayed(lrt_test)(
                counts=self.dds.X[:, i],
                design_matrix=design_matrix,
                reduced_design_matrix=reduced_design_matrix,
                size_factors=self.dds.obsm["size_factors"],
                disp=self.dds.varm["dispersions"][i],
                lfc=LFCs[i],
                min_mu=self.dds.min_mu,
                ridge_factor=ridge_factor,
                reduced_ridge_factor=reduced_ridge_factor,
                beta_tol=self.dds.beta_tol,
            )
            for i in range(num_genes)
        )
    end = time.time()
    if not self.quiet:
        print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

    # ``zip(*res)`` yields nothing for an empty result list, which would make
    # the two-name unpacking raise ValueError when there are no genes.
    if res:
        pvals, stats = zip(*res)
    else:
        pvals, stats = (), ()

    self.p_values = pd.Series(pvals, index=self.dds.var_names)
    self.statistics = pd.Series(stats, index=self.dds.var_names)

    # Account for possible all_zeroes due to outlier refitting in DESeqDataSet
    if self.dds.refit_cooks and self.dds.varm["replaced"].sum() > 0:
        self.statistics.loc[self.dds.new_all_zeroes_genes] = 0.0
        self.p_values.loc[self.dds.new_all_zeroes_genes] = 1.0
450+
451+
452+
453+
361454 # TODO update this to reflect the new contrast format
362455 def lfc_shrink (self , coeff : str , adapt : bool = True ) -> None :
363456 """LFC shrinkage with an apeGLM prior :cite:p:`DeseqStats-zhu2019heavy`.
0 commit comments