Skip to content

Commit 375fde8

Browse files
Refactor sharpness filters and improve single_sided_correction (#573)
1 parent 18a566d commit 375fde8

9 files changed

Lines changed: 407 additions & 153 deletions

File tree

kwave/utils/filters.py

Lines changed: 33 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import numpy as np
55
import scipy
66
from scipy.fftpack import fft, fftshift, ifft, ifftshift
7-
from scipy.signal import convolve, lfilter
7+
from scipy.signal import lfilter
88

99
from kwave.utils.conversion import create_index_at_dim
1010

1111
from .checks import is_number
1212
from .data import scale_SI
13-
from .math import find_closest, gaussian, next_pow2, norm_var, sinc
13+
from .math import find_closest, gaussian, next_pow2, sinc
1414
from .matrix import num_dim, num_dim2
1515
from .signals import get_win
1616

@@ -27,30 +27,25 @@ def single_sided_correction(func_fft: np.ndarray, fft_len: int, dim: int) -> np.
2727
dim: The dimension along which to apply the correction.
2828
2929
Returns:
30-
The corrected FFT of the function.
30+
None, modifies the input array in place to have the corrected FFT of the function.
3131
"""
32-
# Determine the slice to use based on FFT length
33-
if fft_len % 2:
34-
# odd FFT length - multiply all elements except the first one by 2
35-
dim_slice = slice(1, None)
36-
else:
37-
# even FFT length - multiply all elements except the first and last ones by 2
38-
dim_slice = slice(1, -1)
39-
40-
# Create a slice tuple with the appropriate slice at the specified dimension
41-
idx_all = [slice(None)] * func_fft.ndim
42-
idx_all[dim] = dim_slice
43-
idx_tuple = tuple(idx_all)
32+
# Create a slice object for each dimension
33+
slices = [slice(None)] * func_fft.ndim
4434

45-
# Apply the correction
46-
func_fft[idx_tuple] = func_fft[idx_tuple] * 2
35+
if fft_len % 2: # odd FFT length
36+
# Set slice for the specified dimension to select all elements except the first
37+
slices[dim] = slice(1, None)
38+
else: # even FFT length
39+
# Set slice for the specified dimension to select all elements except first and last
40+
slices[dim] = slice(1, -1)
4741

48-
return func_fft
42+
# Apply the slicing and multiply by 2
43+
func_fft[tuple(slices)] *= 2
4944

5045

5146
def spect(
5247
func: np.ndarray,
53-
Fs: float,
48+
fs: float,
5449
dim: Optional[Union[int, str]] = "auto",
5550
fft_len: Optional[int] = 0,
5651
power_two: Optional[bool] = False,
@@ -62,7 +57,7 @@ def spect(
6257
6358
Args:
6459
func: The signal to analyse.
65-
Fs: The sampling frequency in Hz.
60+
fs: The sampling frequency in Hz.
6661
dim: The dimension over which the spectrum is calculated. Defaults to 'auto'.
6762
fft_len: The length of the FFT. If the set length is smaller than the signal length, the default value is used
6863
instead (default = signal length).
@@ -133,10 +128,10 @@ def spect(
133128
slicing[dim] = slice(0, num_unique_pts)
134129
func_fft = func_fft[tuple(slicing)]
135130

136-
func_fft = single_sided_correction(func_fft, fft_len, dim)
131+
single_sided_correction(func_fft, fft_len, dim)
137132

138133
# create the frequency axis variable
139-
f = np.arange(0, num_unique_pts) * Fs / fft_len
134+
f = np.arange(0, num_unique_pts) * fs / fft_len
140135

141136
# calculate the amplitude spectrum
142137
func_as = np.abs(func_fft)
@@ -152,7 +147,7 @@ def spect(
152147

153148

154149
def extract_amp_phase(
155-
data: np.ndarray, Fs: float, source_freq: float, dim: Tuple[str, int] = "auto", fft_padding: int = 3, window: str = "Hanning"
150+
data: np.ndarray, fs: float, source_freq: float, dim: Tuple[str, int] = "auto", fft_padding: int = 3, window: str = "Hanning"
156151
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
157152
"""
158153
Extract the amplitude and phase information at a specified frequency from a vector or matrix of time series data.
@@ -163,7 +158,7 @@ def extract_amp_phase(
163158
164159
Args:
165160
data: Matrix of time signals [s]
166-
Fs: Sampling frequency [Hz]
161+
fs: Sampling frequency [Hz]
167162
source_freq: Frequency at which the amplitude and phase should be extracted [Hz]
168163
dim: The time dimension of the input data. If 'auto', the highest non-singleton dimension is used.
169164
fft_padding: The amount of zero padding to apply to the FFT.
@@ -191,7 +186,8 @@ def extract_amp_phase(
191186
data = win * data
192187

193188
# compute amplitude and phase spectra
194-
f, func_as, func_ps = spect(data, Fs, fft_len=fft_padding * data.shape[dim], dim=dim)
189+
f, func_as, func_ps = spect(data, fs, fft_len=fft_padding * data.shape[dim], dim=dim)
190+
195191
# correct for coherent gain
196192
func_as = func_as / coherent_gain
197193

@@ -211,99 +207,6 @@ def extract_amp_phase(
211207
return amp.squeeze(), phase.squeeze(), f[f_index]
212208

213209

214-
def brenner_sharpness(im):
215-
num_dim = im.ndim
216-
if num_dim == 2:
217-
# compute metric
218-
bren_x = (im[:-2, :] - im[2:, :]) ** 2
219-
bren_y = (im[:, :-2] - im[:, 2:]) ** 2
220-
s = np.sum(bren_x) + np.sum(bren_y)
221-
elif num_dim == 3:
222-
# compute metric
223-
bren_x = (im[:-2, :, :] - im[2:, :, :]) ** 2
224-
bren_y = (im[:, :-2, :] - im[:, 2:, :]) ** 2
225-
bren_z = (im[:, :, :-2] - im[:, :, 2:]) ** 2
226-
s = np.sum(bren_x) + np.sum(bren_y) + np.sum(bren_z)
227-
return s
228-
229-
230-
def tenenbaum_sharpness(im):
231-
num_dim = im.ndim
232-
if num_dim == 2:
233-
# define the 2D sobel gradient operator
234-
sobel = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
235-
236-
# compute metric
237-
s = (convolve(sobel, im) ** 2 + convolve(sobel.T, im) ** 2).sum()
238-
elif num_dim == 3:
239-
# define the 3D sobel gradient operator
240-
sobel3D = np.zeros((3, 3, 3))
241-
sobel3D[:, :, 0] = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]])
242-
sobel3D[:, :, 2] = -sobel3D[:, :, 0]
243-
244-
# compute metric
245-
s = (
246-
convolve(im, sobel3D) ** 2
247-
+ convolve(im, np.transpose(sobel3D, (2, 0, 1))) ** 2
248-
+ convolve(im, np.transpose(sobel3D, (1, 2, 0))) ** 2
249-
).sum()
250-
return s
251-
252-
# TODO: get this passing the tests
253-
# NOTE: Walter thinks this is the proper way to do this, but it doesn't match the MATLAB version
254-
# num_dim = im.ndim
255-
# if num_dim == 2:
256-
# # compute metric
257-
# sx = sobel(im, axis=0, mode='constant')
258-
# sy = sobel(im, axis=1, mode='constant')
259-
# s = (sx ** 2) + (sy ** 2)
260-
# s = np.sum(s)
261-
#
262-
# elif num_dim == 3:
263-
# # compute metric
264-
# sx = sobel(im, axis=0, mode='constant')
265-
# sy = sobel(im, axis=1, mode='constant')
266-
# sz = sobel(im, axis=2, mode='constant')
267-
# s = (sx ** 2) + (sy ** 2) + (sz ** 2)
268-
# s = np.sum(s)
269-
# else:
270-
# raise ValueError("Invalid number of dimensions in im")
271-
272-
273-
def sharpness(im: np.ndarray, mode: Optional[str] = "Brenner") -> float:
274-
"""
275-
Returns a scalar metric related to the sharpness of a 2D or 3D image matrix.
276-
277-
Args:
278-
im: The image matrix.
279-
metric: The metric to use. Defaults to "Brenner".
280-
281-
Returns:
282-
A scalar sharpness metric.
283-
284-
Raises:
285-
AssertionError: If `im` is not a NumPy array.
286-
287-
References:
288-
B. E. Treeby, T. K. Varslot, E. Z. Zhang, J. G. Laufer, and P. C. Beard, "Automatic sound speed selection in
289-
photoacoustic image reconstruction using an autofocus approach," J. Biomed. Opt., vol. 16, no. 9, p. 090501, 2011.
290-
291-
"""
292-
293-
assert isinstance(im, np.ndarray), "Argument im must be of type numpy array"
294-
295-
if mode == "Brenner":
296-
metric = brenner_sharpness(im)
297-
elif mode == "Tenenbaum":
298-
metric = tenenbaum_sharpness(im)
299-
elif mode == "NormVariance":
300-
metric = norm_var(im)
301-
else:
302-
raise ValueError("Unrecognized sharpness metric passed. Valid values are ['Brenner', 'Tanenbaum', 'NormVariance']")
303-
304-
return metric
305-
306-
307210
def fwhm(f, x):
308211
"""
309212
fwhm calculates the Full Width at Half Maximum (FWHM) of a positive
@@ -341,7 +244,7 @@ def half_max_x(x, y):
341244

342245

343246
def gaussian_filter(
344-
signal: Union[np.ndarray, List[float]], Fs: float, frequency: float, bandwidth: float
247+
signal: Union[np.ndarray, List[float]], fs: float, frequency: float, bandwidth: float
345248
) -> Union[np.ndarray, List[float]]:
346249
"""
347250
Applies a frequency domain Gaussian filter with the
@@ -351,7 +254,7 @@ def gaussian_filter(
351254
352255
Args:
353256
signal: Signal to filter [channel, samples]
354-
Fs: Sampling frequency [Hz]
257+
fs: Sampling frequency [Hz]
355258
frequency: Center frequency of filter [Hz]
356259
bandwidth: Bandwidth of filter in percentage
357260
@@ -362,9 +265,9 @@ def gaussian_filter(
362265

363266
N = signal.shape[-1]
364267
if N % 2 == 0:
365-
f = np.arange(-N / 2, N / 2) * Fs / N
268+
f = np.arange(-N / 2, N / 2) * fs / N
366269
else:
367-
f = np.arange(-(N - 1) / 2, (N - 1) / 2 + 1) * Fs / N
270+
f = np.arange(-(N - 1) / 2, (N - 1) / 2 + 1) * fs / N
368271

369272
mean = frequency
370273
variance = (bandwidth / 100 * frequency / (2 * np.sqrt(2 * np.log(2)))) ** 2
@@ -456,7 +359,7 @@ def filter_time_series(
456359
assert not isinstance(kgrid.t_array, str) or kgrid.t_array != "auto", "kgrid.t_array must be explicitly defined."
457360

458361
# compute the sampling frequency
459-
Fs = 1 / kgrid.dt
362+
fs = 1 / kgrid.dt
460363

461364
# extract the minimum sound speed
462365
if medium.sound_speed is not None:
@@ -490,7 +393,7 @@ def filter_time_series(
490393
if ppw != 0:
491394
filtered_signal = apply_filter(
492395
signal,
493-
Fs,
396+
fs,
494397
float(filter_cutoff_f),
495398
"LowPass",
496399
zero_phase=zerophase,
@@ -532,7 +435,7 @@ def filter_time_series(
532435

533436
def apply_filter(
534437
signal: np.ndarray,
535-
Fs: float,
438+
fs: float,
536439
cutoff_f: float,
537440
filter_type: str,
538441
zero_phase: Optional[bool] = False,
@@ -545,7 +448,7 @@ def apply_filter(
545448
546449
Args:
547450
signal: The input signal.
548-
Fs: The sampling frequency of the signal.
451+
fs: The sampling frequency of the signal.
549452
cutoff_f: The cut-off frequency of the filter.
550453
filter_type: The type of filter to apply, either 'HighPass', 'LowPass' or 'BandPass'.
551454
zero_phase: Whether to apply a zero-phase filter. Defaults to False.
@@ -564,13 +467,13 @@ def apply_filter(
564467

565468
# apply the low pass filter
566469
func_filt_lp = apply_filter(
567-
signal, Fs, cutoff_f[1], "LowPass", stop_band_atten=stop_band_atten, transition_width=transition_width, zero_phase=zero_phase
470+
signal, fs, cutoff_f[1], "LowPass", stop_band_atten=stop_band_atten, transition_width=transition_width, zero_phase=zero_phase
568471
)
569472

570473
# apply the high pass filter
571474
filtered_signal = apply_filter(
572475
func_filt_lp,
573-
Fs,
476+
fs,
574477
cutoff_f[0],
575478
"HighPass",
576479
stop_band_atten=stop_band_atten,
@@ -584,7 +487,7 @@ def apply_filter(
584487
high_pass = False
585488
elif filter_type == "HighPass":
586489
high_pass = True
587-
cutoff_f = Fs / 2 - cutoff_f
490+
cutoff_f = fs / 2 - cutoff_f
588491
else:
589492
raise ValueError(f'Unknown filter type {filter_type}. Options are "LowPass, HighPass, BandPass"')
590493

@@ -602,7 +505,7 @@ def apply_filter(
602505
N = int(N)
603506

604507
# construct impulse response of ideal bandpass filter h(n), a sinc function
605-
fc = cutoff_f / Fs # normalised cut-off
508+
fc = cutoff_f / fs # normalised cut-off
606509
n = np.arange(-N / 2, N / 2)
607510
h = 2 * fc * sinc(2 * np.pi * fc * n)
608511

kwave/utils/math.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -357,22 +357,6 @@ def sinc(x: Union[int, float, np.ndarray]) -> Union[int, float, np.ndarray]:
357357
return np.sinc(x / np.pi)
358358

359359

360-
def norm_var(im: np.ndarray) -> float:
361-
"""
362-
Calculates the normalized variance of an array of values.
363-
364-
Args:
365-
im: The input array.
366-
367-
Returns:
368-
The normalized variance of im.
369-
370-
"""
371-
mu = np.mean(im)
372-
s = np.sum((im - mu) ** 2) / mu
373-
return s
374-
375-
376360
def gaussian(
377361
x: Union[int, float, np.ndarray],
378362
magnitude: Optional[Union[int, float]] = None,

0 commit comments

Comments
 (0)