Skip to content

Commit 1c612a1

Browse files
author
Brandon Kirkland
committed
Remove unused spatial regularization from ThresholdMRI (#150)
Remove spatial_sigma_mm, mrf_beta, and _icm_iteration. Both regularization modes (MRF/ICM and Gaussian smoothing) were tested during development and found to have no benefit: Gaussian smoothing hurt all Dice scores, and MRF had negligible effect. The GMM now uses direct posterior argmax as the sole classification path. Also simplifies _classify_brain_gmm: removes redundant array extraction, moves label assignment out of the iteration loop, and replaces per-component mask copies with vectorized indexing.
1 parent cd2041b commit 1c612a1

2 files changed

Lines changed: 39 additions & 385 deletions

File tree

src/openlifu/seg/seg_methods/threshold_mri.py

Lines changed: 36 additions & 241 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,8 @@ class ThresholdMRI(SegmentationMethod):
9393
is labeled as a single "tissue" material. The segmentation assigns four
9494
tissue types: water, skull, tissue, and air. When True, the brain
9595
interior is split into CSF, gray matter, and white matter using a
96-
3-component Gaussian Mixture Model (EM-GMM) on T1-weighted intensity,
97-
with optional spatial regularization. Two regularization modes are
98-
available: Markov Random Field (MRF) via Iterated Conditional Modes
99-
(ICM) when ``mrf_beta > 0``, or Gaussian smoothing of posterior
100-
probability maps when ``spatial_sigma_mm > 0``. Both are disabled by
101-
default. MRF regularization preserves tissue boundaries better than
102-
Gaussian smoothing by penalizing label disagreement among 6-connected
103-
neighbors without blurring probability maps. This yields a total of six assigned tissue types:
96+
3-component Gaussian Mixture Model (EM-GMM) on T1-weighted intensity.
97+
This yields a total of six assigned tissue types:
10498
water, skull, csf, gray_matter, white_matter, and air. Additional
10599
materials in the dict (e.g. standoff) are carried through but not
106100
assigned by the segmentation algorithm.
@@ -146,7 +140,7 @@ class ThresholdMRI(SegmentationMethod):
146140
OpenLIFUFieldData(
147141
"Classify brain tissues",
148142
"If True, classify brain interior into CSF, gray matter, and white matter "
149-
"using a spatially-regularized EM-GMM on T1-weighted intensity",
143+
"using an EM-GMM on T1-weighted intensity",
150144
),
151145
] = False
152146
"""If True, sub-classify brain interior into CSF, gray matter, and white matter."""
@@ -173,38 +167,6 @@ class ThresholdMRI(SegmentationMethod):
173167
"""If True, use SimpleITK N4BiasFieldCorrection for bias correction.
174168
Falls back to homomorphic if SimpleITK is not available."""
175169

