11import pytest
22import torch
33
4- from connectomics .config import from_dict
5- from connectomics .models .build import build_model
4+ from connectomics .config import from_dict , validate_config
65from connectomics .models .architectures .mednext_models import MedNeXtMultiHeadWrapper
7-
6+ from connectomics . models . build import build_model
87
98nnunet_mednext = pytest .importorskip ("nnunet_mednext" )
109from nnunet_mednext import MedNeXt # noqa: E402
@@ -32,7 +31,7 @@ def test_mednext_multi_head_wrapper_returns_named_outputs():
3231 model = MedNeXtMultiHeadWrapper (
3332 trunk ,
3433 {
35- "affinity" : {"out_channels" : 9 , "num_blocks" : 1 },
34+ "affinity" : {"out_channels" : 9 , "num_blocks" : 1 , "hidden_channels" : 2 },
3635 "sdt" : {"out_channels" : 1 , "num_blocks" : 0 },
3736 },
3837 )
@@ -51,7 +50,21 @@ def test_mednext_multi_head_wrapper_returns_named_outputs():
5150 assert outputs ["output" ]["affinity" ].shape == (1 , 9 , 32 , 32 , 32 )
5251 assert outputs ["output" ]["sdt" ].shape == (1 , 1 , 32 , 32 , 32 )
5352 assert model .head_specs ["affinity" ]["num_blocks" ] == 1
53+ assert model .head_specs ["affinity" ]["hidden_channels" ] == 2
54+ assert isinstance (model .heads ["affinity" ].input_projection , torch .nn .Conv3d )
55+ assert model .heads ["affinity" ].projection .in_channels == 2
5456 assert model .head_specs ["sdt" ]["num_blocks" ] == 0
57+ assert model .head_specs ["sdt" ]["hidden_channels" ] == 4
58+
59+
60+ def test_mednext_multi_head_wrapper_rejects_hidden_channels_above_trunk_width ():
61+ trunk = _build_tiny_mednext (deep_supervision = False )
62+
63+ with pytest .raises (ValueError , match = "must not exceed the shared feature width" ):
64+ MedNeXtMultiHeadWrapper (
65+ trunk ,
66+ {"affinity" : {"out_channels" : 9 , "num_blocks" : 1 , "hidden_channels" : 8 }},
67+ )
5568
5669
5770def test_mednext_multi_head_wrapper_rejects_deep_supervision_trunk ():
@@ -70,7 +83,7 @@ def test_build_model_creates_mednext_multi_head_wrapper_from_config():
7083 "out_channels" : 10 ,
7184 "primary_head" : "affinity" ,
7285 "heads" : {
73- "affinity" : {"out_channels" : 9 , "num_blocks" : 1 },
86+ "affinity" : {"out_channels" : 9 , "num_blocks" : 1 , "hidden_channels" : 2 },
7487 "sdt" : {"out_channels" : 1 , "num_blocks" : 0 },
7588 },
7689 "mednext" : {
@@ -91,10 +104,25 @@ def test_build_model_creates_mednext_multi_head_wrapper_from_config():
91104 model = build_model (cfg )
92105 assert isinstance (model , MedNeXtMultiHeadWrapper )
93106 assert model .primary_head == "affinity"
107+ assert model .head_specs ["affinity" ]["hidden_channels" ] == 2
94108
95109 x = torch .randn (1 , 1 , 32 , 32 , 32 )
96110 with torch .no_grad ():
97111 outputs = model (x )
98112
99113 assert outputs ["output" ]["affinity" ].shape == (1 , 9 , 32 , 32 , 32 )
100114 assert outputs ["output" ]["sdt" ].shape == (1 , 1 , 32 , 32 , 32 )
115+
116+
117+ def test_validate_config_rejects_nonpositive_head_hidden_channels ():
118+ cfg = from_dict (
119+ {
120+ "model" : {
121+ "arch" : {"type" : "mednext_custom" },
122+ "heads" : {"affinity" : {"out_channels" : 3 , "num_blocks" : 0 , "hidden_channels" : 0 }},
123+ }
124+ }
125+ )
126+
127+ with pytest .raises (ValueError , match = "model.heads.affinity.hidden_channels must be positive" ):
128+ validate_config (cfg )
0 commit comments