Skip to content

Commit 420ebbe

Browse files
committed
Add multi-condition LOOCV plotting functionality and tests
- Implemented `plot_weights_and_deltaperby_odor` and `plot_condition_comparison` functions for visualizing LOOCV results. - Created a fixture CSV file for testing with multiple odor conditions. - Developed comprehensive tests for loading data, detecting odor columns, target construction, feature matrix building, LOOCV execution, weight aggregation, and full pipeline integration. - Ensured deterministic behavior across multiple runs and consistent feature shapes across conditions.
1 parent 840a29b commit 420ebbe

7 files changed

Lines changed: 1788 additions & 3 deletions

File tree

scripts/fit_plasticity_delta_weights.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,12 @@ def _parse_args(argv=None):
253253
p.add_argument("--door-cache", default="door_cache")
254254
p.add_argument("--mapping-csv", default="data/mappings/door_to_flywire_mapping.csv")
255255
p.add_argument("--feature-set", choices=["all", "union", "intersection"], default="intersection")
256-
p.add_argument("--activation-threshold", type=float, default=0.05)
256+
p.add_argument(
257+
"--activation-threshold",
258+
type=float,
259+
default=0.0,
260+
help="Activation threshold for feature selection (default: 0.0).",
261+
)
257262
p.add_argument("--agg", choices=["max", "mean", "sum"], default="max")
258263

259264
# Sparse-fit settings.

