|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Annotated |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import xarray as xa |
| 10 | + |
| 11 | +from openlifu.bf.delay_methods import DelayMethod |
| 12 | +from openlifu.geo import Point |
| 13 | +from openlifu.util.annotations import OpenLIFUFieldData |
| 14 | +from openlifu.util.units import getunitconversion |
| 15 | +from openlifu.xdc import Transducer |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class SimulationCorrected(DelayMethod): |
| 22 | + """Delay method using k-wave simulation with reciprocity for phase correction. |
| 23 | +
|
| 24 | + Places a virtual point source at the target and records pressure time series |
| 25 | + at all transducer element positions. The arrival time at each element encodes |
| 26 | + the true acoustic path through the heterogeneous skull model. Delays are |
| 27 | + computed as max(arrival_time) - arrival_time for each element. |
| 28 | +
|
| 29 | + Uses a single k-wave simulation (via reciprocity) instead of one per element. |
| 30 | + """ |
| 31 | + |
| 32 | + c0: Annotated[ |
| 33 | + float, |
| 34 | + OpenLIFUFieldData("Speed of Sound (m/s)", "Reference speed of sound in the medium (m/s)"), |
| 35 | + ] = 1500.0 |
| 36 | + """Reference speed of sound in the medium (m/s)""" |
| 37 | + |
| 38 | + cfl: Annotated[float, OpenLIFUFieldData("CFL Number", "Courant-Friedrichs-Lewy number for time stepping")] = 0.3 |
| 39 | + """Courant-Friedrichs-Lewy number for time stepping""" |
| 40 | + |
| 41 | + n_cycles: Annotated[int, OpenLIFUFieldData("Source Cycles", "Number of cycles in the source pulse")] = 3 |
| 42 | + """Number of cycles in the source pulse""" |
| 43 | + |
| 44 | + gpu: Annotated[bool, OpenLIFUFieldData("Use GPU", "Whether to attempt GPU-accelerated simulation")] = True |
| 45 | + """Whether to attempt GPU-accelerated simulation""" |
| 46 | + |
| 47 | + def __post_init__(self): |
| 48 | + if not isinstance(self.c0, int | float): |
| 49 | + raise TypeError("Speed of sound must be a number") |
| 50 | + if self.c0 <= 0: |
| 51 | + raise ValueError("Speed of sound must be greater than 0") |
| 52 | + self.c0 = float(self.c0) |
| 53 | + |
| 54 | + if not isinstance(self.cfl, int | float): |
| 55 | + raise TypeError("CFL must be a number") |
| 56 | + if self.cfl <= 0 or self.cfl >= 1: |
| 57 | + raise ValueError("CFL must be between 0 and 1 (exclusive)") |
| 58 | + self.cfl = float(self.cfl) |
| 59 | + |
| 60 | + if not isinstance(self.n_cycles, int): |
| 61 | + if isinstance(self.n_cycles, float) and self.n_cycles == int(self.n_cycles): |
| 62 | + self.n_cycles = int(self.n_cycles) |
| 63 | + else: |
| 64 | + raise TypeError("n_cycles must be an integer") |
| 65 | + if self.n_cycles < 1: |
| 66 | + raise ValueError("n_cycles must be at least 1") |
| 67 | + |
| 68 | + if not isinstance(self.gpu, bool): |
| 69 | + raise TypeError("gpu must be a boolean") |
| 70 | + |
| 71 | + def calc_delays(self, arr: Transducer, target: Point, params: xa.Dataset, transform: np.ndarray | None = None): |
| 72 | + """Calculate delays using k-wave simulation with reciprocity. |
| 73 | +
|
| 74 | + Fires a virtual point source at the target position and records the pressure |
| 75 | + time series at each transducer element location. Arrival times are extracted |
| 76 | + via the Hilbert envelope peak, and delays are computed so that all elements |
| 77 | + fire in phase at the target. |
| 78 | +
|
| 79 | + Falls back to Direct (geometric) delays if k-wave is not available or |
| 80 | + if the simulation fails for any reason. |
| 81 | +
|
| 82 | + Args: |
| 83 | + :param arr: The transducer array. |
| 84 | + :param target: The focal target point. |
| 85 | + :param params: Simulation grid dataset with sound_speed, density, attenuation fields. |
| 86 | + :param transform: Optional 4x4 affine transform for element positions. |
| 87 | + :returns: 1D numpy array of per-element delays in seconds. |
| 88 | + """ |
| 89 | + try: |
| 90 | + import importlib |
| 91 | + if importlib.util.find_spec("kwave") is None: |
| 92 | + raise ImportError("k-wave not installed") |
| 93 | + except ImportError: |
| 94 | + logger.warning("k-wave not available. Falling back to Direct delay method.") |
| 95 | + return self._fallback_delays(arr, target, params, transform) |
| 96 | + |
| 97 | + try: |
| 98 | + arrival_times = self._run_reciprocal_simulation(arr, target, params, transform) |
| 99 | + delays = np.max(arrival_times) - arrival_times |
| 100 | + return delays |
| 101 | + except (RuntimeError, ValueError, IndexError, OSError): |
| 102 | + logger.exception("Simulation-corrected delay calculation failed. Falling back to Direct method.") |
| 103 | + return self._fallback_delays(arr, target, params, transform) |
| 104 | + |
| 105 | + def _run_reciprocal_simulation( |
| 106 | + self, |
| 107 | + arr: Transducer, |
| 108 | + target: Point, |
| 109 | + params: xa.Dataset, |
| 110 | + transform: np.ndarray | None = None, |
| 111 | + ) -> np.ndarray: |
| 112 | + """Run the reciprocal k-wave simulation and extract arrival times. |
| 113 | +
|
| 114 | + Args: |
| 115 | + arr: The transducer array. |
| 116 | + target: The focal target point. |
| 117 | + params: Simulation grid dataset. |
| 118 | + transform: Optional 4x4 affine transform. |
| 119 | +
|
| 120 | + Returns: |
| 121 | + arrival_times: 1D array of arrival times (seconds) per element. |
| 122 | + """ |
| 123 | + from scipy.signal import hilbert |
| 124 | + |
| 125 | + from openlifu.sim.kwave_if import run_point_source_simulation |
| 126 | + |
| 127 | + # Get the reference sound speed from params if available |
| 128 | + if 'sound_speed' in params and 'ref_value' in params['sound_speed'].attrs: |
| 129 | + sound_speed_ref = params['sound_speed'].attrs['ref_value'] |
| 130 | + else: |
| 131 | + sound_speed_ref = self.c0 |
| 132 | + |
| 133 | + # Get frequency from the transducer |
| 134 | + freq = arr.frequency |
| 135 | + |
| 136 | + # Compute element positions in the simulation coordinate frame. |
| 137 | + # The simulation grid uses params.coords (typically in mm). |
| 138 | + # Element positions come from the transducer in its native units (typically m). |
| 139 | + # We need to convert element positions to the same units as the sim grid, |
| 140 | + # then find the nearest grid voxel for each element. |
| 141 | + coord_dims = list(params.coords.dims) |
| 142 | + coord_units = params[coord_dims[0]].attrs.get('units', 'mm') |
| 143 | + scl_to_grid = getunitconversion('m', coord_units) |
| 144 | + |
| 145 | + matrix = transform if transform is not None else np.eye(4) |
| 146 | + element_positions_m = np.array([ |
| 147 | + el.get_position(units="m", matrix=matrix) |
| 148 | + for el in arr.elements |
| 149 | + ]) |
| 150 | + # Convert to grid units |
| 151 | + element_positions_grid = element_positions_m * scl_to_grid |
| 152 | + |
| 153 | + # Get target position in grid units |
| 154 | + target_pos_grid = target.get_position(units=coord_units) |
| 155 | + |
| 156 | + # Build the sensor mask: find nearest grid indices for each element |
| 157 | + coord_arrays = [params.coords[dim].to_numpy() for dim in coord_dims] |
| 158 | + grid_shape = tuple(len(c) for c in coord_arrays) |
| 159 | + |
| 160 | + sensor_indices = [] |
| 161 | + out_of_grid = set() |
| 162 | + for el_i, epos in enumerate(element_positions_grid): |
| 163 | + idx = [] |
| 164 | + inside = True |
| 165 | + for dim_i, coord_vals in enumerate(coord_arrays): |
| 166 | + cmin, cmax = float(coord_vals[0]), float(coord_vals[-1]) |
| 167 | + half_step = abs(float(coord_vals[1] - coord_vals[0])) / 2 if len(coord_vals) > 1 else 0 |
| 168 | + if epos[dim_i] < cmin - half_step or epos[dim_i] > cmax + half_step: |
| 169 | + inside = False |
| 170 | + nearest_idx = int(np.argmin(np.abs(coord_vals - epos[dim_i]))) |
| 171 | + idx.append(nearest_idx) |
| 172 | + if not inside: |
| 173 | + out_of_grid.add(el_i) |
| 174 | + logger.warning( |
| 175 | + f"Element {el_i} at position {epos} is outside the simulation grid. " |
| 176 | + "Using geometric time-of-flight estimate for this element." |
| 177 | + ) |
| 178 | + sensor_indices.append(tuple(idx)) |
| 179 | + |
| 180 | + # Build sensor mask (3D binary) |
| 181 | + sensor_mask = np.zeros(grid_shape, dtype=int) |
| 182 | + for idx in sensor_indices: |
| 183 | + sensor_mask[idx] = 1 |
| 184 | + |
| 185 | + # Find the target voxel index |
| 186 | + target_idx = [] |
| 187 | + for dim_i, coord_vals in enumerate(coord_arrays): |
| 188 | + nearest_idx = int(np.argmin(np.abs(coord_vals - target_pos_grid[dim_i]))) |
| 189 | + target_idx.append(nearest_idx) |
| 190 | + target_idx = tuple(target_idx) |
| 191 | + |
| 192 | + # Build source mask (single voxel at target) |
| 193 | + source_mask = np.zeros(grid_shape, dtype=int) |
| 194 | + source_mask[target_idx] = 1 |
| 195 | + |
| 196 | + # Run the point source simulation |
| 197 | + sensor_data, dt = run_point_source_simulation( |
| 198 | + params=params, |
| 199 | + source_mask=source_mask, |
| 200 | + sensor_mask=sensor_mask, |
| 201 | + freq=freq, |
| 202 | + n_cycles=self.n_cycles, |
| 203 | + sound_speed_ref=sound_speed_ref, |
| 204 | + cfl=self.cfl, |
| 205 | + gpu=self.gpu, |
| 206 | + ) |
| 207 | + |
| 208 | + # sensor_data is (n_sensor_points, n_timesteps). |
| 209 | + # Multiple elements may map to the same voxel if the grid is coarse. |
| 210 | + # We need to map sensor data rows back to elements. |
| 211 | + |
| 212 | + # Build a lookup: grid index -> sensor_data row index. |
| 213 | + # The sensor mask was constructed by setting 1 at unique voxel locations. |
| 214 | + # k-wave returns data for each nonzero voxel in Fortran (column-major) order. |
| 215 | + nonzero_indices = list(zip(*np.nonzero(sensor_mask))) |
| 216 | + # k-wave returns sensor data in Fortran (column-major) order of the mask: |
| 217 | + # x varies fastest, then y, then z. np.nonzero returns C order (row-major), |
| 218 | + # so we sort by the Fortran linear index to match k-wave's output ordering. |
| 219 | + def fortran_linear_index(idx, shape): |
| 220 | + # For Fortran order: idx[0] + idx[1]*shape[0] + idx[2]*shape[0]*shape[1] |
| 221 | + lin = idx[0] |
| 222 | + stride = shape[0] |
| 223 | + for d in range(1, len(shape)): |
| 224 | + lin += idx[d] * stride |
| 225 | + stride *= shape[d] |
| 226 | + return lin |
| 227 | + |
| 228 | + nonzero_with_fortran = [(fortran_linear_index(idx, grid_shape), idx) for idx in nonzero_indices] |
| 229 | + nonzero_with_fortran.sort(key=lambda x: x[0]) |
| 230 | + sorted_nonzero = [item[1] for item in nonzero_with_fortran] |
| 231 | + |
| 232 | + voxel_to_row = {idx: row for row, idx in enumerate(sorted_nonzero)} |
| 233 | + |
| 234 | + # Extract arrival time for each element |
| 235 | + n_elements = len(arr.elements) |
| 236 | + arrival_times = np.zeros(n_elements) |
| 237 | + |
| 238 | + for el_i, sensor_idx in enumerate(sensor_indices): |
| 239 | + if el_i in out_of_grid: |
| 240 | + # Element is outside the simulation grid; use geometric fallback |
| 241 | + dist = np.linalg.norm(element_positions_m[el_i] - target.get_position(units="m")) |
| 242 | + arrival_times[el_i] = dist / self.c0 |
| 243 | + continue |
| 244 | + |
| 245 | + row = voxel_to_row[sensor_idx] |
| 246 | + time_series = sensor_data[row, :] |
| 247 | + # Compute the analytic signal envelope via the Hilbert transform |
| 248 | + analytic = hilbert(time_series) |
| 249 | + envelope = np.abs(analytic) |
| 250 | + # The arrival time is the time of the envelope peak |
| 251 | + peak_sample = int(np.argmax(envelope)) |
| 252 | + arrival_times[el_i] = peak_sample * dt |
| 253 | + |
| 254 | + return arrival_times |
| 255 | + |
| 256 | + def _fallback_delays(self, arr: Transducer, target: Point, params: xa.Dataset, transform: np.ndarray | None = None) -> np.ndarray: |
| 257 | + """Compute delays using the Direct (geometric) method as a fallback.""" |
| 258 | + from openlifu.bf.delay_methods.direct import Direct |
| 259 | + direct = Direct(c0=self.c0) |
| 260 | + return direct.calc_delays(arr, target, params, transform) |
| 261 | + |
| 262 | + def to_table(self) -> pd.DataFrame: |
| 263 | + """ |
| 264 | + Get a table of the delay method parameters |
| 265 | +
|
| 266 | + :returns: Pandas DataFrame of the delay method parameters |
| 267 | + """ |
| 268 | + records = [ |
| 269 | + {"Name": "Type", "Value": "SimulationCorrected", "Unit": ""}, |
| 270 | + {"Name": "Default Sound Speed", "Value": self.c0, "Unit": "m/s"}, |
| 271 | + {"Name": "CFL Number", "Value": self.cfl, "Unit": ""}, |
| 272 | + {"Name": "Source Cycles", "Value": self.n_cycles, "Unit": ""}, |
| 273 | + {"Name": "Use GPU", "Value": self.gpu, "Unit": ""}, |
| 274 | + ] |
| 275 | + return pd.DataFrame.from_records(records) |
0 commit comments