Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "cell-eval"
version = "0.6.6"
version = "0.6.7"
description = "Evaluation metrics for single-cell perturbation predictions"
readme = "README.md"
authors = [
Expand Down
30 changes: 30 additions & 0 deletions src/cell_eval/_cli/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def parse_args_run(parser: ap.ArgumentParser):
type=str,
help="Metrics to skip (comma-separated for multiple) (see docs for more details)",
)
parser.add_argument(
"--fdr-threshold",
type=float,
default=0.05,
help="FDR threshold for DE significance [default: %(default)s]",
)
parser.add_argument(
"--version",
action="version",
Expand Down Expand Up @@ -142,6 +148,30 @@ def run_evaluation(args: ap.Namespace):
else {}
)

# Add fdr_threshold to all DE metrics that accept it.
# NOTE(review): this list duplicates _build_de_metric_configs in _evaluator.py;
# once MetricsEvaluator accepts fdr_threshold in its constructor, prefer passing
# fdr_threshold=args.fdr_threshold there and deleting this block.
de_metrics_with_fdr = [
    "de_spearman_sig",
    "de_direction_match",
    "de_spearman_lfc_sig",
    "de_sig_genes_recall",
    "de_nsig_counts",
    "pr_auc",
    "roc_auc",
    # overlap/precision metrics
    "overlap_at_N",
    "overlap_at_50",
    "overlap_at_100",
    "overlap_at_200",
    "overlap_at_500",
    "precision_at_N",
    "precision_at_50",
    "precision_at_100",
    "precision_at_200",
    "precision_at_500",
]
for metric_name in de_metrics_with_fdr:
    # Nested setdefault: do not clobber an fdr_threshold the user already
    # supplied for a specific metric via metric kwargs — the CLI-wide
    # --fdr-threshold only fills in where no per-metric value exists.
    metric_kwargs.setdefault(metric_name, {}).setdefault(
        "fdr_threshold", args.fdr_threshold
    )
Comment on lines +151 to +173
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block duplicates logic for setting the fdr_threshold that is now handled within the MetricsEvaluator class (in _build_de_metric_configs). To avoid code duplication and improve maintainability, this block should be removed. The fdr_threshold from the command-line arguments should be passed directly to the MetricsEvaluator constructor instead.

After removing this, you'll need to update the MetricsEvaluator instantiations in this file to include fdr_threshold=args.fdr_threshold. I cannot suggest this change directly as it is outside the diff.


skip_metrics = args.skip_metrics.split(",") if args.skip_metrics else None

