import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import scipy
from scipy.fftpack import fft, fftshift, ifft, ifftshift
from scipy.signal import lfilter
from kwave.utils.conversion import create_index_at_dim
from .checks import is_number
from .data import scale_SI
from .math import find_closest, gaussian, next_pow2, sinc
from .matrix import num_dim, num_dim2
from .signals import get_win
[docs]
def single_sided_correction(func_fft: np.ndarray, fft_len: int, dim: int) -> np.ndarray:
    """Correct a single-sided spectrum by doubling every bin that has a mirrored twin.

    After a two-sided FFT has been truncated to its unique bins, every bin except
    DC (index 0) and, for even ``fft_len``, Nyquist (the last bin along ``dim``)
    is multiplied by 2. For odd ``fft_len`` there is no Nyquist bin, so everything
    after DC is doubled.

    Args:
        func_fft: The truncated single-sided FFT. It is modified in place, so it must
            be writable with a dtype able to hold the doubled values.
        fft_len: The length of the original (two-sided) FFT.
        dim: The axis of ``func_fft`` along which the frequency bins lie.

    Returns:
        ``func_fft`` itself (the same object, already modified in place).
    """
    index = [slice(None)] * func_fft.ndim
    # Skip DC; for even FFT lengths also skip the trailing Nyquist bin.
    index[dim] = slice(1, None) if fft_len % 2 else slice(1, -1)
    func_fft[tuple(index)] *= 2
    return func_fft
