@@ -93,14 +93,8 @@ class ThresholdMRI(SegmentationMethod):
9393 is labeled as a single "tissue" material. The segmentation assigns four
9494 tissue types: water, skull, tissue, and air. When True, the brain
9595 interior is split into CSF, gray matter, and white matter using a
96- 3-component Gaussian Mixture Model (EM-GMM) on T1-weighted intensity,
97- with optional spatial regularization. Two regularization modes are
98- available: Markov Random Field (MRF) via Iterated Conditional Modes
99- (ICM) when ``mrf_beta > 0``, or Gaussian smoothing of posterior
100- probability maps when ``spatial_sigma_mm > 0``. Both are disabled by
101- default. MRF regularization preserves tissue boundaries better than
102- Gaussian smoothing by penalizing label disagreement among 6-connected
103- neighbors without blurring probability maps. This yields a total of six assigned tissue types:
96+ 3-component Gaussian Mixture Model (EM-GMM) on T1-weighted intensity.
97+ This yields a total of six assigned tissue types:
10498 water, skull, csf, gray_matter, white_matter, and air. Additional
10599 materials in the dict (e.g. standoff) are carried through but not
106100 assigned by the segmentation algorithm.
@@ -146,7 +140,7 @@ class ThresholdMRI(SegmentationMethod):
146140 OpenLIFUFieldData (
147141 "Classify brain tissues" ,
148142 "If True, classify brain interior into CSF, gray matter, and white matter "
149- "using a spatially-regularized EM-GMM on T1-weighted intensity" ,
143+ "using an EM-GMM on T1-weighted intensity" ,
150144 ),
151145 ] = False
152146 """If True, sub-classify brain interior into CSF, gray matter, and white matter."""
@@ -173,38 +167,6 @@ class ThresholdMRI(SegmentationMethod):
173167 """If True, use SimpleITK N4BiasFieldCorrection for bias correction.
174168 Falls back to homomorphic if SimpleITK is not available."""
175169
176- spatial_sigma_mm : Annotated [
177- float ,
178- OpenLIFUFieldData (
179- "Spatial regularization sigma (mm)" ,
180- "Gaussian smoothing sigma for spatial regularization of brain tissue "
181- "probability maps. Set to 0 to disable. Only used when classify_brain_tissues is True "
182- "and mrf_beta is 0." ,
183- ),
184- ] = 0.0
185- """Gaussian smoothing sigma (mm) for spatial regularization.
186- Defaults to 0 (disabled). Non-zero values apply masked Gaussian
187- smoothing to posterior probability maps before label assignment.
188- Smoothing is confined to the parenchyma mask to prevent boundary leakage.
189- Ignored when mrf_beta > 0 (MRF regularization takes precedence)."""
190-
191- mrf_beta : Annotated [
192- float ,
193- OpenLIFUFieldData (
194- "MRF regularization strength" ,
195- "Markov Random Field regularization parameter (beta) for brain tissue "
196- "classification. Controls the penalty for label disagreement among "
197- "6-connected neighbors via Iterated Conditional Modes (ICM). "
198- "Set to 0 to disable MRF and fall back to Gaussian smoothing. "
199- "Typical values are 0.1-0.3. Only used when classify_brain_tissues is True." ,
200- ),
201- ] = 0.0
202- """MRF regularization strength (beta).
203- Defaults to 0 (disabled). When positive, Iterated Conditional Modes (ICM) is used
204- for spatial regularization instead of Gaussian smoothing. Higher values
205- produce smoother label maps but may over-regularize thin structures.
206- Set to 0 to disable MRF and use Gaussian smoothing (spatial_sigma_mm) instead."""
207-
208170 brain_extraction_margin_mm : Annotated [
209171 float ,
210172 OpenLIFUFieldData (
@@ -273,12 +235,6 @@ def __post_init__(self) -> None:
273235 if self .bias_correction_sigma_mm < 0 :
274236 msg = f"bias_correction_sigma_mm must be non-negative, got { self .bias_correction_sigma_mm } ."
275237 raise ValueError (msg )
276- if self .spatial_sigma_mm < 0 :
277- msg = f"spatial_sigma_mm must be non-negative, got { self .spatial_sigma_mm } ."
278- raise ValueError (msg )
279- if self .mrf_beta < 0 :
280- msg = f"mrf_beta must be non-negative, got { self .mrf_beta } ."
281- raise ValueError (msg )
282238 if self .brain_extraction_margin_mm < 0 :
283239 msg = f"brain_extraction_margin_mm must be non-negative, got { self .brain_extraction_margin_mm } ."
284240 raise ValueError (msg )
@@ -745,59 +701,6 @@ def _fit_gmm_1d(
745701 order = np .argsort (means )
746702 return means [order ], stds [order ], weights [order ]
747703
748- @staticmethod
749- def _icm_iteration (
750- seg_arr : np .ndarray ,
751- log_likelihoods : np .ndarray ,
752- parenchyma_mask : np .ndarray ,
753- label_indices : list [int ],
754- beta : float ,
755- ) -> np .ndarray :
756- """One pass of Iterated Conditional Modes (ICM) for MRF regularization.
757-
758- For each voxel in the parenchyma mask, computes a posterior that
759- combines the GMM log-likelihood with a spatial prior that counts
760- how many of the 6-connected (face) neighbors share the same label.
761- The voxel is assigned to the class that maximizes this posterior.
762-
763- Uses vectorized numpy operations (np.roll along each axis) rather
764- than per-voxel loops for performance.
765-
766- :param seg_arr: Current integer label array (modified in-place)
767- :param log_likelihoods: Log-likelihood for each class, shape (n_classes, *vol_shape)
768- :param parenchyma_mask: Boolean mask restricting ICM updates
769- :param label_indices: Integer label values for each class (same order as log_likelihoods)
770- :param beta: MRF regularization strength
771- :returns: Updated seg_arr (same object, modified in-place)
772- """
773- n_classes = log_likelihoods .shape [0 ]
774-
775- # For each class k, count how many of the 6 face-neighbors share
776- # label k. np.roll wraps at boundaries, but boundary voxels are
777- # typically outside the parenchyma mask and do not affect the result.
778- neighbor_agreement = np .zeros (
779- (n_classes , * seg_arr .shape ), dtype = np .float32
780- )
781- for k , label_val in enumerate (label_indices ):
782- same_label = (seg_arr == label_val ).astype (np .float32 )
783- agreement = np .zeros_like (same_label )
784- for axis in range (3 ):
785- agreement += np .roll (same_label , 1 , axis = axis )
786- agreement += np .roll (same_label , - 1 , axis = axis )
787- neighbor_agreement [k ] = agreement # Range 0..6
788-
789- # Posterior = log_likelihood + beta * neighbor_agreement.
790- # Higher agreement means more neighbors share this label, which is
791- # rewarded. This is equivalent to penalizing disagreement.
792- posterior = log_likelihoods + beta * neighbor_agreement
793-
794- # Assign each voxel to the class with the highest posterior.
795- best_class = np .argmax (posterior , axis = 0 )
796- for k , label_val in enumerate (label_indices ):
797- seg_arr [parenchyma_mask & (best_class == k )] = label_val
798-
799- return seg_arr
800-
801704 def _classify_brain (
802705 self ,
803706 data : np .ndarray ,
@@ -811,9 +714,9 @@ def _classify_brain(
811714
812715 Modifies ``seg`` in-place. Optionally applies bias field correction,
813716 then fits a 3-component Gaussian Mixture Model (EM-GMM) to the
814- (possibly tighter) GMM parenchyma mask. Posterior probability maps are
815- spatially smoothed with a Gaussian kernel for regularization, and the
816- process is iterated twice to refine the classification .
717+ (possibly tighter) GMM parenchyma mask. Labels are assigned via
718+ argmax of the posterior probabilities, iterated twice to refine
719+ the GMM parameters .
817720
818721 After GMM classification within the tight mask, labels are expanded
819722 to cover the full parenchyma mask via nearest-neighbor assignment
@@ -845,7 +748,7 @@ def _classify_brain(
845748
846749 try :
847750 self ._classify_brain_gmm (
848- classify_data , gmm_parenchyma_mask , seg , material_idx , spacing
751+ classify_data , gmm_parenchyma_mask , seg , material_idx
849752 )
850753 except (RuntimeError , ValueError , np .linalg .LinAlgError ):
851754 logging .warning (
@@ -875,19 +778,12 @@ def _classify_brain_gmm(
875778 parenchyma_mask : np .ndarray ,
876779 seg : np .ndarray ,
877780 material_idx : dict [str , int ],
878- spacing : np .ndarray ,
879781 ) -> None :
880- """Core GMM classification with spatial regularization.
881-
882- Fits a 3-component GMM, computes full-volume log-likelihood maps,
883- then applies spatial regularization. When ``mrf_beta > 0``, uses
884- Markov Random Field regularization via Iterated Conditional Modes
885- (ICM), which penalizes label disagreement among 6-connected
886- neighbors while preserving tissue boundaries. When ``mrf_beta == 0``
887- and ``spatial_sigma_mm > 0``, falls back to Gaussian smoothing of
888- posterior probability maps. The outer loop iterates twice: after
889- each regularization pass, GMM parameters are re-estimated from the
890- current labels/posteriors.
782+ """Core GMM classification.
783+
784+ Fits a 3-component GMM to parenchyma voxel intensities, computes
785+ posterior probabilities, assigns labels via argmax, then iterates
786+ twice to refine GMM parameters from the current assignment.
891787
892788 Called by ``_classify_brain``; raises on failure so the caller can
893789 apply the gray-matter fallback.
@@ -896,15 +792,11 @@ def _classify_brain_gmm(
896792 :param parenchyma_mask: Boolean mask of brain parenchyma
897793 :param seg: Integer label array to modify in-place
898794 :param material_idx: Mapping from material name to integer label
899- :param spacing: Voxel spacing in mm per axis
900795 """
901- use_mrf = self .mrf_beta > 0
902- sigma_voxels = self .spatial_sigma_mm / spacing
903796 label_keys = ["csf" , "gray_matter" , "white_matter" ] # sorted by T1 mean
904797 label_indices = [material_idx [k ] for k in label_keys ]
905798 n_components = len (label_keys )
906- n_outer = 2 # total spatial-regularization iterations
907- n_icm = 5 # ICM iterations per outer loop (when using MRF)
799+ n_outer = 2
908800
909801 # Extract parenchyma intensities for the initial fit.
910802 parenchyma_vals = classify_data [parenchyma_mask ]
@@ -923,144 +815,49 @@ def _classify_brain_gmm(
923815 # Precompute constant for inline logpdf (avoids scipy overhead).
924816 _LOG_2PI_HALF = 0.5 * np .log (2.0 * np .pi )
925817
926- # Determine whether we need full 3D probability maps.
927- # MRF needs 3D log-likelihoods for ICM; Gaussian smoothing needs 3D
928- # prob_maps for the convolution. When neither is active, we can
929- # compute likelihoods only at parenchyma voxels (flat arrays).
930- needs_3d = use_mrf or self .spatial_sigma_mm > 0
931-
932818 for _iteration in range (n_outer ):
933819 safe_stds = np .maximum (stds , 1e-10 )
934820 log_weights = np .log (weights + 1e-300 )
935821
936- if needs_3d :
937- # Build full-volume log-likelihood maps (only inside parenchyma).
938- log_likelihoods = np .full (
939- (n_components , * classify_data .shape ), - np .inf , dtype = np .float64
940- )
941- prob_maps = np .zeros (
942- (n_components , * classify_data .shape ), dtype = np .float64
943- )
944- for k in range (n_components ):
945- z = (classify_data - means [k ]) / safe_stds [k ]
946- log_prob = (
947- - 0.5 * z * z
948- - np .log (safe_stds [k ])
949- - _LOG_2PI_HALF
950- + log_weights [k ]
951- )
952- log_likelihoods [k ][parenchyma_mask ] = log_prob [parenchyma_mask ]
953- prob_maps [k ] = np .exp (log_prob )
954- # Zero out non-parenchyma to avoid leakage.
955- prob_maps [k ][~ parenchyma_mask ] = 0.0
956- else :
957- # Fast path: compute likelihoods only at parenchyma voxels.
958- # shapes: parenchyma_vals = (M,), means = (K,)
959- z_flat = (parenchyma_vals [None , :] - means [:, None ]) / safe_stds [:, None ]
960- log_ll_flat = (
961- - 0.5 * z_flat * z_flat
962- - np .log (safe_stds [:, None ])
963- - _LOG_2PI_HALF
964- + log_weights [:, None ]
965- )
966- prob_flat = np .exp (log_ll_flat )
967-
968- if use_mrf :
969- # MRF regularization via ICM.
970- # Initialize seg labels from the argmax of the raw likelihoods
971- # (only within the parenchyma mask) before ICM iterations.
972- init_labels = np .argmax (
973- log_likelihoods [:, parenchyma_mask ], axis = 0
974- )
975- for k , label_val in enumerate (label_indices ):
976- mask_k = parenchyma_mask .copy ()
977- mask_k [parenchyma_mask ] = init_labels == k
978- seg [mask_k ] = label_val
979-
980- # Run ICM iterations to refine labels using spatial context.
981- for _icm_iter in range (n_icm ):
982- self ._icm_iteration (
983- seg , log_likelihoods , parenchyma_mask ,
984- label_indices , self .mrf_beta ,
985- )
986-
987- # Build soft responsibilities from the final labels for
988- # GMM re-estimation. Use the normalized likelihoods (posteriors)
989- # so the M-step is consistent with the E-step.
990- prob_sum = np .sum (prob_maps , axis = 0 )
991- prob_sum [prob_sum < 1e-300 ] = 1e-300
992- for k in range (n_components ):
993- prob_maps [k ] /= prob_sum
994-
995- # Extract labels from seg for final output.
996- labels_3d = np .zeros (int (np .sum (parenchyma_mask )), dtype = int )
997- for k , label_val in enumerate (label_indices ):
998- labels_3d [seg [parenchyma_mask ] == label_val ] = k
999-
1000- elif self .spatial_sigma_mm > 0 :
1001- # Gaussian smoothing regularization path.
1002- mask_float = parenchyma_mask .astype (np .float64 )
1003- smoothed_mask = gaussian_filter (mask_float , sigma = sigma_voxels )
1004- safe = parenchyma_mask & (smoothed_mask > 1e-6 )
1005- for k in range (n_components ):
1006- smoothed_prob = gaussian_filter (
1007- prob_maps [k ] * mask_float , sigma = sigma_voxels
1008- )
1009- result = np .zeros_like (prob_maps [k ])
1010- result [safe ] = smoothed_prob [safe ] / smoothed_mask [safe ]
1011- prob_maps [k ] = result
1012-
1013- # Normalize posteriors within parenchyma.
1014- prob_sum = np .sum (prob_maps , axis = 0 )
1015- prob_sum [prob_sum < 1e-300 ] = 1e-300
1016- for k in range (n_components ):
1017- prob_maps [k ] /= prob_sum
1018-
1019- # Assign labels via argmax of smoothed posteriors.
1020- labels_3d = np .argmax (prob_maps [:, parenchyma_mask ], axis = 0 )
822+ # Compute likelihoods only at parenchyma voxels (flat arrays).
823+ z_flat = (parenchyma_vals [None , :] - means [:, None ]) / safe_stds [:, None ]
824+ log_ll_flat = (
825+ - 0.5 * z_flat * z_flat
826+ - np .log (safe_stds [:, None ])
827+ - _LOG_2PI_HALF
828+ + log_weights [:, None ]
829+ )
830+ prob_flat = np .exp (log_ll_flat )
1021831
1022- else :
1023- # No spatial regularization: work entirely with flat arrays.
1024- # Normalize posteriors at parenchyma voxels only.
1025- prob_sum_flat = np .sum (prob_flat , axis = 0 )
1026- prob_sum_flat [prob_sum_flat < 1e-300 ] = 1e-300
1027- for k in range (n_components ):
1028- prob_flat [k ] /= prob_sum_flat
1029-
1030- # Assign labels via argmax of posteriors.
1031- labels_3d = np .argmax (prob_flat , axis = 0 )
1032-
1033- # Re-estimate GMM parameters from the current soft assignment
1034- # (for the next iteration). Use the posteriors as weights.
1035- parenchyma_intensities = classify_data [parenchyma_mask ]
832+ # Normalize posteriors.
833+ prob_sum_flat = np .sum (prob_flat , axis = 0 )
834+ prob_sum_flat [prob_sum_flat < 1e-300 ] = 1e-300
835+ prob_flat /= prob_sum_flat
836+
837+ # Re-estimate GMM parameters from the current soft assignment.
1036838 new_means = np .empty (n_components )
1037839 new_stds = np .empty (n_components )
1038840 new_weights = np .empty (n_components )
1039841 for k in range (n_components ):
1040- if needs_3d :
1041- resp_k = prob_maps [k ][parenchyma_mask ]
1042- else :
1043- resp_k = prob_flat [k ]
842+ resp_k = prob_flat [k ]
1044843 nk = np .sum (resp_k )
1045844 if nk < 1e-10 :
1046845 new_means [k ] = means [k ]
1047846 new_stds [k ] = stds [k ]
1048847 new_weights [k ] = weights [k ]
1049848 continue
1050- new_means [k ] = np .dot (resp_k , parenchyma_intensities ) / nk
1051- diff = parenchyma_intensities - new_means [k ]
849+ new_means [k ] = np .dot (resp_k , parenchyma_vals ) / nk
850+ diff = parenchyma_vals - new_means [k ]
1052851 new_stds [k ] = np .sqrt (np .dot (resp_k , diff * diff ) / nk )
1053852 new_stds [k ] = max (new_stds [k ], 1e-10 )
1054- new_weights [k ] = nk / len (parenchyma_intensities )
853+ new_weights [k ] = nk / len (parenchyma_vals )
1055854
1056855 means , stds , weights = new_means , new_stds , new_weights
1057856
1058- # Write final labels into seg. Components are sorted by mean:
1059- # index 0 = CSF (lowest), 1 = gray matter, 2 = white matter (highest).
1060- for k , key in enumerate (label_keys ):
1061- voxel_mask = parenchyma_mask .copy ()
1062- voxel_mask [parenchyma_mask ] = labels_3d == k
1063- seg [voxel_mask ] = material_idx [key ]
857+ # Assign final labels from the last iteration's posteriors.
858+ # Components are sorted by mean: 0=CSF, 1=gray matter, 2=white matter.
859+ labels_flat = np .argmax (prob_flat , axis = 0 )
860+ seg [parenchyma_mask ] = np .array (label_indices )[labels_flat ]
1064861
1065862 def to_table (self ) -> pd .DataFrame :
1066863 """
@@ -1075,8 +872,6 @@ def to_table(self) -> pd.DataFrame:
1075872 {"Name" : "Classify Brain Tissues" , "Value" : self .classify_brain_tissues , "Unit" : "" },
1076873 {"Name" : "Bias Correction Sigma" , "Value" : self .bias_correction_sigma_mm , "Unit" : "mm" },
1077874 {"Name" : "Use N4" , "Value" : self .use_n4 , "Unit" : "" },
1078- {"Name" : "Spatial Regularization Sigma" , "Value" : self .spatial_sigma_mm , "Unit" : "mm" },
1079- {"Name" : "MRF Beta" , "Value" : self .mrf_beta , "Unit" : "" },
1080875 {"Name" : "Brain Extraction Margin" , "Value" : self .brain_extraction_margin_mm , "Unit" : "mm" },
1081876 {"Name" : "Refine Skull Intensity" , "Value" : self .refine_skull_intensity , "Unit" : "" },
1082877 {"Name" : "Reference Material" , "Value" : self .ref_material , "Unit" : "" },
0 commit comments