|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from unittest.mock import patch |
| 4 | + |
| 5 | +import numpy as np |
3 | 6 | import pytest |
4 | 7 |
|
5 | 8 | from openlifu.bf.delay_methods import DelayMethod, SimulationCorrected |
@@ -126,3 +129,58 @@ def test_to_table(self): |
126 | 129 | assert table.iloc[0]['Value'] == 'SimulationCorrected' |
127 | 130 | assert table.iloc[1]['Name'] == 'Default Sound Speed' |
128 | 131 | assert table.iloc[1]['Value'] == 1500.0 |
| 132 | + |
| 133 | + |
| 134 | +class TestSimulationCorrectedBehavior: |
| 135 | + """Test delay calculation logic by mocking the k-wave simulation layer.""" |
| 136 | + |
| 137 | + def test_delays_from_known_arrival_times(self): |
| 138 | + """When _run_reciprocal_simulation returns known arrival times, |
| 139 | + calc_delays should return max(arrival) - arrival for each element.""" |
| 140 | + arrival_times = np.array([0.001, 0.002, 0.003]) |
| 141 | + expected_delays = np.array([0.002, 0.001, 0.0]) |
| 142 | + |
| 143 | + method = SimulationCorrected() |
| 144 | + with patch.object( |
| 145 | + SimulationCorrected, |
| 146 | + "_run_reciprocal_simulation", |
| 147 | + return_value=arrival_times, |
| 148 | + ), patch("importlib.util.find_spec", return_value=True): |
| 149 | + delays = method.calc_delays( |
| 150 | + arr=None, target=None, params=None, transform=None |
| 151 | + ) |
| 152 | + np.testing.assert_allclose(delays, expected_delays) |
| 153 | + |
| 154 | + def test_fallback_on_simulation_failure(self): |
| 155 | + """When _run_reciprocal_simulation raises, calc_delays should |
| 156 | + fall back to Direct geometric delays without crashing.""" |
| 157 | + method = SimulationCorrected() |
| 158 | + with patch.object( |
| 159 | + SimulationCorrected, |
| 160 | + "_run_reciprocal_simulation", |
| 161 | + side_effect=RuntimeError("mocked failure"), |
| 162 | + ), patch("importlib.util.find_spec", return_value=True), patch.object( |
| 163 | + SimulationCorrected, |
| 164 | + "_fallback_delays", |
| 165 | + return_value=np.array([0.0, 0.0, 0.0]), |
| 166 | + ) as mock_fallback: |
| 167 | + delays = method.calc_delays( |
| 168 | + arr=None, target=None, params=None, transform=None |
| 169 | + ) |
| 170 | + mock_fallback.assert_called_once() |
| 171 | + np.testing.assert_array_equal(delays, [0.0, 0.0, 0.0]) |
| 172 | + |
| 173 | + def test_fallback_when_kwave_missing(self): |
| 174 | + """When k-wave is not installed, calc_delays should fall back |
| 175 | + to Direct geometric delays.""" |
| 176 | + method = SimulationCorrected() |
| 177 | + with patch("importlib.util.find_spec", return_value=None), patch.object( |
| 178 | + SimulationCorrected, |
| 179 | + "_fallback_delays", |
| 180 | + return_value=np.array([0.0, 0.0]), |
| 181 | + ) as mock_fallback: |
| 182 | + delays = method.calc_delays( |
| 183 | + arr=None, target=None, params=None, transform=None |
| 184 | + ) |
| 185 | + mock_fallback.assert_called_once() |
| 186 | + np.testing.assert_array_equal(delays, [0.0, 0.0]) |
0 commit comments