[docs]
def spect(
    func: np.ndarray,
    fs: float,
    dim: Optional[Union[int, str]] = "auto",
    fft_len: Optional[int] = 0,
    power_two: Optional[bool] = False,
    unwrap_phase: Optional[bool] = False,
    window: Optional[str] = "Rectangular",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculates the single-sided amplitude and phase spectrum of a signal.

    The signal is multiplied by the selected window, zero-padded to ``fft_len`` along
    ``dim`` and transformed. The result is divided by ``func_length * coherent_gain``
    (the signal length, not the FFT length). It is then reduced to its
    ``ceil((fft_len + 1) / 2)`` unique bins, with every bin except DC (and Nyquist for
    even ``fft_len``) doubled.

    Args:
        func: The signal to analyse (array-like with 1 to 4 dimensions).
        fs: The sampling frequency in Hz.
        dim: The dimension over which the spectrum is calculated. ``'auto'`` (default)
            or ``None`` selects the first non-singleton dimension.
        fft_len: The length of the FFT. Values smaller than the signal length (including
            the default 0) and ``None`` fall back to the signal length. When
            ``power_two`` is set, they fall back to a power of 2 at least as long as the
            signal instead.
        power_two: Whether the fallback FFT length is a power of 2 (default = False).
            It has no effect when an ``fft_len`` at least as long as the signal is given.
        unwrap_phase: Whether to unwrap the phase spectrum along ``dim`` (default = False).
        window: The window type applied to the signal before the FFT is taken
            (default = 'Rectangular'). Any valid input type for get_win may be used.

    Returns:
        f: Frequency array, ``arange(num_unique_pts) * fs / fft_len``.
        func_as: Single-sided amplitude spectrum.
        func_ps: Single-sided phase spectrum (``np.angle``, radians).

    Raises:
        ValueError: If the input signal is scalar, has more than 4 dimensions, or (with
            automatic dimension selection) has no dimension longer than 1.
    """
    func = np.asarray(func)
    sz = func.shape

    # check input isn't scalar
    if func.size == 1:
        raise ValueError("Input signal cannot be scalar.")
    # check input doesn't have more than 4 dimensions
    if len(sz) > 4:
        raise ValueError("Input signal must have 1, 2, 3, or 4 dimensions.")

    # automatically set dimension to the first non-singleton dimension
    if dim is None or (isinstance(dim, str) and dim == "auto"):
        dim = int(np.argmax(np.array(sz) > 1))
        if sz[dim] <= 1:
            raise ValueError("All dimensions are singleton; unable to determine valid dimension.")

    # number of points being analysed
    func_length = sz[dim]

    # set the length of the FFT (None is treated like the default 0)
    if fft_len is None or fft_len <= 0 or fft_len < func_length:
        if power_two:
            # smallest FFT length of the form 2 ^ N that is >= the signal length
            fft_len = 2 ** (next_pow2(func_length))
        else:
            fft_len = func_length

    # window the signal, reshaping the window to lie along `dim`
    win, coherent_gain = get_win(func_length, type_=window, symmetric=False)
    win_shape = [1] * len(sz)
    win_shape[dim] = func_length
    func = np.reshape(win, tuple(win_shape)) * func

    # if fft_len > func_length the input is zero-padded by np.fft.fft
    func_fft = np.fft.fft(func, n=fft_len, axis=dim)

    # correct for the FFT magnitude scaling and the window's coherent gain; the
    # correction uses func_length, NOT fft_len. epsilon only guards against a zero
    # coherent gain and is negligible otherwise.
    epsilon = 1e-10
    func_fft = func_fft / (func_length * coherent_gain + epsilon)

    # keep the unique points: N / 2 + 1 for even N, (N + 1) / 2 for odd N
    num_unique_pts = int(np.ceil((fft_len + 1) / 2))
    keep = [slice(None)] * len(sz)
    keep[dim] = slice(0, num_unique_pts)
    func_fft = func_fft[tuple(keep)]
    single_sided_correction(func_fft, fft_len, dim)  # modifies func_fft in place

    f = np.arange(0, num_unique_pts) * fs / fft_len
    func_as = np.abs(func_fft)
    func_ps = np.angle(func_fft)
    if unwrap_phase:
        func_ps = np.unwrap(func_ps, axis=dim)
    return f, func_as, func_ps
[docs]
def fwhm(f, x):
    """
    Calculates the Full Width at Half Maximum (FWHM) of a positive 1D function f(x).

    The half-maximum level is ``max(f) / 2``. Samples are classified as "at or above"
    or "below" that level. The first two transitions between the classes are taken as
    the leading and trailing edges. Each edge is located by linear interpolation
    between the two samples that straddle it.

    Args:
        f: Samples of f(x); any array-like that squeezes to 1D (e.g. a (1, N) row).
        x: Positions of the samples of f.

    Returns:
        Tuple ``(fwhm_val, (x_leading, x_trailing))``.

    Raises:
        ValueError: If ``f`` is not 1-dimensional after squeezing, or if f(x) does not
            cross its half maximum at least twice.
    """
    f = np.asarray(f).squeeze()
    if f.ndim != 1:
        raise ValueError("Input function must be 1-dimensional.")
    x = np.ravel(x)

    half = f.max() / 2.0
    # Using ">=" (rather than sign()) means a sample lying exactly on the half
    # maximum does not create a spurious pair of crossings.
    above = f >= half
    # index i marks a crossing between samples i and i + 1 (all pairs, including the last)
    crossings = np.flatnonzero(above[:-1] != above[1:])
    if crossings.size < 2:
        raise ValueError("f(x) must cross its half maximum at least twice to define a FWHM.")

    def edge(i):
        # linear interpolation of the half-maximum position between samples i and i + 1
        return x[i] + (x[i + 1] - x[i]) * ((half - f[i]) / (f[i + 1] - f[i]))

    hmx = [edge(crossings[0]), edge(crossings[1])]
    return hmx[1] - hmx[0], tuple(hmx)
[docs]
def gaussian_filter(
    signal: Union[np.ndarray, List[float]], fs: float, frequency: float, bandwidth: float
) -> Union[np.ndarray, List[float]]:
    """
    Applies a frequency-domain Gaussian filter with the specified center frequency and
    percentage bandwidth to the input signal.

    The filter is applied along the last axis, so for a [channel, samples] matrix each
    row is filtered. The filter is the element-wise maximum of two Gaussians centred at
    +frequency and -frequency. Their FWHM is ``bandwidth / 100 * frequency``
    (sigma = FWHM / (2 * sqrt(2 * ln 2))).

    Args:
        signal: Signal to filter [channel, samples] (or a 1D [samples] array/list).
        fs: Sampling frequency [Hz].
        frequency: Center frequency of filter [Hz].
        bandwidth: Bandwidth of filter in percentage.

    Returns:
        The filtered (real-valued) signal as a numpy array with the input's shape.
    """
    signal = np.asarray(signal)
    num_samples = signal.shape[-1]

    # frequency axis in fftshift order: bins -floor(N/2) ... ceil(N/2) - 1
    f = (np.arange(num_samples) - num_samples // 2) * fs / num_samples

    sigma = bandwidth / 100 * frequency / (2 * np.sqrt(2 * np.log(2)))
    variance = sigma**2
    magnitude = 1

    # double-sided Gaussian filter; broadcasts over any leading (channel) axes
    gfilter = np.fmax(gaussian(f, magnitude, frequency, variance), gaussian(f, magnitude, -frequency, variance))

    # shift only the time/frequency axis so leading axes are never permuted
    spectrum = fftshift(fft(signal), axes=-1)
    return np.real(ifft(ifftshift(gfilter * spectrum, axes=-1)))
[docs]
def filter_time_series(
    kgrid: "kWaveGrid",
    medium: "kWaveMedium",
    signal: np.ndarray,
    ppw: Optional[int] = 3,
    rppw: Optional[int] = 0,
    stop_band_atten: Optional[int] = 60,
    transition_width: Optional[float] = 0.1,
    zerophase: Optional[bool] = False,
    plot_spectrums: Optional[bool] = False,
    plot_signals: Optional[bool] = False,
) -> np.ndarray:
    """
    Low-pass filters a time-domain signal using apply_filter (Kaiser-windowed FIR).

    The maximum frequency supported by the grid is
    ``f_max = kgrid.k_max_all * min(sound speed) / (2 * pi)`` (two points per wavelength).
    The low-pass cut-off is ``2 * f_max / ppw``.

    Args:
        kgrid: The kWaveGrid grid. ``kgrid.t_array`` must be explicitly defined.
        medium: The kWaveMedium. It must define ``sound_speed``, or (elastic)
            ``sound_speed_compression`` and ``sound_speed_shear``.
        signal: The time-domain signal to filter, as a row, column or 1D vector.
        ppw: Points per wavelength at the filter cut-off. Larger values give a lower
            cut-off frequency. 0 disables the low-pass filter. Defaults to 3.
        rppw: Points per wavelength of a raised-cosine start-up ramp multiplied onto the
            first samples of the (filtered) signal. Its length in samples is
            ``round(rppw * filter_wavelength / (2 * ppw))``, so it requires ``ppw != 0``.
            0 disables the ramp. Defaults to 0.
        stop_band_atten: Stop-band attenuation in dB passed to apply_filter, where it
            sets the filter order. Defaults to 60.
        transition_width: Transition width as a proportion of the sampling frequency,
            passed to apply_filter. Defaults to 0.1.
        zerophase: Whether apply_filter runs the filter forwards and backwards
            (zero-phase). Defaults to False.
        plot_spectrums: Plotting is not implemented; True raises NotImplementedError.
        plot_signals: Plotting is not implemented; True raises NotImplementedError.

    Raises:
        TypeError: If the input signal is not a vector.
        AssertionError: If ``kgrid.t_array`` is still ``"auto"``.
        ValueError: If no sound speed is defined, or if ``rppw != 0`` while ``ppw == 0``.
        NotImplementedError: If plotting is requested.

    Returns:
        The filtered signal, in the same orientation as the input (row, column, or 1D).
    """
    # plotting is not implemented; fail before doing any work
    if plot_signals or plot_spectrums:
        raise NotImplementedError

    # accept a 1D input by treating it as a row vector (flattened again on return)
    signal = np.asarray(signal)
    input_is_1d = signal.ndim == 1
    if input_is_1d:
        signal = signal[np.newaxis, :]

    # check the input is a vector and make it a row vector
    if num_dim2(signal) != 1:
        raise TypeError("Input signal must be a vector.")
    _, n = signal.shape
    rotate_signal = n == 1
    if rotate_signal:
        signal = signal.T

    logging.log(logging.INFO, "Filtering input signal...")

    # the time step must be known to compute the sampling frequency
    assert not isinstance(kgrid.t_array, str) or kgrid.t_array != "auto", "kgrid.t_array must be explicitly defined."
    fs = 1 / kgrid.dt

    # extract the minimum sound speed
    if medium.sound_speed is not None:
        # fluid code; np.min also accepts a scalar sound speed
        c0 = np.min(medium.sound_speed)
    elif all(medium.is_defined("sound_speed_compression", "sound_speed_shear")):  # pragma: no cover
        # elastic code: pool both sound speeds (as floats so NaN can be stored) and ignore zeros
        ss = np.concatenate([np.ravel(medium.sound_speed_compression), np.ravel(medium.sound_speed_shear)]).astype(float)
        ss[ss == 0] = np.nan
        c0 = np.nanmin(ss)
    else:
        raise ValueError(
            "The input fields medium.sound_speed or medium.sound_speed_compression and medium.sound_speed_shear must " "be defined."
        )

    # maximum supported frequency (two points per wavelength)
    f_max = kgrid.k_max_all * c0 / (2 * np.pi)

    if ppw != 0:
        filter_cutoff_f = 2 * f_max / ppw
        # wavelength of the cut-off frequency expressed as a number of time steps
        filter_wavelength = (2 * np.pi / filter_cutoff_f) / kgrid.dt
        filtered_signal = apply_filter(
            signal,
            fs,
            float(filter_cutoff_f),
            "LowPass",
            zero_phase=zerophase,
            stop_band_atten=float(stop_band_atten),
            transition_width=transition_width,
        )
    else:
        # no low-pass filtering: work on a floating-point copy so the caller's array is untouched
        filtered_signal = signal.astype(np.result_type(signal.dtype, np.float64))

    # add a start-up ramp if required
    ramp_length = 0
    if rppw != 0:
        if ppw == 0:
            raise ValueError("A start-up ramp (rppw != 0) requires ppw != 0, which sets the ramp length.")
        ramp_length = int(round(rppw * filter_wavelength / (2 * ppw)))
        # MATLAB's filtered_signal(1:ramp_length) == the first ramp_length time samples
        n_ramp = min(ramp_length, filtered_signal.shape[-1])
        if n_ramp > 0:
            ramp = (1 - np.cos(np.arange(ramp_length) * np.pi / ramp_length)) / 2
            filtered_signal[..., :n_ramp] = filtered_signal[..., :n_ramp] * ramp[:n_ramp]

    # restore the original vector orientation if modified
    if rotate_signal:
        filtered_signal = filtered_signal.T
    if input_is_1d:
        filtered_signal = filtered_signal.reshape(-1)

    logging.log(logging.INFO, f" maximum frequency supported by kgrid: {scale_SI(f_max)}Hz (2 PPW)")
    if ppw != 0:
        logging.log(logging.INFO, f" filter cutoff frequency: {scale_SI(filter_cutoff_f)}Hz ({ppw} PPW)")
    if rppw != 0 and ramp_length > 0:
        logging.log(logging.INFO, f" ramp frequency: {scale_SI(2 * np.pi / (2 * ramp_length * kgrid.dt))}Hz ({rppw} PPW)")
    logging.log(logging.INFO, " computation complete.")

    return filtered_signal
[docs]
def apply_filter(
    signal: np.ndarray,
    fs: float,
    cutoff_f: float,
    filter_type: str,
    zero_phase: Optional[bool] = False,
    transition_width: Optional[float] = 0.1,
    stop_band_atten: Optional[int] = 60,
) -> np.ndarray:
    """
    Filters an input signal using a FIR filter with Kaiser window coefficients based on the specified cut-off frequency and filter type.
    Both causal and zero phase filters can be applied.

    Args:
        signal: The input signal, as a row, column or 1D vector.
        fs: The sampling frequency of the signal.
        cutoff_f: The cut-off frequency of the filter. For 'BandPass', a pair ``[low, high]``.
        filter_type: The type of filter to apply, either 'HighPass', 'LowPass' or 'BandPass'.
        zero_phase: Whether to filter forwards and then backwards in time, cancelling the
            phase delay. The stop-band attenuation used for the design is halved because
            the signal passes through the filter twice. Defaults to False.
        transition_width: The transition width of the filter, as a proportion of the sampling frequency. Defaults to 0.1.
        stop_band_atten: The stop-band attenuation of the filter in dB. Defaults to 60.

    Returns:
        The filtered signal as a (1, L) row array, where L is the number of input samples.

    Raises:
        AssertionError: If 'BandPass' is requested without exactly two cut-off frequencies.
        ValueError: If ``filter_type`` is unknown, or the requested attenuation and
            transition width give a filter order below 1.
    """
    # a band-pass filter is a low-pass followed by a high-pass
    if filter_type == "BandPass":
        assert isinstance(cutoff_f, (list, tuple, np.ndarray)), "List of two frequencies required as for filter type 'BandPass'"
        assert len(cutoff_f) == 2, "List of two frequencies required as for filter type 'BandPass'"
        low_passed = apply_filter(
            signal, fs, cutoff_f[1], "LowPass", stop_band_atten=stop_band_atten, transition_width=transition_width, zero_phase=zero_phase
        )
        return apply_filter(
            low_passed,
            fs,
            cutoff_f[0],
            "HighPass",
            stop_band_atten=stop_band_atten,
            transition_width=transition_width,
            zero_phase=zero_phase,
        )

    if filter_type == "LowPass":
        high_pass = False
    elif filter_type == "HighPass":
        high_pass = True
        # design the complementary low-pass prototype; it is shifted up by fs / 2 below
        cutoff_f = fs / 2 - cutoff_f
    else:
        raise ValueError(f'Unknown filter type {filter_type}. Options are "LowPass, HighPass, BandPass"')

    # work with a row vector; 1D arrays and column vectors are accepted
    signal = np.atleast_2d(np.asarray(signal))
    rows, cols = signal.shape
    if rows > cols:
        signal = signal.T

    # filtering twice (zero phase) doubles the attenuation in dB, so design for half
    if zero_phase:
        stop_band_atten = stop_band_atten / 2

    # Kaiser estimate of the filter order
    N = int(np.ceil((stop_band_atten - 7.95) / (2.285 * (transition_width * np.pi))))
    if N < 1:
        raise ValueError("stop_band_atten is too small for the given transition_width (filter order must be at least 1).")

    # impulse response of the ideal low-pass filter h(n), a sinc function
    fc = cutoff_f / fs  # normalised cut-off
    taps = np.arange(-N / 2, N / 2)
    h = 2 * fc * sinc(2 * np.pi * fc * taps)

    # Kaiser window parameter beta
    if stop_band_atten > 50:
        beta = 0.1102 * (stop_band_atten - 8.7)
    elif stop_band_atten >= 21:
        beta = 0.5842 * (stop_band_atten - 21) ** 0.4 + 0.07886 * (stop_band_atten - 21)
    else:
        beta = 0

    # Kaiser smoothing window w(k), k = 0 .. N-1
    k = np.arange(0, N)
    w = np.real(scipy.special.iv(0, np.pi * beta * np.sqrt(1 - (2 * k / N - 1) ** 2))) / np.real(scipy.special.iv(0, np.pi * beta))

    # windowed FIR filter coefficients hw(n)
    hw = w * h

    if high_pass:
        # modulate by (-1)^n, n = 1..N (MATLAB: (-1*ones(1,N)).^(1:N)) to shift the response by fs / 2
        hw = ((-1.0) ** np.arange(1, N + 1)) * hw

    # prepend N zeros to give the reverse (zero phase) pass room to work
    padded = np.concatenate([np.zeros(N), signal.ravel()])

    filtered_signal = lfilter(hw, 1, padded)
    if zero_phase:
        # filter the time-reversed output, then reverse back
        filtered_signal = lfilter(hw, 1, filtered_signal[::-1])[::-1]

    # remove the part of the signal corresponding to the added zeros
    filtered_signal = filtered_signal[N:]
    return filtered_signal[np.newaxis]
[docs]
def smooth(a: np.ndarray, restore_max: Optional[bool] = False, window_type: Optional[str] = "Blackman") -> np.ndarray:
    """
    Smooths a matrix by multiplying its spectrum with a window.

    Singleton dimensions are squeezed before smoothing, matching MATLAB's
    ``grid_size(grid_size == 1) = []``, and the result is reshaped back to the input
    shape. Odd-sized dimensions use a symmetric window and even-sized dimensions a
    non-symmetric one, so that the DC component of the window has a value of unity.

    Args:
        a: The spatial distribution to smooth (array-like; booleans are cast to int).
        restore_max: Whether the maximum absolute value of ``a`` is restored after smoothing.
            This is skipped if the smoothed result is identically zero. Defaults to False.
        window_type: Shape of the smoothing window. Any valid inputs to get_win are supported. Defaults to 'Blackman'.

    Returns:
        a_sm: The smoothed matrix, with the same shape as ``a``.
    """
    DEF_USE_ROTATION = True

    a = np.asarray(a)
    if a.dtype == bool:
        a = a.astype(int)
    assert is_number(a) and np.all(~np.isinf(a))
    assert isinstance(restore_max, bool)
    assert isinstance(window_type, str)

    # squeeze singleton dimensions (MATLAB's grid_size(grid_size == 1) = [])
    original_shape = a.shape
    grid_size = a.shape
    if num_dim2(a) != len(grid_size):
        a = np.squeeze(a)
        grid_size = a.shape  # tuple, possibly (N,) for 1D

    # symmetric window for odd sizes, non-symmetric for even sizes
    gs_arr = np.atleast_1d(np.array(grid_size))
    window_symmetry = (gs_arr % 2).astype(bool)
    if window_symmetry.size == 1:
        window_symmetry = bool(window_symmetry.item())

    # get_win expects an int for 1D and a tuple for 2D/3D
    if len(grid_size) == 1:
        grid_size = grid_size[0]

    # the absolute value discards machine-precision negative values
    use_rotation = DEF_USE_ROTATION and np.ndim(a) > 1
    win, _ = get_win(grid_size, type_=window_type, rotation=use_rotation, symmetric=window_symmetry)
    win = np.abs(win)
    if win.ndim != np.ndim(a):
        win = win.reshape(a.shape)

    # apply the filter in the frequency domain
    a_sm = np.real(np.fft.ifftn(np.fft.fftn(a) * np.fft.ifftshift(win)))

    if a_sm.shape != original_shape:
        a_sm = a_sm.reshape(original_shape)

    if restore_max:
        sm_peak = np.abs(a_sm).max()
        # an all-zero result (e.g. all-zero input) would otherwise give 0/0 = NaN everywhere
        if sm_peak > 0:
            a_sm = (np.abs(a).max() / sm_peak) * a_sm
    return a_sm