Commit 2db6b9c

Merge pull request #225 from ArcInstitute/cell-eval-0.7.0
Cell eval 0.7.0
2 parents 8b8e4a7 + e629a20 commit 2db6b9c

16 files changed

Lines changed: 247 additions & 176 deletions

File tree

.github/workflows/CI.yml

Lines changed: 29 additions & 2 deletions
@@ -5,7 +5,7 @@ on: [push, pull_request]
 jobs:
   all_jobs:
     runs-on: ubuntu-latest
-    needs: [formatting, pytest, cli-test]
+    needs: [formatting, typing, pytest, cli-test]
     steps:
       - name: Complete
         run: echo "Complete"
@@ -50,7 +50,7 @@ jobs:
         run: |
           uv run ruff format --check
 
-  pytest:
+  typing:
     runs-on: ubuntu-latest
 
     needs: [install-job]
@@ -69,6 +69,33 @@ jobs:
         run: |
           uv sync --all-extras --dev
 
+      - name: run type checking
+        run: |
+          uv run ty check
+
+  pytest:
+    runs-on: ubuntu-latest
+
+    needs: [install-job]
+
+    strategy:
+      matrix:
+        python-version: ["3.12", "3.13", "3.14"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          enable-cache: true
+          cache-dependency-glob: "pyproject.toml"
+          python-version: "${{ matrix.python-version }}"
+
+      - name: install dependencies
+        run: |
+          uv sync --all-extras --dev
+
       - name: run pytest
         run: |
           uv run pytest -v

.python-version

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3.12
+3.14

CLAUDE.md

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+**cell-eval** is a Python package and CLI tool for evaluating the performance of models that predict cellular responses to perturbations at the single-cell level. Developed by the Arc Research Institute.
+
+It generally revolves around a *real* anndata and a *predicted* anndata where it measures the general differences between the two across a variety of metrics.
+
+- Python 3.11–3.12, managed with **UV** and built with **hatchling**
+- CLI entry point: `cell-eval` (defined in `src/cell_eval/__main__.py`)
+
+## Common Commands
+
+```bash
+# Install dependencies
+uv sync --all-extras --dev
+
+# Run all tests
+uv run pytest -v
+
+# Run a single test
+uv run pytest tests/test_eval.py::test_broken_adata_not_normlog -v
+
+# Formatting (check / fix)
+uv run ruff format --check
+uv run ruff format
+
+# Type checking
+uv run ty check
+
+# Verify CLI works
+uv run cell-eval --help
+```
+
+CI runs: formatting, typing, pytest, and cli-test (see `.github/workflows/CI.yml`).
+
+## Architecture
+
+### Core Data Flow
+
+```
+AnnData inputs (predicted + real)
+→ MetricsEvaluator (validation, normalization, DE computation)
+→ MetricPipeline (profile-based metric selection + execution)
+→ metrics_registry (global MetricRegistry instance)
+→ individual metric functions
+→ polars DataFrames (per-perturbation + aggregated results)
+```
+
+### Key Abstractions
+
+- **`MetricsEvaluator`** (`src/cell_eval/_evaluator.py`) — Main programmatic entry point. Validates input AnnData objects, computes differential expression via `pdex`, and orchestrates the metric pipeline.
+
+- **`MetricRegistry`** (`src/cell_eval/metrics/_registry.py`) — Global singleton `metrics_registry`. Metrics are registered with a name, type (`DE` or `ANNDATA_PAIR`), compute function, and best-value indicator. Supports both plain functions and class-based metrics requiring instantiation.
+
+- **`MetricPipeline`** (`src/cell_eval/_pipeline/_runner.py`) — Selects and runs metrics based on a profile (`full`, `minimal`, `vcc`, `de`, `anndata`, `pds`). Collects per-perturbation results and aggregates them.
+
+- **`Metric` protocol** (`src/cell_eval/metrics/base.py`) — All metric functions take either a `PerturbationAnndataPair` or `DEComparison` and return `float | dict[str, float]`.
+
+- **Type system** (`src/cell_eval/_types/`) — Immutable dataclasses: `PerturbationAnndataPair`, `DEComparison`, plus enums `MetricType`, `MetricBestValue`, `DESortBy`.
+
+### Metrics
+
+Metrics are split into two categories registered in `src/cell_eval/metrics/_impl.py`:
+
+- **AnnData metrics** (`_anndata.py`): pearson_delta, mse, mae, mse_delta, mae_delta, discrimination_score, clustering_agreement, edistance
+- **DE metrics** (`_de.py`): overlap/precision at N, spearman correlations, direction match, significant gene recall, ROC/PR AUC
+
+### CLI
+
+Subcommands in `src/cell_eval/_cli/`: `prep` (data preparation for VCC), `run` (evaluation), `baseline` (create baseline), `score` (normalize against baseline). CLI defaults are in `_cli/_const.py`.
+
+### Test Data Utilities
+
+`cell_eval.data` provides `build_random_anndata()` and `downsample_cells()` for generating synthetic AnnData objects in tests.
+
+## Conventions
+
+- Uses `polars` (not pandas) for DataFrames
+- Uses `match`/`case` statements (Python 3.10+ syntax)
+- Type hints throughout; PEP 561 `py.typed` marker present
+- Private modules prefixed with `_` (public API is re-exported from `__init__.py`)

pyproject.toml

Lines changed: 10 additions & 8 deletions
@@ -1,34 +1,36 @@
 [project]
 name = "cell-eval"
-version = "0.6.8"
+version = "0.7.0"
 description = "Evaluation metrics for single-cell perturbation predictions"
 readme = "README.md"
 authors = [
     { name = "Noam Teyssier", email = "noam.teyssier@arcinstitute.org" },
     { name = "Abhinav Adduri", email = "abhinav.adduri@arcinstitute.org" },
     { name = "Yusuf Roohani", email = "yusuf.roohani@arcinstitute.org" },
 ]
-requires-python = ">=3.10,<3.13"
+requires-python = ">=3.11"
 dependencies = [
     "igraph>=0.11.8",
-    "pdex>=0.1.26",
+    "pdex>=0.2.0",
     "polars>=1.30.0",
     "pyyaml>=6.0.2",
     "scanpy>=1.10.3",
     "pyarrow>=18.0.0",
     "tqdm>=4.67.1",
+    "anndata>=0.12.10",
 ]
 
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [dependency-groups]
-dev = ["ipykernel>=6.29.5", "pytest>=8.3.5", "ruff>=0.11.8"]
+dev = [
+    "ipykernel>=6.29.5",
+    "pytest>=8.3.5",
+    "ruff>=0.11.8",
+    "ty>=0.0.19",
+]
 
 [project.scripts]
 cell-eval = "cell_eval.__main__:main"
-
-[tool.pyright]
-venvPath = "."
-venv = ".venv"

ruff.toml

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/cell_eval/_baseline.py

Lines changed: 12 additions & 14 deletions
@@ -1,11 +1,12 @@
 import logging
-from typing import Any
+from typing import Any, cast
 
 import anndata as ad
 import numpy as np
+import pandas as pd
 import polars as pl
 from numpy.typing import NDArray
-from pdex import parallel_differential_expression
+from pdex import pdex
 from scipy.sparse import issparse
 
 from ._evaluator import _build_pdex_kwargs, _convert_to_normlog
@@ -23,9 +24,7 @@ def build_base_mean_adata(
     allow_discrete: bool = False,
     output_path: str | None = None,
     output_de_path: str | None = None,
-    batch_size: int = 1000,
     num_threads: int = 1,
-    de_method: str = "wilcoxon",
     pdex_kwargs: dict[str, Any] = {},
 ) -> ad.AnnData:
     if isinstance(adata, str):
@@ -67,7 +66,7 @@ def build_base_mean_adata(
             (int(counts[counts_col].sum()), baseline.size),
             baseline,
         ),
-        var=adata.var,
+        var=cast(pd.DataFrame, adata.var),
         obs=obs,
     )
 
@@ -78,21 +77,20 @@ def build_base_mean_adata(
 
     if output_path is not None:
         logger.info(f"Saving baseline data to {output_path}")
-        baseline_adata.write_h5ad(output_path)
+        baseline_adata.write_h5ad(output_path)  # type: ignore[invalid-argument-type]
 
     if output_de_path is not None:
         logger.info("Calculating differential expression")
         pdex_kwargs = _build_pdex_kwargs(
-            groupby_key=pert_col,
+            groupby=pert_col,
             reference=control_pert,
-            num_workers=num_threads,
-            metric=de_method,
-            batch_size=batch_size,
+            threads=num_threads,
             allow_discrete=allow_discrete,
             pdex_kwargs=pdex_kwargs,
         )
-        frame = parallel_differential_expression(
+        frame = pdex(
             adata=baseline_adata,
+            mode="ref",
             **pdex_kwargs,
         )
         logger.info(f"Saving differential expression results to {output_de_path}")
@@ -137,9 +135,9 @@ def _build_counts_df_from_adata(
         raise ValueError(
             f"Column '{pert_col}' not found in adata.obs: {adata.obs.columns}"
         )
-    if control_pert not in adata.obs[pert_col].unique():
+    if control_pert not in cast(pd.Series, adata.obs[pert_col]).unique():
         raise ValueError(
-            f"Control pert '{control_pert}' not found in adata.obs[{pert_col}]: {adata.obs[pert_col].unique()}"
+            f"Control pert '{control_pert}' not found in adata.obs[{pert_col}]: {cast(pd.Series, adata.obs[pert_col]).unique()}"
         )
     logger.info("Building counts DataFrame from adata")
     return (
@@ -161,7 +159,7 @@ def _build_pert_baseline(
         raise ValueError(
             f"Column '{pert_col}' not found in adata.obs: {adata.obs.columns}"
         )
-    unique_perts = adata.obs[pert_col].unique()
+    unique_perts = cast(pd.Series, adata.obs[pert_col]).unique()
     if control_pert not in unique_perts:
         raise ValueError(
             f"Control pert '{control_pert}' not found in unique_perts: {unique_perts}"

src/cell_eval/_cli/_prep.py

Lines changed: 6 additions & 5 deletions
@@ -5,6 +5,7 @@
 import shutil
 import subprocess
 from tempfile import TemporaryDirectory
+from typing import cast
 
 import anndata as ad
 import numpy as np
@@ -136,9 +137,9 @@ def strip_anndata(
         raise ValueError(
             f"Provided celltype column: '{celltype_col}' missing from anndata: {adata.obs.columns}"
         )
-    if ntc_name not in adata.obs[pert_col].unique():
+    if ntc_name not in cast(pd.Series, adata.obs[pert_col]).unique():
         raise ValueError(
-            f"Provided negative control name: '{ntc_name}' missing from anndata: {adata.obs[pert_col].unique()}"
+            f"Provided negative control name: '{ntc_name}' missing from anndata: {cast(pd.Series, adata.obs[pert_col]).unique()}"
         )
 
     # Check if expected dimension is provided and matches the length of the genelist
@@ -196,11 +197,11 @@ def strip_anndata(
 
     logger.info("Simplifying obs dataframe")
     new_obs = pd.DataFrame(
-        {output_pert_col: adata.obs[pert_col].values},
+        {output_pert_col: cast(pd.Series, adata.obs[pert_col]).values},
         index=np.arange(adata.shape[0]).astype(str),
     )
     if celltype_col:
-        new_obs[output_celltype_col] = adata.obs[celltype_col].values
+        new_obs[output_celltype_col] = cast(pd.Series, adata.obs[celltype_col]).values
 
     logger.info("Simplifying var dataframe")
     new_var = pd.DataFrame(
@@ -225,7 +226,7 @@ def strip_anndata(
 
         # Write the h5ad file
         logger.info(f"Writing h5ad output to {tmp_h5ad}")
-        minimal.write_h5ad(tmp_h5ad)
+        minimal.write_h5ad(tmp_h5ad)  # type: ignore[invalid-argument-type]
 
         # Zstd compress the h5ad file (will create pred.h5ad.zst)
         logger.info(f"Zstd compressing {tmp_h5ad}")
src/cell_eval/_cli/_run.py

Lines changed: 0 additions & 16 deletions
@@ -79,18 +79,6 @@ def parse_args_run(parser: ap.ArgumentParser):
         default=1,
         help="Number of threads to use for parallel processing [default: %(default)s]",
     )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=100,
-        help="Batch size for parallel processing [default: %(default)s]",
-    )
-    parser.add_argument(
-        "--de-method",
-        type=str,
-        default="wilcoxon",
-        help="Method to use for differential expression analysis [default: %(default)s]",
-    )
     parser.add_argument(
         "--allow-discrete",
         action="store_true",
@@ -166,9 +154,7 @@ def run_evaluation(args: ap.Namespace):
         de_real=args.de_real,
         control_pert=args.control_pert,
         pert_col=args.pert_col,
-        de_method=args.de_method,
         num_threads=args.num_threads,
-        batch_size=args.batch_size,
         outdir=args.outdir,
         allow_discrete=args.allow_discrete,
         prefix=ct,
@@ -189,9 +175,7 @@ def run_evaluation(args: ap.Namespace):
         de_real=args.de_real,
         control_pert=args.control_pert,
         pert_col=args.pert_col,
-        de_method=args.de_method,
         num_threads=args.num_threads,
-        batch_size=args.batch_size,
         outdir=args.outdir,
         allow_discrete=args.allow_discrete,
         skip_de=args.profile == "pds",
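
Note on the removed `--batch-size` and `--de-method` flags: with standard `argparse` behavior, an option that is no longer registered is rejected as an unrecognized argument, so scripts that still pass these flags would be expected to fail after this change. The sketch below illustrates this with a hypothetical stand-in parser. It is not the real `cell-eval run` parser, and only the `--allow-discrete` flag name is taken from the diff.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in parser; the real one lives in parse_args_run().
    parser = argparse.ArgumentParser(prog="cell-eval-sketch")
    parser.add_argument("--allow-discrete", action="store_true")
    return parser


def accepts(argv: list[str]) -> bool:
    """Return True if the stand-in parser accepts argv, False if it rejects it."""
    try:
        build_parser().parse_args(argv)
    except SystemExit:
        # argparse reports unrecognized arguments by printing usage to stderr
        # and raising SystemExit.
        return False
    return True


still_ok = accepts(["--allow-discrete"])
old_batch_flag = accepts(["--batch-size", "100"])
old_method_flag = accepts(["--de-method", "wilcoxon"])
```

Under these assumptions, `still_ok` is `True` while the two removed flags are rejected, which is why callers of the old CLI may need updating.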
