Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 59 additions & 34 deletions pertpy/tools/_differential_gene_expression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import seaborn as sns
from matplotlib.pyplot import Figure
from matplotlib.ticker import MaxNLocator
from scverse_misc import Deprecation, deprecated_arg

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy._logger import logger
Expand Down Expand Up @@ -90,15 +91,21 @@ def compare_groups(
...

@_doc_params(common_plot_args=doc_common_plot_args)
@deprecated_arg(
"pval_thresh",
Deprecation("1.1.0", "Use `padj_threshold`."),
)
@deprecated_arg("pvalue_col", Deprecation("1.1.0", "Use `padj_col`."), stacklevel=2)
@deprecated_arg("log2fc_thresh", Deprecation("1.1.0", "Use `log2fc_threshold`."), stacklevel=3)
def plot_volcano( # pragma: no cover # noqa: D417
self,
data: pd.DataFrame | ad.AnnData,
*,
log2fc_threshold: float = 0.75,
padj_threshold: float = 0.05,
log2fc_col: str = "log_fc",
pvalue_col: str = "adj_p_value",
padj_col: str = "adj_p_value",
symbol_col: str = "variable",
pval_thresh: float = 0.05,
log2fc_thresh: float = 0.75,
to_label: int | list[str] = 5,
s_curve: bool | None = False,
colors: list[str] = None,
Expand All @@ -116,20 +123,23 @@ def plot_volcano( # pragma: no cover # noqa: D417
x_label: str | None = None,
y_label: str | None = None,
return_fig: bool = False,
log2fc_thresh: float | None = None,
pval_thresh: float | None = None,
pvalue_col: str | None = None,
**kwargs: int,
) -> Figure | None:
"""Creates a volcano plot from a pandas DataFrame or Anndata.

Args:
data: DataFrame or Anndata to plot.
log2fc_threshold: Threshold for log2 fold change significance.
padj_threshold: Threshold on adjusted p-values for significance.
log2fc_col: Column name of log2 Fold-Change values.
pvalue_col: Column name of the p values.
padj_col: Column name of adjusted p-values.
symbol_col: Column name of gene IDs.
varm_key: Key in Anndata.varm slot to use for plotting if an Anndata object was passed.
size_col: Column name to size points by.
point_sizes: Lower and upper bounds of point sizes.
pval_thresh: Threshold p value for significance.
log2fc_thresh: Threshold for log2 fold change significance.
to_label: Number of top genes or list of genes to label.
s_curve: Whether to use a reciprocal threshold for up and down gene determination.
color_dict: Dictionary for coloring dots by categories.
Expand All @@ -143,6 +153,9 @@ def plot_volcano( # pragma: no cover # noqa: D417
shape_order: Order of categories for shapes.
x_label: Label for the x-axis.
y_label: Label for the y-axis.
log2fc_thresh: Deprecated and will be removed in a future release. Use `log2fc_threshold`.
pval_thresh: Deprecated and will be removed in a future release. Use `padj_threshold`.
pvalue_col: Deprecated and will be removed in a future release. Use `padj_col`.
{common_plot_args}
**kwargs: Additional arguments for seaborn.scatterplot.

Expand All @@ -161,11 +174,18 @@ def plot_volcano( # pragma: no cover # noqa: D417
>>> res_df = edgr.test_contrasts(
... edgr.contrast(column="Treatment", baseline="Chemo", group_to_compare="Anti-PD-L1+Chemo")
... )
>>> edgr.plot_volcano(res_df, log2fc_thresh=0)
>>> edgr.plot_volcano(res_df, log2fc_threshold=0)

Preview:
.. image:: /_static/docstring_previews/de_volcano.png
"""
if pvalue_col is not None:
padj_col = pvalue_col
if pval_thresh is not None:
padj_threshold = pval_thresh
if log2fc_thresh is not None:
log2fc_threshold = log2fc_thresh

if colors is None:
colors = ["gray", "#D62728", "#1F77B4"]

Expand All @@ -174,7 +194,7 @@ def _pval_reciprocal(lfc: float) -> float:

Used for plotting the S-curve
"""
return pval_thresh / (lfc - log2fc_thresh)
return padj_threshold / (lfc - log2fc_threshold)

def _map_shape(symbol: str) -> str:
if shape_dict is not None:
Expand All @@ -188,8 +208,8 @@ def _map_genes_categories(
row: pd.Series,
log2fc_col: str,
nlog10_col: str,
log2fc_thresh: float,
pval_thresh: float = None,
log2fc_threshold: float,
padj_threshold: float = None,
s_curve: bool = False,
) -> str:
"""Map genes to categories based on log2fc and p-value.
Expand All @@ -203,16 +223,16 @@ def _map_genes_categories(
if s_curve:
# S-curve condition for Up or Down categorization
reciprocal_thresh = _pval_reciprocal(abs(log2fc))
if log2fc > log2fc_thresh and nlog10 > reciprocal_thresh:
if log2fc > log2fc_threshold and nlog10 > reciprocal_thresh:
return "Up"
elif log2fc < -log2fc_thresh and nlog10 > reciprocal_thresh:
elif log2fc < -log2fc_threshold and nlog10 > reciprocal_thresh:
return "Down"
else:
return "not DE"
# Standard condition for Up or Down categorization
elif log2fc > log2fc_thresh and nlog10 > pval_thresh:
elif log2fc > log2fc_threshold and nlog10 > padj_threshold:
return "Up"
elif log2fc < -log2fc_thresh and nlog10 > pval_thresh:
elif log2fc < -log2fc_threshold and nlog10 > padj_threshold:
return "Down"
else:
return "not DE"
Expand All @@ -221,8 +241,8 @@ def _map_genes_categories_highlight(
row: pd.Series,
log2fc_col: str,
nlog10_col: str,
log2fc_thresh: float,
pval_thresh: float = None,
log2fc_threshold: float,
padj_threshold: float = None,
s_curve: bool = False,
symbol_col: str = None,
) -> str:
Expand All @@ -242,12 +262,12 @@ def _map_genes_categories_highlight(

if s_curve:
# Use S-curve condition for filtering DE
if nlog10 > _pval_reciprocal(abs(log2fc)) and abs(log2fc) > log2fc_thresh:
if nlog10 > _pval_reciprocal(abs(log2fc)) and abs(log2fc) > log2fc_threshold:
return "DE"
return "not DE"
else:
# Use standard condition for filtering DE
if abs(log2fc) < log2fc_thresh or nlog10 < pval_thresh:
if abs(log2fc) < log2fc_threshold or nlog10 < padj_threshold:
return "not DE"
return "DE"

Expand All @@ -261,18 +281,18 @@ def _map_genes_categories_highlight(
df = data.copy(deep=True)

# clean and replace 0s as they would lead to -inf
if df[[log2fc_col, pvalue_col]].isnull().values.any():
if df[[log2fc_col, padj_col]].isnull().values.any():
print("NaNs encountered, dropping rows with NaNs")
df = df.dropna(subset=[log2fc_col, pvalue_col])
df = df.dropna(subset=[log2fc_col, padj_col])

if df[pvalue_col].min() == 0:
if df[padj_col].min() == 0:
print("0s encountered for p value, replacing with 1e-323")
df.loc[df[pvalue_col] == 0, pvalue_col] = 1e-323
df.loc[df[padj_col] == 0, padj_col] = 1e-323

# convert p value threshold to nlog10
pval_thresh = -np.log10(pval_thresh)
padj_threshold = -np.log10(padj_threshold)
# make nlog10 column
df["nlog10"] = -np.log10(df[pvalue_col])
df["nlog10"] = -np.log10(df[padj_col])
y_max = df["nlog10"].max() + 1
# make a column to pick top genes
df["top_genes"] = df["nlog10"] * df[log2fc_col]
Expand Down Expand Up @@ -307,8 +327,8 @@ def _map_genes_categories_highlight(
row,
log2fc_col=log2fc_col,
nlog10_col="nlog10",
log2fc_thresh=log2fc_thresh,
pval_thresh=pval_thresh,
log2fc_threshold=log2fc_threshold,
padj_threshold=padj_threshold,
s_curve=s_curve,
),
axis=1,
Expand All @@ -323,8 +343,8 @@ def _map_genes_categories_highlight(
row,
log2fc_col=log2fc_col,
nlog10_col="nlog10",
log2fc_thresh=log2fc_thresh,
pval_thresh=pval_thresh,
log2fc_threshold=log2fc_threshold,
padj_threshold=padj_threshold,
symbol_col=symbol_col,
s_curve=s_curve,
),
Expand Down Expand Up @@ -411,15 +431,15 @@ def _map_genes_categories_highlight(

# plot vertical and horizontal lines
if s_curve:
x = np.arange((log2fc_thresh + 0.000001), y_max, 0.01)
x = np.arange((log2fc_threshold + 0.000001), y_max, 0.01)
y = _pval_reciprocal(x)
ax.plot(x, y, zorder=1, c="k", lw=2, ls="--")
ax.plot(-x, y, zorder=1, c="k", lw=2, ls="--")

else:
ax.axhline(pval_thresh, zorder=1, c="k", lw=2, ls="--")
ax.axvline(log2fc_thresh, zorder=1, c="k", lw=2, ls="--")
ax.axvline(log2fc_thresh * -1, zorder=1, c="k", lw=2, ls="--")
ax.axhline(padj_threshold, zorder=1, c="k", lw=2, ls="--")
ax.axvline(log2fc_threshold, zorder=1, c="k", lw=2, ls="--")
ax.axvline(log2fc_threshold * -1, zorder=1, c="k", lw=2, ls="--")
plt.ylim(0, y_max)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

Expand Down Expand Up @@ -454,7 +474,7 @@ def _map_genes_categories_highlight(
if x_label is None:
x_label = log2fc_col
if y_label is None:
y_label = f"-$log_{{10}}$ {pvalue_col}"
y_label = f"-$log_{{10}}$ {padj_col}"

plt.xlabel(x_label, size=15)
plt.ylabel(y_label, size=15)
Expand Down Expand Up @@ -651,6 +671,8 @@ def plot_fold_change( # pragma: no cover # noqa: D417
*,
var_names: Sequence[str] = None,
n_top_vars: int = 15,
padj_threshold: float = 0.01,
padj_col: str = "adj_p_value",
log2fc_col: str = "log_fc",
symbol_col: str = "variable",
y_label: str = "Log2 fold change",
Expand All @@ -664,6 +686,8 @@ def plot_fold_change( # pragma: no cover # noqa: D417
results_df: DataFrame with results from DE analysis.
var_names: Variables to plot. If None, the top n_top_vars variables based on the log2 fold change are plotted.
n_top_vars: Number of top variables to plot. The top and bottom n_top_vars variables are plotted, respectively.
padj_threshold: Only variables with adjusted p-values below this threshold are included in the plot.
padj_col: Column name of adjusted p-values.
log2fc_col: Column name of log2 Fold-Change values.
symbol_col: Column name of gene IDs.
y_label: Label for the y-axis.
Expand Down Expand Up @@ -691,12 +715,13 @@ def plot_fold_change( # pragma: no cover # noqa: D417
Preview:
.. image:: /_static/docstring_previews/de_fold_change.png
"""
results_df = results_df[results_df[padj_col] < padj_threshold]
if var_names is None:
var_names = results_df.sort_values(log2fc_col, ascending=False).head(n_top_vars)[symbol_col].tolist()
var_names += results_df.sort_values(log2fc_col, ascending=True).head(n_top_vars)[symbol_col].tolist()
assert len(var_names) == 2 * n_top_vars

df = results_df[results_df[symbol_col].isin(var_names)]
df = results_df[results_df[symbol_col].isin(var_names)].copy()
df.sort_values(log2fc_col, ascending=False, inplace=True)

plt.figure(figsize=figsize)
Expand Down
3 changes: 3 additions & 0 deletions pertpy/tools/_differential_gene_expression/_pydeseq2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings

import numpy as np
import pandas as pd
Expand All @@ -12,6 +13,8 @@
from ._base import LinearModelBase
from ._checks import check_is_integer_matrix

warnings.filterwarnings("always", message=".*(pval_thresh|pvalue_col|log2fc_thresh).*")


class PyDESeq2(LinearModelBase):
"""Differential expression test using PyDESeq2."""
Expand Down
16 changes: 13 additions & 3 deletions pertpy/tools/_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scanpy.tools._score_genes import _sparse_nanmean
from scipy.sparse import issparse
from scipy.stats import hypergeom
from scverse_misc import Deprecation, deprecated_arg
from statsmodels.stats.multitest import multipletests

from pertpy._doc import _doc_params, doc_common_plot_args
Expand Down Expand Up @@ -147,13 +148,18 @@ def score(
adata.uns[f"{key_added}_genes"]["var"].loc[drug, "genes"] = "|".join(adata.var_names[targets[drug]])
adata.uns[f"{key_added}_all_genes"]["var"].loc[drug, "all_genes"] = "|".join(full_targets[drug])

@deprecated_arg(
"pvals_adj_thresh",
Deprecation("1.0.6", "Use `padj_threshold`."),
)
def hypergeometric(
self,
adata: AnnData,
targets: dict[str, list[str] | dict[str, list[str]]] | None = None,
nested: bool = False,
categories: str | list[str] | None = None,
pvals_adj_thresh: float = 0.05,
padj_threshold: float = 0.05,
pvals_adj_thresh: float | None = None,
direction: str = "both",
corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg",
):
Expand All @@ -170,16 +176,20 @@ def hypergeometric(
nested: Whether `targets` is a dictionary of dictionaries with group categories as keys.
categories: If `targets=None` or `nested=True`, this argument can be used to subset the gene groups to one or more categories (keys of the original dictionary).
In case of the ChEMBL drug targets, these are ATC level 1/level 2 category codes.
pvals_adj_thresh: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
padj_threshold: The `pvals_adj` cutoff to use on the `sc.tl.rank_genes_groups()` output to identify markers.
direction: Whether to seek out up/down-regulated genes for the groups, based on the values from `scores`.
Can be `up`, `down`, or `both` (for no selection).
corr_method: Which FDR correction to apply to the p-values of the hypergeometric test.
Can be `benjamini-hochberg` or `bonferroni`.
pvals_adj_thresh: Deprecated and will be removed in a future release. Use `padj_threshold`.

Returns:
Dictionary with clusters for which the original object markers were computed as the keys,
and data frames of test results sorted on q-value as the items.
"""
if pvals_adj_thresh is not None:
padj_threshold = pvals_adj_thresh

universe = set(adata.var_names)
targets = _prepare_targets(targets=targets, nested=nested, categories=categories) # type: ignore
for group in targets:
Expand All @@ -201,7 +211,7 @@ def hypergeometric(
"pvals_adj",
],
)
mask = adata.uns["rank_genes_groups"]["pvals_adj"][cluster] < pvals_adj_thresh
mask = adata.uns["rank_genes_groups"]["pvals_adj"][cluster] < padj_threshold
if direction == "up":
mask = mask & (adata.uns["rank_genes_groups"]["scores"][cluster] > 0)
elif direction == "down":
Expand Down
Loading
Loading