"""Bias-corrected per-channel intensity histogram QC for SPIM data.

Samples random full-resolution patches from within the brain mask, applies
bias field correction patch-wise by upsampling the downsampled correction map,
and generates a four-panel QC figure (linear- and log-scale histograms,
cumulative distribution, and summary statistics) of the corrected intensities.

This is a Snakemake script: it expects the ``snakemake`` object to be
available, which Snakemake provides automatically when the script is executed
as part of a workflow.
"""

import logging

import matplotlib

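# Select the non-interactive Agg backend before importing pyplot so figures
# render headlessly (e.g. on cluster nodes); backend selection is only
# reliable before the first pyplot import, which is why the matplotlib import
# is split off from the block below.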
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import zoom

from dask_setup import get_dask_client
from zarrnii import ZarrNii, ZarrNiiAtlas

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


def main():
    stain = snakemake.wildcards.stain
    n_patches = snakemake.params.n_patches
    patch_size = tuple(snakemake.params.patch_size)
    seed = snakemake.params.seed
    hist_bins = snakemake.params.hist_bins
    hist_range = snakemake.params.hist_range
    biasfield_zarr_level = snakemake.params.biasfield_zarr_level
    zarrnii_kwargs = {
        k: v for k, v in snakemake.params.zarrnii_kwargs.items() if v is not None
    }

    # Patch size for the biasfield at its downsampled level: each pyramid
    # level halves the voxel count per axis, so a patch of `patch_size` voxels
    # at level 0 corresponds to `patch_size / 2**biasfield_zarr_level` voxels
    # at the downsampled level, covering the same physical extent.
    biasfield_patch_size = tuple(
        max(1, p // (2**biasfield_zarr_level)) for p in patch_size
    )
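    # For example (hypothetical values): patch_size=(128, 128, 128) with
    # biasfield_zarr_level=4 yields biasfield_patch_size=(8, 8, 8), i.e. the
    # same physical volume sampled on a grid 2**4 times coarser per axis.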

    with get_dask_client("threads", snakemake.threads):
        # Load the brain mask as a ZarrNiiAtlas with a single "brain" label.
        brain_znii = ZarrNii.from_file(
            snakemake.input.brain_mask, **zarrnii_kwargs
        )
        labels_df = pd.DataFrame(
            {"index": [1], "name": ["brain"], "abbreviation": ["brain"]}
        )
        atlas = ZarrNiiAtlas.create_from_dseg(brain_znii, labels_df)

        # Sample patch centers uniformly within the brain mask (physical coords).
        logging.info(f"Sampling {n_patches} patch centers from brain mask ...")
        centers = atlas.sample_region_patches(
            n_patches=n_patches,
            region_ids=1,
            seed=seed,
        )
        logging.info(f"Sampled {len(centers)} centers.")

        # Load the raw SPIM at level 0 (full resolution) for patch extraction.
        znimg_raw = ZarrNii.from_ome_zarr(
            snakemake.input.spim,
            level=0,
            channel_labels=[stain],
            **zarrnii_kwargs,
        )

        # Load the biasfield at a downsampled pyramid level within its zarr.
        znimg_biasfield = ZarrNii.from_ome_zarr(
            snakemake.input.biasfield,
            level=biasfield_zarr_level,
        )

        # Collect corrected intensities patch by patch.
        all_intensities = []
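        # epsilon guards the division in the correction step against
        # bias-field values at or near zero (e.g. just outside the brain,
        # where the fitted field carries no signal).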
        epsilon = np.finfo(np.float32).eps

        for i, center in enumerate(centers):
            try:
                # Extract a full-resolution raw patch.
                raw_patch = znimg_raw.crop_centered([center], patch_size=patch_size)
                if not isinstance(raw_patch, list):
                    raw_patch = [raw_patch]
                raw_np = np.squeeze(raw_patch[0].data.compute()).astype(np.float32)

                # Extract the corresponding biasfield patch at the downsampled
                # level: use the same physical center but a proportionally
                # smaller voxel count, so both patches cover the same physical
                # extent.
                bf_patch = znimg_biasfield.crop_centered(
                    [center], patch_size=biasfield_patch_size
                )
                if not isinstance(bf_patch, list):
                    bf_patch = [bf_patch]
                bf_np = np.squeeze(bf_patch[0].data.compute()).astype(np.float32)

                # Upsample the biasfield patch to match the raw patch's
                # spatial shape.
                if raw_np.shape != bf_np.shape:
                    zoom_factors = tuple(
                        r / b for r, b in zip(raw_np.shape, bf_np.shape)
                    )
                    bf_np = zoom(bf_np, zoom_factors, order=1)
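                # order=1 (linear) interpolation is assumed adequate here,
                # since a bias field varies smoothly by construction.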

                # Apply the bias field correction: corrected = raw / biasfield.
                corrected = raw_np / np.maximum(bf_np, epsilon)
                all_intensities.append(corrected.ravel())

                logging.info(f"Processed patch {i + 1}/{len(centers)}")

            except Exception as e:
                logging.warning(f"Skipping patch {i + 1}: {e}")
                continue

        if not all_intensities:
            logging.warning("No valid patches collected; producing empty histogram.")
            all_intensities = [np.zeros(1, dtype=np.float32)]

    intensities = np.concatenate(all_intensities)

    # Compute the histogram from the collected corrected intensities.
    hist_counts, bin_edges = np.histogram(
        intensities, bins=hist_bins, range=tuple(hist_range)
    )
    hist_counts = hist_counts.astype(float)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_width = bin_edges[1] - bin_edges[0]

    total_voxels = hist_counts.sum()
    max_range = hist_range[1]

    nonzero_mask = hist_counts > 0
    disp_max = (
        float(bin_centers[nonzero_mask][-1]) * 1.05
        if nonzero_mask.any()
        else max_range
    )
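    # The top histogram bin serves as a proxy for saturated/clipped voxels.
    # Note that np.histogram drops values above hist_range[1] entirely, so
    # corrected intensities beyond the range are not counted here at all.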
    sat_fraction = (
        float(hist_counts[-1]) / total_voxels * 100 if total_voxels > 0 else 0.0
    )

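    # Summary statistics are approximated from the binned counts rather than
    # the raw voxel values, so percentile resolution is limited to one bin
    # width.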
    if total_voxels > 0:
        mean_val = float(np.sum(bin_centers * hist_counts) / total_voxels)
        cumsum_norm = np.cumsum(hist_counts) / total_voxels
        p50_val = float(
            bin_centers[min(np.searchsorted(cumsum_norm, 0.50), len(bin_centers) - 1)]
        )
        p99_val = float(
            bin_centers[min(np.searchsorted(cumsum_norm, 0.99), len(bin_centers) - 1)]
        )
    else:
        mean_val = p50_val = p99_val = 0.0

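    # Limit the linear panel to just past the 99th percentile so that a few
    # bright voxels do not stretch the x-axis and flatten the histogram.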
    lin_xlim = p99_val * 1.05 if total_voxels > 0 else max_range
    visible = hist_counts[bin_centers <= lin_xlim]
    lin_ylim = (
        float(visible.max()) * 1.05 if visible.size and visible.max() > 0 else 1.0
    )

    subject = snakemake.wildcards.subject

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle(
        f"Bias-Corrected Intensity Histogram QC\n"
        f"Subject: {subject} | Stain: {stain} | "
        f"Patches: {len(centers)} | Correction: {snakemake.params.correction_method}",
        fontsize=12,
        fontweight="bold",
    )

    # Panel 1: linear-scale histogram
    ax = axes[0, 0]
    ax.bar(bin_centers, hist_counts, width=bin_width, color="steelblue", alpha=0.75)
    ax.set_xlabel("Corrected intensity")
    ax.set_ylabel("Voxel count")
    ax.set_title("Linear-scale histogram")
    ax.set_xlim(0, lin_xlim)
    ax.set_ylim(0, lin_ylim)

    # Panel 2: log-scale histogram
    ax = axes[0, 1]
    # Take log10 of nonzero counts only; empty bins become NaN so np.log10
    # never sees a zero (which would emit a divide-by-zero RuntimeWarning).
    log_counts = np.log10(np.where(hist_counts > 0, hist_counts, np.nan))
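    # Plotting log10(counts) as bar heights (rather than using a log-scaled
    # y-axis) lets empty bins simply drop out as NaN instead of producing
    # rendering artifacts at zero.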
    ax.bar(bin_centers, log_counts, width=bin_width, color="darkorange", alpha=0.75)
    ax.set_xlabel("Corrected intensity")
    ax.set_ylabel("log\u2081\u2080(voxel count)")
    ax.set_title("Log-scale histogram")
    ax.set_xlim(0, disp_max)

    # Panel 3: cumulative distribution
    ax = axes[1, 0]
    if total_voxels > 0:
        cumsum_pct = cumsum_norm * 100
        ax.plot(bin_centers, cumsum_pct, color="forestgreen", lw=1.5)
        ax.axvline(
            x=p50_val,
            color="purple",
            linestyle="--",
            alpha=0.7,
            label=f"Median ({p50_val:.1f})",
        )
        ax.axvline(
            x=p99_val,
            color="red",
            linestyle="--",
            alpha=0.7,
            label=f"99th pctile ({p99_val:.1f})",
        )
        ax.legend(fontsize=8)
    ax.set_xlabel("Corrected intensity")
    ax.set_ylabel("Cumulative voxels (%)")
    ax.set_title("Cumulative distribution")
    ax.set_ylim(0, 105)
    ax.set_xlim(0, disp_max)

    # Panel 4: summary statistics
    ax = axes[1, 1]
    ax.axis("off")
    summary_text = (
        f"Sampled voxels: {int(total_voxels):>14,}\n"
        f"Patches: {len(centers):>14,}\n"
        f"Mean intensity: {mean_val:>14.2f}\n"
        f"Median (50th): {p50_val:>14.2f}\n"
        f"99th percentile: {p99_val:>14.2f}\n"
        f"Max range: {max_range:>14.1f}\n"
        f"Saturation frac.: {sat_fraction:>13.3f}%"
    )
    ax.text(
        0.1,
        0.55,
        summary_text,
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment="center",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )
    ax.set_title("Summary statistics")

    plt.tight_layout()
    plt.savefig(snakemake.output.png, dpi=150, bbox_inches="tight")
    plt.close()
    logging.info(f"Saved bias-corrected histogram QC to {snakemake.output.png}")


if __name__ == "__main__":
    main()