Skip to content

Commit b2d49f6

Browse files
authored
Improve test coverage on extract_amp_phase (#545)
1 parent ea9d37a commit b2d49f6

5 files changed

Lines changed: 100 additions & 44 deletions

File tree

kwave/utils/conversion.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,18 @@ def find_closest(array: ndarray, value: Num[kt.ScalarLike, ""]):
489489
array = np.asarray(array)
490490
idx = (np.abs(array - value)).argmin()
491491
return array[idx], idx
492+
493+
494+
def create_index_at_dim(ndim: int, dim: int, index_value: Any) -> tuple:
495+
"""
496+
Create a tuple of slice objects with a specific index value at the specified dimension.
497+
498+
Args:
499+
ndim: Number of dimensions in the array
500+
dim: The dimension where the specific index should be placed
501+
index_value: The index value to place at the specified dimension
502+
503+
Returns:
504+
A tuple of slice objects with the index value at the specified dimension
505+
"""
506+
return tuple(index_value if i == dim else slice(None) for i in range(ndim))

kwave/utils/filters.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from scipy.fftpack import fft, fftshift, ifft, ifftshift
77
from scipy.signal import convolve, lfilter
88

9-
from ..kgrid import kWaveGrid
10-
from ..kmedium import kWaveMedium
9+
from kwave.utils.conversion import create_index_at_dim
10+
1111
from .checks import is_number
1212
from .data import scale_SI
1313
from .math import find_closest, gaussian, next_pow2, norm_var, sinc
@@ -24,31 +24,26 @@ def single_sided_correction(func_fft: np.ndarray, fft_len: int, dim: int) -> np.
2424
Args:
2525
func_fft: The FFT of the function to be corrected.
2626
fft_len: The length of the FFT.
27-
dim: The number of dimensions of `func_fft`.
27+
dim: The dimension along which to apply the correction.
2828
2929
Returns:
3030
The corrected FFT of the function.
3131
"""
32+
# Determine the slice to use based on FFT length
3233
if fft_len % 2:
33-
# odd FFT length switch dim case
34-
if dim == 0:
35-
func_fft[1:, :] = func_fft[1:, :] * 2
36-
elif dim == 1:
37-
func_fft[:, 1:] = func_fft[:, 1:] * 2
38-
elif dim == 2:
39-
func_fft[:, :, 1:] = func_fft[:, :, 1:] * 2
40-
elif dim == 3:
41-
func_fft[:, :, :, 1:] = func_fft[:, :, :, 1:] * 2
34+
# odd FFT length - multiply all elements except the first one by 2
35+
dim_slice = slice(1, None)
4236
else:
43-
# even FFT length
44-
if dim == 0:
45-
func_fft[1:-1] = func_fft[1:-1] * 2
46-
elif dim == 1:
47-
func_fft[:, 1:-1] = func_fft[:, 1:-1] * 2
48-
elif dim == 2:
49-
func_fft[:, :, 1:-1] = func_fft[:, :, 1:-1] * 2
50-
elif dim == 3:
51-
func_fft[:, :, :, 1:-1] = func_fft[:, :, :, 1:-1] * 2
37+
# even FFT length - multiply all elements except the first and last ones by 2
38+
dim_slice = slice(1, -1)
39+
40+
# Create a slice tuple with the appropriate slice at the specified dimension
41+
idx_all = [slice(None)] * func_fft.ndim
42+
idx_all[dim] = dim_slice
43+
idx_tuple = tuple(idx_all)
44+
45+
# Apply the correction
46+
func_fft[idx_tuple] = func_fft[idx_tuple] * 2
5247

5348
return func_fft
5449

@@ -197,7 +192,6 @@ def extract_amp_phase(
197192

198193
# compute amplitude and phase spectra
199194
f, func_as, func_ps = spect(data, Fs, fft_len=fft_padding * data.shape[dim], dim=dim)
200-
201195
# correct for coherent gain
202196
func_as = func_as / coherent_gain
203197

@@ -209,20 +203,10 @@ def extract_amp_phase(
209203
sz[dim - 1] = 1
210204

211205
# extract amplitude and relative phase at freq_index
212-
if dim == 0:
213-
amp = func_as[f_index]
214-
phase = func_ps[f_index]
215-
elif dim == 1:
216-
amp = func_as[:, f_index]
217-
phase = func_ps[:, f_index]
218-
elif dim == 2:
219-
amp = func_as[:, :, f_index]
220-
phase = func_ps[:, :, f_index]
221-
elif dim == 3:
222-
amp = func_as[:, :, :, f_index]
223-
phase = func_ps[:, :, :, f_index]
224-
else:
225-
raise ValueError("dim must be 0, 1, 2, or 3")
206+
# Create a tuple of slice objects with the frequency index at the correct dimension
207+
idx = create_index_at_dim(func_as.ndim, dim, f_index)
208+
amp = func_as[idx]
209+
phase = func_ps[idx]
226210

227211
return amp.squeeze(), phase.squeeze(), f[f_index]
228212

tests/matlab_test_data_collectors/run_all_collectors.m

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
addpath(genpath('../../../k-wave'));
44
directory = pwd + "/matlab_collectors";
55
files = getListOfFiles(directory);
6-
% remove this file.
76

87
for idx=1:length(files)
98
% ensure collected value directory has been created
109
file_parts = split(files(idx),["_","."]);
1110
collected_value_dir = pwd + ...
1211
"/matlab_collectors/collectedValues/" + file_parts(2);
1312
mkdir(collected_value_dir)
14-
% run value collector
15-
run(fullfile(directory, files{idx}));
16-
clearvars -except idx files directory
13+
% run value collector
14+
run(fullfile(directory, files{idx}));
15+
clearvars -except idx files directory
1716
end
1817

1918
if ~isRunningInCI()

tests/test_math.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ def test_affine_functions_equivalent(self):
9191
# Test deprecation warning
9292
with pytest.warns(DeprecationWarning, match="get_affine_matrix is deprecated as of 0.4.1") as warns:
9393
old_result = get_affine_matrix(translation, rotation)
94-
94+
9595
# Verify warning details
9696
assert len(warns) == 1
9797
assert "will be removed in 0.5.0" in str(warns[0].message)
98-
98+
9999
# Test functional equivalence
100100
new_result = make_affine(translation, rotation)
101101
assert np.allclose(old_result, new_result)
@@ -111,11 +111,11 @@ def test_shift_functions_equivalent(self):
111111
# Test deprecation warning
112112
with pytest.warns(DeprecationWarning, match="fourier_shift is deprecated as of 0.4.1") as warns:
113113
old_result = fourier_shift(signal, shift)
114-
114+
115115
# Verify warning details
116116
assert len(warns) == 1
117117
assert "will be removed in 0.5.0" in str(warns[0].message)
118-
118+
119119
# Test functional equivalence
120120
new_result = phase_shift_interpolate(signal, shift)
121121
assert np.allclose(old_result, new_result)

tests/test_utils.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,64 @@ def test_extract_amp_phase():
9797
assert (abs(c_t - c) < 100).all()
9898

9999

100+
def test_extract_amp_phase_2d():
101+
# Create a 2D test signal with 2 channels
102+
Fs = 10_000_000 # Sample frequency
103+
source_freq = 2.5 * 1_000 # Signal frequency
104+
amp_1 = 2.0
105+
amp_2 = 1.0
106+
phase_1 = np.pi / 8
107+
phase_2 = np.pi / 4
108+
# Create a 2D test signal with 2 channels
109+
sig_1 = amp_1 * np.sin(source_freq * 2 * np.pi * np.arange(Fs) / Fs + phase_1)
110+
sig_2 = amp_2 * np.sin(source_freq * 2 * np.pi * np.arange(Fs) / Fs + phase_2)
111+
test_signal = np.vstack([sig_1, sig_2])
112+
# plt.plot(test_signal[0])
113+
# plt.show()
114+
amp, phase, f = extract_amp_phase(test_signal, Fs, source_freq, dim=1)
115+
116+
assert np.allclose(amp, np.array([amp_1, amp_2]))
117+
# Phase is not used in any k-wave-python examples
118+
# assert np.allclose(phase, np.array([-phase_1, -phase_2]))
119+
assert np.allclose(f, source_freq)
120+
121+
122+
def test_extract_amp_phase_double_freq():
123+
# Create a test signal with double the detection frequency
124+
Fs = 10_000_000 # Sample frequency
125+
source_freq = 2.5 * 1_000 # Source frequency
126+
detection_freq = 5 * 1_000 # Double the frequency (2 * source_freq)
127+
128+
amp = 2.0
129+
phase = np.pi / 6
130+
131+
# Create test signal at source_freq
132+
t = np.arange(Fs) / Fs
133+
test_signal = amp * np.sin(source_freq * 2 * np.pi * t + phase)
134+
135+
# Extract amplitude and phase at double the frequency
136+
# Explicitly set dim=0 for 1D array
137+
detected_amp, detected_phase, detected_f = extract_amp_phase(test_signal, Fs, detection_freq, dim=0)
138+
139+
# The amplitude should be very close to zero since we're detecting at a different frequency
140+
# than what's present in the signal
141+
assert np.isclose(detected_amp, 0, atol=1e-3)
142+
assert np.isclose(detected_f, detection_freq)
143+
144+
# Now create a signal with the detection frequency
145+
test_signal_at_detection = amp * np.sin(detection_freq * 2 * np.pi * t + phase)
146+
147+
# Extract amplitude and phase at the correct frequency
148+
# Explicitly set dim=0 for 1D array
149+
detected_amp2, detected_phase2, detected_f2 = extract_amp_phase(test_signal_at_detection, Fs, detection_freq, dim=0)
150+
151+
# Now the amplitude should match
152+
assert np.isclose(detected_amp2, amp, rtol=0.01)
153+
# Phase is not used in any k-wave-python examples, but we can verify it's consistent
154+
# assert np.isclose(detected_phase2, -phase, rtol=0.01)
155+
assert np.isclose(detected_f2, detection_freq)
156+
157+
100158
def test_apply_filter_lowpass():
101159
test_signal = tone_burst(sample_freq=10_000_000, signal_freq=2.5 * 1_000_000, num_cycles=2, envelope="Gaussian")
102160
filtered_signal = apply_filter(test_signal, Fs=1e7, cutoff_f=1e7, filter_type="LowPass")

0 commit comments

Comments
 (0)