176-
spatial_sigma_mm: Annotated[
177-
float,
178-
OpenLIFUFieldData(
179-
"Spatial regularization sigma (mm)",
180-
"Gaussian smoothing sigma for spatial regularization of brain tissue "
181-
"probability maps. Set to 0 to disable. Only used when classify_brain_tissues is True "
182-
"and mrf_beta is 0.",
183-
),
184-
] = 0.0
185-
"""Gaussian smoothing sigma (mm) for spatial regularization.
186-
Defaults to 0 (disabled). Non-zero values apply masked Gaussian
187-
smoothing to posterior probability maps before label assignment.
188-
Smoothing is confined to the parenchyma mask to prevent boundary leakage.
189-
Ignored when mrf_beta > 0 (MRF regularization takes precedence)."""
190-
191-
mrf_beta: Annotated[
192-
float,
193-
OpenLIFUFieldData(
194-
"MRF regularization strength",
195-
"Markov Random Field regularization parameter (beta) for brain tissue "
196-
"classification. Controls the penalty for label disagreement among "
197-
"6-connected neighbors via Iterated Conditional Modes (ICM). "
198-
"Set to 0 to disable MRF and fall back to Gaussian smoothing. "
199-
"Typical values are 0.1-0.3. Only used when classify_brain_tissues is True.",
200-
),
201-
] = 0.0
202-
"""MRF regularization strength (beta).
203-
Defaults to 0 (disabled). When positive, Iterated Conditional Modes (ICM) is used
204-
for spatial regularization instead of Gaussian smoothing. Higher values
205-
produce smoother label maps but may over-regularize thin structures.
206-
Set to 0 to disable MRF and use Gaussian smoothing (spatial_sigma_mm) instead."""
207-
208170
brain_extraction_margin_mm: Annotated[
209171
float,
210172
OpenLIFUFieldData(
@@ -273,12 +235,6 @@ def __post_init__(self) -> None:
273235
if self.bias_correction_sigma_mm < 0:
274236
msg = f"bias_correction_sigma_mm must be non-negative, got {self.bias_correction_sigma_mm}."
275237
raise ValueError(msg)
276-
if self.spatial_sigma_mm < 0:
277-
msg = f"spatial_sigma_mm must be non-negative, got {self.spatial_sigma_mm}."
278-
raise ValueError(msg)
279-
if self.mrf_beta < 0:
280-
msg = f"mrf_beta must be non-negative, got {self.mrf_beta}."
281-
raise ValueError(msg)
282238
if self.brain_extraction_margin_mm < 0:
283239
msg = f"brain_extraction_margin_mm must be non-negative, got {self.brain_extraction_margin_mm}."
284240
raise ValueError(msg)
@@ -745,59 +701,6 @@ def _fit_gmm_1d(
745701
order = np.argsort(means)
746702
return means[order], stds[order], weights[order]
747703

748-
@staticmethod
749-
def _icm_iteration(
750-
seg_arr: np.ndarray,
751-
log_likelihoods: np.ndarray,
752-
parenchyma_mask: np.ndarray,
753-
label_indices: list[int],
754-
beta: float,
755-
) -> np.ndarray:
756-
"""One pass of Iterated Conditional Modes (ICM) for MRF regularization.
757-
758-
For each voxel in the parenchyma mask, computes a posterior that
759-
combines the GMM log-likelihood with a spatial prior that counts
760-
how many of the 6-connected (face) neighbors share the same label.
761-
The voxel is assigned to the class that maximizes this posterior.
762-
763-
Uses vectorized numpy operations (np.roll along each axis) rather
764-
than per-voxel loops for performance.
765-
766-
:param seg_arr: Current integer label array (modified in-place)
767-
:param log_likelihoods: Log-likelihood for each class, shape (n_classes, *vol_shape)
768-
:param parenchyma_mask: Boolean mask restricting ICM updates
769-
:param label_indices: Integer label values for each class (same order as log_likelihoods)
770-
:param beta: MRF regularization strength
771-
:returns: Updated seg_arr (same object, modified in-place)
772-
"""
773-
n_classes = log_likelihoods.shape[0]
774-
775-
# For each class k, count how many of the 6 face-neighbors share
776-
# label k. np.roll wraps at boundaries, but boundary voxels are
777-
# typically outside the parenchyma mask and do not affect the result.
778-
neighbor_agreement = np.zeros(
779-
(n_classes, *seg_arr.shape), dtype=np.float32
780-
)
781-
for k, label_val in enumerate(label_indices):
782-
same_label = (seg_arr == label_val).astype(np.float32)
783-
agreement = np.zeros_like(same_label)
784-
for axis in range(3):
785-
agreement += np.roll(same_label, 1, axis=axis)
786-
agreement += np.roll(same_label, -1, axis=axis)
787-
neighbor_agreement[k] = agreement # Range 0..6
788-
789-
# Posterior = log_likelihood + beta * neighbor_agreement.
790-
# Higher agreement means more neighbors share this label, which is
791-
# rewarded. This is equivalent to penalizing disagreement.
792-
posterior = log_likelihoods + beta * neighbor_agreement
793-
794-
# Assign each voxel to the class with the highest posterior.
795-
best_class = np.argmax(posterior, axis=0)
796-
for k, label_val in enumerate(label_indices):
797-
seg_arr[parenchyma_mask & (best_class == k)] = label_val
798-
799-
return seg_arr
800-
801704
def _classify_brain(
802705
self,
803706
data: np.ndarray,
@@ -811,9 +714,9 @@ def _classify_brain(
811714
812715
Modifies ``seg`` in-place. Optionally applies bias field correction,
813716
then fits a 3-component Gaussian Mixture Model (EM-GMM) to the
814-
(possibly tighter) GMM parenchyma mask. Posterior probability maps are
815-
spatially smoothed with a Gaussian kernel for regularization, and the
816-
process is iterated twice to refine the classification.
717+
(possibly tighter) GMM parenchyma mask. Labels are assigned via
718+
argmax of the posterior probabilities, iterated twice to refine
719+
the GMM parameters.
817720
818721
After GMM classification within the tight mask, labels are expanded
819722
to cover the full parenchyma mask via nearest-neighbor assignment
@@ -845,7 +748,7 @@ def _classify_brain(
845748

846749
try:
847750
self._classify_brain_gmm(
848-
classify_data, gmm_parenchyma_mask, seg, material_idx, spacing
751+
classify_data, gmm_parenchyma_mask, seg, material_idx
849752
)
850753
except (RuntimeError, ValueError, np.linalg.LinAlgError):
851754
logging.warning(
@@ -875,19 +778,12 @@ def _classify_brain_gmm(
875778
parenchyma_mask: np.ndarray,
876779
seg: np.ndarray,
877780
material_idx: dict[str, int],
878-
spacing: np.ndarray,
879781
) -> None:
880-
"""Core GMM classification with spatial regularization.
881-
882-
Fits a 3-component GMM, computes full-volume log-likelihood maps,
883-
then applies spatial regularization. When ``mrf_beta > 0``, uses
884-
Markov Random Field regularization via Iterated Conditional Modes
885-
(ICM), which penalizes label disagreement among 6-connected
886-
neighbors while preserving tissue boundaries. When ``mrf_beta == 0``
887-
and ``spatial_sigma_mm > 0``, falls back to Gaussian smoothing of
888-
posterior probability maps. The outer loop iterates twice: after
889-
each regularization pass, GMM parameters are re-estimated from the
890-
current labels/posteriors.
782+
"""Core GMM classification.
783+
784+
Fits a 3-component GMM to parenchyma voxel intensities, computes
785+
posterior probabilities, assigns labels via argmax, then iterates
786+
twice to refine GMM parameters from the current assignment.
891787
892788
Called by ``_classify_brain``; raises on failure so the caller can
893789
apply the gray-matter fallback.
@@ -896,15 +792,11 @@ def _classify_brain_gmm(
896792
:param parenchyma_mask: Boolean mask of brain parenchyma
897793
:param seg: Integer label array to modify in-place
898794
:param material_idx: Mapping from material name to integer label
899-
:param spacing: Voxel spacing in mm per axis
900795
"""
901-
use_mrf = self.mrf_beta > 0
902-
sigma_voxels = self.spatial_sigma_mm / spacing
903796
label_keys = ["csf", "gray_matter", "white_matter"] # sorted by T1 mean
904797
label_indices = [material_idx[k] for k in label_keys]
905798
n_components = len(label_keys)
906-
n_outer = 2 # total spatial-regularization iterations
907-
n_icm = 5 # ICM iterations per outer loop (when using MRF)
799+
n_outer = 2
908800

909801
# Extract parenchyma intensities for the initial fit.
910802
parenchyma_vals = classify_data[parenchyma_mask]
@@ -923,144 +815,49 @@ def _classify_brain_gmm(
923815
# Precompute constant for inline logpdf (avoids scipy overhead).
924816
_LOG_2PI_HALF = 0.5 * np.log(2.0 * np.pi)
925817

926-
# Determine whether we need full 3D probability maps.
927-
# MRF needs 3D log-likelihoods for ICM; Gaussian smoothing needs 3D
928-
# prob_maps for the convolution. When neither is active, we can
929-
# compute likelihoods only at parenchyma voxels (flat arrays).
930-
needs_3d = use_mrf or self.spatial_sigma_mm > 0
931-
932818
for _iteration in range(n_outer):
933819
safe_stds = np.maximum(stds, 1e-10)
934820
log_weights = np.log(weights + 1e-300)
935821

936-
if needs_3d:
937-
# Build full-volume log-likelihood maps (only inside parenchyma).
938-
log_likelihoods = np.full(
939-
(n_components, *classify_data.shape), -np.inf, dtype=np.float64
940-
)
941-
prob_maps = np.zeros(
942-
(n_components, *classify_data.shape), dtype=np.float64
943-
)
944-
for k in range(n_components):
945-
z = (classify_data - means[k]) / safe_stds[k]
946-
log_prob = (
947-
-0.5 * z * z
948-
- np.log(safe_stds[k])
949-
- _LOG_2PI_HALF
950-
+ log_weights[k]
951-
)
952-
log_likelihoods[k][parenchyma_mask] = log_prob[parenchyma_mask]
953-
prob_maps[k] = np.exp(log_prob)
954-
# Zero out non-parenchyma to avoid leakage.
955-
prob_maps[k][~parenchyma_mask] = 0.0
956-
else:
957-
# Fast path: compute likelihoods only at parenchyma voxels.
958-
# shapes: parenchyma_vals = (M,), means = (K,)
959-
z_flat = (parenchyma_vals[None, :] - means[:, None]) / safe_stds[:, None]
960-
log_ll_flat = (
961-
-0.5 * z_flat * z_flat
962-
- np.log(safe_stds[:, None])
963-
- _LOG_2PI_HALF
964-
+ log_weights[:, None]
965-
)
966-
prob_flat = np.exp(log_ll_flat)
967-
968-
if use_mrf:
969-
# MRF regularization via ICM.
970-
# Initialize seg labels from the argmax of the raw likelihoods
971-
# (only within the parenchyma mask) before ICM iterations.
972-
init_labels = np.argmax(
973-
log_likelihoods[:, parenchyma_mask], axis=0
974-
)
975-
for k, label_val in enumerate(label_indices):
976-
mask_k = parenchyma_mask.copy()
977-
mask_k[parenchyma_mask] = init_labels == k
978-
seg[mask_k] = label_val
979-
980-
# Run ICM iterations to refine labels using spatial context.
981-
for _icm_iter in range(n_icm):
982-
self._icm_iteration(
983-
seg, log_likelihoods, parenchyma_mask,
984-
label_indices, self.mrf_beta,
985-
)
986-
987-
# Build soft responsibilities from the final labels for
988-
# GMM re-estimation. Use the normalized likelihoods (posteriors)
989-
# so the M-step is consistent with the E-step.
990-
prob_sum = np.sum(prob_maps, axis=0)
991-
prob_sum[prob_sum < 1e-300] = 1e-300
992-
for k in range(n_components):
993-
prob_maps[k] /= prob_sum
994-
995-
# Extract labels from seg for final output.
996-
labels_3d = np.zeros(int(np.sum(parenchyma_mask)), dtype=int)
997-
for k, label_val in enumerate(label_indices):
998-
labels_3d[seg[parenchyma_mask] == label_val] = k
999-
1000-
elif self.spatial_sigma_mm > 0:
1001-
# Gaussian smoothing regularization path.
1002-
mask_float = parenchyma_mask.astype(np.float64)
1003-
smoothed_mask = gaussian_filter(mask_float, sigma=sigma_voxels)
1004-
safe = parenchyma_mask & (smoothed_mask > 1e-6)
1005-
for k in range(n_components):
1006-
smoothed_prob = gaussian_filter(
1007-
prob_maps[k] * mask_float, sigma=sigma_voxels
1008-
)
1009-
result = np.zeros_like(prob_maps[k])
1010-
result[safe] = smoothed_prob[safe] / smoothed_mask[safe]
1011-
prob_maps[k] = result
1012-
1013-
# Normalize posteriors within parenchyma.
1014-
prob_sum = np.sum(prob_maps, axis=0)
1015-
prob_sum[prob_sum < 1e-300] = 1e-300
1016-
for k in range(n_components):
1017-
prob_maps[k] /= prob_sum
1018-
1019-
# Assign labels via argmax of smoothed posteriors.
1020-
labels_3d = np.argmax(prob_maps[:, parenchyma_mask], axis=0)
822+
# Compute likelihoods only at parenchyma voxels (flat arrays).
823+
z_flat = (parenchyma_vals[None, :] - means[:, None]) / safe_stds[:, None]
824+
log_ll_flat = (
825+
-0.5 * z_flat * z_flat
826+
- np.log(safe_stds[:, None])
827+
- _LOG_2PI_HALF
828+
+ log_weights[:, None]
829+
)
830+
prob_flat = np.exp(log_ll_flat)
1021831

1022-
else:
1023-
# No spatial regularization: work entirely with flat arrays.
1024-
# Normalize posteriors at parenchyma voxels only.
1025-
prob_sum_flat = np.sum(prob_flat, axis=0)
1026-
prob_sum_flat[prob_sum_flat < 1e-300] = 1e-300
1027-
for k in range(n_components):
1028-
prob_flat[k] /= prob_sum_flat
1029-
1030-
# Assign labels via argmax of posteriors.
1031-
labels_3d = np.argmax(prob_flat, axis=0)
1032-
1033-
# Re-estimate GMM parameters from the current soft assignment
1034-
# (for the next iteration). Use the posteriors as weights.
1035-
parenchyma_intensities = classify_data[parenchyma_mask]
832+
# Normalize posteriors.
833+
prob_sum_flat = np.sum(prob_flat, axis=0)
834+
prob_sum_flat[prob_sum_flat < 1e-300] = 1e-300
835+
prob_flat /= prob_sum_flat
836+
837+
# Re-estimate GMM parameters from the current soft assignment.
1036838
new_means = np.empty(n_components)
1037839
new_stds = np.empty(n_components)
1038840
new_weights = np.empty(n_components)
1039841
for k in range(n_components):
1040-
if needs_3d:
1041-
resp_k = prob_maps[k][parenchyma_mask]
1042-
else:
1043-
resp_k = prob_flat[k]
842+
resp_k = prob_flat[k]
1044843
nk = np.sum(resp_k)
1045844
if nk < 1e-10:
1046845
new_means[k] = means[k]
1047846
new_stds[k] = stds[k]
1048847
new_weights[k] = weights[k]
1049848
continue
1050-
new_means[k] = np.dot(resp_k, parenchyma_intensities) / nk
1051-
diff = parenchyma_intensities - new_means[k]
849+
new_means[k] = np.dot(resp_k, parenchyma_vals) / nk
850+
diff = parenchyma_vals - new_means[k]
1052851
new_stds[k] = np.sqrt(np.dot(resp_k, diff * diff) / nk)
1053852
new_stds[k] = max(new_stds[k], 1e-10)
1054-
new_weights[k] = nk / len(parenchyma_intensities)
853+
new_weights[k] = nk / len(parenchyma_vals)
1055854

1056855
means, stds, weights = new_means, new_stds, new_weights
1057856

1058-
# Write final labels into seg. Components are sorted by mean:
1059-
# index 0 = CSF (lowest), 1 = gray matter, 2 = white matter (highest).
1060-
for k, key in enumerate(label_keys):
1061-
voxel_mask = parenchyma_mask.copy()
1062-
voxel_mask[parenchyma_mask] = labels_3d == k
1063-
seg[voxel_mask] = material_idx[key]
857+
# Assign final labels from the last iteration's posteriors.
858+
# Components are sorted by mean: 0=CSF, 1=gray matter, 2=white matter.
859+
labels_flat = np.argmax(prob_flat, axis=0)
860+
seg[parenchyma_mask] = np.array(label_indices)[labels_flat]
1064861

1065862
def to_table(self) -> pd.DataFrame:
1066863
"""
@@ -1075,8 +872,6 @@ def to_table(self) -> pd.DataFrame:
1075872
{"Name": "Classify Brain Tissues", "Value": self.classify_brain_tissues, "Unit": ""},
1076873
{"Name": "Bias Correction Sigma", "Value": self.bias_correction_sigma_mm, "Unit": "mm"},
1077874
{"Name": "Use N4", "Value": self.use_n4, "Unit": ""},
1078-
{"Name": "Spatial Regularization Sigma", "Value": self.spatial_sigma_mm, "Unit": "mm"},
1079-
{"Name": "MRF Beta", "Value": self.mrf_beta, "Unit": ""},
1080875
{"Name": "Brain Extraction Margin", "Value": self.brain_extraction_margin_mm, "Unit": "mm"},
1081876
{"Name": "Refine Skull Intensity", "Value": self.refine_skull_intensity, "Unit": ""},
1082877
{"Name": "Reference Material", "Value": self.ref_material, "Unit": ""},

0 commit comments

Comments
 (0)