diff --git a/docs/docs/tutorials/analysis.ipynb b/docs/docs/tutorials/analysis.ipynb index 0d30fd4b..2e3bb3d0 100644 --- a/docs/docs/tutorials/analysis.ipynb +++ b/docs/docs/tutorials/analysis.ipynb @@ -130,6 +130,16 @@ "vanadium_analysis.parameters_to_dataset()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "536d0cad", + "metadata": {}, + "outputs": [], + "source": [ + "vanadium_analysis.get_parameters_near_bounds()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -338,7 +348,7 @@ ], "metadata": { "kernelspec": { - "display_name": "easydynamics_newbase", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -352,7 +362,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.14.4" } }, "nbformat": 4, diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py index 01f0a20f..4b5b4cb8 100644 --- a/src/easydynamics/analysis/analysis_base.py +++ b/src/easydynamics/analysis/analysis_base.py @@ -432,6 +432,64 @@ def normalize_resolution(self) -> None: """ self.instrument_model.normalize_resolution() + def get_parameters_near_bounds( + self, + rtol: float = 1e-5, + atol: float = 1e-8, + ) -> list[Parameter]: + """ + Get a list of parameters that are near their bounds. + + Parameters + ---------- + rtol : float, default=1e-5 + Relative tolerance for determining if a parameter is near its bound. + atol : float, default=1e-8 + Absolute tolerance for determining if a parameter is near its bound. + + Returns + ------- + list[Parameter] + A list of parameters that are near their bounds. + + Raises + ------ + TypeError + If rtol or atol is not a float. + ValueError + If rtol or atol is negative. + """ + + if not isinstance(rtol, (int, float)): + raise TypeError(f'rtol must be a float. Got {type(rtol)}.') + + if rtol < 0: + raise ValueError(f'rtol must be non-negative. Got {rtol}.') + + if not isinstance(atol, (int, float)): + raise TypeError(f'atol must be a float. Got {type(atol)}.') + + if atol < 0: + raise ValueError(f'atol must be non-negative. Got {atol}.') + + parameters = self.get_all_parameters() + at_bounds = [] + + for p in parameters: + value = p.value + if not np.isfinite(value): + at_bounds.append(p) + continue + + at_min = not np.isneginf(p.min) and np.isclose(value, p.min, rtol=rtol, atol=atol) + + at_max = not np.isposinf(p.max) and np.isclose(value, p.max, rtol=rtol, atol=atol) + + if at_min or at_max: + at_bounds.append(p) + + return at_bounds + ############# # Private methods ############# diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index 7a990d37..d39550cb 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -98,14 +98,15 @@ def convolution( self._energy_grid = self._create_energy_grid() # Give warnings if peaks are very wide or very narrow - self._check_width_thresholds( - model=self.sample_components, - model_name='sample model', - ) - self._check_width_thresholds( - model=self.resolution_components, - model_name='resolution model', - ) + if not self.convolution_settings.suppress_warnings: + self._check_width_thresholds( + model=self.sample_components, + model_name='sample model', + ) + self._check_width_thresholds( + model=self.resolution_components, + model_name='resolution model', + ) # Evaluate sample model. If called via the Convolution class, # delta functions are already filtered out. diff --git a/src/easydynamics/settings/convolution_settings.py b/src/easydynamics/settings/convolution_settings.py index a845fc91..3a4995c5 100644 --- a/src/easydynamics/settings/convolution_settings.py +++ b/src/easydynamics/settings/convolution_settings.py @@ -15,6 +15,7 @@ def __init__( self, upsample_factor: Numeric | None = 5, extension_factor: Numeric | None = 0.2, + suppress_warnings: bool = False, display_name: str | None = 'MyConvolutionSettings', unique_name: str | None = None, ) -> None: @@ -27,6 +28,8 @@ def __init__( The factor by which to upsample the input data before convolution. extension_factor : Numeric | None, default=0.2 The factor by which to extend the input data range before convolution. + suppress_warnings : bool, default=False + Whether to suppress warnings about wide or narrow peaks in the models. display_name : str | None, default='MyConvolutionSettings' Display name of the model. unique_name : str | None, default=None @@ -36,7 +39,7 @@ def __init__( ------ TypeError If upsample_factor is not a number or None. If extension_factor is not a number or - None. + None. If suppress_warnings is not a boolean. ValueError If upsample_factor is not greater than 1. If extension_factor is negative. """ @@ -61,6 +64,10 @@ def __init__( raise ValueError('Upsample factor must be greater than 1.') self._upsample_factor = upsample_factor + if not isinstance(suppress_warnings, bool): + raise TypeError('suppress_warnings must be True or False.') + self._suppress_warnings = suppress_warnings + self._convolution_plan_is_valid = False @property @@ -184,6 +191,37 @@ def convolution_plan_is_valid(self, is_valid: bool) -> None: raise TypeError('convolution_plan_is_valid must be True or False.') self._convolution_plan_is_valid = is_valid + @property + def suppress_warnings(self) -> bool: + """ + Get whether to suppress warnings. + + Returns + ------- + bool + Whether to suppress warnings. + """ + return self._suppress_warnings + + @suppress_warnings.setter + def suppress_warnings(self, suppress: bool) -> None: + """ + Set whether to suppress warnings. + + Parameters + ---------- + suppress : bool + Whether to suppress warnings. + + Raises + ------ + TypeError + If suppress is not a bool. + """ + if not isinstance(suppress, bool): + raise TypeError('suppress_warnings must be True or False.') + self._suppress_warnings = suppress + def __repr__(self) -> str: """ Return a string representation of the ConvolutionSettings. @@ -196,5 +234,6 @@ def __repr__(self) -> str: return ( f'{self.__class__.__name__}(' f'upsample_factor={self.upsample_factor}, ' - f'extension_factor={self.extension_factor})' + f'extension_factor={self.extension_factor}, ' + f'suppress_warnings={self.suppress_warnings})' ) diff --git a/tests/unit/easydynamics/analysis/test_analysis_base.py b/tests/unit/easydynamics/analysis/test_analysis_base.py index a7241831..079404e4 100644 --- a/tests/unit/easydynamics/analysis/test_analysis_base.py +++ b/tests/unit/easydynamics/analysis/test_analysis_base.py @@ -6,12 +6,14 @@ import numpy as np import pytest +import scipp as sc from easyscience.variable import Parameter from easydynamics.analysis.analysis_base import AnalysisBase from easydynamics.experiment import Experiment from easydynamics.sample_model import InstrumentModel from easydynamics.sample_model import SampleModel +from easydynamics.sample_model.components.gaussian import Gaussian from easydynamics.settings.convolution_settings import ConvolutionSettings from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings @@ -29,6 +31,33 @@ def analysis_base(self): instrument_model=instrument_model, ) + @pytest.fixture + def analysis_base_with_components(self): + Q = sc.array(dims=['Q'], values=[1, 2], unit='1/Angstrom') + energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV') + data = sc.array( + dims=['Q', 'energy'], + values=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + variances=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + ) + + data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy}) + + experiment = Experiment(data=data_array) + + comp1 = Gaussian(area=1, width=2, center=3) + comp2 = Gaussian(area=4, width=5, center=6) + sample_model = SampleModel() + sample_model.append_component(comp1) + sample_model.append_component(comp2) + instrument_model = InstrumentModel() + return AnalysisBase( + display_name='TestAnalysis', + experiment=experiment, + sample_model=sample_model, + instrument_model=instrument_model, + ) + def test_init(self, analysis_base): # WHEN THEN @@ -369,6 +398,95 @@ def test_normalize_resolution_calls_instrument_model(self, analysis_base): analysis_base.normalize_resolution() mock_normalize_resolution.assert_called_once() + def test_get_parameters_near_bounds_no_bounds(self, analysis_base_with_components): + # WHEN THEN + near_bounds = analysis_base_with_components.get_parameters_near_bounds() + + # EXPECT + assert isinstance(near_bounds, list) + assert len(near_bounds) == 0 + + def test_get_parameters_near_bounds_at_bounds(self, analysis_base_with_components): + # WHEN + components = analysis_base_with_components.sample_model.get_component_collection( + Q_index=0 + ).components + components[0].area.min = 1.0 + components[1].center.max = 6.0 + + # THEN + near_bounds = analysis_base_with_components.get_parameters_near_bounds() + + # EXPECT + assert isinstance(near_bounds, list) + assert len(near_bounds) == 2 + assert components[0].area in near_bounds + assert components[1].center in near_bounds + + def test_get_parameters_near_bounds_with_tolerances(self, analysis_base_with_components): + # WHEN + components = analysis_base_with_components.sample_model.get_component_collection( + Q_index=0 + ).components + components[0].area.min = 0.99999 + components[1].center.max = 6.00001 + + # THEN + near_bounds = analysis_base_with_components.get_parameters_near_bounds() + + # EXPECT + assert isinstance(near_bounds, list) + assert len(near_bounds) == 2 + assert components[0].area in near_bounds + assert components[1].center in near_bounds + + @pytest.mark.parametrize( + 'rtol, atol, expected_param, expected_error', + [ + ('1e-5', 1e-8, 'rtol', TypeError), # str + (None, 1e-8, 'rtol', TypeError), # None + (1e-5, '1e-8', 'atol', TypeError), # str + (1e-5, None, 'atol', TypeError), # None + (-1e-5, 1e-8, 'rtol', ValueError), # negative rtol + (1e-5, -1e-8, 'atol', ValueError), # negative atol + ], + ids=[ + 'rtol as string', + 'rtol as None', + 'atol as string', + 'atol as None', + 'rtol negative', + 'atol negative', + ], + ) + def test_get_parameters_near_bounds_errors( + self, analysis_base_with_components, rtol, atol, expected_param, expected_error + ): + with pytest.raises(expected_error) as exc: + analysis_base_with_components.get_parameters_near_bounds( + rtol=rtol, + atol=atol, + ) + + assert expected_param in str(exc.value) + + def test_not_finite_parameters(self, analysis_base_with_components): + # WHEN + components = analysis_base_with_components.sample_model.get_component_collection( + Q_index=0 + ).components + components[0].area.value = np.inf + components[1].center.value = np.nan + + # THEN + near_bounds = analysis_base_with_components.get_parameters_near_bounds() + + # EXPECT + assert isinstance(near_bounds, list) + assert len(near_bounds) == 2 + assert components[0].area in near_bounds + assert components[1].center in near_bounds + ############# # Private methods ############# diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution.py b/tests/unit/easydynamics/convolution/test_numerical_convolution.py index 8b015355..7b6ada94 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution.py @@ -109,3 +109,127 @@ def test_convolution_with_temperature( ) assert np.allclose(result, expected_result, rtol=1e-4) + + @pytest.mark.parametrize( + 'plan_valid, suppress_warnings, use_db, upsample', + [ + (True, True, False, None), + (False, True, False, None), + (True, False, False, None), + (True, False, True, None), + (True, False, True, 10), + (False, False, True, 10), + ], + ids=[ + 'plan_valid=True, suppress_warnings=True, use_db=False, upsample=None', + 'plan_valid=False, suppress_warnings=True, use_db=False, upsample=None', + 'plan_valid=True, suppress_warnings=False, use_db=False, upsample=None', + 'plan_valid=True, suppress_warnings=False, use_db=True, upsample=None', + 'plan_valid=True, suppress_warnings=False, use_db=True, upsample=10', + 'plan_valid=False, suppress_warnings=False, use_db=True, upsample=10', + ], + ) + def test_convolution_branches( + self, + default_numerical_convolution, + monkeypatch, + plan_valid, + suppress_warnings, + use_db, + upsample, + ): + "Test that convolution branches are executed as expected based on settings." + # WHEN + conv = default_numerical_convolution + + # --- Configure branches --- + conv.convolution_settings.suppress_warnings = suppress_warnings + + conv.detailed_balance_settings.use_detailed_balance = use_db + conv.temperature = 10.0 if use_db else None + + conv.upsample_factor = upsample + + conv.convolution_settings.convolution_plan_is_valid = plan_valid + + # --- Track calls --- + create_grid_called = False + check_width_calls = [] + + # Mock the methods that would be called in the branches to track + # whether they were called or not. + def fake_create_energy_grid(): + nonlocal create_grid_called + create_grid_called = True + return conv._energy_grid + + def fake_check_width_thresholds(*args, **kwargs): + check_width_calls.append((args, kwargs)) + + monkeypatch.setattr(conv, '_create_energy_grid', fake_create_energy_grid) + monkeypatch.setattr(conv, '_check_width_thresholds', fake_check_width_thresholds) + + # --- Simplify numerics --- + dense = conv._energy_grid.energy_dense + + monkeypatch.setattr( + conv.sample_components, + 'evaluate', + lambda x: np.ones_like(dense), # noqa: ARG005 + ) + monkeypatch.setattr( + conv.resolution_components, + 'evaluate', + lambda x: np.ones_like(dense), # noqa: ARG005 + ) + + db_called = False + + def fake_db(*args, **kwargs): # noqa: ARG001 + nonlocal db_called + db_called = True + return np.ones_like(dense) + + monkeypatch.setattr( + 'easydynamics.convolution.numerical_convolution.detailed_balance_factor', + fake_db, + ) + + monkeypatch.setattr( + 'easydynamics.convolution.numerical_convolution.fftconvolve', + lambda a, b, mode: np.ones_like(dense), # noqa: ARG005 + ) + + interp_called = False + + def fake_interp(*args, **kwargs): # noqa: ARG001 + nonlocal interp_called + interp_called = True + return np.ones_like(conv.energy.values) + + monkeypatch.setattr(np, 'interp', fake_interp) + + # THEN + result = conv.convolution() + + # EXPECT + # Branch 1: energy grid recreation + assert create_grid_called is (not plan_valid) + + # Branch 2: warnings + if suppress_warnings: + assert len(check_width_calls) == 0 + else: + assert len(check_width_calls) == 2 + + # Branch 3: detailed balance + assert db_called is (use_db and conv.temperature is not None) + + # Branch 4: interpolation + assert interp_called is (upsample is not None) + + # Sanity: result shape + if upsample is not None: + assert result.shape == conv.energy.values.shape + else: + assert result.shape == dense.shape diff --git a/tests/unit/easydynamics/settings/test_convolution_settings.py b/tests/unit/easydynamics/settings/test_convolution_settings.py index 60e70bdf..202f5bae 100644 --- a/tests/unit/easydynamics/settings/test_convolution_settings.py +++ b/tests/unit/easydynamics/settings/test_convolution_settings.py @@ -31,12 +31,14 @@ def test_init_with_custom_parameters(self): convolution_settings = ConvolutionSettings( upsample_factor=10, extension_factor=0.5, + suppress_warnings=True, ) # THEN EXPECT assert convolution_settings.upsample_factor == 10 assert convolution_settings.extension_factor == pytest.approx(0.5) assert convolution_settings.convolution_plan_is_valid is False + assert convolution_settings.suppress_warnings is True def test_init_with_None(self): """ @@ -53,6 +55,7 @@ def test_init_with_None(self): assert convolution_settings.upsample_factor is None assert convolution_settings.extension_factor is None assert convolution_settings.convolution_plan_is_valid is False + assert convolution_settings.suppress_warnings is False @pytest.mark.parametrize( 'invalid_input, expected_exception, match', @@ -62,6 +65,7 @@ def test_init_with_None(self): ({'upsample_factor': '5'}, TypeError, 'must be a numerical value or None'), ({'upsample_factor': 1.0}, ValueError, 'must be greater than 1'), ({'upsample_factor': 0.5}, ValueError, 'must be greater than 1'), + ({'suppress_warnings': 'yes'}, TypeError, r'must be True or False'), ], ids=[ 'extension_factor_not_numeric', @@ -69,6 +73,7 @@ def test_init_with_None(self): 'upsample_factor_not_numeric', 'upsample_factor_equal_1', 'upsample_factor_less_than_1', + 'suppress_warnings_not_boolean', ], ) def test_init_raises_for_invalid_input(self, invalid_input, expected_exception, match): @@ -219,12 +224,28 @@ def test_convolution_plan_is_valid_setter_invalid( with pytest.raises(expected_exception, match=match): default_convolution_settings.convolution_plan_is_valid = value + def test_suppress_warnings_setter_valid(self, default_convolution_settings): + # WHEN + default_convolution_settings.suppress_warnings = True + + # THEN EXPECT + assert default_convolution_settings.suppress_warnings is True + + def test_suppress_warnings_setter_invalid(self, default_convolution_settings): + # WHEN / THEN / EXPECT + with pytest.raises(TypeError, match=r'must be True or False'): + default_convolution_settings.suppress_warnings = 'True' + def test_repr_default(self, default_convolution_settings): # WHEN repr_str = repr(default_convolution_settings) # EXPECT - assert repr_str == ('ConvolutionSettings(upsample_factor=5.0, extension_factor=0.2)') + assert repr_str == ( + f'ConvolutionSettings(upsample_factor={default_convolution_settings.upsample_factor}, ' + f'extension_factor={default_convolution_settings.extension_factor}, ' + f'suppress_warnings={default_convolution_settings.suppress_warnings})' + ) def test_repr_reflects_updated_values(self, default_convolution_settings): # WHEN @@ -234,4 +255,8 @@ def test_repr_reflects_updated_values(self, default_convolution_settings): repr_str = repr(default_convolution_settings) # EXPECT - assert repr_str == ('ConvolutionSettings(upsample_factor=3.0, extension_factor=0.5)') + assert repr_str == ( + f'ConvolutionSettings(upsample_factor={default_convolution_settings.upsample_factor}, ' + f'extension_factor={default_convolution_settings.extension_factor}, ' + f'suppress_warnings={default_convolution_settings.suppress_warnings})' + )