From a772291d3764809f5dcbf9fccbb64971cc24cd34 Mon Sep 17 00:00:00 2001 From: henrikjacobsenfys Date: Mon, 30 Mar 2026 14:32:00 +0200 Subject: [PATCH 1/3] create bases classes --- src/easydynamics/base_classes/__init__.py | 10 ++ .../base_classes/easydynamics_base.py | 52 ++++++++++ .../base_classes/easydynamics_modelbase.py | 87 +++++++++++++++++ .../convolution/analytical_convolution.py | 37 +++---- .../convolution/convolution_base.py | 55 +++++++---- .../convolution/numerical_convolution_base.py | 78 ++++++++------- src/easydynamics/experiment/experiment.py | 4 +- .../components/model_component.py | 53 ++-------- .../diffusion_model/diffusion_model_base.py | 33 +------ src/easydynamics/sample_model/model_base.py | 96 +++++++------------ src/easydynamics/utils/utils.py | 29 +++--- tests/unit/easydynamics/utils/test_utils.py | 52 +++++----- 12 files changed, 340 insertions(+), 246 deletions(-) create mode 100644 src/easydynamics/base_classes/__init__.py create mode 100644 src/easydynamics/base_classes/easydynamics_base.py create mode 100644 src/easydynamics/base_classes/easydynamics_modelbase.py diff --git a/src/easydynamics/base_classes/__init__.py b/src/easydynamics/base_classes/__init__.py new file mode 100644 index 00000000..69134fc3 --- /dev/null +++ b/src/easydynamics/base_classes/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from .easydynamics_base import EasyDynamicsBase +from .easydynamics_modelbase import EasyDynamicsModelBase + +__all__ = [ + 'EasyDynamicsModelBase', + 'EasyDynamicsBase', +] diff --git a/src/easydynamics/base_classes/easydynamics_base.py b/src/easydynamics/base_classes/easydynamics_base.py new file mode 100644 index 00000000..d70c4016 --- /dev/null +++ b/src/easydynamics/base_classes/easydynamics_base.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from easyscience.base_classes.new_base import NewBase + + +class EasyDynamicsBase(NewBase): + """Base class for all EasyDynamics classes.""" + + def __init__( + self, + name: str | None = 'MyEasyDynamicsModel', + display_name: str | None = 'MyEasyDynamicsModel', + unique_name: str | None = None, + ) -> None: + """Initialize the EasyDynamicsBase. + + Args: + name (str | None, default="MyEasyDynamicsModel"): Name of the model. + display_name (str, default="MyEasyDynamicsModel"): Display name of the model. + unique_name (str | None, default=None): Unique name of the model. If None, + a unique name will be generated. + + Raises: + TypeError: If name is not a string or None. + """ + super().__init__(display_name=display_name, unique_name=unique_name) + + if name is not None and not isinstance(name, str): + raise TypeError('Name must be a string or None.') + self._name = name + + @property + def name(self) -> str | None: + """Get the name of the model. + + Returns: + str | None: The name of the model. + """ + return self._name + + @name.setter + def name(self, name_str: str | None) -> None: + """Set the name of the model. + + Args: + name_str (str | None): The new name to set. + """ + + if name_str is not None and not isinstance(name_str, str): + raise TypeError('Name must be a string or None.') + self._name = name_str diff --git a/src/easydynamics/base_classes/easydynamics_modelbase.py b/src/easydynamics/base_classes/easydynamics_modelbase.py new file mode 100644 index 00000000..6d334b88 --- /dev/null +++ b/src/easydynamics/base_classes/easydynamics_modelbase.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import scipp as sc +from easyscience.base_classes import ModelBase + +from easydynamics.utils.utils import _validate_unit + + +class EasyDynamicsModelBase(ModelBase): + """Base class for all EasyDynamics models.""" + + def __init__( + self, + unit: str | sc.Unit = "meV", + name: str | None = "MyEasyDynamicsModel", + display_name: str | None = "MyEasyDynamicsModel", + unique_name: str | None = None, + ) -> None: + """Initialize the EasyDynamicsModelBase. + + Args: + unit (str | sc.Unit, default="meV"): Unit of the model. + name (str | None, default="MyEasyDynamicsModel"): Name of the model. + display_name (str | None, default="MyEasyDynamicsModel"): Display name of the model. + unique_name (str | None, default=None): Unique name of the model. If None, + a unique name will be generated. + + Raises: + TypeError: If name is not a string or None. + """ + super().__init__(display_name=display_name, unique_name=unique_name) + self._unit = _validate_unit(unit) + + if name is not None and not isinstance(name, str): + raise TypeError("Name must be a string or None.") + self._name = name + + @property + def unit(self) -> str | sc.Unit | None: + """Get the unit of the model. + + Returns: + str | sc.Unit | None: The unit of the model. + """ + + return self._unit + + @unit.setter + def unit(self, unit_str: str) -> None: + """Unit is read-only and cannot be set directly. + + Args: + unit_str (str): The new unit to set (ignored). + + Raises: + AttributeError: Always raised to indicate that the unit is + read-only. + """ + raise AttributeError( + f"Unit is read-only. Use convert_unit to change the unit between allowed types " + f"or create a new {self.__class__.__name__} with the desired unit." + ) # noqa: E501 + + @property + def name(self) -> str | None: + """Get the name of the model. + + Returns: + str | None: The name of the model. + """ + return self._name + + @name.setter + def name(self, name_str: str) -> None: + """Set the name of the model. + + Args: + name_str (str): The new name to set. + + Raises: + TypeError: If name_str is not a string or None. + """ + + if name_str is not None and not isinstance(name_str, str): + raise TypeError("Name must be a string or None.") + self._name = name_str diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index 761f1cec..86e4cae7 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -27,18 +27,18 @@ class AnalyticalConvolution(ConvolutionBase): # Mapping of supported component type pairs to convolution methods. # Delta functions are handled separately. _CONVOLUTIONS = { - ('Gaussian', 'Gaussian'): '_convolute_gaussian_gaussian', - ('Gaussian', 'Lorentzian'): '_convolute_gaussian_lorentzian', - ('Gaussian', 'Voigt'): '_convolute_gaussian_voigt', - ('Lorentzian', 'Lorentzian'): '_convolute_lorentzian_lorentzian', - ('Lorentzian', 'Voigt'): '_convolute_lorentzian_voigt', - ('Voigt', 'Voigt'): '_convolute_voigt_voigt', + ("Gaussian", "Gaussian"): "_convolute_gaussian_gaussian", + ("Gaussian", "Lorentzian"): "_convolute_gaussian_lorentzian", + ("Gaussian", "Voigt"): "_convolute_gaussian_voigt", + ("Lorentzian", "Lorentzian"): "_convolute_lorentzian_lorentzian", + ("Lorentzian", "Voigt"): "_convolute_lorentzian_voigt", + ("Voigt", "Voigt"): "_convolute_voigt_voigt", } def __init__( self, energy: np.ndarray | sc.Variable, - energy_unit: str | sc.Unit = 'meV', + energy_unit: str | sc.Unit = "meV", sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, energy_offset: Numeric | Parameter = 0.0, @@ -59,7 +59,7 @@ def __init__( """ super().__init__( energy=energy, - energy_unit=energy_unit, + unit=energy_unit, sample_components=sample_components, resolution_components=resolution_components, energy_offset=energy_offset, @@ -146,8 +146,8 @@ def _convolute_analytic_pair( if isinstance(resolution_component, DeltaFunction): raise ValueError( - 'Analytical convolution with a delta function \ - in the resolution model is not supported.' + "Analytical convolution with a delta function \ + in the resolution model is not supported." ) # Delta function + anything --> @@ -173,8 +173,8 @@ def _convolute_analytic_pair( if func_name is None: raise ValueError( - f'Analytical convolution not supported for component pair: ' - f'{type(sample_component).__name__}, {type(resolution_component).__name__}' + f"Analytical convolution not supported for component pair: " + f"{type(sample_component).__name__}, {type(resolution_component).__name__}" ) # Call the corresponding method @@ -225,7 +225,9 @@ def _convolute_gaussian_gaussian( np.ndarray: The evaluated convolution values at self.energy. """ - width = np.sqrt(sample_component.width.value**2 + resolution_component.width.value**2) + width = np.sqrt( + sample_component.width.value**2 + resolution_component.width.value**2 + ) area = sample_component.area.value * resolution_component.area.value @@ -284,7 +286,8 @@ def _convolute_gaussian_voigt( center = sample_component.center.value + resolution_component.center.value gaussian_width = np.sqrt( - sample_component.width.value**2 + resolution_component.gaussian_width.value**2 + sample_component.width.value**2 + + resolution_component.gaussian_width.value**2 ) lorentzian_width = resolution_component.lorentzian_width.value @@ -386,11 +389,13 @@ def _convolute_voigt_voigt( center = sample_component.center.value + resolution_component.center.value gaussian_width = np.sqrt( - sample_component.gaussian_width.value**2 + resolution_component.gaussian_width.value**2 + sample_component.gaussian_width.value**2 + + resolution_component.gaussian_width.value**2 ) lorentzian_width = ( - sample_component.lorentzian_width.value + resolution_component.lorentzian_width.value + sample_component.lorentzian_width.value + + resolution_component.lorentzian_width.value ) return self._voigt_eval( area=area, diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index f9cfc9fd..5ecfc346 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -5,12 +5,13 @@ import scipp as sc from easyscience.variable import Parameter +from easydynamics.base_classes import EasyDynamicsBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric -class ConvolutionBase: +class ConvolutionBase(EasyDynamicsBase): """Base class for convolutions of sample and resolution models. This base class has no convolution functionality. @@ -21,7 +22,7 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, - energy_unit: str | sc.Unit = 'meV', + energy_unit: str | sc.Unit = "meV", energy_offset: Numeric | Parameter = 0.0, ) -> None: """Initialize the ConvolutionBase. @@ -49,21 +50,23 @@ def __init__( energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): - raise TypeError(f'Energy must be a numpy ndarray or a scipp Variable. Got {energy}') + raise TypeError( + f"Energy must be a numpy ndarray or a scipp Variable. Got {energy}" + ) if not isinstance(energy_unit, (str, sc.Unit)): - raise TypeError('Energy_unit must be a string or sc.Unit.') + raise TypeError("Energy_unit must be a string or sc.Unit.") if isinstance(energy, np.ndarray): - energy = sc.array(dims=['energy'], values=energy, unit=energy_unit) + energy = sc.array(dims=["energy"], values=energy, unit=energy_unit) if isinstance(energy_offset, Numeric): energy_offset = Parameter( - name='energy_offset', value=float(energy_offset), unit=energy_unit + name="energy_offset", value=float(energy_offset), unit=energy_unit ) if not isinstance(energy_offset, Parameter): - raise TypeError('Energy_offset must be a number or a Parameter.') + raise TypeError("Energy_offset must be a number or a Parameter.") self._energy = energy self._energy_unit = energy_unit @@ -73,7 +76,7 @@ def __init__( isinstance(sample_components, (ComponentCollection, ModelComponent)) ): raise TypeError( - f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 + f"`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 ) if isinstance(sample_components, ModelComponent): sample_components = ComponentCollection(components=[sample_components]) @@ -83,10 +86,12 @@ def __init__( isinstance(resolution_components, (ComponentCollection, ModelComponent)) ): raise TypeError( - f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 + f"`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 ) if isinstance(resolution_components, ModelComponent): - resolution_components = ComponentCollection(components=[resolution_components]) + resolution_components = ComponentCollection( + components=[resolution_components] + ) self._resolution_components = resolution_components @property @@ -110,7 +115,7 @@ def energy_offset(self, energy_offset: Numeric | Parameter) -> None: TypeError: If energy_offset is not a number or a Parameter. """ if not isinstance(energy_offset, Parameter | Numeric): - raise TypeError('Energy_offset must be a number or a Parameter.') + raise TypeError("Energy_offset must be a number or a Parameter.") if isinstance(energy_offset, Numeric): self._energy_offset.value = float(energy_offset) @@ -142,7 +147,7 @@ def energy_with_offset(self, value: sc.Variable) -> None: read-only. """ raise AttributeError( - 'Energy with offset is a read-only property derived from energy and energy_offset.' + "Energy with offset is a read-only property derived from energy and energy_offset." ) @property @@ -173,10 +178,14 @@ def energy(self, energy: np.ndarray | sc.Variable) -> None: energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): - raise TypeError('Energy must be a Number, a numpy ndarray or a scipp Variable.') + raise TypeError( + "Energy must be a Number, a numpy ndarray or a scipp Variable." + ) if isinstance(energy, np.ndarray): - self._energy = sc.array(dims=['energy'], values=energy, unit=self._energy.unit) + self._energy = sc.array( + dims=["energy"], values=energy, unit=self._energy.unit + ) if isinstance(energy, sc.Variable): self._energy = energy @@ -194,8 +203,8 @@ def energy_unit(self) -> str: @energy_unit.setter def energy_unit(self, unit_str: str) -> None: raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' - f'or create a new {self.__class__.__name__} with the desired unit.' + f"Unit is read-only. Use convert_unit to change the unit between allowed types " + f"or create a new {self.__class__.__name__} with the desired unit." ) # noqa: E501 def convert_energy_unit(self, energy_unit: str | sc.Unit) -> None: @@ -210,7 +219,7 @@ def convert_energy_unit(self, energy_unit: str | sc.Unit) -> None: unit. """ if not isinstance(energy_unit, (str, sc.Unit)): - raise TypeError('Energy unit must be a string or scipp unit.') + raise TypeError("Energy unit must be a string or scipp unit.") old_energy = self.energy.copy() try: @@ -239,7 +248,9 @@ def sample_components(self) -> ComponentCollection | ModelComponent: return self._sample_components @sample_components.setter - def sample_components(self, sample_components: ComponentCollection | ModelComponent) -> None: + def sample_components( + self, sample_components: ComponentCollection | ModelComponent + ) -> None: """Set the sample model. Args: @@ -252,7 +263,7 @@ def sample_components(self, sample_components: ComponentCollection | ModelCompon """ if not isinstance(sample_components, (ComponentCollection, ModelComponent)): raise TypeError( - f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 + f"`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 ) if isinstance(sample_components, ModelComponent): @@ -286,9 +297,11 @@ def resolution_components( """ if not isinstance(resolution_components, (ComponentCollection, ModelComponent)): raise TypeError( - f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 + f"`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 ) if isinstance(resolution_components, ModelComponent): - resolution_components = ComponentCollection(components=[resolution_components]) + resolution_components = ComponentCollection( + components=[resolution_components] + ) self._resolution_components = resolution_components diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index 913726f6..ffbd3736 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -42,8 +42,8 @@ def __init__( upsample_factor: Numeric | None = 5, extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, - temperature_unit: str | sc.Unit = 'K', - energy_unit: str | sc.Unit = 'meV', + temperature_unit: str | sc.Unit = "K", + energy_unit: str | sc.Unit = "meV", normalize_detailed_balance: bool = True, ) -> None: """Initialize the NumericalConvolutionBase. @@ -79,15 +79,17 @@ def __init__( energy=energy, sample_components=sample_components, resolution_components=resolution_components, - energy_unit=energy_unit, + unit=energy_unit, energy_offset=energy_offset, ) - if temperature is not None and not isinstance(temperature, (Numeric, Parameter)): - raise TypeError('Temperature must be None, a number or a Parameter.') + if temperature is not None and not isinstance( + temperature, (Numeric, Parameter) + ): + raise TypeError("Temperature must be None, a number or a Parameter.") if not isinstance(temperature_unit, (str, sc.Unit)): - raise TypeError('Temperature_unit must be a string or sc.Unit.') + raise TypeError("Temperature_unit must be a string or sc.Unit.") self._temperature_unit = temperature_unit self._temperature = None self.temperature = temperature @@ -140,10 +142,10 @@ def upsample_factor(self, factor: Numeric | None) -> None: return if not isinstance(factor, Numeric): - raise TypeError('Upsample factor must be a numerical value or None.') + raise TypeError("Upsample factor must be a numerical value or None.") factor = float(factor) if factor <= 1.0: - raise ValueError('Upsample factor must be greater than 1.') + raise ValueError("Upsample factor must be greater than 1.") self._upsample_factor = factor @@ -183,9 +185,9 @@ def extension_factor(self, factor: Numeric) -> None: """ if not isinstance(factor, Numeric): - raise TypeError('Extension factor must be a number.') + raise TypeError("Extension factor must be a number.") if factor < 0.0: - raise ValueError('Extension factor must be non-negative.') + raise ValueError("Extension factor must be non-negative.") self._extension_factor = float(factor) # Recreate dense grid when extension factor is updated @@ -225,7 +227,7 @@ def temperature(self, temp: Parameter | Numeric | None) -> None: self._temperature.value = float(temp) else: self._temperature = Parameter( - name='temperature', + name="temperature", value=float(temp), unit=self._temperature_unit, fixed=True, @@ -233,7 +235,7 @@ def temperature(self, temp: Parameter | Numeric | None) -> None: elif isinstance(temp, Parameter): self._temperature = temp else: - raise TypeError('Temperature must be None, a float or a Parameter.') + raise TypeError("Temperature must be None, a float or a Parameter.") @property def normalize_detailed_balance(self) -> bool: @@ -262,7 +264,7 @@ def normalize_detailed_balance(self, normalize: bool) -> None: """ if not isinstance(normalize, bool): - raise TypeError('normalize_detailed_balance must be True or False.') + raise TypeError("normalize_detailed_balance must be True or False.") self._normalize_detailed_balance = normalize @@ -290,7 +292,7 @@ def _create_energy_grid( is_uniform = np.allclose(energy_diff, energy_diff[0]) if not is_uniform: raise ValueError( - 'Input array `energy` must be uniformly spaced if upsample_factor is not given.' # noqa: E501 + "Input array `energy` must be uniformly spaced if upsample_factor is not given." # noqa: E501 ) energy_dense = self.energy.values @@ -307,7 +309,7 @@ def _create_energy_grid( energy_span_dense = extended_max - extended_min if len(energy_dense) < 2: - raise ValueError('Energy array must have at least two points.') + raise ValueError("Energy array must have at least two points.") energy_dense_step = energy_dense[1] - energy_dense[0] # Handle offset for even length of energy_dense in convolution. @@ -320,7 +322,9 @@ def _create_energy_grid( # select the 4 central points we either get # indices [2,3,4,5] or [1,2,3,4], both of which are offset by # 0.5*dx from the true center at index 3.5. - energy_even_length_offset = -0.5 * energy_dense_step if len(energy_dense) % 2 == 0 else 0.0 + energy_even_length_offset = ( + -0.5 * energy_dense_step if len(energy_dense) % 2 == 0 else 0.0 + ) # Handle the case when energy_dense is not symmetric around 0. # The resolution is still centered around zero (or close to it), @@ -362,27 +366,35 @@ def _check_width_thresholds( """ # Handle ComponentCollection or ModelComponent - components = model.components if isinstance(model, ComponentCollection) else [model] + components = ( + model.components if isinstance(model, ComponentCollection) else [model] + ) for comp in components: - if hasattr(comp, 'width'): - if comp.width.value > LARGE_WIDTH_THRESHOLD * self._energy_grid.energy_span_dense: + if hasattr(comp, "width"): + if ( + comp.width.value + > LARGE_WIDTH_THRESHOLD * self._energy_grid.energy_span_dense + ): warnings.warn( f"The width of the {model_name} component '{comp.unique_name}' \ ({comp.width.value}) is large compared to the span of the input " - f'array ({self._energy_grid.energy_span_dense}). \ + f"array ({self._energy_grid.energy_span_dense}). \ This may lead to inaccuracies in the convolution. \ - Increase extension_factor to improve accuracy.', + Increase extension_factor to improve accuracy.", UserWarning, stacklevel=3, ) - if comp.width.value < SMALL_WIDTH_THRESHOLD * self._energy_grid.energy_dense_step: + if ( + comp.width.value + < SMALL_WIDTH_THRESHOLD * self._energy_grid.energy_dense_step + ): warnings.warn( f"The width of the {model_name} component '{comp.unique_name}' \ ({comp.width.value}) is small compared to the spacing of the input " - f'array ({self._energy_grid.energy_dense_step}). \ + f"array ({self._energy_grid.energy_dense_step}). \ This may lead to inaccuracies in the convolution. \ - Increase upsample_factor to improve accuracy.', + Increase upsample_factor to improve accuracy.", UserWarning, stacklevel=3, ) @@ -396,13 +408,13 @@ def __repr__(self) -> str: NumericalConvolutionBase. """ return ( - f'{self.__class__.__name__}(' - f'energy=array of shape {self.energy.values.shape},\n ' - f'sample_components={repr(self.sample_components)}, \n' - f'resolution_components={repr(self.resolution_components)},\n ' - f'energy_unit={self._energy_unit}, ' - f'upsample_factor={self.upsample_factor}, ' - f'extension_factor={self.extension_factor}, ' - f'temperature={self.temperature}, ' - f'normalize_detailed_balance={self.normalize_detailed_balance})' + f"{self.__class__.__name__}(" + f"energy=array of shape {self.energy.values.shape},\n " + f"sample_components={repr(self.sample_components)}, \n" + f"resolution_components={repr(self.resolution_components)},\n " + f"energy_unit={self._energy_unit}, " + f"upsample_factor={self.upsample_factor}, " + f"extension_factor={self.extension_factor}, " + f"temperature={self.temperature}, " + f"normalize_detailed_balance={self.normalize_detailed_balance})" ) diff --git a/src/easydynamics/experiment/experiment.py b/src/easydynamics/experiment/experiment.py index cfe7dcfb..0a5e8523 100644 --- a/src/easydynamics/experiment/experiment.py +++ b/src/easydynamics/experiment/experiment.py @@ -6,14 +6,14 @@ import numpy as np import plopp as pp import scipp as sc -from easyscience.base_classes.new_base import NewBase from scipp.io import load_hdf5 as sc_load_hdf5 from scipp.io import save_hdf5 as sc_save_hdf5 +from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase from easydynamics.utils.utils import _in_notebook -class Experiment(NewBase): +class Experiment(EasyDynamicsBase): """Holds data from an experiment as a sc.DataArray along with metadata. diff --git a/src/easydynamics/sample_model/components/model_component.py b/src/easydynamics/sample_model/components/model_component.py index 7fc63880..ae5f310f 100644 --- a/src/easydynamics/sample_model/components/model_component.py +++ b/src/easydynamics/sample_model/components/model_component.py @@ -8,13 +8,13 @@ import numpy as np import scipp as sc -from easyscience.base_classes.model_base import ModelBase from scipp import UnitError +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.utils.utils import Numeric -class ModelComponent(ModelBase): +class ModelComponent(EasyDynamicsModelBase): """Abstract base class for all model components.""" def __init__( @@ -32,35 +32,11 @@ def __init__( unique_name (str | None, default=None): A unique identifier for the component. """ - self.validate_unit(unit) - super().__init__(display_name=display_name, unique_name=unique_name) - self._unit = unit - - @property - def unit(self) -> str: - """Get the unit. - - Returns: - str: The unit of the model component. - """ - return str(self._unit) - - @unit.setter - def unit(self, unit_str: str) -> None: - """Unit is read-only. Use convert_unit to change the unit - between allowed types or create a new ModelComponent with the - desired unit. - - Args: - unit_str (str): The new unit to set. - - Raises: - AttributeError: Always raised since unit is read-only. - """ - raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' - f'or create a new {self.__class__.__name__} with the desired unit.' - ) # noqa: E501 + super().__init__( + unit=unit, + display_name=display_name, + unique_name=unique_name, + ) def fix_all_parameters(self) -> None: """Fix all parameters in the model component.""" @@ -146,21 +122,6 @@ def _prepare_x_for_evaluate( return np.sort(x_in) - @staticmethod - def validate_unit(unit: str | sc.Unit | None) -> None: - """Validate that the unit is either a string or a scipp Unit. - - Args: - unit (str | sc.Unit | None): The unit to validate. - - Raises: - TypeError: If unit is not a string or scipp Unit. - """ - if unit is not None and not isinstance(unit, (str, sc.Unit)): - raise TypeError( - f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}' - ) - def convert_unit(self, unit: str | sc.Unit) -> None: """Convert the unit of the Parameters in the component. diff --git a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py index 1721fdc8..2c45fe36 100644 --- a/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py +++ b/src/easydynamics/sample_model/diffusion_model/diffusion_model_base.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause import scipp as sc -from easyscience.base_classes.model_base import ModelBase from easyscience.variable import DescriptorNumber from easyscience.variable import Parameter from scipp import UnitError +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.utils.utils import Numeric -class DiffusionModelBase(ModelBase): +class DiffusionModelBase(EasyDynamicsModelBase): """Base class for constructing diffusion models.""" def __init__( @@ -51,40 +51,13 @@ def __init__( scale = Parameter(name='scale', value=float(scale), fixed=False, min=0.0, unit=unit) - super().__init__(display_name=display_name, unique_name=unique_name) - self._unit = unit + super().__init__(display_name=display_name, unique_name=unique_name, unit=unit) self._scale = scale # ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ - @property - def unit(self) -> str | sc.Unit | None: - """Get the unit of the energy axis of the DiffusionModel. - - Returns: - str | sc.Unit | None: Unit of the DiffusionModel. - """ - return str(self._unit) - - @unit.setter - def unit(self, unit_str: str) -> None: - """The unit of the energy axis is read-only. To change the unit, - use convert_unit or create a new DiffusionModel with the desired - unit. - - Args: - unit_str (str): The new unit to set (ignored) - - Raises: - AttributeError: Always, since the unit is read-only. - """ - raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' - f'or create a new {self.__class__.__name__} with the desired unit.' - ) # noqa: E501 - @property def scale(self) -> Parameter: """Get the scale parameter of the diffusion model. diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py index 6c1392ea..1790cf15 100644 --- a/src/easydynamics/sample_model/model_base.py +++ b/src/easydynamics/sample_model/model_base.py @@ -5,18 +5,17 @@ import numpy as np import scipp as sc -from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase from easyscience.variable import Parameter +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric from easydynamics.utils.utils import Q_type from easydynamics.utils.utils import _validate_and_convert_Q -from easydynamics.utils.utils import _validate_unit -class ModelBase(EasyScienceModelBase): +class ModelBase(EasyDynamicsModelBase): """Base class for Sample Models. Contains common functionality for models with components and Q @@ -49,10 +48,10 @@ def __init__( TypeError: If components is not a ModelComponent or ComponentCollection. """ super().__init__( + unit=unit, display_name=display_name, unique_name=unique_name, ) - self._unit = _validate_unit(unit) self._Q = _validate_and_convert_Q(Q) if components is not None and not isinstance( @@ -134,63 +133,6 @@ def clear_components(self) -> None: # Properties # ------------------------------------------------------------------ - @property - def unit(self) -> str | sc.Unit | None: - """Get the unit of the ComponentCollection. - - Returns: - str | sc.Unit | None: The unit of the ComponentCollection. - """ - - return self._unit - - @unit.setter - def unit(self, unit_str: str) -> None: - """Unit is read-only and cannot be set directly. - - Args: - unit_str (str): The new unit to set (ignored). - - Raises: - AttributeError: Always raised to indicate that the unit is - read-only. - """ - raise AttributeError( - f'Unit is read-only. Use convert_unit to change the unit between allowed types ' - f'or create a new {self.__class__.__name__} with the desired unit.' - ) # noqa: E501 - - def convert_unit(self, unit: str | sc.Unit) -> None: - """Convert the unit of the ComponentCollection and all its - components. - - Args: - unit (str | sc.Unit): The new unit to convert to. - - Raises: - TypeError: If the provided unit is not a string or sc.Unit. - Exception: If the provided unit is not compatible with the - current unit. - """ - - old_unit = self._unit - - if not isinstance(unit, (str, sc.Unit)): - raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') - try: - for component in self.components: - component.convert_unit(unit) - self._unit = unit - except Exception as e: - # Attempt to rollback on failure - try: - for component in self.components: - component.convert_unit(old_unit) - except Exception: # noqa: S110 - pass # Best effort rollback - raise e - self._on_components_change() - @property def components(self) -> list[ModelComponent]: """Get the components of the SampleModel. @@ -279,6 +221,38 @@ def clear_Q(self, confirm: bool = False) -> None: # ------------------------------------------------------------------ # Other methods # ------------------------------------------------------------------ + + def convert_unit(self, unit: str | sc.Unit) -> None: + """Convert the unit of the ComponentCollection and all its + components. + + Args: + unit (str | sc.Unit): The new unit to convert to. + + Raises: + TypeError: If the provided unit is not a string or sc.Unit. + Exception: If the provided unit is not compatible with the + current unit. + """ + + old_unit = self._unit + + if not isinstance(unit, (str, sc.Unit)): + raise TypeError(f'Unit must be a string or sc.Unit, got {type(unit).__name__}') + try: + for component in self.components: + component.convert_unit(unit) + self._unit = unit + except Exception as e: + # Attempt to rollback on failure + try: + for component in self.components: + component.convert_unit(old_unit) + except Exception: # noqa: S110 + pass # Best effort rollback + raise e + self._on_components_change() + def fix_all_parameters(self) -> None: """Fix all Parameters in all ComponentCollections.""" for par in self.get_all_variables(): diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py index 80a90712..be7f5506 100644 --- a/src/easydynamics/utils/utils.py +++ b/src/easydynamics/utils/utils.py @@ -12,8 +12,8 @@ Q_type = np.ndarray | Numeric | list | ArrayLike | sc.Variable energy_type = np.ndarray | Numeric | list | ArrayLike | sc.Variable -hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar) -angstrom = DescriptorNumber('angstrom', 1e-10, unit='m') +hbar = DescriptorNumber.from_scipp("hbar", scipp_hbar) +angstrom = DescriptorNumber("angstrom", 1e-10, unit="m") def _validate_and_convert_Q( @@ -37,7 +37,7 @@ def _validate_and_convert_Q( if Q is None: return None if not isinstance(Q, (np.ndarray, Numeric, list, sc.Variable)): - raise TypeError('Q must be a number, list, numpy array, or scipp Variable.') + raise TypeError("Q must be a number, list, numpy array, or scipp Variable.") if isinstance(Q, Numeric): Q = np.array([Q]) @@ -45,14 +45,14 @@ def _validate_and_convert_Q( Q = np.array(Q) if isinstance(Q, np.ndarray): if Q.ndim > 1: - raise ValueError('Q must be a 1-dimensional array.') + raise ValueError("Q must be a 1-dimensional array.") - Q = sc.array(dims=['Q'], values=Q, unit='1/angstrom') + Q = sc.array(dims=["Q"], values=Q, unit="1/angstrom") if isinstance(Q, sc.Variable): - if Q.dims != ('Q',): + if Q.dims != ("Q",): raise ValueError("Q must have a single dimension named 'Q'.") - Q = Q.to(unit='1/angstrom') + Q = Q.to(unit="1/angstrom") return Q.values @@ -70,9 +70,14 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None: """ if unit is not None and not isinstance(unit, (str, sc.Unit)): - raise TypeError(f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}') - if isinstance(unit, str): - unit = sc.Unit(unit) + raise TypeError( + f"unit must be None, a string, or a scipp Unit, got {type(unit).__name__}" + ) + # if isinstance(unit, str): + # unit = sc.Unit(unit) + + if isinstance(unit, sc.Unit): + unit = str(unit) return unit @@ -86,9 +91,9 @@ def _in_notebook() -> bool: from IPython import get_ipython shell = get_ipython().__class__.__name__ - if shell == 'ZMQInteractiveShell': + if shell == "ZMQInteractiveShell": return True # Jupyter notebook or JupyterLab - elif shell == 'TerminalInteractiveShell': + elif shell == "TerminalInteractiveShell": return False # Terminal IPython else: return False diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py index e5a2e9e9..ab24741a 100644 --- a/tests/unit/easydynamics/utils/test_utils.py +++ b/tests/unit/easydynamics/utils/test_utils.py @@ -12,7 +12,7 @@ class TestValidateAndConvertQ: @pytest.mark.parametrize( - 'Q_input, expected', + "Q_input, expected", [ (1.0, np.array([1.0])), (2, np.array([2])), @@ -30,7 +30,7 @@ def test_validate_and_convert_Q_numeric_and_array(self, Q_input, expected): def test_validate_and_convert_Q_scipp_variable(self): # WHEN - Q = sc.array(dims=['Q'], values=[1.0, 2.0], unit='1/angstrom') + Q = sc.array(dims=["Q"], values=[1.0, 2.0], unit="1/angstrom") # THEN result = _validate_and_convert_Q(Q) @@ -44,29 +44,29 @@ def test_validate_and_convert_Q_none(self): assert _validate_and_convert_Q(None) is None @pytest.mark.parametrize( - 'Q_input', + "Q_input", [ - 'invalid', - {'a': 1}, + "invalid", + {"a": 1}, (1, 2), object(), ], ) def test_validate_and_convert_Q_invalid_type(self, Q_input): # WHEN THEN EXPECT - with pytest.raises(TypeError, match='Q must be a number'): + with pytest.raises(TypeError, match="Q must be a number"): _validate_and_convert_Q(Q_input) def test_validate_and_convert_Q_ndarray_wrong_dim(self): # WHEN THEN Q = np.array([[1.0, 2.0]]) # EXPECT - with pytest.raises(ValueError, match='Q must be a 1-dimensional array'): + with pytest.raises(ValueError, match="Q must be a 1-dimensional array"): _validate_and_convert_Q(Q) def test_validate_and_convert_Q_scipp_wrong_dims(self): # WHEN THEN - Q = sc.array(dims=['x'], values=[1.0, 2.0], unit='1/angstrom') + Q = sc.array(dims=["x"], values=[1.0, 2.0], unit="1/angstrom") # EXPECT with pytest.raises(ValueError, match="single dimension named 'Q'"): @@ -78,12 +78,12 @@ def test_validate_and_convert_Q_scipp_wrong_dims(self): class TestValidateUnit: @pytest.mark.parametrize( - 'unit_input', + "unit_input", [ None, - '1/angstrom', - 'meV', - sc.Unit('meV'), + "1/angstrom", + "meV", + sc.Unit("meV"), ], ) def test_validate_unit_valid(self, unit_input): @@ -92,16 +92,16 @@ def test_validate_unit_valid(self, unit_input): if unit_input is None: assert unit is None else: - assert isinstance(unit, sc.Unit) + assert isinstance(unit, str) def test_validate_unit_string_conversion(self): - unit = _validate_unit('meV') + unit = _validate_unit(sc.Unit("meV")) - assert isinstance(unit, sc.Unit) - assert unit == sc.Unit('meV') + assert isinstance(unit, str) + assert unit == "meV" @pytest.mark.parametrize( - 'unit_input', + "unit_input", [ 123, 45.6, @@ -111,7 +111,9 @@ def test_validate_unit_string_conversion(self): ], ) def test_validate_unit_invalid_type(self, unit_input): - with pytest.raises(TypeError, match='unit must be None, a string, or a scipp Unit'): + with pytest.raises( + TypeError, match="unit must be None, a string, or a scipp Unit" + ): _validate_unit(unit_input) @@ -125,10 +127,10 @@ def test_in_notebook_returns_true_for_jupyter(self, monkeypatch): # WHEN class ZMQInteractiveShell: - __name__ = 'ZMQInteractiveShell' + __name__ = "ZMQInteractiveShell" # THEN - monkeypatch.setattr('IPython.get_ipython', lambda: ZMQInteractiveShell()) + monkeypatch.setattr("IPython.get_ipython", lambda: ZMQInteractiveShell()) # EXPECT assert _in_notebook() is True @@ -139,11 +141,11 @@ def test_in_notebook_returns_false_for_terminal_ipython(self, monkeypatch): # WHEN class TerminalInteractiveShell: - __name__ = 'TerminalInteractiveShell' + __name__ = "TerminalInteractiveShell" # THEN - monkeypatch.setattr('IPython.get_ipython', lambda: TerminalInteractiveShell()) + monkeypatch.setattr("IPython.get_ipython", lambda: TerminalInteractiveShell()) # EXPECT assert _in_notebook() is False @@ -154,10 +156,10 @@ def test_in_notebook_returns_false_for_unknown_shell(self, monkeypatch): # WHEN class UnknownShell: - __name__ = 'UnknownShell' + __name__ = "UnknownShell" # THEN - monkeypatch.setattr('IPython.get_ipython', lambda: UnknownShell()) + monkeypatch.setattr("IPython.get_ipython", lambda: UnknownShell()) # EXPECT assert _in_notebook() is False @@ -170,7 +172,7 @@ def raise_import_error(*args, **kwargs): raise ImportError # THEN - monkeypatch.setattr('builtins.__import__', raise_import_error) + monkeypatch.setattr("builtins.__import__", raise_import_error) # EXPECT assert _in_notebook() is False From 97efa67d8c693849309fd11103a13985f02ecf62 Mon Sep 17 00:00:00 2001 From: henrikjacobsenfys Date: Mon, 30 Mar 2026 14:45:37 +0200 Subject: [PATCH 2/3] use base classes everywhere --- .../base_classes/easydynamics_base.py | 5 +- .../base_classes/easydynamics_modelbase.py | 14 +-- .../convolution/analytical_convolution.py | 39 +++---- src/easydynamics/convolution/convolution.py | 6 +- .../convolution/convolution_base.py | 106 +++++++----------- .../convolution/numerical_convolution.py | 11 +- .../convolution/numerical_convolution_base.py | 80 ++++++------- src/easydynamics/utils/utils.py | 22 ++-- .../convolution/test_convolution.py | 4 +- .../convolution/test_convolution_base.py | 40 +++---- .../convolution/test_numerical_convolution.py | 2 +- .../test_numerical_convolution_base.py | 10 +- tests/unit/easydynamics/utils/test_utils.py | 48 ++++---- 13 files changed, 172 insertions(+), 215 deletions(-) diff --git a/src/easydynamics/base_classes/easydynamics_base.py b/src/easydynamics/base_classes/easydynamics_base.py index d70c4016..03eabe5b 100644 --- a/src/easydynamics/base_classes/easydynamics_base.py +++ b/src/easydynamics/base_classes/easydynamics_base.py @@ -17,7 +17,7 @@ def __init__( Args: name (str | None, default="MyEasyDynamicsModel"): Name of the model. - display_name (str, default="MyEasyDynamicsModel"): Display name of the model. + display_name (str | None, default="MyEasyDynamicsModel"): Display name of the model. unique_name (str | None, default=None): Unique name of the model. If None, a unique name will be generated. @@ -45,6 +45,9 @@ def name(self, name_str: str | None) -> None: Args: name_str (str | None): The new name to set. + + Raises: + TypeError: If name_str is not a string or None. """ if name_str is not None and not isinstance(name_str, str): diff --git a/src/easydynamics/base_classes/easydynamics_modelbase.py b/src/easydynamics/base_classes/easydynamics_modelbase.py index 6d334b88..dfd66820 100644 --- a/src/easydynamics/base_classes/easydynamics_modelbase.py +++ b/src/easydynamics/base_classes/easydynamics_modelbase.py @@ -12,9 +12,9 @@ class EasyDynamicsModelBase(ModelBase): def __init__( self, - unit: str | sc.Unit = "meV", - name: str | None = "MyEasyDynamicsModel", - display_name: str | None = "MyEasyDynamicsModel", + unit: str | sc.Unit = 'meV', + name: str | None = 'MyEasyDynamicsModel', + display_name: str | None = 'MyEasyDynamicsModel', unique_name: str | None = None, ) -> None: """Initialize the EasyDynamicsModelBase. @@ -33,7 +33,7 @@ def __init__( self._unit = _validate_unit(unit) if name is not None and not isinstance(name, str): - raise TypeError("Name must be a string or None.") + raise TypeError('Name must be a string or None.') self._name = name @property @@ -58,8 +58,8 @@ def unit(self, unit_str: str) -> None: read-only. """ raise AttributeError( - f"Unit is read-only. Use convert_unit to change the unit between allowed types " - f"or create a new {self.__class__.__name__} with the desired unit." + f'Unit is read-only. Use convert_unit to change the unit between allowed types ' + f'or create a new {self.__class__.__name__} with the desired unit.' ) # noqa: E501 @property @@ -83,5 +83,5 @@ def name(self, name_str: str) -> None: """ if name_str is not None and not isinstance(name_str, str): - raise TypeError("Name must be a string or None.") + raise TypeError('Name must be a string or None.') self._name = name_str diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index 86e4cae7..e9b50765 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -27,18 +27,18 @@ class AnalyticalConvolution(ConvolutionBase): # Mapping of supported component type pairs to convolution methods. # Delta functions are handled separately. _CONVOLUTIONS = { - ("Gaussian", "Gaussian"): "_convolute_gaussian_gaussian", - ("Gaussian", "Lorentzian"): "_convolute_gaussian_lorentzian", - ("Gaussian", "Voigt"): "_convolute_gaussian_voigt", - ("Lorentzian", "Lorentzian"): "_convolute_lorentzian_lorentzian", - ("Lorentzian", "Voigt"): "_convolute_lorentzian_voigt", - ("Voigt", "Voigt"): "_convolute_voigt_voigt", + ('Gaussian', 'Gaussian'): '_convolute_gaussian_gaussian', + ('Gaussian', 'Lorentzian'): '_convolute_gaussian_lorentzian', + ('Gaussian', 'Voigt'): '_convolute_gaussian_voigt', + ('Lorentzian', 'Lorentzian'): '_convolute_lorentzian_lorentzian', + ('Lorentzian', 'Voigt'): '_convolute_lorentzian_voigt', + ('Voigt', 'Voigt'): '_convolute_voigt_voigt', } def __init__( self, energy: np.ndarray | sc.Variable, - energy_unit: str | sc.Unit = "meV", + unit: str | sc.Unit = 'meV', sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, energy_offset: Numeric | Parameter = 0.0, @@ -48,7 +48,7 @@ def __init__( Args: energy (np.ndarray | sc.Variable): 1D array of energy values where the convolution is evaluated. - energy_unit (str | sc.Unit, default='meV'): The unit of the + unit (str | sc.Unit, default='meV'): The unit of the energy. sample_components (ComponentCollection | ModelComponent | None, default=None): The sample model to be convolved. @@ -59,7 +59,7 @@ def __init__( """ super().__init__( energy=energy, - unit=energy_unit, + unit=unit, sample_components=sample_components, resolution_components=resolution_components, energy_offset=energy_offset, @@ -146,8 +146,8 @@ def _convolute_analytic_pair( if isinstance(resolution_component, DeltaFunction): raise ValueError( - "Analytical convolution with a delta function \ - in the resolution model is not supported." + 'Analytical convolution with a delta function \ + in the resolution model is not supported.' ) # Delta function + anything --> @@ -173,8 +173,8 @@ def _convolute_analytic_pair( if func_name is None: raise ValueError( - f"Analytical convolution not supported for component pair: " - f"{type(sample_component).__name__}, {type(resolution_component).__name__}" + f'Analytical convolution not supported for component pair: ' + f'{type(sample_component).__name__}, {type(resolution_component).__name__}' ) # Call the corresponding method @@ -225,9 +225,7 @@ def _convolute_gaussian_gaussian( np.ndarray: The evaluated convolution values at self.energy. """ - width = np.sqrt( - sample_component.width.value**2 + resolution_component.width.value**2 - ) + width = np.sqrt(sample_component.width.value**2 + resolution_component.width.value**2) area = sample_component.area.value * resolution_component.area.value @@ -286,8 +284,7 @@ def _convolute_gaussian_voigt( center = sample_component.center.value + resolution_component.center.value gaussian_width = np.sqrt( - sample_component.width.value**2 - + resolution_component.gaussian_width.value**2 + sample_component.width.value**2 + resolution_component.gaussian_width.value**2 ) lorentzian_width = resolution_component.lorentzian_width.value @@ -389,13 +386,11 @@ def _convolute_voigt_voigt( center = sample_component.center.value + resolution_component.center.value gaussian_width = np.sqrt( - sample_component.gaussian_width.value**2 - + resolution_component.gaussian_width.value**2 + sample_component.gaussian_width.value**2 + resolution_component.gaussian_width.value**2 ) lorentzian_width = ( - sample_component.lorentzian_width.value - + resolution_component.lorentzian_width.value + sample_component.lorentzian_width.value + resolution_component.lorentzian_width.value ) return self._voigt_eval( area=area, diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index 3f3f1ef3..5a011526 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -58,7 +58,7 @@ def __init__( extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, ) -> None: """Initialize the Convolution class. @@ -83,7 +83,7 @@ def __init__( Default is None. temperature_unit (str | sc.Unit, default='K'): The unit of the temperature parameter. Default is 'K'. - energy_unit (str | sc.Unit, default='meV'): The unit of the energy. + unit (str | sc.Unit, default='meV'): The unit of the energy. Default is 'meV'. normalize_detailed_balance (bool, default=True): Whether to normalize the detailed balance correction. Default is @@ -101,7 +101,7 @@ def __init__( extension_factor=extension_factor, temperature=temperature, temperature_unit=temperature_unit, - energy_unit=energy_unit, + unit=unit, normalize_detailed_balance=normalize_detailed_balance, ) diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index 5ecfc346..36e253fc 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -5,13 +5,13 @@ import scipp as sc from easyscience.variable import Parameter -from easydynamics.base_classes import EasyDynamicsBase +from easydynamics.base_classes import EasyDynamicsModelBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.utils.utils import Numeric -class ConvolutionBase(EasyDynamicsBase): +class ConvolutionBase(EasyDynamicsModelBase): """Base class for convolutions of sample and resolution models. This base class has no convolution functionality. @@ -22,8 +22,10 @@ def __init__( energy: np.ndarray | sc.Variable, sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, - energy_unit: str | sc.Unit = "meV", + unit: str | sc.Unit = 'meV', energy_offset: Numeric | Parameter = 0.0, + display_name: str | None = 'MyConvolutionModel', + unique_name: str | None = None, ) -> None: """Initialize the ConvolutionBase. @@ -34,49 +36,51 @@ def __init__( The sample model to be convolved. resolution_components (ComponentCollection | ModelComponent | None, default=None): The resolution model to convolve with. - energy_unit (str | sc.Unit, default='meV'): The unit of the - energy. + unit (str | sc.Unit, default='meV'): The unit of the + energy axis. energy_offset (Numeric | Parameter, default=0.0): The energy offset applied to the convolution. Default is 0.0. + display_name (str | None, default='MyConvolutionModel'): + The display name of the model. + unique_name (str | None, default=None): The unique name of + the model. If None, a unique name will be generated. Raises: TypeError: If energy is not a numpy ndarray or a scipp - Variable or if energy_unit is not a string or scipp unit, or if + Variable or if energy_offset is not a number or a Parameter, or if sample_components is not a ComponentCollection or ModelComponent, or if resolution_components is not a ComponentCollection or ModelComponent. """ + super().__init__( + unit=unit, + display_name=display_name, + unique_name=unique_name, + ) + if isinstance(energy, Numeric): energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): - raise TypeError( - f"Energy must be a numpy ndarray or a scipp Variable. Got {energy}" - ) - - if not isinstance(energy_unit, (str, sc.Unit)): - raise TypeError("Energy_unit must be a string or sc.Unit.") + raise TypeError(f'Energy must be a numpy ndarray or a scipp Variable. Got {energy}') if isinstance(energy, np.ndarray): - energy = sc.array(dims=["energy"], values=energy, unit=energy_unit) + energy = sc.array(dims=['energy'], values=energy, unit=unit) if isinstance(energy_offset, Numeric): - energy_offset = Parameter( - name="energy_offset", value=float(energy_offset), unit=energy_unit - ) + energy_offset = Parameter(name='energy_offset', value=float(energy_offset), unit=unit) if not isinstance(energy_offset, Parameter): - raise TypeError("Energy_offset must be a number or a Parameter.") + raise TypeError('Energy_offset must be a number or a Parameter.') self._energy = energy - self._energy_unit = energy_unit self._energy_offset = energy_offset if sample_components is not None and not ( isinstance(sample_components, (ComponentCollection, ModelComponent)) ): raise TypeError( - f"`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 + f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) if isinstance(sample_components, ModelComponent): sample_components = ComponentCollection(components=[sample_components]) @@ -86,12 +90,10 @@ def __init__( isinstance(resolution_components, (ComponentCollection, ModelComponent)) ): raise TypeError( - f"`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 + f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) if isinstance(resolution_components, ModelComponent): - resolution_components = ComponentCollection( - components=[resolution_components] - ) + resolution_components = ComponentCollection(components=[resolution_components]) self._resolution_components = resolution_components @property @@ -115,7 +117,7 @@ def energy_offset(self, energy_offset: Numeric | Parameter) -> None: TypeError: If energy_offset is not a number or a Parameter. """ if not isinstance(energy_offset, Parameter | Numeric): - raise TypeError("Energy_offset must be a number or a Parameter.") + raise TypeError('Energy_offset must be a number or a Parameter.') if isinstance(energy_offset, Numeric): self._energy_offset.value = float(energy_offset) @@ -147,7 +149,7 @@ def energy_with_offset(self, value: sc.Variable) -> None: read-only. """ raise AttributeError( - "Energy with offset is a read-only property derived from energy and energy_offset." + 'Energy with offset is a read-only property derived from energy and energy_offset.' ) @property @@ -178,64 +180,44 @@ def energy(self, energy: np.ndarray | sc.Variable) -> None: energy = np.array([float(energy)]) if not isinstance(energy, (np.ndarray, sc.Variable)): - raise TypeError( - "Energy must be a Number, a numpy ndarray or a scipp Variable." - ) + raise TypeError('Energy must be a Number, a numpy ndarray or a scipp Variable.') if isinstance(energy, np.ndarray): - self._energy = sc.array( - dims=["energy"], values=energy, unit=self._energy.unit - ) + self._energy = sc.array(dims=['energy'], values=energy, unit=self._energy.unit) if isinstance(energy, sc.Variable): self._energy = energy - self._energy_unit = energy.unit + self._unit = energy.unit - @property - def energy_unit(self) -> str: - """Get the energy unit. - - Returns: - str: The unit of the energy. - """ - return self._energy_unit - - @energy_unit.setter - def energy_unit(self, unit_str: str) -> None: - raise AttributeError( - f"Unit is read-only. Use convert_unit to change the unit between allowed types " - f"or create a new {self.__class__.__name__} with the desired unit." - ) # noqa: E501 - - def convert_energy_unit(self, energy_unit: str | sc.Unit) -> None: + def convert_unit(self, unit: str | sc.Unit) -> None: """Convert the energy and energy_offset to the specified unit. Args: - energy_unit (str | sc.Unit): The unit of the energy. + unit (str | sc.Unit): The unit of the energy. Raises: - TypeError: If energy_unit is not a string or scipp unit. + TypeError: If unit is not a string or scipp unit. Exception: If energy cannot be converted to the specified unit. """ - if not isinstance(energy_unit, (str, sc.Unit)): - raise TypeError("Energy unit must be a string or scipp unit.") + if not isinstance(unit, (str, sc.Unit)): + raise TypeError('Energy unit must be a string or scipp unit.') old_energy = self.energy.copy() try: - self.energy = sc.to_unit(self.energy, energy_unit) + self.energy = sc.to_unit(self.energy, unit) except Exception as e: self.energy = old_energy raise e old_energy_offset = self.energy_offset try: - self.energy_offset.convert_unit(energy_unit) + self.energy_offset.convert_unit(unit) except Exception as e: self.energy_offset = old_energy_offset raise e - self._energy_unit = energy_unit + self._unit = unit @property def sample_components(self) -> ComponentCollection | ModelComponent: @@ -248,9 +230,7 @@ def sample_components(self) -> ComponentCollection | ModelComponent: return self._sample_components @sample_components.setter - def sample_components( - self, sample_components: ComponentCollection | ModelComponent - ) -> None: + def sample_components(self, sample_components: ComponentCollection | ModelComponent) -> None: """Set the sample model. Args: @@ -263,7 +243,7 @@ def sample_components( """ if not isinstance(sample_components, (ComponentCollection, ModelComponent)): raise TypeError( - f"`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 + f'`sample_components` is an instance of {type(sample_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) if isinstance(sample_components, ModelComponent): @@ -297,11 +277,9 @@ def resolution_components( """ if not isinstance(resolution_components, (ComponentCollection, ModelComponent)): raise TypeError( - f"`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent." # noqa: E501 + f'`resolution_components` is an instance of {type(resolution_components).__name__}, but must be a ComponentCollection or ModelComponent.' # noqa: E501 ) if isinstance(resolution_components, ModelComponent): - resolution_components = ComponentCollection( - components=[resolution_components] - ) + resolution_components = ComponentCollection(components=[resolution_components]) self._resolution_components = resolution_components diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index 10a55c81..32c5bb3e 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -33,7 +33,7 @@ def __init__( extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', - energy_unit: str | sc.Unit = 'meV', + unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, ) -> None: """Initialize the NumericalConvolution object. @@ -53,14 +53,11 @@ def __init__( extend the input data range before convolution. temperature (Parameter | Numeric | None, default=None): The temperature to use for detailed balance correction. - Default is None. temperature_unit (str | sc.Unit, default='K'): The unit of the temperature parameter. - energy_unit (str | sc.Unit, default='meV'): The unit of the - energy. Default is 'meV'. + unit (str | sc.Unit, default='meV'): The unit of the energy. normalize_detailed_balance (bool, default=True): Whether to - normalize the detailed balance correction. Default is - True. + normalize the detailed balance correction. """ super().__init__( energy=energy, @@ -71,7 +68,7 @@ def __init__( extension_factor=extension_factor, temperature=temperature, temperature_unit=temperature_unit, - energy_unit=energy_unit, + unit=unit, normalize_detailed_balance=normalize_detailed_balance, ) diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index ffbd3736..d580ede8 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -42,8 +42,8 @@ def __init__( upsample_factor: Numeric | None = 5, extension_factor: Numeric | None = 0.2, temperature: Parameter | Numeric | None = None, - temperature_unit: str | sc.Unit = "K", - energy_unit: str | sc.Unit = "meV", + temperature_unit: str | sc.Unit = 'K', + unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, ) -> None: """Initialize the NumericalConvolutionBase. @@ -65,7 +65,7 @@ def __init__( use for detailed balance correction. temperature_unit (str | sc.Unit, default='K'): The unit of the temperature parameter. - energy_unit (str | sc.Unit, default='meV'): The unit of the energy. + unit (str | sc.Unit, default='meV'): The unit of the energy. normalize_detailed_balance (bool, default=True): Whether to normalize the detailed balance correction. @@ -79,17 +79,15 @@ def __init__( energy=energy, sample_components=sample_components, resolution_components=resolution_components, - unit=energy_unit, + unit=unit, energy_offset=energy_offset, ) - if temperature is not None and not isinstance( - temperature, (Numeric, Parameter) - ): - raise TypeError("Temperature must be None, a number or a Parameter.") + if temperature is not None and not isinstance(temperature, (Numeric, Parameter)): + raise TypeError('Temperature must be None, a number or a Parameter.') if not isinstance(temperature_unit, (str, sc.Unit)): - raise TypeError("Temperature_unit must be a string or sc.Unit.") + raise TypeError('Temperature_unit must be a string or sc.Unit.') self._temperature_unit = temperature_unit self._temperature = None self.temperature = temperature @@ -142,10 +140,10 @@ def upsample_factor(self, factor: Numeric | None) -> None: return if not isinstance(factor, Numeric): - raise TypeError("Upsample factor must be a numerical value or None.") + raise TypeError('Upsample factor must be a numerical value or None.') factor = float(factor) if factor <= 1.0: - raise ValueError("Upsample factor must be greater than 1.") + raise ValueError('Upsample factor must be greater than 1.') self._upsample_factor = factor @@ -185,9 +183,9 @@ def extension_factor(self, factor: Numeric) -> None: """ if not isinstance(factor, Numeric): - raise TypeError("Extension factor must be a number.") + raise TypeError('Extension factor must be a number.') if factor < 0.0: - raise ValueError("Extension factor must be non-negative.") + raise ValueError('Extension factor must be non-negative.') self._extension_factor = float(factor) # Recreate dense grid when extension factor is updated @@ -227,7 +225,7 @@ def temperature(self, temp: Parameter | Numeric | None) -> None: self._temperature.value = float(temp) else: self._temperature = Parameter( - name="temperature", + name='temperature', value=float(temp), unit=self._temperature_unit, fixed=True, @@ -235,7 +233,7 @@ def temperature(self, temp: Parameter | Numeric | None) -> None: elif isinstance(temp, Parameter): self._temperature = temp else: - raise TypeError("Temperature must be None, a float or a Parameter.") + raise TypeError('Temperature must be None, a float or a Parameter.') @property def normalize_detailed_balance(self) -> bool: @@ -264,7 +262,7 @@ def normalize_detailed_balance(self, normalize: bool) -> None: """ if not isinstance(normalize, bool): - raise TypeError("normalize_detailed_balance must be True or False.") + raise TypeError('normalize_detailed_balance must be True or False.') self._normalize_detailed_balance = normalize @@ -292,7 +290,7 @@ def _create_energy_grid( is_uniform = np.allclose(energy_diff, energy_diff[0]) if not is_uniform: raise ValueError( - "Input array `energy` must be uniformly spaced if upsample_factor is not given." # noqa: E501 + 'Input array `energy` must be uniformly spaced if upsample_factor is not given.' # noqa: E501 ) energy_dense = self.energy.values @@ -309,7 +307,7 @@ def _create_energy_grid( energy_span_dense = extended_max - extended_min if len(energy_dense) < 2: - raise ValueError("Energy array must have at least two points.") + raise ValueError('Energy array must have at least two points.') energy_dense_step = energy_dense[1] - energy_dense[0] # Handle offset for even length of energy_dense in convolution. @@ -322,9 +320,7 @@ def _create_energy_grid( # select the 4 central points we either get # indices [2,3,4,5] or [1,2,3,4], both of which are offset by # 0.5*dx from the true center at index 3.5. - energy_even_length_offset = ( - -0.5 * energy_dense_step if len(energy_dense) % 2 == 0 else 0.0 - ) + energy_even_length_offset = -0.5 * energy_dense_step if len(energy_dense) % 2 == 0 else 0.0 # Handle the case when energy_dense is not symmetric around 0. # The resolution is still centered around zero (or close to it), @@ -366,35 +362,27 @@ def _check_width_thresholds( """ # Handle ComponentCollection or ModelComponent - components = ( - model.components if isinstance(model, ComponentCollection) else [model] - ) + components = model.components if isinstance(model, ComponentCollection) else [model] for comp in components: - if hasattr(comp, "width"): - if ( - comp.width.value - > LARGE_WIDTH_THRESHOLD * self._energy_grid.energy_span_dense - ): + if hasattr(comp, 'width'): + if comp.width.value > LARGE_WIDTH_THRESHOLD * self._energy_grid.energy_span_dense: warnings.warn( f"The width of the {model_name} component '{comp.unique_name}' \ ({comp.width.value}) is large compared to the span of the input " - f"array ({self._energy_grid.energy_span_dense}). \ + f'array ({self._energy_grid.energy_span_dense}). \ This may lead to inaccuracies in the convolution. \ - Increase extension_factor to improve accuracy.", + Increase extension_factor to improve accuracy.', UserWarning, stacklevel=3, ) - if ( - comp.width.value - < SMALL_WIDTH_THRESHOLD * self._energy_grid.energy_dense_step - ): + if comp.width.value < SMALL_WIDTH_THRESHOLD * self._energy_grid.energy_dense_step: warnings.warn( f"The width of the {model_name} component '{comp.unique_name}' \ ({comp.width.value}) is small compared to the spacing of the input " - f"array ({self._energy_grid.energy_dense_step}). \ + f'array ({self._energy_grid.energy_dense_step}). \ This may lead to inaccuracies in the convolution. \ - Increase upsample_factor to improve accuracy.", + Increase upsample_factor to improve accuracy.', UserWarning, stacklevel=3, ) @@ -408,13 +396,13 @@ def __repr__(self) -> str: NumericalConvolutionBase. """ return ( - f"{self.__class__.__name__}(" - f"energy=array of shape {self.energy.values.shape},\n " - f"sample_components={repr(self.sample_components)}, \n" - f"resolution_components={repr(self.resolution_components)},\n " - f"energy_unit={self._energy_unit}, " - f"upsample_factor={self.upsample_factor}, " - f"extension_factor={self.extension_factor}, " - f"temperature={self.temperature}, " - f"normalize_detailed_balance={self.normalize_detailed_balance})" + f'{self.__class__.__name__}(' + f'energy=array of shape {self.energy.values.shape},\n ' + f'sample_components={repr(self.sample_components)}, \n' + f'resolution_components={repr(self.resolution_components)},\n ' + f'unit={self.unit}, ' + f'upsample_factor={self.upsample_factor}, ' + f'extension_factor={self.extension_factor}, ' + f'temperature={self.temperature}, ' + f'normalize_detailed_balance={self.normalize_detailed_balance})' ) diff --git a/src/easydynamics/utils/utils.py b/src/easydynamics/utils/utils.py index be7f5506..6ab23d11 100644 --- a/src/easydynamics/utils/utils.py +++ b/src/easydynamics/utils/utils.py @@ -12,8 +12,8 @@ Q_type = np.ndarray | Numeric | list | ArrayLike | sc.Variable energy_type = np.ndarray | Numeric | list | ArrayLike | sc.Variable -hbar = DescriptorNumber.from_scipp("hbar", scipp_hbar) -angstrom = DescriptorNumber("angstrom", 1e-10, unit="m") +hbar = DescriptorNumber.from_scipp('hbar', scipp_hbar) +angstrom = DescriptorNumber('angstrom', 1e-10, unit='m') def _validate_and_convert_Q( @@ -37,7 +37,7 @@ def _validate_and_convert_Q( if Q is None: return None if not isinstance(Q, (np.ndarray, Numeric, list, sc.Variable)): - raise TypeError("Q must be a number, list, numpy array, or scipp Variable.") + raise TypeError('Q must be a number, list, numpy array, or scipp Variable.') if isinstance(Q, Numeric): Q = np.array([Q]) @@ -45,14 +45,14 @@ def _validate_and_convert_Q( Q = np.array(Q) if isinstance(Q, np.ndarray): if Q.ndim > 1: - raise ValueError("Q must be a 1-dimensional array.") + raise ValueError('Q must be a 1-dimensional array.') - Q = sc.array(dims=["Q"], values=Q, unit="1/angstrom") + Q = sc.array(dims=['Q'], values=Q, unit='1/angstrom') if isinstance(Q, sc.Variable): - if Q.dims != ("Q",): + if Q.dims != ('Q',): raise ValueError("Q must have a single dimension named 'Q'.") - Q = Q.to(unit="1/angstrom") + Q = Q.to(unit='1/angstrom') return Q.values @@ -70,9 +70,7 @@ def _validate_unit(unit: str | sc.Unit | None) -> sc.Unit | None: """ if unit is not None and not isinstance(unit, (str, sc.Unit)): - raise TypeError( - f"unit must be None, a string, or a scipp Unit, got {type(unit).__name__}" - ) + raise TypeError(f'unit must be None, a string, or a scipp Unit, got {type(unit).__name__}') # if isinstance(unit, str): # unit = sc.Unit(unit) @@ -91,9 +89,9 @@ def _in_notebook() -> bool: from IPython import get_ipython shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": + if shell == 'ZMQInteractiveShell': return True # Jupyter notebook or JupyterLab - elif shell == "TerminalInteractiveShell": + elif shell == 'TerminalInteractiveShell': return False # Terminal IPython else: return False diff --git a/tests/unit/easydynamics/convolution/test_convolution.py b/tests/unit/easydynamics/convolution/test_convolution.py index 9151409e..2983da4f 100644 --- a/tests/unit/easydynamics/convolution/test_convolution.py +++ b/tests/unit/easydynamics/convolution/test_convolution.py @@ -78,7 +78,7 @@ def test_init(self, default_convolution): assert default_convolution.upsample_factor == 5 assert default_convolution.extension_factor == 0.2 assert default_convolution.temperature is None - assert default_convolution.energy_unit == 'meV' + assert default_convolution.unit == 'meV' assert default_convolution.normalize_detailed_balance is True assert isinstance(default_convolution._energy_grid, EnergyGrid) @@ -112,7 +112,7 @@ def test_init_components(self, convolution_with_components): assert convolution_with_components.upsample_factor == 5 assert convolution_with_components.extension_factor == 0.2 assert convolution_with_components.temperature is None - assert convolution_with_components.energy_unit == 'meV' + assert convolution_with_components.unit == 'meV' assert convolution_with_components.normalize_detailed_balance is True assert isinstance(convolution_with_components._energy_grid, EnergyGrid) diff --git a/tests/unit/easydynamics/convolution/test_convolution_base.py b/tests/unit/easydynamics/convolution/test_convolution_base.py index 44ae133c..393c4af3 100644 --- a/tests/unit/easydynamics/convolution/test_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_convolution_base.py @@ -78,7 +78,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': 'invalid', 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 0, }, 'Energy must be', @@ -88,7 +88,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': 'invalid', 'resolution_components': ComponentCollection(), - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 0, }, ( @@ -101,7 +101,7 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': 'invalid', - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 0, }, ( @@ -114,17 +114,17 @@ def test_init_energy_numerical_none_offset(self): 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'energy_unit': 123, + 'unit': 123, 'energy_offset': 0, }, - 'Energy_unit must be ', + 'unit must be ', ), ( { 'energy': np.linspace(-10, 10, 100), 'sample_components': ComponentCollection(), 'resolution_components': ComponentCollection(), - 'energy_unit': 'meV', + 'unit': 'meV', 'energy_offset': 'invalid', }, 'Energy_offset must be ', @@ -173,48 +173,48 @@ def test_energy_setter_invalid_type_raises(self, convolution_base): ): convolution_base.energy = 'invalid' - def test_energy_unit_property(self, convolution_base): + def test_unit_property(self, convolution_base): # WHEN THEN EXPECT assert convolution_base.energy.unit == 'meV' - def test_energy_unit_setter_raises(self, convolution_base): + def test_unit_setter_raises(self, convolution_base): # WHEN THEN EXPECT with pytest.raises( AttributeError, match='Use convert_unit to change the unit between allowed types ', ): - convolution_base.energy_unit = 'K' + convolution_base.unit = 'K' - def test_convert_energy_unit(self, convolution_base): + def test_convert_unit(self, convolution_base): # WHEN THEN - convolution_base.convert_energy_unit('eV') + convolution_base.convert_unit('eV') # EXPECT assert convolution_base.energy.unit == 'eV' - assert convolution_base.energy_unit == 'eV' + assert convolution_base.unit == 'eV' assert np.allclose(convolution_base.energy.values, np.linspace(-0.01, 0.01, 100)) - def test_convert_energy_unit_invalid_type_raises(self, convolution_base): + def test_convert_unit_invalid_type_raises(self, convolution_base): # WHEN THEN EXPECT with pytest.raises( TypeError, match='Energy unit must be a string or scipp unit.', ): - convolution_base.convert_energy_unit(123) + convolution_base.convert_unit(123) - def test_convert_energy_unit_invalid_unit_rollback(self, convolution_base): + def test_convert_unit_invalid_unit_rollback(self, convolution_base): # WHEN THEN with pytest.raises( UnitError, match='Conversion from `meV` to `s` is not valid.', ): - convolution_base.convert_energy_unit('s') + convolution_base.convert_unit('s') # EXPECT - assert convolution_base.energy_unit == 'meV' + assert convolution_base.unit == 'meV' assert np.allclose(convolution_base.energy.values, np.linspace(-10, 10, 100)) - def test_convert_energy_unit_invalid_offset_unit_rollback(self, convolution_base): + def test_convert_unit_invalid_offset_unit_rollback(self, convolution_base): # WHEN convolution_base.energy_offset = Parameter(name='energy_offset', value=5, unit='s') @@ -223,10 +223,10 @@ def test_convert_energy_unit_invalid_offset_unit_rollback(self, convolution_base UnitError, match='Conversion from `s` to `meV` is not valid.', ): - convolution_base.convert_energy_unit('meV') + convolution_base.convert_unit('meV') # EXPECT - assert convolution_base.energy_unit == 'meV' + assert convolution_base.unit == 'meV' assert convolution_base.energy_offset.unit == 's' def test_energy_offset_property(self, convolution_base): diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py index 29e5dfd6..de28a2bc 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py @@ -48,7 +48,7 @@ def test_init(self, default_numerical_convolution): assert default_numerical_convolution.upsample_factor == 5 assert default_numerical_convolution.extension_factor == 0.2 assert default_numerical_convolution.temperature is None - assert default_numerical_convolution.energy_unit == 'meV' + assert default_numerical_convolution.unit == 'meV' assert default_numerical_convolution.normalize_detailed_balance is True assert isinstance(default_numerical_convolution._energy_grid, EnergyGrid) diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py index 3934eefb..1c462608 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py @@ -46,7 +46,7 @@ def test_init(self, default_numerical_convolution_base): assert default_numerical_convolution_base.upsample_factor == 5 assert default_numerical_convolution_base.extension_factor == 0.2 assert default_numerical_convolution_base.temperature is None - assert default_numerical_convolution_base.energy_unit == 'meV' + assert default_numerical_convolution_base.unit == 'meV' assert default_numerical_convolution_base.normalize_detailed_balance is True assert isinstance(default_numerical_convolution_base._energy_grid, EnergyGrid) @@ -63,7 +63,7 @@ def test_init_with_custom_parameters(self): extension_factor = 0.5 temperature = 300.0 temperature_unit = 'K' - energy_unit = 'meV' + unit = 'meV' normalize_detailed_balance = False # THEN @@ -75,7 +75,7 @@ def test_init_with_custom_parameters(self): extension_factor=extension_factor, temperature=temperature, temperature_unit=temperature_unit, - energy_unit=energy_unit, + unit=unit, normalize_detailed_balance=normalize_detailed_balance, ) @@ -84,7 +84,7 @@ def test_init_with_custom_parameters(self): assert numerical_convolution_base.extension_factor == extension_factor assert numerical_convolution_base.temperature.value == temperature assert numerical_convolution_base.temperature.unit == temperature_unit - assert numerical_convolution_base.energy_unit == energy_unit + assert numerical_convolution_base.unit == unit assert numerical_convolution_base.normalize_detailed_balance == normalize_detailed_balance assert isinstance(numerical_convolution_base._energy_grid, EnergyGrid) @@ -502,7 +502,7 @@ def test_repr(self, default_numerical_convolution_base): assert 'resolution_components=' in repr_str # Important parameters - assert 'energy_unit=meV' in repr_str + assert 'unit=meV' in repr_str assert 'upsample_factor=5' in repr_str assert 'extension_factor=0.2' in repr_str assert 'temperature=None' in repr_str diff --git a/tests/unit/easydynamics/utils/test_utils.py b/tests/unit/easydynamics/utils/test_utils.py index ab24741a..e4d950c3 100644 --- a/tests/unit/easydynamics/utils/test_utils.py +++ b/tests/unit/easydynamics/utils/test_utils.py @@ -12,7 +12,7 @@ class TestValidateAndConvertQ: @pytest.mark.parametrize( - "Q_input, expected", + 'Q_input, expected', [ (1.0, np.array([1.0])), (2, np.array([2])), @@ -30,7 +30,7 @@ def test_validate_and_convert_Q_numeric_and_array(self, Q_input, expected): def test_validate_and_convert_Q_scipp_variable(self): # WHEN - Q = sc.array(dims=["Q"], values=[1.0, 2.0], unit="1/angstrom") + Q = sc.array(dims=['Q'], values=[1.0, 2.0], unit='1/angstrom') # THEN result = _validate_and_convert_Q(Q) @@ -44,29 +44,29 @@ def test_validate_and_convert_Q_none(self): assert _validate_and_convert_Q(None) is None @pytest.mark.parametrize( - "Q_input", + 'Q_input', [ - "invalid", - {"a": 1}, + 'invalid', + {'a': 1}, (1, 2), object(), ], ) def test_validate_and_convert_Q_invalid_type(self, Q_input): # WHEN THEN EXPECT - with pytest.raises(TypeError, match="Q must be a number"): + with pytest.raises(TypeError, match='Q must be a number'): _validate_and_convert_Q(Q_input) def test_validate_and_convert_Q_ndarray_wrong_dim(self): # WHEN THEN Q = np.array([[1.0, 2.0]]) # EXPECT - with pytest.raises(ValueError, match="Q must be a 1-dimensional array"): + with pytest.raises(ValueError, match='Q must be a 1-dimensional array'): _validate_and_convert_Q(Q) def test_validate_and_convert_Q_scipp_wrong_dims(self): # WHEN THEN - Q = sc.array(dims=["x"], values=[1.0, 2.0], unit="1/angstrom") + Q = sc.array(dims=['x'], values=[1.0, 2.0], unit='1/angstrom') # EXPECT with pytest.raises(ValueError, match="single dimension named 'Q'"): @@ -78,12 +78,12 @@ def test_validate_and_convert_Q_scipp_wrong_dims(self): class TestValidateUnit: @pytest.mark.parametrize( - "unit_input", + 'unit_input', [ None, - "1/angstrom", - "meV", - sc.Unit("meV"), + '1/angstrom', + 'meV', + sc.Unit('meV'), ], ) def test_validate_unit_valid(self, unit_input): @@ -95,13 +95,13 @@ def test_validate_unit_valid(self, unit_input): assert isinstance(unit, str) def test_validate_unit_string_conversion(self): - unit = _validate_unit(sc.Unit("meV")) + unit = _validate_unit(sc.Unit('meV')) assert isinstance(unit, str) - assert unit == "meV" + assert unit == 'meV' @pytest.mark.parametrize( - "unit_input", + 'unit_input', [ 123, 45.6, @@ -111,9 +111,7 @@ def test_validate_unit_string_conversion(self): ], ) def test_validate_unit_invalid_type(self, unit_input): - with pytest.raises( - TypeError, match="unit must be None, a string, or a scipp Unit" - ): + with pytest.raises(TypeError, match='unit must be None, a string, or a scipp Unit'): _validate_unit(unit_input) @@ -127,10 +125,10 @@ def test_in_notebook_returns_true_for_jupyter(self, monkeypatch): # WHEN class ZMQInteractiveShell: - __name__ = "ZMQInteractiveShell" + __name__ = 'ZMQInteractiveShell' # THEN - monkeypatch.setattr("IPython.get_ipython", lambda: ZMQInteractiveShell()) + monkeypatch.setattr('IPython.get_ipython', lambda: ZMQInteractiveShell()) # EXPECT assert _in_notebook() is True @@ -141,11 +139,11 @@ def test_in_notebook_returns_false_for_terminal_ipython(self, monkeypatch): # WHEN class TerminalInteractiveShell: - __name__ = "TerminalInteractiveShell" + __name__ = 'TerminalInteractiveShell' # THEN - monkeypatch.setattr("IPython.get_ipython", lambda: TerminalInteractiveShell()) + monkeypatch.setattr('IPython.get_ipython', lambda: TerminalInteractiveShell()) # EXPECT assert _in_notebook() is False @@ -156,10 +154,10 @@ def test_in_notebook_returns_false_for_unknown_shell(self, monkeypatch): # WHEN class UnknownShell: - __name__ = "UnknownShell" + __name__ = 'UnknownShell' # THEN - monkeypatch.setattr("IPython.get_ipython", lambda: UnknownShell()) + monkeypatch.setattr('IPython.get_ipython', lambda: UnknownShell()) # EXPECT assert _in_notebook() is False @@ -172,7 +170,7 @@ def raise_import_error(*args, **kwargs): raise ImportError # THEN - monkeypatch.setattr("builtins.__import__", raise_import_error) + monkeypatch.setattr('builtins.__import__', raise_import_error) # EXPECT assert _in_notebook() is False From 3559c862b9bf391e105061c70d21d37a290691cf Mon Sep 17 00:00:00 2001 From: henrikjacobsenfys Date: Mon, 30 Mar 2026 20:09:47 +0200 Subject: [PATCH 3/3] add test --- .../convolution/analytical_convolution.py | 7 ++ src/easydynamics/convolution/convolution.py | 8 ++ .../convolution/convolution_base.py | 4 +- .../convolution/numerical_convolution.py | 8 ++ .../convolution/numerical_convolution_base.py | 7 ++ .../base_classes/test_easydynamics_base.py | 70 ++++++++++++++++ .../test_easydynamics_modelbase.py | 80 +++++++++++++++++++ .../components/test_model_component.py | 2 - 8 files changed, 182 insertions(+), 4 deletions(-) create mode 100644 tests/unit/easydynamics/base_classes/test_easydynamics_base.py create mode 100644 tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index e9b50765..804ad50e 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -42,6 +42,8 @@ def __init__( sample_components: ComponentCollection | ModelComponent | None = None, resolution_components: ComponentCollection | ModelComponent | None = None, energy_offset: Numeric | Parameter = 0.0, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """Initialize an AnalyticalConvolution. @@ -56,6 +58,9 @@ def __init__( The resolution model to convolve with. energy_offset (Numeric | Parameter, default=0.0): An offset to shift the energy values by. + display_name (str | None, default='MyConvolution'): Display name of the model. + unique_name (str | None, default=None): Unique name of the model. If None, + a unique name will be generated. """ super().__init__( energy=energy, @@ -63,6 +68,8 @@ def __init__( sample_components=sample_components, resolution_components=resolution_components, energy_offset=energy_offset, + display_name=display_name, + unique_name=unique_name, ) def convolution( diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index 5a011526..c347e8a2 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -60,6 +60,8 @@ def __init__( temperature_unit: str | sc.Unit = 'K', unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """Initialize the Convolution class. @@ -88,6 +90,10 @@ def __init__( normalize_detailed_balance (bool, default=True): Whether to normalize the detailed balance correction. Default is True. + display_name (str | None, default='MyConvolution'): The display + name of the convolution. + unique_name (str | None, default=None): The unique name of the + convolution. """ self._convolution_plan_is_valid = False @@ -103,6 +109,8 @@ def __init__( temperature_unit=temperature_unit, unit=unit, normalize_detailed_balance=normalize_detailed_balance, + display_name=display_name, + unique_name=unique_name, ) self._reactions_enabled = True diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index 36e253fc..22a3046a 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -24,7 +24,7 @@ def __init__( resolution_components: ComponentCollection | ModelComponent | None = None, unit: str | sc.Unit = 'meV', energy_offset: Numeric | Parameter = 0.0, - display_name: str | None = 'MyConvolutionModel', + display_name: str | None = 'MyConvolution', unique_name: str | None = None, ) -> None: """Initialize the ConvolutionBase. @@ -40,7 +40,7 @@ def __init__( energy axis. energy_offset (Numeric | Parameter, default=0.0): The energy offset applied to the convolution. Default is 0.0. - display_name (str | None, default='MyConvolutionModel'): + display_name (str | None, default='MyConvolution'): The display name of the model. unique_name (str | None, default=None): The unique name of the model. If None, a unique name will be generated. diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index 32c5bb3e..eafb797b 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -35,6 +35,8 @@ def __init__( temperature_unit: str | sc.Unit = 'K', unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """Initialize the NumericalConvolution object. @@ -58,6 +60,10 @@ def __init__( unit (str | sc.Unit, default='meV'): The unit of the energy. normalize_detailed_balance (bool, default=True): Whether to normalize the detailed balance correction. + display_name (str | None, default='MyConvolution'): The display + name of the convolution. + unique_name (str | None, default=None): The unique name of the + convolution. If None, a unique name will be generated. """ super().__init__( energy=energy, @@ -70,6 +76,8 @@ def __init__( temperature_unit=temperature_unit, unit=unit, normalize_detailed_balance=normalize_detailed_balance, + display_name=display_name, + unique_name=unique_name, ) def convolution( diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index d580ede8..366f07d0 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -45,6 +45,8 @@ def __init__( temperature_unit: str | sc.Unit = 'K', unit: str | sc.Unit = 'meV', normalize_detailed_balance: bool = True, + display_name: str | None = 'MyConvolution', + unique_name: str | None = None, ) -> None: """Initialize the NumericalConvolutionBase. @@ -68,6 +70,9 @@ def __init__( unit (str | sc.Unit, default='meV'): The unit of the energy. normalize_detailed_balance (bool, default=True): Whether to normalize the detailed balance correction. + display_name (str | None, default='MyConvolution'): Display name of the model. + unique_name (str | None, default=None): Unique name of the model. If None, + a unique name will be generated. Raises: TypeError: If temperature is not None, a number, or a @@ -81,6 +86,8 @@ def __init__( resolution_components=resolution_components, unit=unit, energy_offset=energy_offset, + display_name=display_name, + unique_name=unique_name, ) if temperature is not None and not isinstance(temperature, (Numeric, Parameter)): diff --git a/tests/unit/easydynamics/base_classes/test_easydynamics_base.py b/tests/unit/easydynamics/base_classes/test_easydynamics_base.py new file mode 100644 index 00000000..a020fa97 --- /dev/null +++ b/tests/unit/easydynamics/base_classes/test_easydynamics_base.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from easydynamics.base_classes import EasyDynamicsBase + + +class TestEasyDynamicsBase: + """Tests for the EasyDynamicsBase class.""" + + @pytest.fixture + def easy_dynamics_base(self): + """Fixture for creating an instance of EasyDynamicsBase.""" + + return EasyDynamicsBase(name='TestModel') + + def test_initialization(self, easy_dynamics_base): + """Test that the EasyDynamicsBase is initialized correctly.""" + + # WHEN THEN EXPECT + assert easy_dynamics_base.name == 'TestModel' + assert easy_dynamics_base.display_name == 'MyEasyDynamicsModel' + assert easy_dynamics_base.unique_name is not None + + def test_init_raises_type_error_for_invalid_name(self): + """Test that initializing with an invalid name raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + EasyDynamicsBase(name=123) # Not a string + + def test_init_name_can_be_none(self): + """Test that initializing with name as None works correctly.""" + # WHEN THEN EXPECT + model = EasyDynamicsBase(name=None) + + # THEN EXPECT + assert model.name is None + + def test_name_setter_and_getter(self, easy_dynamics_base): + """Test that the name setter and getter work correctly.""" + # WHEN THEN EXPECT + assert easy_dynamics_base.name == 'TestModel' + + # THEN + easy_dynamics_base.name = 'NewName' + + # EXPECT + assert easy_dynamics_base.name == 'NewName' + + # THEN + easy_dynamics_base.name = None + + # EXPECT + assert easy_dynamics_base.name is None + + @pytest.mark.parametrize( + 'invalid_name', + [ + 123, # Not a string + [1, 2, 3], # Not a string + {'name': 'Test'}, # Not a string + ], + ids=['integer', 'list', 'dict'], + ) + def test_name_setter_invalid_type(self, easy_dynamics_base, invalid_name): + """Test that setting the name to an invalid type raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + easy_dynamics_base.name = invalid_name diff --git a/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py new file mode 100644 index 00000000..8d7612e6 --- /dev/null +++ b/tests/unit/easydynamics/base_classes/test_easydynamics_modelbase.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from easydynamics.base_classes import EasyDynamicsModelBase + + +class TestEasyDynamicsModelBase: + """Tests for the EasyDynamicsModelBase class.""" + + @pytest.fixture + def easy_dynamics_modelbase(self): + """Fixture for creating an instance of EasyDynamicsModelBase.""" + + return EasyDynamicsModelBase(name='TestModel', unit='meV') + + def test_initialization(self, easy_dynamics_modelbase): + """Test that the EasyDynamicsModelBase is initialized correctly.""" + + # WHEN THEN EXPECT + assert easy_dynamics_modelbase.name == 'TestModel' + assert easy_dynamics_modelbase.display_name == 'MyEasyDynamicsModel' + assert easy_dynamics_modelbase.unique_name is not None + + def test_init_raises_type_error_for_invalid_name(self): + """Test that initializing with an invalid name raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + EasyDynamicsModelBase(name=123) # Not a string + + def test_init_name_can_be_none(self): + """Test that initializing with name as None works correctly.""" + # WHEN THEN EXPECT + model = EasyDynamicsModelBase(name=None) + + # THEN EXPECT + assert model.name is None + + def test_name_setter_and_getter(self, easy_dynamics_modelbase): + """Test that the name setter and getter work correctly.""" + # WHEN THEN EXPECT + assert easy_dynamics_modelbase.name == 'TestModel' + + # THEN + easy_dynamics_modelbase.name = 'NewName' + + # EXPECT + assert easy_dynamics_modelbase.name == 'NewName' + + # THEN + easy_dynamics_modelbase.name = None + + # EXPECT + assert easy_dynamics_modelbase.name is None + + @pytest.mark.parametrize( + 'invalid_name', + [ + 123, # Not a string + [1, 2, 3], # Not a string + {'name': 'Test'}, # Not a string + ], + ids=['integer', 'list', 'dict'], + ) + def test_name_setter_invalid_type(self, easy_dynamics_modelbase, invalid_name): + """Test that setting the name to an invalid type raises a TypeError.""" + # WHEN THEN EXPECT + with pytest.raises(TypeError, match='Name must be a string or None.'): + easy_dynamics_modelbase.name = invalid_name + + def test_unit_property(self, easy_dynamics_modelbase): + # WHEN THEN EXPECT + + assert easy_dynamics_modelbase.unit == 'meV' + + def test_unit_setter_raises(self, easy_dynamics_modelbase): + # WHEN / THEN / EXPECT + with pytest.raises(AttributeError, match='Use convert_unit to change '): + easy_dynamics_modelbase.unit = 'K' diff --git a/tests/unit/easydynamics/sample_model/components/test_model_component.py b/tests/unit/easydynamics/sample_model/components/test_model_component.py index 93e42bc0..886a829c 100644 --- a/tests/unit/easydynamics/sample_model/components/test_model_component.py +++ b/tests/unit/easydynamics/sample_model/components/test_model_component.py @@ -8,8 +8,6 @@ from easydynamics.sample_model.components.model_component import ModelComponent -Numeric = float | int - class DummyComponent(ModelComponent): def __init__(self):