66from scipy .fftpack import fft , fftshift , ifft , ifftshift
77from scipy .signal import convolve , lfilter
88
9- from .. kgrid import kWaveGrid
10- from .. kmedium import kWaveMedium
9+ from kwave . utils . conversion import create_index_at_dim
10+
1111from .checks import is_number
1212from .data import scale_SI
1313from .math import find_closest , gaussian , next_pow2 , norm_var , sinc
@@ -24,31 +24,26 @@ def single_sided_correction(func_fft: np.ndarray, fft_len: int, dim: int) -> np.
2424 Args:
2525 func_fft: The FFT of the function to be corrected.
2626 fft_len: The length of the FFT.
27- dim: The number of dimensions of `func_fft` .
27+ dim: The dimension along which to apply the correction .
2828
2929 Returns:
3030 The corrected FFT of the function.
3131 """
32+ # Determine the slice to use based on FFT length
3233 if fft_len % 2 :
33- # odd FFT length switch dim case
34- if dim == 0 :
35- func_fft [1 :, :] = func_fft [1 :, :] * 2
36- elif dim == 1 :
37- func_fft [:, 1 :] = func_fft [:, 1 :] * 2
38- elif dim == 2 :
39- func_fft [:, :, 1 :] = func_fft [:, :, 1 :] * 2
40- elif dim == 3 :
41- func_fft [:, :, :, 1 :] = func_fft [:, :, :, 1 :] * 2
34+ # odd FFT length - multiply all elements except the first one by 2
35+ dim_slice = slice (1 , None )
4236 else :
43- # even FFT length
44- if dim == 0 :
45- func_fft [1 :- 1 ] = func_fft [1 :- 1 ] * 2
46- elif dim == 1 :
47- func_fft [:, 1 :- 1 ] = func_fft [:, 1 :- 1 ] * 2
48- elif dim == 2 :
49- func_fft [:, :, 1 :- 1 ] = func_fft [:, :, 1 :- 1 ] * 2
50- elif dim == 3 :
51- func_fft [:, :, :, 1 :- 1 ] = func_fft [:, :, :, 1 :- 1 ] * 2
37+ # even FFT length - multiply all elements except the first and last ones by 2
38+ dim_slice = slice (1 , - 1 )
39+
40+ # Create a slice tuple with the appropriate slice at the specified dimension
41+ idx_all = [slice (None )] * func_fft .ndim
42+ idx_all [dim ] = dim_slice
43+ idx_tuple = tuple (idx_all )
44+
45+ # Apply the correction
46+ func_fft [idx_tuple ] = func_fft [idx_tuple ] * 2
5247
5348 return func_fft
5449
@@ -197,7 +192,6 @@ def extract_amp_phase(
197192
198193 # compute amplitude and phase spectra
199194 f , func_as , func_ps = spect (data , Fs , fft_len = fft_padding * data .shape [dim ], dim = dim )
200-
201195 # correct for coherent gain
202196 func_as = func_as / coherent_gain
203197
@@ -209,20 +203,10 @@ def extract_amp_phase(
209203 sz [dim - 1 ] = 1
210204
211205 # extract amplitude and relative phase at freq_index
212- if dim == 0 :
213- amp = func_as [f_index ]
214- phase = func_ps [f_index ]
215- elif dim == 1 :
216- amp = func_as [:, f_index ]
217- phase = func_ps [:, f_index ]
218- elif dim == 2 :
219- amp = func_as [:, :, f_index ]
220- phase = func_ps [:, :, f_index ]
221- elif dim == 3 :
222- amp = func_as [:, :, :, f_index ]
223- phase = func_ps [:, :, :, f_index ]
224- else :
225- raise ValueError ("dim must be 0, 1, 2, or 3" )
206+ # Create a tuple of slice objects with the frequency index at the correct dimension
207+ idx = create_index_at_dim (func_as .ndim , dim , f_index )
208+ amp = func_as [idx ]
209+ phase = func_ps [idx ]
226210
227211 return amp .squeeze (), phase .squeeze (), f [f_index ]
228212
0 commit comments