if args.celltype_col is not None:
Expand Down
35 changes: 34 additions & 1 deletion src/cell_eval/_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class MetricsEvaluator:
pdex_kwargs: dict[str, Any] | None = None
Keyword arguments for parallel_differential_expression.
These will overwrite arguments passed to MetricsEvaluator.__init__ if they conflict.
fdr_threshold: float = 0.05
FDR threshold for DE significance used in DE metrics.
"""

def __init__(
Expand All @@ -71,6 +73,7 @@ def __init__(
prefix: str | None = None,
pdex_kwargs: dict[str, Any] | None = None,
skip_de: bool = False,
fdr_threshold: float = 0.05,
):
# Enable a global string cache for categorical columns
pl.enable_string_cache()
Expand Down Expand Up @@ -107,6 +110,7 @@ def __init__(

self.outdir = outdir
self.prefix = prefix
self.fdr_threshold = fdr_threshold

def compute(
self,
Expand All @@ -117,9 +121,13 @@ def compute(
write_csv: bool = True,
break_on_error: bool = False,
) -> tuple[pl.DataFrame, pl.DataFrame]:
# Inject fdr_threshold into DE metric configs
de_metric_configs = _build_de_metric_configs(self.fdr_threshold)
merged_configs = {**de_metric_configs, **(metric_configs or {})}

pipeline = MetricPipeline(
profile=profile,
metric_configs=metric_configs,
metric_configs=merged_configs,
break_on_error=break_on_error,
)
if skip_metrics is not None:
Expand Down Expand Up @@ -156,6 +164,31 @@ def compute(
return results, agg_results


def _build_de_metric_configs(fdr_threshold: float) -> dict[str, dict[str, Any]]:
"""Build metric configs with fdr_threshold for all DE metrics that accept it."""
de_metrics_with_fdr = [
"de_spearman_sig",
"de_direction_match",
"de_spearman_lfc_sig",
"de_sig_genes_recall",
"de_nsig_counts",
"pr_auc",
"roc_auc",
# overlap/precision metrics
"overlap_at_N",
"overlap_at_50",
"overlap_at_100",
"overlap_at_200",
"overlap_at_500",
"precision_at_N",
"precision_at_50",
"precision_at_100",
"precision_at_200",
"precision_at_500",
]
return {metric: {"fdr_threshold": fdr_threshold} for metric in de_metrics_with_fdr}
Comment on lines +167 to +189
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of hardcoding the list of metrics that accept fdr_threshold, you can generate it dynamically by inspecting the signatures of the registered DE metrics. This would make the code more robust and easier to maintain, as you wouldn't need to update this list manually when adding or modifying metrics.

Here's a suggested implementation that uses the inspect module. You'll also need to add the following imports at the top of the file:

import inspect
from .metrics import metrics_registry
from ._types import MetricType
def _build_de_metric_configs(fdr_threshold: float) -> dict[str, dict[str, Any]]:
    """Build metric configs with fdr_threshold for all DE metrics that accept it."""
    de_metrics_with_fdr = []
    for metric_name in metrics_registry.list_metrics(MetricType.DE):
        metric_info = metrics_registry.get_metric(metric_name)
        func_to_inspect = metric_info.func
        if metric_info.is_class:
            func_to_inspect = func_to_inspect.__init__

        sig = inspect.signature(func_to_inspect)
        if "fdr_threshold" in sig.parameters:
            de_metrics_with_fdr.append(metric_name)

    return {metric: {"fdr_threshold": fdr_threshold} for metric in de_metrics_with_fdr}



def _build_anndata_pair(
real: ad.AnnData | str,
pred: ad.AnnData | str,
Expand Down
15 changes: 10 additions & 5 deletions src/cell_eval/metrics/_de.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,24 @@ def __call__(self, data: DEComparison) -> dict[str, dict[str, int]]:
return counts


def compute_pr_auc(
    data: DEComparison, fdr_threshold: float = 0.05
) -> dict[str, float]:
    """Precision-recall AUC per perturbation for recovery of significant genes.

    Thin wrapper that delegates to ``compute_generic_auc`` with the PR method;
    ``fdr_threshold`` sets the significance cutoff used to label real genes.
    """
    return compute_generic_auc(data=data, method="pr", fdr_threshold=fdr_threshold)


def compute_roc_auc(
    data: DEComparison, fdr_threshold: float = 0.05
) -> dict[str, float]:
    """ROC AUC per perturbation for recovery of significant genes.

    Thin wrapper that delegates to ``compute_generic_auc`` with the ROC method;
    ``fdr_threshold`` sets the significance cutoff used to label real genes.
    """
    return compute_generic_auc(data=data, method="roc", fdr_threshold=fdr_threshold)


def compute_generic_auc(
data: DEComparison,
method: Literal["pr", "roc"] = "pr",
fdr_threshold: float = 0.05,
) -> dict[str, float]:
"""Compute AUC values for significant recovery per perturbation."""

Expand All @@ -221,7 +226,7 @@ def compute_generic_auc(
pred_fdr_col = data.pred.fdr_col

labeled_real = data.real.data.with_columns(
(pl.col(real_fdr_col) < 0.05).cast(pl.Float32).alias("label")
(pl.col(real_fdr_col) < fdr_threshold).cast(pl.Float32).alias("label")
).select([target_col, feature_col, "label"])

merged = (
Expand Down
Loading