Source code for kwave.options.simulation_execution_options

import os
import warnings
from logging import getLogger
from pathlib import Path
from typing import Optional, Union

from kwave import BINARY_DIR, PLATFORM
from kwave.ksensor import kSensor

# Module-level logger, named after this module so records can be filtered per module.
logger = getLogger(name=__name__)


[docs] class SimulationExecutionOptions: """ A class to manage and configure the execution options for k-Wave simulations. Attributes: backend: Execution backend to use. Options: - "OMP": C++ OpenMP binary (default for CPU) - "CUDA": C++ CUDA binary (default for GPU) - "python": Pure Python/CuPy solver (no external binaries required) """
[docs] def __init__( self, is_gpu_simulation: bool = False, binary_path: Optional[str] = None, binary_dir: Optional[str] = None, binary_name: Optional[str] = None, kwave_function_name: Optional[str] = "kspaceFirstOrder3D", delete_data: bool = True, device_num: Optional[int] = None, num_threads: Optional[int] = None, thread_binding: Optional[bool] = None, system_call: Optional[str] = None, verbose_level: int = 0, auto_chunking: Optional[bool] = True, show_sim_log: bool = True, checkpoint_interval: Optional[int] = None, # [seconds] checkpoint_timesteps: Optional[int] = None, # [timestep integer] checkpoint_file: Optional[Path | str] = None, # [path to hdf5 file] backend: Optional[str] = None, # "OMP", "CUDA", or "python" ): self._num_threads_explicit = num_threads is not None self.backend = backend self.is_gpu_simulation = is_gpu_simulation self._binary_path = binary_path self._binary_name = binary_name self._binary_dir = binary_dir self.kwave_function_name = kwave_function_name self.delete_data = delete_data self.device_num = device_num self.num_threads = num_threads self.thread_binding = thread_binding self.system_call = system_call self.verbose_level = verbose_level self.auto_chunking = auto_chunking self.show_sim_log = show_sim_log self.checkpoint_interval = checkpoint_interval self.checkpoint_timesteps = checkpoint_timesteps self.checkpoint_file = checkpoint_file if self.checkpoint_file is not None: if self.checkpoint_interval is None and self.checkpoint_timesteps is None: raise ValueError("One of checkpoint_interval or checkpoint_timesteps must be set when checkpoint_file is set.")
@property def num_threads_explicit(self) -> bool: """True if user explicitly set num_threads (vs auto-detected cpu_count).""" return self._num_threads_explicit @property def num_threads(self) -> Union[int, str]: return self._num_threads @num_threads.setter def num_threads(self, value: Union[int, str]): cpu_count = os.cpu_count() if cpu_count is None: raise RuntimeError("Unable to determine the number of CPUs on this system. Please specify the number of threads explicitly.") if value == "all": warnings.warn( "The 'all' option is deprecated. The value of None sets the maximal number of threads (excluding Windows).", DeprecationWarning, ) value = cpu_count if value is None: value = cpu_count if not isinstance(value, int): raise ValueError("Got {value}. Number of threads must be 'all' or a positive integer") if value <= 0 or value > cpu_count: raise ValueError(f"Number of threads {value} must be a positive integer and less than total threads on the system {cpu_count}.") self._num_threads = value @property def verbose_level(self) -> int: return self._verbose_level @verbose_level.setter def verbose_level(self, value: int): if not (isinstance(value, int) and 0 <= value <= 2): raise ValueError("Verbose level must be between 0 and 2") self._verbose_level = value @property def backend(self) -> Optional[str]: return self._backend @backend.setter def backend(self, value: Optional[str]): valid_backends = [None, "OMP", "CUDA", "python"] if value not in valid_backends: raise ValueError(f"Backend must be one of {valid_backends}, got '{value}'") self._backend = value @property def is_python_backend(self) -> bool: return self._backend == "python" @property def is_gpu_simulation(self) -> Optional[bool]: return self._is_gpu_simulation @is_gpu_simulation.setter def is_gpu_simulation(self, value: Optional[bool]): "Set the flag to enable default GPU simulation. This option will supersede custom binary paths." 
self._is_gpu_simulation = value # Automatically update the binary name based on the GPU simulation flag if value is not None: self._binary_name = None @property def binary_name(self) -> str: valid_binary_names = ["kspaceFirstOrder-OMP", "kspaceFirstOrder-CUDA"] if self._binary_name is None: # set default binary name based on GPU simulation value if self.is_gpu_simulation is None: raise ValueError("`is_gpu_simulation` must be set to either True or False before determining the binary name.") if self.is_gpu_simulation: self._binary_name = "kspaceFirstOrder-CUDA" else: self._binary_name = "kspaceFirstOrder-OMP" if PLATFORM == "windows": self._binary_name += ".exe" valid_binary_names = [name + ".exe" for name in valid_binary_names] elif self._binary_name not in valid_binary_names: warnings.warn("Custom binary name set. Ignoring `is_gpu_simulation` state.") return self._binary_name @binary_name.setter def binary_name(self, value: str): self._binary_name = value @property def binary_path(self) -> Path: if self._binary_path is not None: return self._binary_path binary_dir = BINARY_DIR if self._binary_dir is None else self._binary_dir if binary_dir is None: raise ValueError("Binary directory is not specified.") path = Path(binary_dir) / self.binary_name if PLATFORM == "windows" and not path.name.endswith(".exe"): path = path.with_suffix(".exe") return path @binary_path.setter def binary_path(self, value: str): # check if the binary path is a valid path if not os.path.exists(value): raise FileNotFoundError( f"Binary path {value} does not exist. If you are trying to set `binary_dir`, use the `binary_dir` attribute instead." ) self._binary_path = value @property def binary_dir(self) -> str: return BINARY_DIR if self._binary_dir is None else self._binary_dir @binary_dir.setter def binary_dir(self, value: str): # check if binary_dir is a directory if not os.path.isdir(value): raise NotADirectoryError( f"{value} is not a directory. 
If you are trying to set the `binary_path`, use the `binary_path` attribute instead." ) self._binary_dir = Path(value) @property def device_num(self) -> Optional[int]: return self._device_num @device_num.setter def device_num(self, value: Optional[int]): if value is not None and value < 0: raise ValueError("Device number must be non-negative") self._device_num = value @property def checkpoint_interval(self) -> Optional[int]: return self._checkpoint_interval @checkpoint_interval.setter def checkpoint_interval(self, value: Optional[int]): if value is not None: if not isinstance(value, int) or value < 0: raise ValueError("Checkpoint interval must be a positive integer") self._checkpoint_interval = value @property def checkpoint_timesteps(self) -> Optional[int]: return self._checkpoint_timesteps @checkpoint_timesteps.setter def checkpoint_timesteps(self, value: Optional[int]): if value is not None: if not isinstance(value, int) or value < 0: raise ValueError("Checkpoint timesteps must be a positive integer") self._checkpoint_timesteps = value @property def checkpoint_file(self) -> Optional[Path]: if self._checkpoint_file is None: return None return self._checkpoint_file @checkpoint_file.setter def checkpoint_file(self, value: Optional[Path | str]): if value is not None: if not isinstance(value, (str, Path)): raise ValueError("Checkpoint file must be a string or Path object.") if isinstance(value, str): value = Path(value) if not value.parent.is_dir(): raise FileNotFoundError(f"Checkpoint folder {value.parent} does not exist.") if value.suffix != ".h5": raise ValueError(f"Checkpoint file {value} must have .h5 extension.") self._checkpoint_file = value
[docs] def as_list(self, sensor: kSensor) -> list[str]: options_list = [] if self.device_num is not None: options_list.append("-g") options_list.append(str(self.device_num)) if self._num_threads is not None and PLATFORM != "windows": options_list.append("-t") options_list.append(str(self._num_threads)) if self.verbose_level > 0: options_list.append("--verbose") options_list.append(str(self.verbose_level)) if (self.checkpoint_interval is not None or self.checkpoint_timesteps is not None) and self.checkpoint_file is not None: if self.checkpoint_timesteps is not None: options_list.append("--checkpoint_timesteps") options_list.append(str(self.checkpoint_timesteps)) if self.checkpoint_interval is not None: options_list.append("--checkpoint_interval") options_list.append(str(self.checkpoint_interval)) options_list.append("--checkpoint_file") options_list.append(str(self.checkpoint_file)) record_options_map = { "p": "p_raw", "p_max": "p_max", "p_min": "p_min", "p_rms": "p_rms", "p_max_all": "p_max_all", "p_min_all": "p_min_all", "p_final": "p_final", "u": "u_raw", "u_max": "u_max", "u_min": "u_min", "u_rms": "u_rms", "u_max_all": "u_max_all", "u_min_all": "u_min_all", "u_final": "u_final", } if sensor.record is not None: matching_keys = sorted(set(sensor.record).intersection(record_options_map.keys())) options_list.extend([f"--{record_options_map[key]}" for key in matching_keys]) if "u_non_staggered" in sensor.record or "I_avg" in sensor.record or "I" in sensor.record: options_list.append("--u_non_staggered_raw") if ("I_avg" in sensor.record or "I" in sensor.record) and ("p" not in sensor.record): options_list.append("--p_raw") else: options_list.append("--p_raw") if sensor.record_start_index is not None: options_list.append("-s") options_list.append(f"{sensor.record_start_index}") return options_list
[docs] def get_options_string(self, sensor: kSensor) -> str: # raise a deprecation warning warnings.warn("This method is deprecated. Use `as_list` method instead.", DeprecationWarning) options_list = self.as_list(sensor) return " ".join(options_list)
@property def env_vars(self) -> dict: env = os.environ if PLATFORM != "darwin": env.update({"OMP_PLACES": "cores"}) if self.thread_binding is not None: if PLATFORM == "darwin": raise ValueError("Thread binding is not supported in MacOS.") # read the parameters and update the system options if self.thread_binding: env.update({"OMP_PROC_BIND": "SPREAD"}) else: env.update({"OMP_PROC_BIND": "CLOSE"}) else: if PLATFORM != "darwin": env.update({"OMP_PROC_BIND": "SPREAD"}) return env