Skip to content

Commit b2da92b

Browse files
Copilotakhanf
andauthored
Remove k wildcard from atropos_seg, add k-retry logic with gmm_init_k/gmm_min_k config
Agent-Logs-Url: https://github.com/khanlab/SPIMquant/sessions/e025f01c-8e6d-495f-b403-90c8c2aa6710 Co-authored-by: akhanf <11492701+akhanf@users.noreply.github.com>
1 parent 7303300 commit b2da92b

5 files changed

Lines changed: 60 additions & 20 deletions

File tree

spimquant/config/.ipynb_checkpoints/snakebids-checkpoint.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ templatereg:
236236
- 40
237237

238238
masking:
239-
gmm_k: 9
239+
gmm_init_k: 9
240+
gmm_min_k: 2
240241
gmm_bg_class: 1
241242
pre_atropos_downsampling: '50%'
242243

spimquant/config/snakebids.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ templatereg:
352352
- 40
353353

354354
masking:
355-
gmm_k: 9
355+
gmm_init_k: 9
356+
gmm_min_k: 2
356357
gmm_bg_class: 1
357358
pre_atropos_downsampling: '50%'
358359

spimquant/workflow/rules/masking.smk

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,22 @@ rule pre_atropos:
7474

7575
rule atropos_seg:
7676
"""Perform tissue classification using Atropos (k-class GMM).
77-
77+
7878
Uses ANTs Atropos to classify tissue into k intensity classes via Gaussian
7979
mixture modeling with Markov random field (MRF) spatial smoothing. Outputs
8080
a discrete segmentation and posterior probability maps for each class.
81+
82+
Automatically decrements k from init_k down to min_k if Atropos fails,
83+
handling cases where the image lacks enough distinct intensity classes.
8184
"""
8285
input:
8386
downsampled=rules.pre_atropos.output.downsampled,
8487
mask=rules.pre_atropos.output.mask,
8588
params:
8689
mrf_smoothing=0.3,
8790
mrf_radius="2x2x2",
91+
init_k=config["masking"]["gmm_init_k"],
92+
min_k=config["masking"]["gmm_min_k"],
8893
output:
8994
dseg=temp(
9095
bids(
@@ -93,7 +98,6 @@ rule atropos_seg:
9398
stain="{stain}",
9499
level="{level}",
95100
desc="dsAtropos",
96-
k="{k}",
97101
suffix="dseg.nii",
98102
**inputs["spim"].wildcards,
99103
),
@@ -106,27 +110,19 @@ rule atropos_seg:
106110
stain="{stain}",
107111
level="{level}",
108112
desc="Atropos",
109-
k="{k}",
110113
suffix="posteriors",
111114
**inputs["spim"].wildcards,
112115
)
113116
),
114117
),
115118
conda:
116119
"../envs/ants.yaml"
117-
shadow:
118-
"minimal"
119120
threads: 16
120121
resources:
121122
mem_mb=32000,
122123
runtime=45,
123-
shell:
124-
"mkdir -p {output.posteriors_dir} && "
125-
"ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} "
126-
"Atropos -v -d 3 --initialization KMeans[{wildcards.k}] "
127-
" --intensity-image {input.downsampled} "
128-
" --output [{output.dseg},{output.posteriors_dir}/class-%02d.nii] "
129-
" --mask-image {input.mask} --mrf [{params.mrf_smoothing},{params.mrf_radius}]"
124+
script:
125+
"../scripts/atropos_seg.py"
130126

131127

132128
rule post_atropos:
@@ -148,7 +144,6 @@ rule post_atropos:
148144
stain="{stain}",
149145
level="{level}",
150146
desc="Atropos",
151-
k="{k}",
152147
suffix="dseg.nii",
153148
**inputs["spim"].wildcards,
154149
),
@@ -273,7 +268,6 @@ rule create_mask_from_gmm_and_prior:
273268
stain="{stain}",
274269
level="{level}",
275270
desc="Atropos",
276-
k=config["masking"]["gmm_k"],
277271
suffix="dseg.nii",
278272
**inputs["spim"].wildcards,
279273
),
@@ -285,8 +279,6 @@ rule create_mask_from_gmm_and_prior:
285279
suffix="mask.nii.gz",
286280
**inputs["spim"].wildcards,
287281
),
288-
params:
289-
k=config["masking"]["gmm_k"],
290282
output:
291283
mask=bids(
292284
root=root,
@@ -313,7 +305,6 @@ rule create_mask_from_gmm:
313305
stain="{stain}",
314306
level="{level}",
315307
desc="Atropos",
316-
k=config["masking"]["gmm_k"],
317308
suffix="dseg.nii",
318309
**inputs["spim"].wildcards,
319310
),
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
Run Atropos segmentation with automatic k-retry logic.
3+
4+
Tries GMM segmentation with init_k classes, and decrements k if Atropos fails,
5+
down to min_k. This handles cases where the image doesn't have enough distinct
6+
intensity classes to support the initial k value.
7+
"""
8+
9+
import os
10+
import shutil
11+
import subprocess
12+
13+
init_k = snakemake.params.init_k
14+
min_k = snakemake.params.min_k
15+
threads = snakemake.threads
16+
mrf_smoothing = snakemake.params.mrf_smoothing
17+
mrf_radius = snakemake.params.mrf_radius
18+
19+
downsampled = snakemake.input.downsampled
20+
mask = snakemake.input.mask
21+
dseg = snakemake.output.dseg
22+
posteriors_dir = snakemake.output.posteriors_dir
23+
24+
os.makedirs(posteriors_dir, exist_ok=True)
25+
26+
for k in range(init_k, min_k - 1, -1):
27+
cmd = (
28+
f"ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS={threads} "
29+
f"Atropos -v -d 3 --initialization KMeans[{k}] "
30+
f" --intensity-image {downsampled} "
31+
f" --output [{dseg},{posteriors_dir}/class-%02d.nii] "
32+
f" --mask-image {mask} --mrf [{mrf_smoothing},{mrf_radius}]"
33+
)
34+
result = subprocess.run(cmd, shell=True)
35+
if result.returncode == 0:
36+
break
37+
elif k == min_k:
38+
raise subprocess.CalledProcessError(result.returncode, cmd)
39+
else:
40+
# Clean up and retry with lower k
41+
if os.path.exists(dseg):
42+
os.remove(dseg)
43+
shutil.rmtree(posteriors_dir, ignore_errors=True)
44+
os.makedirs(posteriors_dir, exist_ok=True)

spimquant/workflow/scripts/create_mask_from_gmm_and_prior.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313

1414
out_mask = np.zeros(template_mask_vol.shape)
1515

16-
for i in range(1, snakemake.params.k + 1):
16+
# Detect the number of tissue classes from the segmentation labels
17+
tissue_labels = sorted(np.unique(tissue_vol[tissue_vol > 0]).astype(int))
18+
19+
for i in tissue_labels:
1720

1821
# if more voxels in foreground than in background, we assign it to the mask
1922
nvox_fg = (fg_tissue == i).sum()

0 commit comments

Comments
 (0)