Skip to content

Commit 4ea5897

Browse files
author
Donglai Wei
committed
Add MedNeXt head hidden channels and update tutorials
1 parent ec4f5cf commit 4ea5897

5 files changed

Lines changed: 85 additions & 17 deletions

File tree

connectomics/config/pipeline/config_io.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ def validate_config(cfg: Config) -> None:
304304
for head_name, head_cfg in model_heads.items():
305305
head_out_channels = int(getattr(head_cfg, "out_channels", 0))
306306
head_num_blocks = int(getattr(head_cfg, "num_blocks", 0))
307+
head_hidden_channels = getattr(head_cfg, "hidden_channels", None)
308+
if head_hidden_channels is not None:
309+
head_hidden_channels = int(head_hidden_channels)
307310
if head_out_channels <= 0:
308311
raise ValueError(
309312
f"model.heads.{head_name}.out_channels must be positive "
@@ -314,6 +317,11 @@ def validate_config(cfg: Config) -> None:
314317
f"model.heads.{head_name}.num_blocks must be non-negative "
315318
f"(got {head_num_blocks})"
316319
)
320+
if head_hidden_channels is not None and head_hidden_channels <= 0:
321+
raise ValueError(
322+
f"model.heads.{head_name}.hidden_channels must be positive "
323+
f"(got {head_hidden_channels})"
324+
)
317325