scripts/run_multicond_loocv.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Run multi-condition leave-one-odor-out LOOCV regression.
4+
5+
Fits one regression model per condition:
6+
- Control (opto_AIR): raw PER, mean-centered.
7+
- Trained conditions: ΔPER = trained − control, mean-centered.
8+
9+
Features are the intersection set across all 7 odors using the DoOR
10+
receptor feature builder.
11+
"""
12+
13+
import argparse
14+
import logging
15+
import sys
16+
from pathlib import Path
17+
from typing import List, Optional, Sequence, Tuple
18+
19+
import numpy as np
20+
import pandas as pd
21+
22+
# Ensure src/ is importable when running as a standalone script.
23+
_repo_root = Path(__file__).resolve().parent.parent
24+
if str(_repo_root / "src") not in sys.path:
25+
sys.path.insert(0, str(_repo_root / "src"))
26+
27+
from door_toolkit.encoder import DoOREncoder
28+
from door_toolkit.glomerulus_features import (
29+
build_design_matrix,
30+
load_receptor_to_glomerulus_mapping,
31+
)
32+
from door_toolkit.multicond_loocv import run_multicond_loocv
33+
from door_toolkit.multicond_loocv_plots import (
34+
plot_weights_and_deltaperby_odor,
35+
plot_condition_comparison,
36+
)
37+
38+
logger = logging.getLogger(__name__)
39+
40+
# ---------------------------------------------------------------------------
41+
# Odor-name mapping: CSV column names -> DoOR names
42+
# ---------------------------------------------------------------------------
43+
# The PER CSV uses short/lab names; DoOR expects canonical chemical names.
44+
# We map each of the 7 CSV columns to the DoOR name used by the encoder.
45+
46+
CSV_ODOR_TO_DOOR = {
47+
"3-Octonol": "3-octanol",
48+
"3-octonol": "3-octanol",
49+
"Apple_Cider_Vinegar": "acetic acid",
50+
"apple_cider_vinegar": "acetic acid",
51+
"Benzaldehyde": "benzaldehyde",
52+
"benzaldehyde": "benzaldehyde",
53+
"Citral": "citral",
54+
"citral": "citral",
55+
"Ethyl_Butyrate": "ethyl butyrate",
56+
"ethyl_butyrate": "ethyl butyrate",
57+
"Hexanol": "1-hexanol",
58+
"hexanol": "1-hexanol",
59+
"Linalool": "linalool",
60+
"linalool": "linalool",
61+
}
62+
63+
64+
def _resolve_door_name(csv_col: str) -> str:
65+
"""Map a CSV odor column name to a DoOR name."""
66+
if csv_col in CSV_ODOR_TO_DOOR:
67+
return CSV_ODOR_TO_DOOR[csv_col]
68+
# Fallback: try lowercase
69+
low = csv_col.lower().replace(" ", "_")
70+
if low in CSV_ODOR_TO_DOOR:
71+
return CSV_ODOR_TO_DOOR[low]
72+
# Last resort: use as-is (the encoder does its own fuzzy matching)
73+
return csv_col
74+
75+
76+
def _build_feature_builder(
77+
*,
78+
door_cache: str,
79+
mapping_csv: str,
80+
feature_set: str,
81+
activation_threshold: float,
82+
agg: str,
83+
):
84+
"""Create a feature builder callable for the multicond pipeline."""
85+
encoder = DoOREncoder(cache_path=door_cache, use_torch=False)
86+
mapping, mapping_meta = load_receptor_to_glomerulus_mapping(mapping_csv)
87+
logger.info(
88+
"Loaded DoOR mapping: %d receptors (adult_only=%s)",
89+
mapping_meta.get("n_receptors_mapped", -1),
90+
mapping_meta.get("adult_only", True),
91+
)
92+
93+
def _builder(
94+
csv_odors: List[str],
95+
) -> Tuple[np.ndarray, List[str], dict]:
96+
door_odors = [_resolve_door_name(o) for o in csv_odors]
97+
logger.info("CSV odors -> DoOR: %s", list(zip(csv_odors, door_odors)))
98+
X, feature_names, meta = build_design_matrix(
99+
door_odors,
100+
encoder,
101+
mapping,
102+
feature_set=feature_set,
103+
activation_threshold=activation_threshold,
104+
agg=agg,
105+
)
106+
return X, feature_names, meta
107+
108+
return _builder
109+
110+
111+
def _parse_alpha_grid(text: str) -> List[float]:
112+
if text.strip().lower() == "default":
113+
return list(np.logspace(-4, 1, 60))
114+
return [float(v.strip()) for v in text.split(",") if v.strip()]
115+
116+
117+
def _parse_conditions(text: str) -> List[str]:
118+
return [t.strip() for t in text.split(",") if t.strip()]
119+
120+
121+
def _parse_args(argv=None):
122+
p = argparse.ArgumentParser(
123+
description="Multi-condition leave-one-odor-out LOOCV regression."
124+
)
125+
p.add_argument(
126+
"--csv", required=True,
127+
help="Path to PER CSV (reaction_rates_summary_unordered.csv).",
128+
)
129+
p.add_argument(
130+
"--control-row", default="opto_AIR",
131+
help="Control condition row label.",
132+
)
133+
p.add_argument(
134+
"--conditions", required=True,
135+
help="Comma-separated conditions (including control if desired).",
136+
)
137+
p.add_argument(
138+
"--model", choices=["lasso", "elasticnet"], default="lasso",
139+
)
140+
p.add_argument("--outdir", default="out/multicond_loocv")
141+
142+
# Feature builder settings
143+
p.add_argument("--door-cache", default="door_cache")
144+
p.add_argument(
145+
"--mapping-csv",
146+
default="data/mappings/door_to_flywire_mapping.csv",
147+
)
148+
p.add_argument(
149+
"--feature-set",
150+
choices=["all", "union", "intersection", "no_blanks"],
151+
default="no_blanks",
152+
help="all=60 receptors; union=54 active; intersection=1; no_blanks=57 (excludes 3 all-zero receptors)",
153+
)
154+
p.add_argument("--activation-threshold", type=float, default=0.0)
155+
p.add_argument("--agg", choices=["max", "mean", "sum"], default="max")
156+
157+
# Sparse-fit settings
158+
p.add_argument(
159+
"--alpha-grid", default="default",
160+
help="Comma-separated alpha grid or 'default'.",
161+
)
162+
p.add_argument("--l1-ratio", type=float, default=0.5)
163+
p.add_argument("--seed", type=int, default=0)
164+
p.add_argument(
165+
"--no-standardize", dest="standardize", action="store_false",
166+
default=True,
167+
)
168+
p.add_argument("--zero-eps", type=float, default=1e-6)
169+
p.add_argument("--min-nonzero", type=int, default=1)
170+
171+
# Plotting options
172+
p.add_argument(
173+
"--plot", action="store_true",
174+
help="Generate per-odor baseline vs. delta weight plots.",
175+
)
176+
p.add_argument(
177+
"--plot-top-n", type=int, default=10,
178+
help="Number of top features to plot per odor (default: 10).",
179+
)
180+
p.add_argument(
181+
"--plot-outdir", default=None,
182+
help="Output directory for plots (default: <outdir>/plots).",
183+
)
184+
p.add_argument(
185+
"--plot-baseline-weights", default=None,
186+
help="Path to baseline weights CSV (feature, baseline_w columns).",
187+
)
188+
p.add_argument(
189+
"--plot-comparison", action="store_true",
190+
help="Also plot all conditions comparison across top features.",
191+
)
192+
193+
return p.parse_args(argv)
194+
195+
196+
def main(argv=None) -> int:
197+
args = _parse_args(argv)
198+
logging.basicConfig(
199+
level=logging.INFO,
200+
format="%(name)s %(levelname)s: %(message)s",
201+
)
202+
203+
conditions = _parse_conditions(args.conditions)
204+
alpha_grid = _parse_alpha_grid(args.alpha_grid)
205+
206+
feature_builder = _build_feature_builder(
207+
door_cache=args.door_cache,
208+
mapping_csv=args.mapping_csv,
209+
feature_set=args.feature_set,
210+
activation_threshold=args.activation_threshold,
211+
agg=args.agg,
212+
)
213+
214+
result = run_multicond_loocv(
215+
csv_path=args.csv,
216+
control_row=args.control_row,
217+
conditions=conditions,
218+
feature_builder=feature_builder,
219+
model=args.model,
220+
alpha_grid=alpha_grid,
221+
l1_ratio=args.l1_ratio,
222+
seed=args.seed,
223+
standardize=args.standardize,
224+
zero_eps=args.zero_eps,
225+
min_nonzero=args.min_nonzero,
226+
outdir=args.outdir,
227+
)
228+
229+
# Plotting
230+
if args.plot:
231+
plot_outdir = args.plot_outdir or str(Path(args.outdir) / "plots")
232+
233+
# Load baseline weights if provided
234+
baseline_df = None
235+
if args.plot_baseline_weights:
236+
baseline_df = pd.read_csv(args.plot_baseline_weights)
237+
if "receptor" in baseline_df.columns:
238+
baseline_df = baseline_df.rename(
239+
columns={"receptor": "feature"}
240+
)
241+
elif "feature" not in baseline_df.columns:
242+
raise ValueError(
243+
"Baseline weights CSV must have 'feature' or 'receptor' column"
244+
)
245+
246+
plots = plot_weights_and_deltaperby_odor(
247+
plot_outdir,
248+
odors=result["odors"],
249+
feature_names=result["feature_names"],
250+
condition_data=result["condition_data"],
251+
baseline_weights=baseline_df,
252+
top_n=args.plot_top_n,
253+
control_row=args.control_row,
254+
)
255+
print("Plots written ({0}):".format(len(plots)))
256+
for p in plots:
257+
print(" {0}".format(p))
258+
259+
if args.plot_comparison:
260+
comp_plots = plot_condition_comparison(
261+
plot_outdir,
262+
conditions=result["conditions"],
263+
feature_names=result["feature_names"],
264+
condition_data=result["condition_data"],
265+
top_n=args.plot_top_n,
266+
control_row=args.control_row,
267+
)
268+
print("Comparison plots written ({0}):".format(len(comp_plots)))
269+
for p in comp_plots:
270+
print(" {0}".format(p))
271+
272+
return 0
273+
274+
275+
if __name__ == "__main__":
276+
raise SystemExit(main())

src/door_toolkit/glomerulus_features.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def build_design_matrix(
178178
encoder: DoOREncoder,
179179
mapping: Dict[str, List[str]],
180180
*,
181-
feature_set: Literal["all", "union", "intersection"] = "union",
181+
feature_set: Literal["all", "union", "intersection", "no_blanks"] = "union",
182182
activation_threshold: float = 0.05,
183183
agg: Literal["max", "mean", "sum"] = "max",
184184
) -> Tuple[np.ndarray, List[str], dict]:
@@ -192,6 +192,7 @@ def build_design_matrix(
192192
- "all": every receptor in the mapping.
193193
- "union": receptors active for at least one odor.
194194
- "intersection": receptors active for all odors.
195+
- "no_blanks": receptors with at least one non-zero value (excludes all-zero receptors).
195196
activation_threshold: Response above this marks a receptor as active.
196197
agg: Aggregation (ignored, kept for API compatibility).
197198
@@ -230,8 +231,11 @@ def build_design_matrix(
230231
selected_mask = active_masks.any(axis=0)
231232
elif feature_set == "intersection":
232233
selected_mask = active_masks.all(axis=0)
234+
elif feature_set == "no_blanks":
235+
# Include receptors with at least one non-zero value (any absolute value > 0)
236+
selected_mask = np.any(np.abs(full_matrix) > 0.0, axis=0)
233237
else:
234-
raise ValueError(f"feature_set must be all/union/intersection, got '{feature_set}'")
238+
raise ValueError(f"feature_set must be all/union/intersection/no_blanks, got '{feature_set}'")
235239

236240
selected_indices = np.where(selected_mask)[0]
237241
selected_receptor_names = [all_receptors[i] for i in selected_indices]

0 commit comments

Comments
 (0)