from dataclasses import dataclass
import logging
import numpy as np
from kwave.kgrid import kWaveGrid
from kwave.utils.matrix import num_dim2
@dataclass
class kSource:
    """
    Container for the sources of a k-Wave simulation.

    Holds an optional initial pressure distribution (``p0``), a time-varying pressure source
    (``p`` / ``p_mask``), time-varying particle-velocity sources (``ux``, ``uy``, ``uz`` /
    ``u_mask``) and time-varying stress sources (``sxx`` ... ``syz`` / ``s_mask``).
    :meth:`validate` checks these fields against a computational grid.

    NOTE(review): none of the attributes below carry type annotations, so ``@dataclass``
    generates no fields: the generated ``__init__`` takes no arguments and the generated
    ``__eq__`` treats any two ``kSource`` instances as equal. The attributes are class-level
    defaults that are shadowed per instance when assigned.
    """

    _p0 = None
    #: time varying pressure at each of the source positions given by source.p_mask
    p = None
    #: binary matrix specifying the positions of the time varying pressure source distribution
    p_mask = None
    #: optional input to control whether the input pressure is injected as a mass source or enforced
    # as a dirichlet boundary condition; valid inputs are 'additive' (the default),
    # 'additive-no-correction' or 'dirichlet'
    p_mode = None
    #: Pressure reference frequency
    p_frequency_ref = None
    #: time varying particle velocity in the x-direction at each of the source positions given by source.u_mask
    ux = None
    #: time varying particle velocity in the y-direction at each of the source positions given by source.u_mask
    uy = None
    #: time varying particle velocity in the z-direction at each of the source positions given by source.u_mask
    uz = None
    #: binary matrix specifying the positions of the time varying particle velocity distribution
    u_mask = None
    #: optional input to control whether the input velocity is applied as a force source or enforced as a dirichlet
    # boundary condition; valid inputs are 'additive' (the default), 'additive-no-correction' or 'dirichlet'
    u_mode = None
    #: Velocity reference frequency
    u_frequency_ref = None
    sxx = None  #: Stress source in x -> x direction
    syy = None  #: Stress source in y -> y direction
    szz = None  #: Stress source in z -> z direction
    sxy = None  #: Stress source in x -> y direction
    sxz = None  #: Stress source in x -> z direction
    syz = None  #: Stress source in y -> z direction
    s_mask = None  #: Stress source mask
    s_mode = None  #: Stress source mode ('additive' or 'dirichlet')

    def is_p0_empty(self) -> bool:
        """
        Check if the `p0` field is unset, empty, or contains only zeros

        Returns:
            True when there is no usable initial pressure distribution
        """
        return self.p0 is None or len(self.p0) == 0 or (np.sum(self.p0 != 0) == 0)

    @property
    def p0(self):
        """
        Initial pressure within the acoustic medium (``None`` when not set)
        """
        return self._p0

    @p0.setter
    def p0(self, val):
        # an unset, empty or all-zero initial pressure is stored as None so that
        # "no initial pressure" has a single representation
        if val is None or len(val) == 0 or np.sum(val != 0) == 0:
            self._p0 = None
        else:
            self._p0 = val

    def validate(self, kgrid: "kWaveGrid", medium=None) -> None:
        """
        Validate the object fields for correctness

        The initial pressure, the time-varying pressure source, the time-varying velocity
        source and the stress source are checked in that order. Defining ``p_frequency_ref``
        or ``u_frequency_ref`` switches ``p_mode`` / ``u_mode`` to ``'additive-no-correction'``.

        Args:
            kgrid: Instance of `~kwave.kgrid.kWaveGrid` class
            medium: Optional object exposing ``sound_speed``. It is only used to check
                ``p_frequency_ref`` / ``u_frequency_ref`` against the maximum frequency the
                grid supports. When omitted, a ``medium`` attribute set on this object is used
                if present; otherwise that range check is skipped.

        Returns:
            None

        Raises:
            ValueError: if ``p0`` or a source mask does not match the grid size, or the number
                of time series does not match the number of source elements / labels.
            AssertionError: for the remaining consistency checks.
            NotImplementedError: for a labelled ``u_mask`` and for any stress source, whose
                validation is not implemented.
        """
        if medium is None:
            medium = getattr(self, "medium", None)

        if self.p0 is not None and self.p0.shape != kgrid.k.shape:
            # throw an error if p0 is not the correct size
            raise ValueError("source.p0 must be the same size as the computational grid.")
        # TODO: if using the elastic code, reformulate source.p0 in terms of the stress source
        # terms using the fact that source.p = [0.5 0.5] / (2*CFL) is the same as source.p0 = 1

        # check for a time varying pressure source input
        if self.p is not None:
            self._validate_pressure(kgrid, medium)

        # check for a time varying velocity source input
        if self.u_mask is not None or any(getattr(self, name) is not None for name in ("ux", "uy", "uz")):
            self._validate_velocity(kgrid, medium)

        # check for a time varying stress source input
        stress_names = ("sxx", "syy", "szz", "sxy", "sxz", "syz")
        if self.s_mask is not None or any(getattr(self, name) is not None for name in stress_names):
            self._validate_stress(kgrid)

    def _validate_pressure(self, kgrid, medium) -> None:
        """Check the time varying pressure source fields (p, p_mask, p_mode, p_frequency_ref)."""
        # force p_mask to be given if p is given
        assert self.p_mask is not None

        # check mask is the correct size
        # noinspection PyTypeChecker
        if (num_dim2(self.p_mask) != kgrid.dim) or (self.p_mask.shape != kgrid.k.shape):
            raise ValueError("source.p_mask must be the same size as the computational grid.")

        # check mask is not empty
        assert np.sum(self.p_mask) != 0, "source.p_mask must be a binary grid with at least one element set to 1."

        # don't allow both source.p0 and source.p in the same simulation
        # USERS: please contact us via http://www.k-wave.org/forum if this is a problem
        assert self.p0 is None, "source.p0 and source.p can't be defined in the same simulation."

        # check the source mode input is valid
        if self.p_mode is not None:
            assert self.p_mode in [
                "additive",
                "dirichlet",
                "additive-no-correction",
            ], "source.p_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''."

        # check if a reference frequency is defined
        if self.p_frequency_ref is not None:
            self._check_frequency_ref("p_frequency_ref", self.p_frequency_ref, kgrid, medium)
            # change source mode to not include the k-space correction
            self.p_mode = "additive-no-correction"

        # rows of source.p are time series, columns are time points
        self._warn_if_too_long("p", len(self.p[0]), kgrid.Nt)

        # check if the mask is binary or labelled
        p_unique = np.unique(self.p_mask)
        num_series = self.p.shape[0]
        if p_unique.size <= 2 and p_unique.sum() == 1:
            # binary mask: if more than one time series is given, there must be
            # exactly one per source element
            if num_series > 1 and num_series != self.p_mask.sum():
                raise ValueError("The number of time series in source.p " "must match the number of source elements in source.p_mask.")
        else:
            # labelled mask: check the source labels are monotonic, and start from 1
            if (sum(p_unique[1:] - p_unique[:-1]) != len(p_unique) - 1) or (not any(p_unique == 1)):
                raise ValueError(
                    "If using a labelled source.p_mask, " "the source labels must be monotonically increasing and start from 1."
                )
            # one time series (row of source.p) per label; the "- 1" assumes the
            # background value (0) is among p_unique
            if num_series != (np.size(p_unique) - 1):
                raise ValueError(
                    "The number of time series in source.p " "must match the number of labelled source elements in source.p_mask."
                )

    def _validate_velocity(self, kgrid, medium) -> None:
        """Check the time varying velocity source fields (ux, uy, uz, u_mask, u_mode, u_frequency_ref)."""
        # force u_mask to be given
        assert self.u_mask is not None

        # check mask is the correct size
        assert (
            num_dim2(self.u_mask) == kgrid.dim and self.u_mask.shape == kgrid.k.shape
        ), "source.u_mask must be the same size as the computational grid."

        # check mask is not empty
        assert np.array(self.u_mask).sum() != 0, "source.u_mask must be a binary grid with at least one element set to 1."

        # check the source mode input is valid
        if self.u_mode is not None:
            assert self.u_mode in [
                "additive",
                "dirichlet",
                "additive-no-correction",
            ], "source.u_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''."

        # check if a reference frequency is defined
        if self.u_frequency_ref is not None:
            self._check_frequency_ref("u_frequency_ref", self.u_frequency_ref, kgrid, medium)
            # change source mode to not include the k-space correction
            self.u_mode = "additive-no-correction"

        # (name, signal, number of time points) for each velocity component
        components = [("ux", self.ux, self.flag_ux), ("uy", self.uy, self.flag_uy), ("uz", self.uz, self.flag_uz)]
        for name, signal, num_time_points in components:
            if signal is not None:
                self._warn_if_too_long(name, num_time_points, kgrid.Nt)

        # check if the mask is binary or labelled
        u_unique = np.unique(self.u_mask)
        if u_unique.size <= 2 and u_unique.sum() == 1:
            # binary mask: number of time series (rows) of every given, non-empty component
            u_sum = np.sum(self.u_mask)
            num_series = [signal[:, 0].size for _, signal, num_time_points in components if signal is not None and num_time_points]
            # if more than one time series is given, each component must give one per source element
            if any(count > 1 for count in num_series) and any(count != u_sum for count in num_series):
                raise ValueError(
                    "The number of time series in source.ux (etc) " "must match the number of source elements in source.u_mask."
                )
        else:
            # TODO: labelled u_mask (labels monotonically increasing from 1, one time series per label)
            raise NotImplementedError("Validation of a labelled source.u_mask is not implemented.")

    def _validate_stress(self, kgrid) -> None:
        """Check the stress source fields; always ends by raising NotImplementedError."""
        # force s_mask to be given
        assert self.s_mask is not None

        # check mask is the correct size
        if (num_dim2(self.s_mask) != kgrid.dim) or (self.s_mask.shape != kgrid.k.shape):
            raise ValueError("source.s_mask must be the same size as the computational grid.")

        # check mask is not empty
        assert np.array(self.s_mask).sum() != 0, "source.s_mask must be a binary grid with at least one element set to 1."

        # check the source mode input is valid
        if self.s_mode is not None:
            assert self.s_mode in [
                "additive",
                "dirichlet",
            ], "source.s_mode must be set to ''additive'' or ''dirichlet''."

        # the stress inputs may be defined independently and be of any length
        for name in ("sxx", "syy", "szz", "sxy", "sxz", "syz"):
            signal = getattr(self, name)
            if signal is not None:
                self._warn_if_too_long(name, len(signal[0]), kgrid.Nt)

        # TODO: port the remaining checks (binary s_mask: one time series per source element;
        # labelled s_mask: labels monotonically increasing from 1 and one time series per label).
        raise NotImplementedError("Validation of stress sources (source.sxx etc.) is not implemented.")

    @staticmethod
    def _check_frequency_ref(name: str, freq_ref, kgrid, medium) -> None:
        """Check a reference frequency is a positive scalar within the range supported by the grid."""
        # check frequency is a scalar, positive number
        assert np.isscalar(freq_ref) and freq_ref > 0
        if medium is None:
            # no sound speed available, so the range check cannot be performed
            return
        # highest frequency supported by the grid: f = c_min * k_max / (2 * pi)
        f_max = kgrid.k_max_all * np.min(medium.sound_speed) / (2 * np.pi)
        assert freq_ref <= f_max, f"source.{name} is higher than the maximum frequency supported by the spatial grid."

    @staticmethod
    def _warn_if_too_long(name: str, num_time_points: int, nt) -> None:
        """Log a warning when a source signal has more time points than kgrid.Nt."""
        if num_time_points > nt:
            logging.log(logging.WARN, f" source.{name} has more time points than kgrid.Nt, remaining time points will not be used.")

    @property
    def flag_ux(self):
        """
        Get the length of the sources in X-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Length of the sources (number of time points), or 0 if ux is not set
        """
        return len(self.ux[0]) if self.ux is not None else 0

    @property
    def flag_uy(self):
        """
        Get the length of the sources in Y-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Length of the sources (number of time points), or 0 if uy is not set
        """
        return len(self.uy[0]) if self.uy is not None else 0

    @property
    def flag_uz(self):
        """
        Get the length of the sources in Z-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Length of the sources (number of time points), or 0 if uz is not set
        """
        return len(self.uz[0]) if self.uz is not None else 0