318326
primary_head = getattr(cfg.model, "primary_head", None)
319327
if primary_head is not None and primary_head not in model_heads:
@@ -326,7 +334,11 @@ def validate_config(cfg: Config) -> None:
326334
f"inference.head='{inference_head}' is not present in model.heads "
327335
f"({sorted(model_heads.keys())})."
328336
)
329-
if visualization_head is not None and visualization_head != "all" and visualization_head not in model_heads:
337+
if (
338+
visualization_head is not None
339+
and visualization_head != "all"
340+
and visualization_head not in model_heads
341+
):
330342
raise ValueError(
331343
f"monitor.logging.images.head='{visualization_head}' is not present in "
332344
f"model.heads ({sorted(model_heads.keys())})."

connectomics/models/architectures/mednext_models.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424

2525
try:
2626
from nnunet_mednext import MedNeXt as MedNeXtBase
27-
from nnunet_mednext import MedNeXtBlock
28-
from nnunet_mednext import create_mednext_v1
27+
from nnunet_mednext import MedNeXtBlock, create_mednext_v1
2928

3029
MEDNEXT_AVAILABLE = True
3130
except ImportError:
@@ -137,6 +136,7 @@ def __init__(
137136
in_channels: int,
138137
out_channels: int,
139138
num_blocks: int,
139+
hidden_channels: int | None = None,
140140
*,
141141
exp_r: int,
142142
kernel_size: int,
@@ -150,17 +150,33 @@ def __init__(
150150
raise ValueError(f"MedNeXt task head num_blocks must be >= 0, got {num_blocks}")
151151
if out_channels <= 0:
152152
raise ValueError(f"MedNeXt task head out_channels must be positive, got {out_channels}")
153+
if hidden_channels is None:
154+
hidden_channels = in_channels
155+
if hidden_channels <= 0:
156+
raise ValueError(
157+
f"MedNeXt task head hidden_channels must be positive, got {hidden_channels}"
158+
)
159+
if hidden_channels > in_channels:
160+
raise ValueError(
161+
"MedNeXt task head hidden_channels must not exceed the shared feature width "
162+
f"({hidden_channels} > {in_channels})"
163+
)
153164
if dim == "2d":
154165
conv = nn.Conv2d
155166
elif dim == "3d":
156167
conv = nn.Conv3d
157168
else:
158169
raise ValueError(f"MedNeXt task head dim must be '2d' or '3d', got {dim}")
159170

171+
self.input_projection = (
172+
conv(in_channels, hidden_channels, kernel_size=1)
173+
if hidden_channels != in_channels
174+
else nn.Identity()
175+
)
160176
blocks = [
161177
MedNeXtBlock(
162-
in_channels=in_channels,
163-
out_channels=in_channels,
178+
in_channels=hidden_channels,
179+
out_channels=hidden_channels,
164180
exp_r=exp_r,
165181
kernel_size=kernel_size,
166182
do_res=do_res,
@@ -171,9 +187,11 @@ def __init__(
171187
for _ in range(num_blocks)
172188
]
173189
self.blocks = nn.Sequential(*blocks) if blocks else nn.Identity()
174-
self.projection = conv(in_channels, out_channels, kernel_size=1)
190+
self.projection = conv(hidden_channels, out_channels, kernel_size=1)
191+
self.hidden_channels = hidden_channels
175192

176193
def forward(self, x: torch.Tensor) -> torch.Tensor:
194+
x = self.input_projection(x)
177195
x = self.blocks(x)
178196
return self.projection(x)
179197

@@ -219,15 +237,19 @@ def __init__(
219237
for head_name, head_cfg in heads.items():
220238
out_channels = int(_cfg_value(head_cfg, "out_channels", head_cfg))
221239
num_blocks = int(_cfg_value(head_cfg, "num_blocks", 0))
240+
hidden_channels = _cfg_value(head_cfg, "hidden_channels", None)
241+
hidden_channels = int(hidden_channels) if hidden_channels is not None else None
222242
task_heads[head_name] = MedNeXtTaskHead(
223243
in_channels=self.feature_channels,
224244
out_channels=out_channels,
225245
num_blocks=num_blocks,
246+
hidden_channels=hidden_channels,
226247
**self.head_block_kwargs,
227248
)
228249
head_specs[head_name] = {
229250
"out_channels": out_channels,
230251
"num_blocks": num_blocks,
252+
"hidden_channels": hidden_channels or self.feature_channels,
231253
}
232254

233255
self.heads = nn.ModuleDict(task_heads)

tests/unit/test_mednext_multi_head_wrapper.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import pytest
22
import torch
33

4-
from connectomics.config import from_dict
5-
from connectomics.models.build import build_model
4+
from connectomics.config import from_dict, validate_config
65
from connectomics.models.architectures.mednext_models import MedNeXtMultiHeadWrapper
7-
6+
from connectomics.models.build import build_model
87

98
nnunet_mednext = pytest.importorskip("nnunet_mednext")
109
from nnunet_mednext import MedNeXt # noqa: E402
@@ -32,7 +31,7 @@ def test_mednext_multi_head_wrapper_returns_named_outputs():
3231
model = MedNeXtMultiHeadWrapper(
3332
trunk,
3433
{
35-
"affinity": {"out_channels": 9, "num_blocks": 1},
34+
"affinity": {"out_channels": 9, "num_blocks": 1, "hidden_channels": 2},
3635
"sdt": {"out_channels": 1, "num_blocks": 0},
3736
},
3837
)
@@ -51,7 +50,21 @@ def test_mednext_multi_head_wrapper_returns_named_outputs():
5150
assert outputs["output"]["affinity"].shape == (1, 9, 32, 32, 32)
5251
assert outputs["output"]["sdt"].shape == (1, 1, 32, 32, 32)
5352
assert model.head_specs["affinity"]["num_blocks"] == 1
53+
assert model.head_specs["affinity"]["hidden_channels"] == 2
54+
assert isinstance(model.heads["affinity"].input_projection, torch.nn.Conv3d)
55+
assert model.heads["affinity"].projection.in_channels == 2
5456
assert model.head_specs["sdt"]["num_blocks"] == 0
57+
assert model.head_specs["sdt"]["hidden_channels"] == 4
58+
59+
60+
def test_mednext_multi_head_wrapper_rejects_hidden_channels_above_trunk_width():
61+
trunk = _build_tiny_mednext(deep_supervision=False)
62+
63+
with pytest.raises(ValueError, match="must not exceed the shared feature width"):
64+
MedNeXtMultiHeadWrapper(
65+
trunk,
66+
{"affinity": {"out_channels": 9, "num_blocks": 1, "hidden_channels": 8}},
67+
)
5568

5669

5770
def test_mednext_multi_head_wrapper_rejects_deep_supervision_trunk():
@@ -70,7 +83,7 @@ def test_build_model_creates_mednext_multi_head_wrapper_from_config():
7083
"out_channels": 10,
7184
"primary_head": "affinity",
7285
"heads": {
73-
"affinity": {"out_channels": 9, "num_blocks": 1},
86+
"affinity": {"out_channels": 9, "num_blocks": 1, "hidden_channels": 2},
7487
"sdt": {"out_channels": 1, "num_blocks": 0},
7588
},
7689
"mednext": {
@@ -91,10 +104,25 @@ def test_build_model_creates_mednext_multi_head_wrapper_from_config():
91104
model = build_model(cfg)
92105
assert isinstance(model, MedNeXtMultiHeadWrapper)
93106
assert model.primary_head == "affinity"
107+
assert model.head_specs["affinity"]["hidden_channels"] == 2
94108

95109
x = torch.randn(1, 1, 32, 32, 32)
96110
with torch.no_grad():
97111
outputs = model(x)
98112

99113
assert outputs["output"]["affinity"].shape == (1, 9, 32, 32, 32)
100114
assert outputs["output"]["sdt"].shape == (1, 1, 32, 32, 32)
115+
116+
117+
def test_validate_config_rejects_nonpositive_head_hidden_channels():
118+
cfg = from_dict(
119+
{
120+
"model": {
121+
"arch": {"type": "mednext_custom"},
122+
"heads": {"affinity": {"out_channels": 3, "num_blocks": 0, "hidden_channels": 0}},
123+
}
124+
}
125+
)
126+
127+
with pytest.raises(ValueError, match="model.heads.affinity.hidden_channels must be positive"):
128+
validate_config(cfg)

tutorials/mito_betaseg.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@ default:
1818
aff_r1:
1919
out_channels: 3
2020
num_blocks: 1
21+
hidden_channels: 8
2122
target_slice: "0:3"
2223
aff_r5:
2324
out_channels: 3
2425
num_blocks: 1
26+
hidden_channels: 8
2527
target_slice: "3:6"
2628
sdt:
2729
out_channels: 1
2830
num_blocks: 1
31+
hidden_channels: 8
2932
target_slice: "6:7"
3033
loss:
3134
losses:
@@ -60,12 +63,10 @@ default:
6063
bg_value: -1.0
6164
dataloader:
6265
patch_size: [128, 128, 128]
63-
batch_size: 4
66+
batch_size: 2
6467
use_lazy_zarr: true
6568
image_transform:
66-
normalize: "0-1"
67-
clip_percentile_low: 0.005
68-
clip_percentile_high: 0.995
69+
normalize: "divide-255"
6970
augmentation:
7071
profile: aug_em_neuron
7172
rotate:
@@ -123,7 +124,7 @@ train:
123124
channel_mode: all
124125
checkpoint:
125126
save_top_k: 3
126-
monitor: val_loss_total_epoch
127+
monitor: val_loss_total
127128
mode: min
128129

129130
test:

tutorials/neuron_snemi_sdt_multitask.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ test:
148148
label: test-labels.h5
149149
resolution: [30, 6, 6]
150150
inference:
151-
head: aff_r5
151+
head: aff_r1
152152
decoding:
153153
- template: decoding_waterz
154154
kwargs:
@@ -160,6 +160,11 @@ test:
160160
dust_merge_size: 800
161161
dust_merge_affinity: 0.3
162162
dust_remove_size: 600
163+
branch_merge: true
164+
branch_iou_threshold: 0.5
165+
branch_best_buddy: true
166+
branch_one_sided_threshold: 0.8
167+
branch_one_sided_min_size: 100
163168

164169
# ============================================================================
165170
# Parameter tuning for waterz agglomeration thresholds (--mode tune)

0 commit comments

Comments
 (0)