diff --git a/docs/docs/tutorials/sample_model.ipynb b/docs/docs/tutorials/sample_model.ipynb index 5371f7df..882048ce 100644 --- a/docs/docs/tutorials/sample_model.ipynb +++ b/docs/docs/tutorials/sample_model.ipynb @@ -134,7 +134,7 @@ ], "metadata": { "kernelspec": { - "display_name": "easydynamics_newbase", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -148,7 +148,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.14.4" } }, "nbformat": 4, diff --git a/docs/docs/tutorials/tutorial0_basics.ipynb b/docs/docs/tutorials/tutorial0_basics.ipynb index 227fd5c2..44442036 100644 --- a/docs/docs/tutorials/tutorial0_basics.ipynb +++ b/docs/docs/tutorials/tutorial0_basics.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "7121f9c8", "metadata": {}, "outputs": [], diff --git a/src/easydynamics/__init__.py b/src/easydynamics/__init__.py index c83d92cb..54557144 100644 --- a/src/easydynamics/__init__.py +++ b/src/easydynamics/__init__.py @@ -3,8 +3,8 @@ """EasyDynamics library.""" from easydynamics.analysis import Analysis -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.experiment import Experiment +from easydynamics.settings.convolution_settings import ConvolutionSettings __all__ = [ 'Analysis', diff --git a/src/easydynamics/analysis/analysis.py b/src/easydynamics/analysis/analysis.py index 9b0966fd..b8493ec9 100644 --- a/src/easydynamics/analysis/analysis.py +++ b/src/easydynamics/analysis/analysis.py @@ -13,10 +13,11 @@ from easydynamics.analysis.analysis1d import Analysis1d from easydynamics.analysis.analysis_base import AnalysisBase -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.experiment import Experiment from easydynamics.sample_model import SampleModel from easydynamics.sample_model.instrument_model import InstrumentModel +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings from easydynamics.utils.utils import _in_notebook @@ -35,6 +36,7 @@ def __init__( sample_model: SampleModel | None = None, instrument_model: InstrumentModel | None = None, convolution_settings: ConvolutionSettings | None = None, + detailed_balance_settings: DetailedBalanceSettings | None = None, extra_parameters: Parameter | list[Parameter] | None = None, ) -> None: """ @@ -56,6 +58,8 @@ def __init__( is created. convolution_settings : ConvolutionSettings | None, default=None The settings for the convolution. If None, default settings will be used. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + The settings for detailed balance. If None, default settings will be used. extra_parameters : Parameter | list[Parameter] | None, default=None Extra parameters to be included in the analysis for advanced users. If None, no extra parameters are added. @@ -71,6 +75,7 @@ def __init__( sample_model=sample_model, instrument_model=instrument_model, convolution_settings=convolution_settings, + detailed_balance_settings=detailed_balance_settings, extra_parameters=extra_parameters, ) @@ -538,6 +543,8 @@ def _create_analysis_list(self) -> None: experiment=self.experiment, sample_model=self.sample_model, instrument_model=self.instrument_model, + convolution_settings=self.convolution_settings, + detailed_balance_settings=self.detailed_balance_settings, extra_parameters=self._extra_parameters, Q_index=Q_index, ) diff --git a/src/easydynamics/analysis/analysis1d.py b/src/easydynamics/analysis/analysis1d.py index df698446..489f4381 100644 --- a/src/easydynamics/analysis/analysis1d.py +++ b/src/easydynamics/analysis/analysis1d.py @@ -13,12 +13,14 @@ from easydynamics.analysis.analysis_base import AnalysisBase from easydynamics.convolution.convolution import Convolution -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.experiment import Experiment from easydynamics.sample_model import InstrumentModel from easydynamics.sample_model import SampleModel from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings +from easydynamics.utils.detailed_balance import detailed_balance_factor class Analysis1d(AnalysisBase): @@ -37,6 +39,7 @@ def __init__( instrument_model: InstrumentModel | None = None, Q_index: int | None = None, convolution_settings: ConvolutionSettings | None = None, + detailed_balance_settings: DetailedBalanceSettings | None = None, extra_parameters: Parameter | list[Parameter] | None = None, ) -> None: """ @@ -61,6 +64,8 @@ def __init__( until a Q index is set. convolution_settings : ConvolutionSettings | None, default=None The settings for the convolution. If None, default settings will be used. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + The settings for detailed balance. If None, default settings will be used. extra_parameters : Parameter | list[Parameter] | None, default=None Extra parameters to be included in the analysis for advanced users. If None, no extra parameters are added. @@ -72,6 +77,7 @@ def __init__( sample_model=sample_model, instrument_model=instrument_model, convolution_settings=convolution_settings, + detailed_balance_settings=detailed_balance_settings, extra_parameters=extra_parameters, ) @@ -452,6 +458,7 @@ def _evaluate_components( convolver: Convolution | None = None, convolve: bool = True, energy: sc.Variable | None = None, + apply_detailed_balance: bool = False, ) -> np.ndarray: """ Calculate the contribution of a set of components, optionally convolving with the @@ -475,6 +482,9 @@ def _evaluate_components( energy : sc.Variable | None, default=None Optional energy grid to use for evaluation. If None, the energy grid from the experiment is used. + apply_detailed_balance : bool, default=False + Whether to apply detailed balance correction. + Returns ------- @@ -496,24 +506,42 @@ def _evaluate_components( if isinstance(components, ComponentCollection) and components.is_empty: return np.zeros_like(energy.values) - # No convolution - if not convolve: - return components.evaluate(energy_with_offset) - - # If a convolver is provided, use it. This allows reusing the + # If a convolver is provided, we use it. This allows reusing the # same convolver for multiple evaluations during fitting for # performance reasons. if convolver is not None: return convolver.convolution() - # If no convolver is provided, create a new one. This is for - # evaluating individual components for plotting, where - # performance is not important. + # No convolution can happen for multiple reasons: + # Case 1: convolve=False, used for evaluating background components, where we don't want + # to convolve with the resolution. In this case, apply_detailed_balance is False, + # and we evaluate the components without DBF regardles of the settings + # Case 2: convolve=True but there is no resolution_model. In this case, + # apply_detailed_balance is True. We apply DBF if temperature is provided and + # the settings say to use detailed balance. - # We don't create a convolver if the resolution is empty. resolution = self.instrument_model.resolution_model.get_component_collection(Q_index) - if resolution.is_empty: - return components.evaluate(energy_with_offset) + if not convolve or resolution.is_empty: + result_no_convolution = components.evaluate(energy_with_offset) + if ( + apply_detailed_balance + and self.temperature is not None + and self.detailed_balance_settings.use_detailed_balance + ): + DBF = detailed_balance_factor( + energy=energy_with_offset, + temperature=self.temperature, + divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, + energy_unit=self.unit, + ) + result_no_convolution *= DBF + return result_no_convolution + + # If no convolver is provided, we create a new one. This is for + # evaluating individual components for plotting, where + # performance is not important. We already handled the case of + # background components above, so we know that this is for sample components, + # where detailed balance settings should be applied. conv = Convolution( energy=energy, @@ -522,6 +550,7 @@ def _evaluate_components( energy_offset=energy_offset, convolution_settings=self.convolution_settings, temperature=self.temperature, + detailed_balance_settings=self.detailed_balance_settings, ) return conv.convolution() @@ -552,6 +581,7 @@ def _evaluate_sample( convolver=self._convolver, convolve=True, energy=energy, + apply_detailed_balance=True, ) def _evaluate_sample_component( @@ -580,6 +610,7 @@ def _evaluate_sample_component( convolver=None, convolve=True, energy=energy, + apply_detailed_balance=True, ) def _evaluate_background(self, energy: sc.Variable | None = None) -> np.ndarray: @@ -606,6 +637,7 @@ def _evaluate_background(self, energy: sc.Variable | None = None) -> np.ndarray: convolver=None, convolve=False, energy=energy, + apply_detailed_balance=False, ) def _evaluate_background_component( @@ -635,6 +667,7 @@ def _evaluate_background_component( convolver=None, convolve=False, energy=energy, + apply_detailed_balance=False, ) def _create_convolver( @@ -678,6 +711,7 @@ def _create_convolver( energy_offset=self.instrument_model.get_energy_offset(Q_index), convolution_settings=self.convolution_settings, temperature=self.temperature, + detailed_balance_settings=self.detailed_balance_settings, ) ############# diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py index 0bd78cd7..01f0a20f 100644 --- a/src/easydynamics/analysis/analysis_base.py +++ b/src/easydynamics/analysis/analysis_base.py @@ -3,16 +3,17 @@ import numpy as np import scipp as sc -from easyscience.base_classes.model_base import ModelBase as EasyScienceModelBase from easyscience.variable import Parameter -from easydynamics.convolution.convolution_settings import ConvolutionSettings +from easydynamics.base_classes.easydynamics_modelbase import EasyDynamicsModelBase from easydynamics.experiment import Experiment from easydynamics.sample_model import InstrumentModel from easydynamics.sample_model import SampleModel +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings -class AnalysisBase(EasyScienceModelBase): +class AnalysisBase(EasyDynamicsModelBase): """ Base class for analysis in EasyDynamics. @@ -31,6 +32,7 @@ def __init__( sample_model: SampleModel | None = None, instrument_model: InstrumentModel | None = None, convolution_settings: ConvolutionSettings | None = None, + detailed_balance_settings: DetailedBalanceSettings | None = None, extra_parameters: Parameter | list[Parameter] | None = None, ) -> None: """ @@ -53,6 +55,8 @@ def __init__( is created. convolution_settings : ConvolutionSettings | None, default=None The settings for the convolution. If None, default settings will be used. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + The settings for detailed balance. If None, default settings will be used. extra_parameters : Parameter | list[Parameter] | None, default=None Extra parameters to be included in the analysis for advanced users. If None, no extra parameters are added. @@ -61,8 +65,10 @@ def __init__( ------ TypeError If experiment is not an Experiment or None or if sample_model is not a SampleModel or - None or if instrument_model is not an InstrumentModel or None or if extra_parameters is - not a Parameter, a list of Parameters, or None. + None or if instrument_model is not an InstrumentModel or None or if + convolution_settings is not a ConvolutionSettings or None or if + detailed_balance_settings is not a DetailedBalanceSettings or None or if + extra_parameters is not a Parameter, a list of Parameters, or None. """ super().__init__(display_name=display_name, unique_name=unique_name) @@ -109,6 +115,15 @@ def __init__( else: self._extra_parameters = [] + if detailed_balance_settings is None: + self._detailed_balance_settings = DetailedBalanceSettings() + elif isinstance(detailed_balance_settings, DetailedBalanceSettings): + self._detailed_balance_settings = detailed_balance_settings + else: + raise TypeError( + 'detailed_balance_settings must be an instance of DetailedBalanceSettings or None.' + ) + self._on_experiment_changed() ############# @@ -338,6 +353,37 @@ def convolution_settings(self, value: ConvolutionSettings) -> None: self._convolution_settings = value self._on_convolution_settings_changed() + @property + def detailed_balance_settings(self) -> DetailedBalanceSettings: + """ + Get the DetailedBalanceSettings of the SampleModel. + + Returns + ------- + DetailedBalanceSettings + The DetailedBalanceSettings of the SampleModel. + """ + return self._detailed_balance_settings + + @detailed_balance_settings.setter + def detailed_balance_settings(self, value: DetailedBalanceSettings) -> None: + """ + Set the DetailedBalanceSettings of the SampleModel. + + Parameters + ---------- + value : DetailedBalanceSettings + The DetailedBalanceSettings to set. + + Raises + ------ + TypeError + If value is not a DetailedBalanceSettings. + """ + if not isinstance(value, DetailedBalanceSettings): + raise TypeError('detailed_balance_settings must be a DetailedBalanceSettings') + self._detailed_balance_settings = value + @property def extra_parameters(self) -> list[Parameter]: """ diff --git a/src/easydynamics/convolution/__init__.py b/src/easydynamics/convolution/__init__.py index fd5fe9ab..2cbfdcf5 100644 --- a/src/easydynamics/convolution/__init__.py +++ b/src/easydynamics/convolution/__init__.py @@ -2,6 +2,5 @@ # SPDX-License-Identifier: BSD-3-Clause from easydynamics.convolution.convolution import Convolution -from easydynamics.convolution.convolution_settings import ConvolutionSettings -__all__ = ['Convolution', 'ConvolutionSettings'] +__all__ = ['Convolution'] diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index 77805e61..63bd66c6 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -8,7 +8,6 @@ from easyscience.variable import Parameter from easydynamics.convolution.analytical_convolution import AnalyticalConvolution -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.convolution.numerical_convolution import NumericalConvolution from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase from easydynamics.sample_model import ComponentCollection @@ -17,6 +16,8 @@ from easydynamics.sample_model import Lorentzian from easydynamics.sample_model import Voigt from easydynamics.sample_model.components.model_component import ModelComponent +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings from easydynamics.utils.utils import Numeric @@ -45,6 +46,7 @@ class Convolution(NumericalConvolutionBase): '_temperature', '_energy_unit', '_normalize_detailed_balance', + '_detailed_balance_settings', } def __init__( @@ -56,6 +58,7 @@ def __init__( convolution_settings: ConvolutionSettings | None = None, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', + detailed_balance_settings: DetailedBalanceSettings | None = None, unit: str | sc.Unit = 'meV', display_name: str | None = 'MyConvolution', unique_name: str | None = None, @@ -79,6 +82,8 @@ def __init__( The temperature to use for detailed balance correction. temperature_unit : str | sc.Unit, default='K' The unit of the temperature parameter. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + The settings for detailed balance. If None, default settings will be used. unit : str | sc.Unit, default='meV' The unit of the energy. display_name : str | None, default='MyConvolution' @@ -96,6 +101,7 @@ def __init__( convolution_settings=convolution_settings, temperature=temperature, temperature_unit=temperature_unit, + detailed_balance_settings=detailed_balance_settings, unit=unit, display_name=display_name, unique_name=unique_name, @@ -224,7 +230,10 @@ def _build_convolution_plan(self) -> None: # If temperature is set, all other components go to # numerical sample model - if self.temperature is not None: + if ( + self.temperature is not None + and self.detailed_balance_settings.use_detailed_balance + ): numerical_sample_components.append_component(sample_component) continue @@ -278,6 +287,8 @@ def _set_convolvers(self) -> None: convolution_settings=self.convolution_settings, temperature=self.temperature, temperature_unit=self._temperature_unit, + detailed_balance_settings=self.detailed_balance_settings, + unit=self.unit, ) else: self._numerical_convolver = None diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index 434b6dd3..7a990d37 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -6,10 +6,11 @@ from easyscience.variable import Parameter from scipy.signal import fftconvolve -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings from easydynamics.utils.detailed_balance import detailed_balance_factor from easydynamics.utils.utils import Numeric @@ -32,6 +33,7 @@ def __init__( convolution_settings: ConvolutionSettings | None = None, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', + detailed_balance_settings: DetailedBalanceSettings | None = None, unit: str | sc.Unit = 'meV', display_name: str | None = 'MyConvolution', unique_name: str | None = None, @@ -55,6 +57,8 @@ def __init__( The temperature to use for detailed balance correction. temperature_unit : str | sc.Unit, default='K' The unit of the temperature parameter. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + The settings for detailed balance. If None, default settings will be used. unit : str | sc.Unit, default='meV' The unit of the energy. display_name : str | None, default='MyConvolution' @@ -70,6 +74,7 @@ def __init__( convolution_settings=convolution_settings, temperature=temperature, temperature_unit=temperature_unit, + detailed_balance_settings=detailed_balance_settings, unit=unit, display_name=display_name, unique_name=unique_name, @@ -111,12 +116,12 @@ def convolution( ) # Detailed balance correction - if self.temperature is not None: + if self.temperature is not None and self.detailed_balance_settings.use_detailed_balance: detailed_balance_factor_correction = detailed_balance_factor( energy=self._energy_grid.energy_dense - self.energy_offset.value, temperature=self.temperature, energy_unit=self.energy.unit, - divide_by_temperature=self.normalize_detailed_balance, + divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, ) sample_vals *= detailed_balance_factor_correction diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index 5ba2a51b..7001e32b 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -8,10 +8,11 @@ from easyscience.variable import Parameter from easydynamics.convolution.convolution_base import ConvolutionBase -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.convolution.energy_grid import EnergyGrid from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings from easydynamics.utils.utils import Numeric # The thresholds are illustrated in @@ -41,7 +42,7 @@ def __init__( convolution_settings: ConvolutionSettings | None = None, temperature: Parameter | Numeric | None = None, temperature_unit: str | sc.Unit = 'K', - normalize_detailed_balance: bool = True, + detailed_balance_settings: DetailedBalanceSettings | None = None, unit: str | sc.Unit = 'meV', display_name: str | None = 'MyConvolution', unique_name: str | None = None, @@ -65,8 +66,8 @@ def __init__( The temperature to use for detailed balance correction. temperature_unit : str | sc.Unit, default='K' The unit of the temperature parameter. - normalize_detailed_balance : bool, default=True - Whether to normalize the detailed balance factor by temperature. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + The settings for detailed balance. If None, default settings will be used. unit : str | sc.Unit, default='meV' The unit of the energy. display_name : str | None, default='MyConvolution' @@ -103,7 +104,13 @@ def __init__( convolution_settings = ConvolutionSettings() self._convolution_settings = convolution_settings - self._normalize_detailed_balance = normalize_detailed_balance + if detailed_balance_settings is None: + detailed_balance_settings = DetailedBalanceSettings() + if not isinstance(detailed_balance_settings, DetailedBalanceSettings): + raise TypeError( + 'detailed_balance_settings must be a DetailedBalanceSettings instance.' + ) + self._detailed_balance_settings = detailed_balance_settings # Create a dense grid to improve accuracy. # When upsample_factor>1, we evaluate on this grid and @@ -292,42 +299,35 @@ def temperature(self, temp: Parameter | Numeric | None) -> None: raise TypeError('Temperature must be None, a float or a Parameter.') @property - def normalize_detailed_balance(self) -> bool: + def detailed_balance_settings(self) -> DetailedBalanceSettings: """ - Get whether to normalize the detailed balance factor. - - If True, the detailed balance factor is divided by temperature. + Get the DetailedBalanceSettings of the Convolution. Returns ------- - bool - Whether to normalize the detailed balance factor. + DetailedBalanceSettings + The DetailedBalanceSettings of the Convolution. """ + return self._detailed_balance_settings - return self._normalize_detailed_balance - - @normalize_detailed_balance.setter - def normalize_detailed_balance(self, normalize: bool) -> None: + @detailed_balance_settings.setter + def detailed_balance_settings(self, value: DetailedBalanceSettings) -> None: """ - Set whether to normalize the detailed balance factor. - - If True, the detailed balance factor is divided by temperature. + Set the DetailedBalanceSettings of the Convolution. Parameters ---------- - normalize : bool - Whether to normalize the detailed balance factor. + value : DetailedBalanceSettings + The DetailedBalanceSettings to set. Raises ------ TypeError - If normalize is not a bool. + If value is not a DetailedBalanceSettings. """ - - if not isinstance(normalize, bool): - raise TypeError('normalize_detailed_balance must be True or False.') - - self._normalize_detailed_balance = normalize + if not isinstance(value, DetailedBalanceSettings): + raise TypeError('detailed_balance_settings must be a DetailedBalanceSettings') + self._detailed_balance_settings = value def _create_energy_grid( self, @@ -476,5 +476,5 @@ def __repr__(self) -> str: f'upsample_factor={self.upsample_factor}, ' f'extension_factor={self.extension_factor}, ' f'temperature={self.temperature}, ' - f'normalize_detailed_balance={self.normalize_detailed_balance})' + f'detailed_balance={self.detailed_balance_settings!r})' ) diff --git a/src/easydynamics/sample_model/sample_model.py b/src/easydynamics/sample_model/sample_model.py index b89829ad..fdf8eb09 100644 --- a/src/easydynamics/sample_model/sample_model.py +++ b/src/easydynamics/sample_model/sample_model.py @@ -11,6 +11,7 @@ from easydynamics.sample_model.components.model_component import ModelComponent from easydynamics.sample_model.diffusion_model.diffusion_model_base import DiffusionModelBase from easydynamics.sample_model.model_base import ModelBase +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings from easydynamics.utils import detailed_balance_factor from easydynamics.utils.utils import Numeric from easydynamics.utils.utils import Q_type @@ -35,7 +36,7 @@ def __init__( diffusion_models: DiffusionModelBase | list[DiffusionModelBase] | None = None, temperature: float | None = None, temperature_unit: str | sc.Unit = 'K', - divide_by_temperature: bool = True, + detailed_balance_settings: DetailedBalanceSettings | None = None, ) -> None: """ Initialize the SampleModel. @@ -60,14 +61,15 @@ def __init__( default, None. temperature_unit : str | sc.Unit, default='K' Unit of the temperature. - divide_by_temperature : bool, default=True - Whether to divide the detailed balance factor by temperature. + detailed_balance_settings : DetailedBalanceSettings | None, default=None + Settings for detailed balancing. Raises ------ TypeError If diffusion_models is not a DiffusionModelBase, a list of DiffusionModelBase, or None, - or if temperature is not a number or None, or if divide_by_temperature is not a bool. + or if temperature is not a number or None, or if detailed_balance_settings is not a + DetailedBalanceSettings instance. ValueError If temperature is negative. """ @@ -110,9 +112,12 @@ def __init__( ) self._temperature_unit = temperature_unit - if not isinstance(divide_by_temperature, bool): - raise TypeError('divide_by_temperature must be True or False') - self._divide_by_temperature = divide_by_temperature + if detailed_balance_settings is None: + self._detailed_balance_settings = DetailedBalanceSettings() + elif isinstance(detailed_balance_settings, DetailedBalanceSettings): + self._detailed_balance_settings = detailed_balance_settings + else: + raise TypeError('detailed_balance_settings must be a DetailedBalanceSettings or None') # ------------------------------------------------------------------ # Component management @@ -321,22 +326,22 @@ def convert_temperature_unit(self, unit: str | sc.Unit) -> None: If the provided unit is invalid or cannot be converted. """ - if self._temperature is None: + if self.temperature is None: raise ValueError('Temperature is not set, cannot convert unit.') - old_unit = self._temperature.unit + old_unit = self.temperature.unit try: - self._temperature.convert_unit(unit) + self.temperature.convert_unit(unit) self._temperature_unit = unit except Exception: # Attempt to rollback on failure with suppress(Exception): - self._temperature.convert_unit(old_unit) + self.temperature.convert_unit(old_unit) raise @property - def divide_by_temperature(self) -> bool: + def normalize_detailed_balance(self) -> bool: """ Get whether to divide the detailed balance factor by temperature. @@ -345,10 +350,10 @@ def divide_by_temperature(self) -> bool: bool True if the detailed balance factor is divided by temperature, False otherwise. """ - return self._divide_by_temperature + return self.detailed_balance_settings.normalize_detailed_balance - @divide_by_temperature.setter - def divide_by_temperature(self, value: bool) -> None: + @normalize_detailed_balance.setter + def normalize_detailed_balance(self, value: bool) -> None: """ Set whether to divide the detailed balance factor by temperature. @@ -363,8 +368,70 @@ def divide_by_temperature(self, value: bool) -> None: If value is not a bool. """ if not isinstance(value, bool): - raise TypeError('divide_by_temperature must be True or False') - self._divide_by_temperature = value + raise TypeError('normalize_detailed_balance must be True or False') + self.detailed_balance_settings.normalize_detailed_balance = value + + @property + def use_detailed_balance(self) -> bool: + """ + Get whether to apply detailed balance to the model. + + Returns + ------- + bool + True if detailed balance is applied, False otherwise. + """ + return self.detailed_balance_settings.use_detailed_balance + + @use_detailed_balance.setter + def use_detailed_balance(self, value: bool) -> None: + """ + Set whether to apply detailed balance to the model. + + Parameters + ---------- + value : bool + True to apply detailed balance, False otherwise. + + Raises + ------ + TypeError + If value is not a bool. + """ + if not isinstance(value, bool): + raise TypeError('use_detailed_balance must be True or False') + self.detailed_balance_settings.use_detailed_balance = value + + @property + def detailed_balance_settings(self) -> DetailedBalanceSettings: + """ + Get the DetailedBalanceSettings of the SampleModel. + + Returns + ------- + DetailedBalanceSettings + The DetailedBalanceSettings of the SampleModel. + """ + return self._detailed_balance_settings + + @detailed_balance_settings.setter + def detailed_balance_settings(self, value: DetailedBalanceSettings) -> None: + """ + Set the DetailedBalanceSettings of the SampleModel. + + Parameters + ---------- + value : DetailedBalanceSettings + The DetailedBalanceSettings to set. + + Raises + ------ + TypeError + If value is not a DetailedBalanceSettings. + """ + if not isinstance(value, DetailedBalanceSettings): + raise TypeError('detailed_balance_settings must be a DetailedBalanceSettings') + self._detailed_balance_settings = value # ------------------------------------------------------------------ # Other methods @@ -390,12 +457,12 @@ def evaluate( y = super().evaluate(x) - if self._temperature is not None: + if self.temperature is not None and self.detailed_balance_settings.use_detailed_balance: DBF = detailed_balance_factor( energy=x, - temperature=self._temperature, - divide_by_temperature=self._divide_by_temperature, - energy_unit=self._unit, + temperature=self.temperature, + divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, + energy_unit=self.unit, ) y = [yi * DBF for yi in y] @@ -422,8 +489,8 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: """ all_vars = super().get_all_variables(Q_index=Q_index) - if self._temperature is not None: - all_vars.append(self._temperature) + if self.temperature is not None: + all_vars.append(self.temperature) for diffusion_model in self._diffusion_models: all_vars.extend(diffusion_model.get_all_variables()) @@ -472,9 +539,9 @@ def __repr__(self) -> str: """ return ( - f'{self.__class__.__name__}(unique_name={self.unique_name}, unit={self._unit}), ' - f'Q = {self._Q}, ' - f'components = {self._components}, diffusion_models = {self._diffusion_models}, ' - f'temperature = {self._temperature}, ' - f'divide_by_temperature = {self._divide_by_temperature}' + f'{self.__class__.__name__}(unique_name={self.unique_name}, unit={self.unit}), ' + f'Q = {self.Q}, ' + f'components = {self.components}, diffusion_models = {self.diffusion_models}, ' + f'temperature = {self.temperature}, ' + f'detailed_balance_settings = {self.detailed_balance_settings}' ) diff --git a/src/easydynamics/settings/__init__.py b/src/easydynamics/settings/__init__.py new file mode 100644 index 00000000..c401fbce --- /dev/null +++ b/src/easydynamics/settings/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings + +__all__ = [ + 'ConvolutionSettings', + 'DetailedBalanceSettings', +] diff --git a/src/easydynamics/convolution/convolution_settings.py b/src/easydynamics/settings/convolution_settings.py similarity index 100% rename from src/easydynamics/convolution/convolution_settings.py rename to src/easydynamics/settings/convolution_settings.py diff --git a/src/easydynamics/settings/detailed_balance_settings.py b/src/easydynamics/settings/detailed_balance_settings.py new file mode 100644 index 00000000..75cfce1f --- /dev/null +++ b/src/easydynamics/settings/detailed_balance_settings.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + + +from easydynamics.base_classes.easydynamics_base import EasyDynamicsBase + + +class DetailedBalanceSettings(EasyDynamicsBase): + """ + Class to manage detailed balance settings for a SampleModel or Analysis. + """ + + def __init__( + self, + use_detailed_balance: bool = True, + normalize_detailed_balance: bool = True, + display_name: str = 'DetailedBalanceSettings', + unique_name: str | None = None, + ) -> None: + """ + Initialize the DetailedBalanceSettings. + + Parameters + ---------- + use_detailed_balance : bool, default=True + Whether to apply detailed balance to the model. If False, no detailed balance is + applied. + normalize_detailed_balance : bool, default=True + Whether to normalize the detailed balance factor by dividing with temperature. + display_name : str, default='DetailedBalanceSettings' + Display name of the model. + unique_name : str | None, default=None + Unique name of the model. If None, a unique name will be generated. + + + Raises + ------ + TypeError + If use_detailed_balance or normalize_detailed_balance is not a bool. + """ + if not isinstance(use_detailed_balance, bool): + raise TypeError('use_detailed_balance must be True or False') + self._use_detailed_balance = use_detailed_balance + + if not isinstance(normalize_detailed_balance, bool): + raise TypeError('normalize_detailed_balance must be True or False') + self._normalize_detailed_balance = normalize_detailed_balance + + super().__init__( + display_name=display_name, + unique_name=unique_name, + ) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def use_detailed_balance(self) -> bool: + """ + Get whether to apply detailed balance to the model. + + Returns + ------- + bool + True if detailed balance is applied, False otherwise. + """ + return self._use_detailed_balance + + @use_detailed_balance.setter + def use_detailed_balance(self, value: bool) -> None: + """ + Set whether to apply detailed balance to the model. + + Parameters + ---------- + value : bool + True to apply detailed balance, False otherwise. + + Raises + ------ + TypeError + If value is not a bool. + """ + if not isinstance(value, bool): + raise TypeError('use_detailed_balance must be True or False') + self._use_detailed_balance = value + + @property + def normalize_detailed_balance(self) -> bool: + """ + Get whether to divide the detailed balance factor by temperature. + + Returns + ------- + bool + True if the detailed balance factor should be normalized by dividing with temperature, + False otherwise. + """ + return self._normalize_detailed_balance + + @normalize_detailed_balance.setter + def normalize_detailed_balance(self, value: bool) -> None: + """ + Set whether to normalize the detailed balance factor by dividing with temperature. + + Parameters + ---------- + value : bool + True to normalize the detailed balance factor by dividing with temperature, False + otherwise. + + Raises + ------ + TypeError + If value is not a bool. + """ + if not isinstance(value, bool): + raise TypeError('normalize_detailed_balance must be True or False') + self._normalize_detailed_balance = value + + def __repr__(self) -> str: + """ + Return a string representation of the DetailedBalanceSettings. + + Returns + ------- + str + A string representation of the DetailedBalanceSettings. + """ + return ( + f'DetailedBalanceSettings(use_detailed_balance={self.use_detailed_balance}, ' + f'normalize_detailed_balance={self.normalize_detailed_balance})' + ) diff --git a/tests/unit/easydynamics/analysis/test_analysis.py b/tests/unit/easydynamics/analysis/test_analysis.py index 56985b6c..91e8f795 100644 --- a/tests/unit/easydynamics/analysis/test_analysis.py +++ b/tests/unit/easydynamics/analysis/test_analysis.py @@ -10,11 +10,11 @@ import scipp as sc from easydynamics.analysis.analysis import Analysis -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.experiment import Experiment from easydynamics.sample_model import InstrumentModel from easydynamics.sample_model import SampleModel from easydynamics.sample_model.components.gaussian import Gaussian +from easydynamics.settings.convolution_settings import ConvolutionSettings class TestAnalysis: diff --git a/tests/unit/easydynamics/analysis/test_analysis1d.py b/tests/unit/easydynamics/analysis/test_analysis1d.py index 5f956c1a..234e166c 100644 --- a/tests/unit/easydynamics/analysis/test_analysis1d.py +++ b/tests/unit/easydynamics/analysis/test_analysis1d.py @@ -446,6 +446,38 @@ def test_evaluate_components_empty_resolution(self, analysis1d): components.evaluate.assert_called_once() assert np.array_equal(result, np.array([1.0, 2.0, 3.0])) + def test_evaluate_components_empty_resolution_DBF(self, analysis1d): + # WHEN + components = MagicMock() + components.evaluate = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + + # Set temperature so DBF will be applied + analysis1d.sample_model.temperature = 10 + mock_dbf = np.array([10.0, 10.0, 10.0]) + + # The default analysis1d has no resolution model components, so + # no convolution should be applied even if convolve=True + + with patch( + 'easydynamics.analysis.analysis1d.detailed_balance_factor', + return_value=mock_dbf, + ) as dbf_mock: + # WHEN + result = analysis1d._evaluate_components( + components=components, + convolver=None, + convolve=True, + apply_detailed_balance=True, + ) + + # EXPECT + components.evaluate.assert_called_once() + dbf_mock.assert_called_once() + + # EXPECT multiplication applied + expected = np.array([1.0, 2.0, 3.0]) * mock_dbf + assert np.array_equal(result, expected) + def test_evaluate_with_resolution(self, analysis1d): # WHEN (set up the resolution model and create a component to # evaluate) @@ -511,6 +543,7 @@ def test_evaluate_sample(self, analysis1d): convolver=analysis1d._convolver, convolve=True, energy=None, + apply_detailed_balance=True, ) def test_evaluate_sample_component(self, analysis1d): @@ -530,6 +563,7 @@ def test_evaluate_sample_component(self, analysis1d): convolver=None, convolve=True, energy=None, + apply_detailed_balance=True, ) def test_evaluate_background(self, analysis1d): @@ -555,6 +589,7 @@ def test_evaluate_background(self, analysis1d): convolver=None, convolve=False, energy=None, + apply_detailed_balance=False, ) def test_evaluate_background_component(self, analysis1d): @@ -574,6 +609,7 @@ def test_evaluate_background_component(self, analysis1d): convolver=None, convolve=False, energy=None, + apply_detailed_balance=False, ) def test_create_convolver(self, analysis1d): diff --git a/tests/unit/easydynamics/analysis/test_analysis_base.py b/tests/unit/easydynamics/analysis/test_analysis_base.py index 33e019f2..a7241831 100644 --- a/tests/unit/easydynamics/analysis/test_analysis_base.py +++ b/tests/unit/easydynamics/analysis/test_analysis_base.py @@ -9,10 +9,11 @@ from easyscience.variable import Parameter from easydynamics.analysis.analysis_base import AnalysisBase -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.experiment import Experiment from easydynamics.sample_model import InstrumentModel from easydynamics.sample_model import SampleModel +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings class TestAnalysisBase: @@ -48,6 +49,19 @@ def test_init_convolution_settings(self): # EXPECT assert analysis.convolution_settings is convolution_settings + def test_init_detailed_balance_settings(self): + # WHEN + detailed_balance_settings = DetailedBalanceSettings( + use_detailed_balance=False, + normalize_detailed_balance=False, + ) + + # THEN + analysis = AnalysisBase(detailed_balance_settings=detailed_balance_settings) + + # EXPECT + assert analysis.detailed_balance_settings is detailed_balance_settings + def test_init_extra_parameter(self): extra_parameter = Parameter(name='param1', value=1.0) analysis = AnalysisBase(extra_parameters=extra_parameter) @@ -89,6 +103,11 @@ def test_init_calls_on_experiment_changed(self): TypeError, 'convolution_settings must be an instance of ConvolutionSettings', ), + ( + {'detailed_balance_settings': 'not detailed balance settings'}, + TypeError, + 'detailed_balance_settings must be an instance of DetailedBalanceSettings', + ), ( {'extra_parameters': 123}, TypeError, @@ -105,6 +124,7 @@ def test_init_calls_on_experiment_changed(self): 'invalid sample_model', 'invalid instrument_model', 'invalid convolution_settings', + 'invalid detailed_balance_settings', 'invalid extra_parameters', 'invalid extra_parameters list', ], @@ -272,6 +292,26 @@ def test_convolution_settings_calls_on_convolution_settings_changed(self, analys # EXPECT mock_on_convolution_settings_changed.assert_called_once() + def test_detailed_balance_settings_property(self, analysis_base): + # WHEN + new_settings = DetailedBalanceSettings( + use_detailed_balance=False, normalize_detailed_balance=False + ) + + # THEN + analysis_base.detailed_balance_settings = new_settings + + # EXPECT + assert analysis_base.detailed_balance_settings is new_settings + + def test_detailed_balance_settings_setter_invalid(self, analysis_base): + # WHEN / THEN / EXPECT + with pytest.raises( + TypeError, + match='detailed_balance_settings must be a DetailedBalanceSettings', + ): + analysis_base.detailed_balance_settings = 'invalid_settings' + @pytest.mark.parametrize( 'extra_parameters', [ diff --git a/tests/unit/easydynamics/convolution/test_convolution.py b/tests/unit/easydynamics/convolution/test_convolution.py index a34da86c..f3aff801 100644 --- a/tests/unit/easydynamics/convolution/test_convolution.py +++ b/tests/unit/easydynamics/convolution/test_convolution.py @@ -78,7 +78,7 @@ def test_init(self, default_convolution): assert default_convolution.extension_factor == pytest.approx(0.2) assert default_convolution.temperature is None assert default_convolution.unit == 'meV' - assert default_convolution.normalize_detailed_balance is True + assert default_convolution.detailed_balance_settings.normalize_detailed_balance is True assert isinstance(default_convolution._energy_grid, EnergyGrid) assert isinstance(default_convolution._analytical_sample_components, ComponentCollection) @@ -112,7 +112,10 @@ def test_init_components(self, convolution_with_components): assert convolution_with_components.extension_factor == pytest.approx(0.2) assert convolution_with_components.temperature is None assert convolution_with_components.unit == 'meV' - assert convolution_with_components.normalize_detailed_balance is True + assert ( + convolution_with_components.detailed_balance_settings.normalize_detailed_balance + is True + ) assert isinstance(convolution_with_components._energy_grid, EnergyGrid) assert isinstance( diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py index a8ad6fb7..8b015355 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py @@ -49,7 +49,10 @@ def test_init(self, default_numerical_convolution): assert default_numerical_convolution.extension_factor == pytest.approx(0.2) assert default_numerical_convolution.temperature is None assert default_numerical_convolution.unit == 'meV' - assert default_numerical_convolution.normalize_detailed_balance is True + assert ( + default_numerical_convolution.detailed_balance_settings.normalize_detailed_balance + is True + ) assert isinstance(default_numerical_convolution._energy_grid, EnergyGrid) @pytest.mark.parametrize('upsample_factor', [None, 5]) diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py index 074b2c41..d349d591 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py @@ -6,11 +6,12 @@ import scipp as sc from easyscience.variable import Parameter -from easydynamics.convolution.convolution_settings import ConvolutionSettings from easydynamics.convolution.energy_grid import EnergyGrid from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase from easydynamics.sample_model import Gaussian from easydynamics.sample_model.component_collection import ComponentCollection +from easydynamics.settings.convolution_settings import ConvolutionSettings +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings class TestNumericalConvolutionBase: @@ -48,7 +49,10 @@ def test_init(self, default_numerical_convolution_base): assert default_numerical_convolution_base.extension_factor == pytest.approx(0.2) assert default_numerical_convolution_base.temperature is None assert default_numerical_convolution_base.unit == 'meV' - assert default_numerical_convolution_base.normalize_detailed_balance is True + assert ( + default_numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance + is True + ) assert isinstance(default_numerical_convolution_base._energy_grid, EnergyGrid) def test_init_with_custom_parameters(self): @@ -63,7 +67,7 @@ def test_init_with_custom_parameters(self): resolution_settings = ConvolutionSettings(upsample_factor=10, extension_factor=0.5) temperature = 300.0 temperature_unit = 'K' - normalize_detailed_balance = False + detailed_balance_settings = DetailedBalanceSettings(normalize_detailed_balance=False) unit = 'meV' # THEN @@ -72,7 +76,7 @@ def test_init_with_custom_parameters(self): sample_components=sample_components, resolution_components=resolution_components, convolution_settings=resolution_settings, - normalize_detailed_balance=normalize_detailed_balance, + detailed_balance_settings=detailed_balance_settings, temperature=temperature, temperature_unit=temperature_unit, unit=unit, @@ -84,47 +88,58 @@ def test_init_with_custom_parameters(self): assert numerical_convolution_base.temperature.value == temperature assert numerical_convolution_base.temperature.unit == temperature_unit assert numerical_convolution_base.unit == unit - assert numerical_convolution_base.normalize_detailed_balance is False + assert ( + numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance + is False + ) + assert numerical_convolution_base.detailed_balance_settings is detailed_balance_settings assert isinstance(numerical_convolution_base._energy_grid, EnergyGrid) - def test_init_raises_type_error_for_invalid_temperature(self): - """ - Test that initialization raises TypeError for invalid - temperature. - """ - # WHEN - energy = np.linspace(-5, 5, 50) - sample_components = ComponentCollection(display_name='ComponentCollection') - resolution_components = ComponentCollection(display_name='ResolutionModel') - invalid_temperature = 'invalid_temperature' - - # THEN EXPECT - with pytest.raises(TypeError, match=r'Temperature must be None, a number or a Parameter.'): - NumericalConvolutionBase( - energy=energy, - sample_components=sample_components, - resolution_components=resolution_components, - temperature=invalid_temperature, - ) - - def test_init_raises_type_error_for_invalid_temperature_unit(self): + @pytest.mark.parametrize( + 'invalid_input, expected_exception, match', + [ + # temperature + ( + {'temperature': 'invalid_temperature'}, + TypeError, + r'Temperature must be None, a number or a Parameter.', + ), + # temperature_unit + ( + {'temperature_unit': 123}, + TypeError, + r'Temperature_unit must be a string or sc.Unit.', + ), + # detailed_balance_settings + ( + {'detailed_balance_settings': 'invalid_settings'}, + TypeError, + r'detailed_balance_settings must be a DetailedBalanceSettings instance.', + ), + ], + ids=[ + 'temperature_invalid_type', + 'temperature_unit_invalid_type', + 'detailed_balance_settings_invalid_type', + ], + ) + def test_init_raises_for_invalid_input(self, invalid_input, expected_exception, match): """ - Test that initialization raises TypeError for invalid - temperature_unit. + Test that initialization raises appropriate exceptions for + invalid input parameters. """ # WHEN energy = np.linspace(-5, 5, 50) sample_components = ComponentCollection(display_name='ComponentCollection') resolution_components = ComponentCollection(display_name='ResolutionModel') - invalid_temperature_unit = 123 # Not a string or sc.Unit # THEN EXPECT - with pytest.raises(TypeError, match=r'Temperature_unit must be a string or sc.Unit.'): + with pytest.raises(expected_exception, match=match): NumericalConvolutionBase( energy=energy, sample_components=sample_components, resolution_components=resolution_components, - temperature_unit=invalid_temperature_unit, + **invalid_input, ) #################### @@ -340,10 +355,15 @@ def test_normalize_detailed_balance_setter(self, default_numerical_convolution_b Test setting normalize_detailed_balance to False. """ # WHEN - default_numerical_convolution_base.normalize_detailed_balance = False + default_numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance = ( + False + ) # THEN EXPECT - assert default_numerical_convolution_base.normalize_detailed_balance is False + assert ( + default_numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance + is False + ) def test_normalize_detailed_balance_setter_raises(self, default_numerical_convolution_base): """ @@ -352,7 +372,29 @@ def test_normalize_detailed_balance_setter_raises(self, default_numerical_convol """ # WHEN THEN EXPECT with pytest.raises(TypeError, match='normalize_detailed_balance must be'): - default_numerical_convolution_base.normalize_detailed_balance = 'invalid' + default_numerical_convolution_base.detailed_balance_settings.normalize_detailed_balance = ( # noqa: E501 + 'invalid' + ) + + def test_detailed_balance_settings_property(self, default_numerical_convolution_base): + # WHEN + new_settings = DetailedBalanceSettings( + use_detailed_balance=False, normalize_detailed_balance=False + ) + + # THEN + default_numerical_convolution_base.detailed_balance_settings = new_settings + + # EXPECT + assert default_numerical_convolution_base.detailed_balance_settings is new_settings + + def test_detailed_balance_settings_setter_invalid(self, default_numerical_convolution_base): + # WHEN / THEN / EXPECT + with pytest.raises( + TypeError, + match='detailed_balance_settings must be a DetailedBalanceSettings', + ): + default_numerical_convolution_base.detailed_balance_settings = 'invalid_settings' def test_convolution_settings_setter_valid( self, diff --git a/tests/unit/easydynamics/sample_model/test_sample_model.py b/tests/unit/easydynamics/sample_model/test_sample_model.py index a093b1c1..712dfa8b 100644 --- a/tests/unit/easydynamics/sample_model/test_sample_model.py +++ b/tests/unit/easydynamics/sample_model/test_sample_model.py @@ -15,6 +15,7 @@ BrownianTranslationalDiffusion, ) from easydynamics.sample_model.sample_model import SampleModel +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings class TestSampleModel: @@ -66,51 +67,75 @@ def test_init(self, sample_model): assert len(model.diffusion_models) == 1 assert isinstance(model.diffusion_models[0], BrownianTranslationalDiffusion) assert model.temperature.value == pytest.approx(10.0) - assert model.divide_by_temperature is True + assert model.normalize_detailed_balance is True + assert model.use_detailed_balance is True + assert isinstance(model.detailed_balance_settings, DetailedBalanceSettings) np.testing.assert_array_equal(model.Q, np.array([1.0, 2.0, 3.0])) - def test_init_list_of_diffusion_model(self): + def test_init_custom_input(self): # WHEN THEN diffusion_model1 = BrownianTranslationalDiffusion() diffusion_model2 = BrownianTranslationalDiffusion() - sample_model = SampleModel(diffusion_models=[diffusion_model1, diffusion_model2]) + + detailed_balance_settings = DetailedBalanceSettings( + use_detailed_balance=False, + normalize_detailed_balance=False, + ) + sample_model = SampleModel( + diffusion_models=[diffusion_model1, diffusion_model2], + detailed_balance_settings=detailed_balance_settings, + ) # EXPECT assert len(sample_model.diffusion_models) == 2 assert sample_model.diffusion_models[0] is diffusion_model1 assert sample_model.diffusion_models[1] is diffusion_model2 + assert sample_model.use_detailed_balance is False + assert sample_model.normalize_detailed_balance is False + assert sample_model.detailed_balance_settings is detailed_balance_settings - def test_init_raises_with_invalid_diffusion_model(self): - # WHEN / THEN / EXPECT - with pytest.raises( - TypeError, - match='diffusion_models must be ', - ): - SampleModel(diffusion_models='invalid_diffusion_model') - - def test_init_raises_with_invalid_temperature(self): - # WHEN / THEN / EXPECT - with pytest.raises( - TypeError, - match='temperature must be a number or None', - ): - SampleModel(temperature='invalid_temperature') - - def test_init_raises_with_negative_temperature(self): - # WHEN / THEN / EXPECT - with pytest.raises( - ValueError, - match='temperature must be non-negative', - ): - SampleModel(temperature=-5.0) - - def test_init_raises_with_invalid_divide_by_temperature(self): - # WHEN / THEN / EXPECT - with pytest.raises( - TypeError, - match='divide_by_temperature must be True or False', - ): - SampleModel(divide_by_temperature='invalid_value') + @pytest.mark.parametrize( + 'invalid_input, expected_exception, match', + [ + # diffusion_models + ( + {'diffusion_models': 'invalid_diffusion_model'}, + TypeError, + 'diffusion_models must be a DiffusionModelBase', + ), + # temperature + ( + {'temperature': 'invalid_temperature'}, + TypeError, + 'temperature must be a number or None', + ), + ( + {'temperature': -5.0}, + ValueError, + 'temperature must be non-negative', + ), + # detailed_balance_settings + ( + {'detailed_balance_settings': 'invalid_settings'}, + TypeError, + 'detailed_balance_settings must be a DetailedBalanceSettings or None', + ), + ], + ids=[ + 'diffusion_models_invalid_type', + 'temperature_not_numeric', + 'temperature_negative', + 'detailed_balance_settings_invalid_type', + ], + ) + def test_init_raises_for_invalid_input(self, invalid_input, expected_exception, match): + """ + Test that initialization raises appropriate exceptions for + invalid input parameters. + """ + # WHEN THEN EXPECT + with pytest.raises(expected_exception, match=match): + SampleModel(**invalid_input) def test_append_and_remove_and_clear_diffusion_model(self, sample_model): # WHEN @@ -279,29 +304,75 @@ def test_convert_temperature_unit_raises_with_invalid_unit(self, sample_model): ): model.convert_temperature_unit('invalid_unit') - def test_divide_by_temperature_setter(self, sample_model): + def test_normalize_detailed_balance_setter(self, sample_model): # WHEN model = sample_model # THEN - model.divide_by_temperature = False + model.normalize_detailed_balance = False # EXPECT - assert model.divide_by_temperature is False + assert model.normalize_detailed_balance is False # THEN - model.divide_by_temperature = True + model.normalize_detailed_balance = True # EXPECT - assert model.divide_by_temperature is True + assert model.normalize_detailed_balance is True - def test_divide_by_temperature_setter_raises_with_invalid_type(self, sample_model): + def test_normalize_detailed_balance_setter_raises_with_invalid_type(self, sample_model): # WHEN / THEN / EXPECT with pytest.raises( TypeError, - match='divide_by_temperature must be True or False', + match='normalize_detailed_balance must be True or False', ): - sample_model.divide_by_temperature = 'invalid_value' + sample_model.normalize_detailed_balance = 'invalid_value' + + def test_use_detailed_balance_setter(self, sample_model): + # WHEN + model = sample_model + + # THEN + model.use_detailed_balance = False + + # EXPECT + assert model.use_detailed_balance is False + + # THEN + model.use_detailed_balance = True + + # EXPECT + assert model.use_detailed_balance is True + + def test_use_detailed_balance_setter_raises_with_invalid_type(self, sample_model): + # WHEN / THEN / EXPECT + with pytest.raises( + TypeError, + match='use_detailed_balance must be True or False', + ): + sample_model.use_detailed_balance = 'invalid_value' + + def test_detailed_balance_settings_property(self, sample_model): + # WHEN + new_settings = DetailedBalanceSettings( + use_detailed_balance=False, normalize_detailed_balance=False + ) + + # THEN + sample_model.detailed_balance_settings = new_settings + + # EXPECT + assert sample_model.detailed_balance_settings is new_settings + assert sample_model.use_detailed_balance is False + assert sample_model.normalize_detailed_balance is False + + def test_detailed_balance_settings_setter_invalid(self, sample_model): + # WHEN / THEN / EXPECT + with pytest.raises( + TypeError, + match='detailed_balance_settings must be a DetailedBalanceSettings', + ): + sample_model.detailed_balance_settings = 'invalid_settings' def test_evaluate_calls_dbf(self, sample_model): # WHEN @@ -325,7 +396,7 @@ def test_evaluate_calls_dbf(self, sample_model): mock_dbf.assert_called_once_with( energy=x, temperature=sample_model.temperature, - divide_by_temperature=sample_model.divide_by_temperature, + divide_by_temperature=sample_model.normalize_detailed_balance, energy_unit=sample_model.unit, ) @@ -337,6 +408,51 @@ def test_evaluate_calls_dbf(self, sample_model): np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0]) * 10.0) np.testing.assert_allclose(result[1], np.array([4.0, 5.0, 6.0]) * 10.0) + @pytest.mark.parametrize( + 'temperature, use_detailed_balance', + [ + (None, True), # DB disabled because temperature is None + (300.0, False), # DB disabled explicitly + ], + ids=[ + 'temperature_none', + 'use_detailed_balance_false', + ], + ) + def test_evaluate_doesnt_call_dbf_when_disabled( + self, sample_model, temperature, use_detailed_balance + ): + # WHEN + x = np.array([0.0, 1.0, 2.0]) + + collection1 = Mock() + collection2 = Mock() + + collection1.evaluate.return_value = np.array([1.0, 2.0, 3.0]) + collection2.evaluate.return_value = np.array([4.0, 5.0, 6.0]) + + sample_model._component_collections = [collection1, collection2] + + sample_model.temperature = temperature + sample_model.use_detailed_balance = use_detailed_balance + + with patch('easydynamics.sample_model.sample_model.detailed_balance_factor') as mock_dbf: + mock_dbf.return_value = np.array([10.0, 10.0, 10.0]) # simplified DBF + # THEN + result = sample_model.evaluate(x) + + # EXPECT + # Check that DBF was not called since detailed balance is disabled + mock_dbf.assert_not_called() + + # Check that evaluate was called on each component + collection1.evaluate.assert_called_once_with(x) + collection2.evaluate.assert_called_once_with(x) + + # Check that results were not modified by DBF + np.testing.assert_allclose(result[0], np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose(result[1], np.array([4.0, 5.0, 6.0])) + def test_generate_component_collections(self, sample_model): # WHEN THEN sample_model._generate_component_collections() @@ -389,4 +505,4 @@ def test_repr(self, sample_model): assert 'components' in repr_str assert 'diffusion_models' in repr_str assert 'temperature' in repr_str - assert 'divide_by_temperature' in repr_str + assert 'normalize_detailed_balance' in repr_str diff --git a/tests/unit/easydynamics/convolution/test_convolution_settings.py b/tests/unit/easydynamics/settings/test_convolution_settings.py similarity index 98% rename from tests/unit/easydynamics/convolution/test_convolution_settings.py rename to tests/unit/easydynamics/settings/test_convolution_settings.py index fa2f2a8b..60e70bdf 100644 --- a/tests/unit/easydynamics/convolution/test_convolution_settings.py +++ b/tests/unit/easydynamics/settings/test_convolution_settings.py @@ -3,7 +3,7 @@ import pytest -from easydynamics.convolution.convolution_settings import ConvolutionSettings +from easydynamics.settings.convolution_settings import ConvolutionSettings class TestConvolutionSettings: diff --git a/tests/unit/easydynamics/settings/test_detailed_balance_settings.py b/tests/unit/easydynamics/settings/test_detailed_balance_settings.py new file mode 100644 index 00000000..dba4d9d6 --- /dev/null +++ b/tests/unit/easydynamics/settings/test_detailed_balance_settings.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings + + +class TestDetailedBalanceSettings: + @pytest.fixture + def default_detailed_balance_settings(self): + return DetailedBalanceSettings() + + def test_init(self, default_detailed_balance_settings): + """ + Test initialization of DetailedBalanceSettings with default + parameters. + """ + # WHEN THEN EXPECT + assert isinstance(default_detailed_balance_settings, DetailedBalanceSettings) + assert default_detailed_balance_settings.use_detailed_balance is True + assert default_detailed_balance_settings.normalize_detailed_balance is True + assert default_detailed_balance_settings.display_name == 'DetailedBalanceSettings' + + def test_init_with_custom_parameters(self): + """ + Test initialization of DetailedBalanceSettings with custom + parameters. + """ + # WHEN + detailed_balance_settings = DetailedBalanceSettings( + use_detailed_balance=False, + normalize_detailed_balance=False, + ) + + # THEN EXPECT + assert detailed_balance_settings.use_detailed_balance is False + assert detailed_balance_settings.normalize_detailed_balance is False + assert detailed_balance_settings.display_name == 'DetailedBalanceSettings' + + @pytest.mark.parametrize( + 'invalid_input, expected_exception, match', + [ + ( + {'use_detailed_balance': 'not_a_boolean'}, + TypeError, + 'must be True or False', + ), + ( + {'normalize_detailed_balance': 'not_a_boolean'}, + TypeError, + 'must be True or False', + ), + ], + ids=[ + 'use_detailed_balance_not_boolean', + 'normalize_detailed_balance_not_boolean', + ], + ) + def test_init_raises_for_invalid_input(self, invalid_input, expected_exception, match): + """ + Test that initialization raises appropriate exceptions for + invalid input parameters. + """ + # WHEN THEN EXPECT + with pytest.raises(expected_exception, match=match): + DetailedBalanceSettings(**invalid_input) + + def test_setters_valid(self, default_detailed_balance_settings): + + # WHEN + # Ensure it's True first so we can test the reset + assert default_detailed_balance_settings.use_detailed_balance is True + assert default_detailed_balance_settings.normalize_detailed_balance is True + + # THEN + default_detailed_balance_settings.use_detailed_balance = False + default_detailed_balance_settings.normalize_detailed_balance = False + + # EXPECT + assert default_detailed_balance_settings.use_detailed_balance is False + assert default_detailed_balance_settings.normalize_detailed_balance is False + + @pytest.mark.parametrize( + 'value, expected_exception, match', + [ + ('5', TypeError, 'must be True or False'), + (1, TypeError, 'must be True or False'), + (None, TypeError, 'must be True or False'), + ], + ids=[ + 'string', + 'integer', + 'none', + ], + ) + def test_setters_invalid( + self, + default_detailed_balance_settings, + value, + expected_exception, + match, + ): + # WHEN THEN EXPECT + with pytest.raises(expected_exception, match=match): + default_detailed_balance_settings.use_detailed_balance = value + + # WHEN THEN EXPECT + with pytest.raises(expected_exception, match=match): + default_detailed_balance_settings.normalize_detailed_balance = value + + def test_repr_default(self, default_detailed_balance_settings): + # WHEN + repr_str = repr(default_detailed_balance_settings) + + # EXPECT + assert repr_str == ( + 'DetailedBalanceSettings(use_detailed_balance=True, normalize_detailed_balance=True)' + ) + + def test_repr_reflects_updated_values(self, default_detailed_balance_settings): + # WHEN + default_detailed_balance_settings.use_detailed_balance = False + default_detailed_balance_settings.normalize_detailed_balance = False + + repr_str = repr(default_detailed_balance_settings) + + # EXPECT + assert repr_str == ( + 'DetailedBalanceSettings(use_detailed_balance=False, normalize_detailed_balance=False)' + )