Skip to content

Commit f1637fc

Browse files
Copilot and akhanf authored
Add per-subject bias-corrected intensity histogram QC rule
Agent-Logs-Url: https://github.com/khanlab/SPIMquant/sessions/81efff74-dbee-4b61-bae9-8536884fc75c Co-authored-by: akhanf <11492701+akhanf@users.noreply.github.com>
1 parent 934adc6 commit f1637fc

5 files changed

Lines changed: 336 additions & 0 deletions

File tree

spimquant/workflow/Snakefile

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,20 @@ rule all_qc:
677677
),
678678
stain=stains,
679679
),
680+
# Bias-corrected intensity histograms (per stain, requires segmentation)
681+
inputs["spim"].expand(
682+
bids(
683+
root=root,
684+
datatype="qc",
685+
stain="{stain}",
686+
desc="biascorrected",
687+
suffix="histogram.png",
688+
**inputs["spim"].wildcards,
689+
),
690+
stain=stains_for_seg,
691+
)
692+
if do_seg
693+
else [],
680694
# Segmentation overview figures (per stain, per seg method)
681695
inputs["spim"].expand(
682696
bids(

spimquant/workflow/rules/qc.smk

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,68 @@ saturation/clip fraction (percentage of voxels at the maximum bin).
7373
"../scripts/qc_intensity_histogram.py"
7474

7575

76+
rule qc_bias_corrected_histogram:
    """Per-channel bias-corrected intensity histogram QC.

    Samples random full-resolution patches from within the brain mask using
    ZarrNiiAtlas patch sampling, applies bias field correction patch-wise by
    upsampling the downsampled correction map to each patch, and generates a
    four-panel intensity histogram of the corrected intensities.

    Inputs:
    - Raw SPIM OME-Zarr (full-resolution patches are extracted at level 0)
    - Downsampled bias field OME-Zarr (loaded at registration_level within the
      zarr pyramid for efficient patch extraction; upsampled per-patch via
      scipy.ndimage.zoom)
    - Brain mask NIfTI (used as a single-label ZarrNiiAtlas to draw random
      patch centre coordinates within the brain)
    """
    input:
        spim=inputs["spim"].path,
        # Bias field produced at segmentation_level by the configured
        # correction method (desc matches the producing rule's output name).
        biasfield=bids(
            root=work,
            datatype="seg",
            stain="{stain}",
            level=config["segmentation_level"],
            desc=config["correction_method"],
            suffix="biasfield.ome.zarr",
            **inputs["spim"].wildcards,
        ),
        # Brain mask is derived from the registration stain (stain_for_reg),
        # not the QC'd stain — it is only used to sample patch centres.
        brain_mask=bids(
            root=root,
            datatype="micr",
            stain=stain_for_reg,
            level=config["registration_level"],
            desc="brain",
            suffix="mask.nii.gz",
            **inputs["spim"].wildcards,
        ),
    output:
        png=bids(
            root=root,
            datatype="qc",
            stain="{stain}",
            desc="biascorrected",
            suffix="histogram.png",
            **inputs["spim"].wildcards,
        ),
    threads: 8
    resources:
        mem_mb=32000,
        runtime=60,
    params:
        # Patch sampling controls, with defaults when absent from config.
        n_patches=config.get("n_patches_per_label", 5),
        patch_size=config.get("patch_size", [256, 256, 256]),
        seed=config.get("patch_seed", 42),
        hist_bins=500,
        # NOTE(review): assumes 16-bit acquisition intensities — confirm
        # against the raw SPIM dtype.
        hist_range=[0, 65535],
        # Pyramid level at which the bias-field zarr is read before
        # per-patch upsampling in the script.
        biasfield_zarr_level=config["registration_level"],
        correction_method=config["correction_method"],
        zarrnii_kwargs={"orientation": config["orientation"]},
    script:
        "../scripts/qc_bias_corrected_histogram.py"
136+
137+
76138
rule qc_segmentation_overview:
77139
"""Segmentation overview slice montage QC.
78140

spimquant/workflow/rules/segmentation.smk

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,24 @@ rule n4_biasfield:
9292
),
9393
group_jobs=True,
9494
),
95+
biasfield=temp(
96+
directory(
97+
bids(
98+
root=work,
99+
datatype="seg",
100+
stain="{stain}",
101+
level="{level}",
102+
desc="n4",
103+
suffix="biasfield.ome.zarr",
104+
**inputs["spim"].wildcards,
105+
)
106+
),
107+
group_jobs=True,
108+
),
95109
threads: 128 if config["dask_scheduler"] == "distributed" else 32
96110
resources:
97111
mem_mb=500000 if config["dask_scheduler"] == "distributed" else 250000,
112+
disk_mb=2097152,
98113
runtime=180,
99114
script:
100115
"../scripts/n4_biasfield.py"

spimquant/workflow/scripts/n4_biasfield.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
znimg_corrected = znimg.apply_scaled_processing(
3131
N4BiasFieldCorrection(shrink_factor=snakemake.params.shrink_factor),
3232
downsample_factor=adjusted_downsample_factor,
33+
upsampled_ome_zarr_path=snakemake.output.biasfield,
3334
)
3435

3536
# write to ome_zarr
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""Bias-corrected per-channel intensity histogram QC for SPIM data.
2+
3+
Samples random full-resolution patches from within the brain mask, applies
4+
bias field correction patch-wise by upsampling the downsampled correction map,
5+
and generates a four-panel intensity histogram of the corrected intensities.
6+
7+
This is a Snakemake script that expects the ``snakemake`` object to be
8+
available, which is automatically provided when executed as part of a
9+
Snakemake workflow.
10+
"""
11+
12+
import logging
13+
14+
import matplotlib
15+
16+
matplotlib.use("agg")
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
import pandas as pd
20+
from scipy.ndimage import zoom
21+
22+
from dask_setup import get_dask_client
23+
from zarrnii import ZarrNii, ZarrNiiAtlas
24+
25+
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
26+
27+
28+
def main():
    """Build the four-panel bias-corrected intensity histogram QC figure.

    All inputs, outputs, and parameters come from the Snakemake-injected
    ``snakemake`` object. Random full-resolution patches are drawn inside the
    brain mask, each is divided by the (per-patch upsampled) bias field, and
    the pooled corrected intensities are summarised as a linear histogram, a
    log histogram, a cumulative distribution, and a summary-statistics panel,
    saved to ``snakemake.output.png``.

    Per-patch failures are logged and skipped (best-effort QC); the function
    itself raises only on unrecoverable errors outside the patch loop.
    """
    stain = snakemake.wildcards.stain
    n_patches = snakemake.params.n_patches
    patch_size = tuple(snakemake.params.patch_size)
    seed = snakemake.params.seed
    hist_bins = snakemake.params.hist_bins
    hist_range = snakemake.params.hist_range
    biasfield_zarr_level = snakemake.params.biasfield_zarr_level
    # Drop None-valued kwargs so ZarrNii constructors fall back to defaults.
    zarrnii_kwargs = {
        k: v for k, v in snakemake.params.zarrnii_kwargs.items() if v is not None
    }

    # Patch size for the biasfield at its downsampled level.
    # Since each pyramid level halves the voxel count per axis, a patch of
    # `patch_size` voxels at level 0 corresponds to
    # `patch_size / 2**biasfield_zarr_level` voxels at the downsampled level,
    # covering the same physical extent.
    biasfield_patch_size = tuple(
        max(1, p // (2**biasfield_zarr_level)) for p in patch_size
    )

    with get_dask_client("threads", snakemake.threads):
        # Load brain mask as a ZarrNiiAtlas with a single "brain" label.
        brain_znii = ZarrNii.from_file(
            snakemake.input.brain_mask, **zarrnii_kwargs
        )
        labels_df = pd.DataFrame(
            {"index": [1], "name": ["brain"], "abbreviation": ["brain"]}
        )
        atlas = ZarrNiiAtlas.create_from_dseg(brain_znii, labels_df)

        # Sample patch centers uniformly within the brain mask (physical coords).
        logging.info(f"Sampling {n_patches} patch centers from brain mask ...")
        centers = atlas.sample_region_patches(
            n_patches=n_patches,
            region_ids=1,
            seed=seed,
        )
        logging.info(f"Sampled {len(centers)} centers.")

        # Load raw SPIM at level 0 (full resolution) for patch extraction.
        znimg_raw = ZarrNii.from_ome_zarr(
            snakemake.input.spim,
            level=0,
            channel_labels=[stain],
            **zarrnii_kwargs,
        )

        # Load biasfield at a downsampled pyramid level within the biasfield zarr.
        znimg_biasfield = ZarrNii.from_ome_zarr(
            snakemake.input.biasfield,
            level=biasfield_zarr_level,
        )

        # Collect corrected intensities patch by patch.
        all_intensities = []
        epsilon = np.finfo(np.float32).eps

        for i, center in enumerate(centers):
            try:
                # Extract full-resolution raw patch.
                raw_patch = znimg_raw.crop_centered([center], patch_size=patch_size)
                if not isinstance(raw_patch, list):
                    raw_patch = [raw_patch]
                raw_np = np.squeeze(raw_patch[0].data.compute()).astype(np.float32)

                # Extract corresponding biasfield patch at the downsampled level:
                # same physical center, proportionally fewer voxels, so both
                # patches cover the same physical extent.
                bf_patch = znimg_biasfield.crop_centered(
                    [center], patch_size=biasfield_patch_size
                )
                if not isinstance(bf_patch, list):
                    bf_patch = [bf_patch]
                bf_np = np.squeeze(bf_patch[0].data.compute()).astype(np.float32)

                # Upsample the biasfield patch to match the raw patch shape.
                # NOTE(review): assumes both squeezed arrays have the same
                # ndim (zip would silently truncate otherwise) — confirm
                # against crop_centered's output layout.
                if raw_np.shape != bf_np.shape:
                    zoom_factors = tuple(
                        r / b for r, b in zip(raw_np.shape, bf_np.shape)
                    )
                    bf_np = zoom(bf_np, zoom_factors, order=1)

                # Apply bias field correction: corrected = raw / biasfield.
                # epsilon floor avoids division by (near-)zero field values.
                corrected = raw_np / np.maximum(bf_np, epsilon)
                all_intensities.append(corrected.ravel())

                logging.info(f"Processed patch {i + 1}/{len(centers)}")

            except Exception as e:
                # Best-effort QC: a bad patch (e.g. out-of-bounds crop) must
                # not fail the whole figure.
                logging.warning(f"Skipping patch {i + 1}: {e}")
                continue

    if not all_intensities:
        logging.warning("No valid patches collected; producing empty histogram.")
        all_intensities = [np.zeros(1, dtype=np.float32)]

    intensities = np.concatenate(all_intensities)

    # BUGFIX: corrected intensities are floats and can exceed the fixed
    # histogram range (raw / biasfield > hist_range[1] wherever the field
    # is < 1). np.histogram(range=...) silently DROPS out-of-range values,
    # so the saturation fraction (top-bin share) would miss exactly the
    # over-range voxels it is meant to flag. Clip into range so they are
    # counted in the edge bins instead.
    intensities = np.clip(intensities, hist_range[0], hist_range[1])

    # Compute histogram from the collected corrected intensities.
    hist_counts, bin_edges = np.histogram(
        intensities, bins=hist_bins, range=tuple(hist_range)
    )
    hist_counts = hist_counts.astype(float)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_width = bin_edges[1] - bin_edges[0]

    total_voxels = hist_counts.sum()
    max_range = hist_range[1]

    # Display upper limit: just past the highest populated bin.
    nonzero_mask = hist_counts > 0
    disp_max = (
        float(bin_centers[nonzero_mask][-1]) * 1.05
        if nonzero_mask.any()
        else max_range
    )
    # Share of voxels in the top bin = saturation/clip indicator.
    sat_fraction = (
        float(hist_counts[-1]) / total_voxels * 100 if total_voxels > 0 else 0.0
    )

    # Histogram-based summary statistics (bin-center approximation).
    if total_voxels > 0:
        mean_val = float(np.sum(bin_centers * hist_counts) / total_voxels)
        cumsum_norm = np.cumsum(hist_counts) / total_voxels
        p50_val = float(
            bin_centers[min(np.searchsorted(cumsum_norm, 0.50), len(bin_centers) - 1)]
        )
        p99_val = float(
            bin_centers[min(np.searchsorted(cumsum_norm, 0.99), len(bin_centers) - 1)]
        )
    else:
        mean_val = p50_val = p99_val = 0.0

    # Linear panel axis limits: zoom to the 99th percentile so the bulk of
    # the distribution is visible.
    lin_xlim = p99_val * 1.05 if total_voxels > 0 else max_range
    visible = hist_counts[bin_centers <= lin_xlim]
    lin_ylim = (
        float(visible.max()) * 1.05 if visible.size and visible.max() > 0 else 1.0
    )

    subject = snakemake.wildcards.subject

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle(
        f"Bias-Corrected Intensity Histogram QC\n"
        f"Subject: {subject} | Stain: {stain} | "
        f"Patches: {len(centers)} | Correction: {snakemake.params.correction_method}",
        fontsize=12,
        fontweight="bold",
    )

    # Panel 1: linear-scale histogram
    ax = axes[0, 0]
    ax.bar(bin_centers, hist_counts, width=bin_width, color="steelblue", alpha=0.75)
    ax.set_xlabel("Corrected intensity")
    ax.set_ylabel("Voxel count")
    ax.set_title("Linear-scale histogram")
    ax.set_xlim(0, lin_xlim)
    ax.set_ylim(0, lin_ylim)

    # Panel 2: log-scale histogram (empty bins rendered as gaps via NaN)
    ax = axes[0, 1]
    log_counts = np.where(hist_counts > 0, np.log10(hist_counts), np.nan)
    ax.bar(bin_centers, log_counts, width=bin_width, color="darkorange", alpha=0.75)
    ax.set_xlabel("Corrected intensity")
    ax.set_ylabel("log\u2081\u2080(voxel count)")
    ax.set_title("Log-scale histogram")
    ax.set_xlim(0, disp_max)

    # Panel 3: cumulative distribution with median / 99th percentile markers
    ax = axes[1, 0]
    if total_voxels > 0:
        cumsum_pct = cumsum_norm * 100
        ax.plot(bin_centers, cumsum_pct, color="forestgreen", lw=1.5)
        ax.axvline(
            x=p50_val,
            color="purple",
            linestyle="--",
            alpha=0.7,
            label=f"Median ({p50_val:.1f})",
        )
        ax.axvline(
            x=p99_val,
            color="red",
            linestyle="--",
            alpha=0.7,
            label=f"99th pctile ({p99_val:.1f})",
        )
        ax.legend(fontsize=8)
    ax.set_xlabel("Corrected intensity")
    ax.set_ylabel("Cumulative voxels (%)")
    ax.set_title("Cumulative distribution")
    ax.set_ylim(0, 105)
    ax.set_xlim(0, disp_max)

    # Panel 4: summary statistics (text only)
    ax = axes[1, 1]
    ax.axis("off")
    summary_text = (
        f"Sampled voxels: {int(total_voxels):>14,}\n"
        f"Patches: {len(centers):>14,}\n"
        f"Mean intensity: {mean_val:>14.2f}\n"
        f"Median (50th): {p50_val:>14.2f}\n"
        f"99th percentile: {p99_val:>14.2f}\n"
        f"Max range: {max_range:>14.1f}\n"
        f"Saturation frac.: {sat_fraction:>13.3f}%"
    )
    ax.text(
        0.1,
        0.55,
        summary_text,
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment="center",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )
    ax.set_title("Summary statistics")

    plt.tight_layout()
    plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight")
    plt.close()
    logging.info(f"Saved bias-corrected histogram QC to {snakemake.output.png}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)