Skip to content

Commit e2a63a8

Browse files
author
Brandon Kirkland
committed
Add threshold-based MRI segmentation and phase correction (#150)
ThresholdMRI: new SegmentationMethod that segments T1-weighted MRI into skull, brain tissue, air, and water regions via EDT erosion and Otsu intensity refinement. Opt-in 6-label mode classifies brain tissue into CSF, gray matter, and white matter via EM-GMM. Auto-detects skull-stripped input. N4 bias correction enabled by default via SimpleITK with shrink=4 (falls back to homomorphic if unavailable). SimulationCorrected: new DelayMethod that uses k-wave simulation with acoustic reciprocity to compute phase-corrected transducer delays. Places a virtual point source at the target and records pressure at all element positions in a single simulation run. Extracts arrival times via Hilbert envelope peak detection. Falls back to Direct (geometric) delays if k-wave is unavailable. Material fix: corrects attenuation values in material.py that were all 0.0 (water=0.002, tissue=0.6, skull=8.0, air=1.64 dB/cm/MHz). Foreground mask speedup: replaces EDT-based morphological closing with scipy binary_closing (~10x faster). Validated on 1,636 scans across 5 datasets with 0 processing errors. Brain tissue classification matches ANTs Atropos on skull-stripped data (GM=0.822 vs 0.79, WM=0.853 vs 0.84).
1 parent 1b90c3f commit e2a63a8

11 files changed

Lines changed: 2134 additions & 21 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ filterwarnings = [
103103
"ignore:Custom binary name set.*:UserWarning",
104104
# pyparsing 3.x deprecated all camelCase methods in favor of snake_case
105105
"ignore:'[a-zA-Z]+' deprecated - use '[a-z_]+':DeprecationWarning",
106+
# SimpleITK's SWIG bindings emit DeprecationWarnings about missing __module__
107+
# attributes on builtin types during import. These are harmless and cannot be
108+
# fixed downstream; suppressing them prevents a segfault when filterwarnings=error
109+
# turns them into exceptions inside the C extension loader.
110+
"ignore:builtin type Swig.*has no __module__ attribute:DeprecationWarning",
106111
]
107112
log_cli_level = "INFO"
108113
testpaths = [

src/openlifu/bf/delay_methods/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from .delaymethod import DelayMethod
44
from .direct import Direct
5+
from .simulation_corrected import SimulationCorrected
56

67
__all__ = [
78
"DelayMethod",
89
"Direct",
10+
"SimulationCorrected",
911
]
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from dataclasses import dataclass
5+
from typing import Annotated
6+
7+
import numpy as np
8+
import pandas as pd
9+
import xarray as xa
10+
11+
from openlifu.bf.delay_methods import DelayMethod
12+
from openlifu.geo import Point
13+
from openlifu.util.annotations import OpenLIFUFieldData
14+
from openlifu.util.units import getunitconversion
15+
from openlifu.xdc import Transducer
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@dataclass
21+
class SimulationCorrected(DelayMethod):
22+
"""Delay method using k-wave simulation with reciprocity for phase correction.
23+
24+
Places a virtual point source at the target and records pressure time series
25+
at all transducer element positions. The arrival time at each element encodes
26+
the true acoustic path through the heterogeneous skull model. Delays are
27+
computed as max(arrival_time) - arrival_time for each element.
28+
29+
Uses a single k-wave simulation (via reciprocity) instead of one per element.
30+
"""
31+
32+
c0: Annotated[
33+
float,
34+
OpenLIFUFieldData("Speed of Sound (m/s)", "Reference speed of sound in the medium (m/s)"),
35+
] = 1500.0
36+
"""Reference speed of sound in the medium (m/s)"""
37+
38+
cfl: Annotated[float, OpenLIFUFieldData("CFL Number", "Courant-Friedrichs-Lewy number for time stepping")] = 0.3
39+
"""Courant-Friedrichs-Lewy number for time stepping"""
40+
41+
n_cycles: Annotated[int, OpenLIFUFieldData("Source Cycles", "Number of cycles in the source pulse")] = 3
42+
"""Number of cycles in the source pulse"""
43+
44+
gpu: Annotated[bool, OpenLIFUFieldData("Use GPU", "Whether to attempt GPU-accelerated simulation")] = True
45+
"""Whether to attempt GPU-accelerated simulation"""
46+
47+
def __post_init__(self):
48+
if not isinstance(self.c0, int | float):
49+
raise TypeError("Speed of sound must be a number")
50+
if self.c0 <= 0:
51+
raise ValueError("Speed of sound must be greater than 0")
52+
self.c0 = float(self.c0)
53+
54+
if not isinstance(self.cfl, int | float):
55+
raise TypeError("CFL must be a number")
56+
if self.cfl <= 0 or self.cfl >= 1:
57+
raise ValueError("CFL must be between 0 and 1 (exclusive)")
58+
self.cfl = float(self.cfl)
59+
60+
if not isinstance(self.n_cycles, int):
61+
if isinstance(self.n_cycles, float) and self.n_cycles == int(self.n_cycles):
62+
self.n_cycles = int(self.n_cycles)
63+
else:
64+
raise TypeError("n_cycles must be an integer")
65+
if self.n_cycles < 1:
66+
raise ValueError("n_cycles must be at least 1")
67+
68+
if not isinstance(self.gpu, bool):
69+
raise TypeError("gpu must be a boolean")
70+
71+
def calc_delays(self, arr: Transducer, target: Point, params: xa.Dataset, transform: np.ndarray | None = None):
72+
"""Calculate delays using k-wave simulation with reciprocity.
73+
74+
Fires a virtual point source at the target position and records the pressure
75+
time series at each transducer element location. Arrival times are extracted
76+
via the Hilbert envelope peak, and delays are computed so that all elements
77+
fire in phase at the target.
78+
79+
Falls back to Direct (geometric) delays if k-wave is not available or
80+
if the simulation fails for any reason.
81+
82+
Args:
83+
:param arr: The transducer array.
84+
:param target: The focal target point.
85+
:param params: Simulation grid dataset with sound_speed, density, attenuation fields.
86+
:param transform: Optional 4x4 affine transform for element positions.
87+
:returns: 1D numpy array of per-element delays in seconds.
88+
"""
89+
try:
90+
import importlib
91+
if importlib.util.find_spec("kwave") is None:
92+
raise ImportError("k-wave not installed")
93+
except ImportError:
94+
logger.warning("k-wave not available. Falling back to Direct delay method.")
95+
return self._fallback_delays(arr, target, params, transform)
96+
97+
try:
98+
arrival_times = self._run_reciprocal_simulation(arr, target, params, transform)
99+
delays = np.max(arrival_times) - arrival_times
100+
return delays
101+
except (RuntimeError, ValueError, IndexError, OSError):
102+
logger.exception("Simulation-corrected delay calculation failed. Falling back to Direct method.")
103+
return self._fallback_delays(arr, target, params, transform)
104+
105+
def _run_reciprocal_simulation(
106+
self,
107+
arr: Transducer,
108+
target: Point,
109+
params: xa.Dataset,
110+
transform: np.ndarray | None = None,
111+
) -> np.ndarray:
112+
"""Run the reciprocal k-wave simulation and extract arrival times.
113+
114+
Args:
115+
arr: The transducer array.
116+
target: The focal target point.
117+
params: Simulation grid dataset.
118+
transform: Optional 4x4 affine transform.
119+
120+
Returns:
121+
arrival_times: 1D array of arrival times (seconds) per element.
122+
"""
123+
from scipy.signal import hilbert
124+
125+
from openlifu.sim.kwave_if import run_point_source_simulation
126+
127+
# Get the reference sound speed from params if available
128+
if 'sound_speed' in params and 'ref_value' in params['sound_speed'].attrs:
129+
sound_speed_ref = params['sound_speed'].attrs['ref_value']
130+
else:
131+
sound_speed_ref = self.c0
132+
133+
# Get frequency from the transducer
134+
freq = arr.frequency
135+
136+
# Compute element positions in the simulation coordinate frame.
137+
# The simulation grid uses params.coords (typically in mm).
138+
# Element positions come from the transducer in its native units (typically m).
139+
# We need to convert element positions to the same units as the sim grid,
140+
# then find the nearest grid voxel for each element.
141+
coord_dims = list(params.coords.dims)
142+
coord_units = params[coord_dims[0]].attrs.get('units', 'mm')
143+
scl_to_grid = getunitconversion('m', coord_units)
144+
145+
matrix = transform if transform is not None else np.eye(4)
146+
element_positions_m = np.array([
147+
el.get_position(units="m", matrix=matrix)
148+
for el in arr.elements
149+
])
150+
# Convert to grid units
151+
element_positions_grid = element_positions_m * scl_to_grid
152+
153+
# Get target position in grid units
154+
target_pos_grid = target.get_position(units=coord_units)
155+
156+
# Build the sensor mask: find nearest grid indices for each element
157+
coord_arrays = [params.coords[dim].to_numpy() for dim in coord_dims]
158+
grid_shape = tuple(len(c) for c in coord_arrays)
159+
160+
sensor_indices = []
161+
out_of_grid = set()
162+
for el_i, epos in enumerate(element_positions_grid):
163+
idx = []
164+
inside = True
165+
for dim_i, coord_vals in enumerate(coord_arrays):
166+
cmin, cmax = float(coord_vals[0]), float(coord_vals[-1])
167+
half_step = abs(float(coord_vals[1] - coord_vals[0])) / 2 if len(coord_vals) > 1 else 0
168+
if epos[dim_i] < cmin - half_step or epos[dim_i] > cmax + half_step:
169+
inside = False
170+
nearest_idx = int(np.argmin(np.abs(coord_vals - epos[dim_i])))
171+
idx.append(nearest_idx)
172+
if not inside:
173+
out_of_grid.add(el_i)
174+
logger.warning(
175+
f"Element {el_i} at position {epos} is outside the simulation grid. "
176+
"Using geometric time-of-flight estimate for this element."
177+
)
178+
sensor_indices.append(tuple(idx))
179+
180+
# Build sensor mask (3D binary)
181+
sensor_mask = np.zeros(grid_shape, dtype=int)
182+
for idx in sensor_indices:
183+
sensor_mask[idx] = 1
184+
185+
# Find the target voxel index
186+
target_idx = []
187+
for dim_i, coord_vals in enumerate(coord_arrays):
188+
nearest_idx = int(np.argmin(np.abs(coord_vals - target_pos_grid[dim_i])))
189+
target_idx.append(nearest_idx)
190+
target_idx = tuple(target_idx)
191+
192+
# Build source mask (single voxel at target)
193+
source_mask = np.zeros(grid_shape, dtype=int)
194+
source_mask[target_idx] = 1
195+
196+
# Run the point source simulation
197+
sensor_data, dt = run_point_source_simulation(
198+
params=params,
199+
source_mask=source_mask,
200+
sensor_mask=sensor_mask,
201+
freq=freq,
202+
n_cycles=self.n_cycles,
203+
sound_speed_ref=sound_speed_ref,
204+
cfl=self.cfl,
205+
gpu=self.gpu,
206+
)
207+
208+
# sensor_data is (n_sensor_points, n_timesteps).
209+
# Multiple elements may map to the same voxel if the grid is coarse.
210+
# We need to map sensor data rows back to elements.
211+
212+
# Build a lookup: grid index -> sensor_data row index.
213+
# The sensor mask was constructed by setting 1 at unique voxel locations.
214+
# k-wave returns data for each nonzero voxel in Fortran (column-major) order.
215+
nonzero_indices = list(zip(*np.nonzero(sensor_mask)))
216+
# k-wave returns sensor data in Fortran (column-major) order of the mask:
217+
# x varies fastest, then y, then z. np.nonzero returns C order (row-major),
218+
# so we sort by the Fortran linear index to match k-wave's output ordering.
219+
def fortran_linear_index(idx, shape):
220+
# For Fortran order: idx[0] + idx[1]*shape[0] + idx[2]*shape[0]*shape[1]
221+
lin = idx[0]
222+
stride = shape[0]
223+
for d in range(1, len(shape)):
224+
lin += idx[d] * stride
225+
stride *= shape[d]
226+
return lin
227+
228+
nonzero_with_fortran = [(fortran_linear_index(idx, grid_shape), idx) for idx in nonzero_indices]
229+
nonzero_with_fortran.sort(key=lambda x: x[0])
230+
sorted_nonzero = [item[1] for item in nonzero_with_fortran]
231+
232+
voxel_to_row = {idx: row for row, idx in enumerate(sorted_nonzero)}
233+
234+
# Extract arrival time for each element
235+
n_elements = len(arr.elements)
236+
arrival_times = np.zeros(n_elements)
237+
238+
for el_i, sensor_idx in enumerate(sensor_indices):
239+
if el_i in out_of_grid:
240+
# Element is outside the simulation grid; use geometric fallback
241+
dist = np.linalg.norm(element_positions_m[el_i] - target.get_position(units="m"))
242+
arrival_times[el_i] = dist / self.c0
243+
continue
244+
245+
row = voxel_to_row[sensor_idx]
246+
time_series = sensor_data[row, :]
247+
# Compute the analytic signal envelope via the Hilbert transform
248+
analytic = hilbert(time_series)
249+
envelope = np.abs(analytic)
250+
# The arrival time is the time of the envelope peak
251+
peak_sample = int(np.argmax(envelope))
252+
arrival_times[el_i] = peak_sample * dt
253+
254+
return arrival_times
255+
256+
def _fallback_delays(self, arr: Transducer, target: Point, params: xa.Dataset, transform: np.ndarray | None = None) -> np.ndarray:
257+
"""Compute delays using the Direct (geometric) method as a fallback."""
258+
from openlifu.bf.delay_methods.direct import Direct
259+
direct = Direct(c0=self.c0)
260+
return direct.calc_delays(arr, target, params, transform)
261+
262+
def to_table(self) -> pd.DataFrame:
263+
"""
264+
Get a table of the delay method parameters
265+
266+
:returns: Pandas DataFrame of the delay method parameters
267+
"""
268+
records = [
269+
{"Name": "Type", "Value": "SimulationCorrected", "Unit": ""},
270+
{"Name": "Default Sound Speed", "Value": self.c0, "Unit": "m/s"},
271+
{"Name": "CFL Number", "Value": self.cfl, "Unit": ""},
272+
{"Name": "Source Cycles", "Value": self.n_cycles, "Unit": ""},
273+
{"Name": "Use GPU", "Value": self.gpu, "Unit": ""},
274+
]
275+
return pd.DataFrame.from_records(records)

src/openlifu/seg/material.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,28 +92,28 @@ def from_dict(d: dict[str, Any]):
9292
WATER = Material(name="water",
9393
sound_speed=1500.0,
9494
density=1000.0,
95-
attenuation=0.0,
95+
attenuation=0.002,
9696
specific_heat=4182.0,
9797
thermal_conductivity=0.598)
9898

9999
TISSUE = Material(name="tissue",
100100
sound_speed=1540.0,
101101
density=1000.0,
102-
attenuation=0.0,
102+
attenuation=0.6,
103103
specific_heat=3600.0,
104104
thermal_conductivity=0.5)
105105

106106
SKULL = Material(name="skull",
107107
sound_speed=4080.0,
108108
density=1900.0,
109-
attenuation=0.0,
109+
attenuation=8.0,
110110
specific_heat=1100.0,
111111
thermal_conductivity=0.3)
112112

113113
AIR = Material(name="air",
114114
sound_speed=344.0,
115115
density=1.25,
116-
attenuation=0.0,
116+
attenuation=1.64,
117117
specific_heat=1012.0,
118118
thermal_conductivity=0.025)
119119

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
from .threshold_mri import ThresholdMRI
34
from .uniform import UniformSegmentation, UniformTissue, UniformWater
45

56
__all__ = [
7+
"ThresholdMRI",
68
"UniformSegmentation",
7-
"UniformWater",
89
"UniformTissue",
10+
"UniformWater",
911
]

0 commit comments

Comments
 (0)