Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions docs/docs/tutorials/analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@
"vanadium_analysis.parameters_to_dataset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "536d0cad",
"metadata": {},
"outputs": [],
"source": [
"vanadium_analysis.get_parameters_near_bounds()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -338,7 +348,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "easydynamics_newbase",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -352,7 +362,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
"version": "3.14.4"
}
},
"nbformat": 4,
Expand Down
58 changes: 58 additions & 0 deletions src/easydynamics/analysis/analysis_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,64 @@ def normalize_resolution(self) -> None:
"""
self.instrument_model.normalize_resolution()

def get_parameters_near_bounds(
self,
rtol: float = 1e-5,
atol: float = 1e-8,
) -> list[Parameter]:
"""
Get a list of parameters that are near their bounds.

Parameters
----------
rtol : float, default=1e-5
Relative tolerance for determining if a parameter is near its bound.
atol : float, default=1e-8
Absolute tolerance for determining if a parameter is near its bound.

Returns
-------
list[Parameter]
A list of parameters that are near their bounds.

Raises
------
TypeError
If rtol or atol is not a float.
ValueError
If rtol or atol is negative.
"""

if not isinstance(rtol, (int, float)):
raise TypeError(f'rtol must be a float. Got {type(rtol)}.')

if rtol < 0:
raise ValueError(f'rtol must be non-negative. Got {rtol}.')

if not isinstance(atol, (int, float)):
raise TypeError(f'atol must be a float. Got {type(atol)}.')

if atol < 0:
raise ValueError(f'atol must be non-negative. Got {atol}.')

parameters = self.get_all_parameters()
at_bounds = []

for p in parameters:
value = p.value
if not np.isfinite(value):
at_bounds.append(p)
continue

Comment on lines +479 to +483

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NaN values are silently ignored. If p.value is NaN, np.isclose will returns False, so a broken parameter isn't reported.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would use a test for it

at_min = not np.isneginf(p.min) and np.isclose(value, p.min, rtol=rtol, atol=atol)

at_max = not np.isposinf(p.max) and np.isclose(value, p.max, rtol=rtol, atol=atol)

if at_min or at_max:
at_bounds.append(p)

return at_bounds

#############
# Private methods
#############
Expand Down
17 changes: 9 additions & 8 deletions src/easydynamics/convolution/numerical_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def convolution(
self._energy_grid = self._create_energy_grid()

# Give warnings if peaks are very wide or very narrow
self._check_width_thresholds(
model=self.sample_components,
model_name='sample model',
)
self._check_width_thresholds(
model=self.resolution_components,
model_name='resolution model',
)
if not self.convolution_settings.suppress_warnings:
self._check_width_thresholds(
model=self.sample_components,
model_name='sample model',
)
self._check_width_thresholds(
model=self.resolution_components,
model_name='resolution model',
)

# Evaluate sample model. If called via the Convolution class,
# delta functions are already filtered out.
Expand Down
43 changes: 41 additions & 2 deletions src/easydynamics/settings/convolution_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
self,
upsample_factor: Numeric | None = 5,
extension_factor: Numeric | None = 0.2,
suppress_warnings: bool = False,
display_name: str | None = 'MyConvolutionSettings',
unique_name: str | None = None,
) -> None:
Expand All @@ -27,6 +28,8 @@ def __init__(
The factor by which to upsample the input data before convolution.
extension_factor : Numeric | None, default=0.2
The factor by which to extend the input data range before convolution.
suppress_warnings : bool, default=False
Whether to suppress warnings about wide or narrow peaks in the models.
display_name : str | None, default='MyConvolutionSettings'
Display name of the model.
unique_name : str | None, default=None
Expand All @@ -36,7 +39,7 @@ def __init__(
------
TypeError
If upsample_factor is not a number or None. If extension_factor is not a number or
None.
None. If suppress_warnings is not a boolean.
ValueError
If upsample_factor is not greater than 1. If extension_factor is negative.
"""
Expand All @@ -61,6 +64,10 @@ def __init__(
raise ValueError('Upsample factor must be greater than 1.')
self._upsample_factor = upsample_factor

if not isinstance(suppress_warnings, bool):
raise TypeError('suppress_warnings must be True or False.')
self._suppress_warnings = suppress_warnings

self._convolution_plan_is_valid = False

@property
Expand Down Expand Up @@ -184,6 +191,37 @@ def convolution_plan_is_valid(self, is_valid: bool) -> None:
raise TypeError('convolution_plan_is_valid must be True or False.')
self._convolution_plan_is_valid = is_valid

@property
def suppress_warnings(self) -> bool:
"""
Get whether to suppress warnings.

Returns
-------
bool
Whether to suppress warnings.
"""
return self._suppress_warnings

@suppress_warnings.setter
def suppress_warnings(self, suppress: bool) -> None:
"""
Set whether to suppress warnings.

Parameters
----------
suppress : bool
Whether to suppress warnings.

Raises
------
TypeError
If suppress is not a bool.
"""
if not isinstance(suppress, bool):
raise TypeError('suppress_warnings must be True or False.')
self._suppress_warnings = suppress

def __repr__(self) -> str:
"""
Return a string representation of the ConvolutionSettings.
Expand All @@ -196,5 +234,6 @@ def __repr__(self) -> str:
return (
f'{self.__class__.__name__}('
f'upsample_factor={self.upsample_factor}, '
f'extension_factor={self.extension_factor})'
f'extension_factor={self.extension_factor}, '
f'suppress_warnings={self.suppress_warnings})'
)
118 changes: 118 additions & 0 deletions tests/unit/easydynamics/analysis/test_analysis_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import numpy as np
import pytest
import scipp as sc
from easyscience.variable import Parameter

from easydynamics.analysis.analysis_base import AnalysisBase
from easydynamics.experiment import Experiment
from easydynamics.sample_model import InstrumentModel
from easydynamics.sample_model import SampleModel
from easydynamics.sample_model.components.gaussian import Gaussian
from easydynamics.settings.convolution_settings import ConvolutionSettings
from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings

Expand All @@ -29,6 +31,33 @@ def analysis_base(self):
instrument_model=instrument_model,
)

@pytest.fixture
def analysis_base_with_components(self):
Q = sc.array(dims=['Q'], values=[1, 2], unit='1/Angstrom')
energy = sc.array(dims=['energy'], values=[10.0, 20.0, 30.0], unit='meV')
data = sc.array(
dims=['Q', 'energy'],
values=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
variances=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
)

data_array = sc.DataArray(data=data, coords={'Q': Q, 'energy': energy})

experiment = Experiment(data=data_array)

comp1 = Gaussian(area=1, width=2, center=3)
comp2 = Gaussian(area=4, width=5, center=6)
sample_model = SampleModel()
sample_model.append_component(comp1)
sample_model.append_component(comp2)
instrument_model = InstrumentModel()
return AnalysisBase(
display_name='TestAnalysis',
experiment=experiment,
sample_model=sample_model,
instrument_model=instrument_model,
)

def test_init(self, analysis_base):
# WHEN THEN

Expand Down Expand Up @@ -369,6 +398,95 @@ def test_normalize_resolution_calls_instrument_model(self, analysis_base):
analysis_base.normalize_resolution()
mock_normalize_resolution.assert_called_once()

def test_get_parameters_near_bounds_no_bounds(self, analysis_base_with_components):
# WHEN THEN
near_bounds = analysis_base_with_components.get_parameters_near_bounds()

# EXPECT
assert isinstance(near_bounds, list)
assert len(near_bounds) == 0

def test_get_parameters_near_bounds_at_bounds(self, analysis_base_with_components):
# WHEN
components = analysis_base_with_components.sample_model.get_component_collection(
Q_index=0
).components
components[0].area.min = 1.0
components[1].center.max = 6.0

# THEN
near_bounds = analysis_base_with_components.get_parameters_near_bounds()

# EXPECT
assert isinstance(near_bounds, list)
assert len(near_bounds) == 2
assert components[0].area in near_bounds
assert components[1].center in near_bounds

def test_get_parameters_near_bounds_with_tolerances(self, analysis_base_with_components):
# WHEN
components = analysis_base_with_components.sample_model.get_component_collection(
Q_index=0
).components
components[0].area.min = 0.99999
components[1].center.max = 6.00001

# THEN
near_bounds = analysis_base_with_components.get_parameters_near_bounds()

# EXPECT
assert isinstance(near_bounds, list)
assert len(near_bounds) == 2
assert components[0].area in near_bounds
assert components[1].center in near_bounds

@pytest.mark.parametrize(
'rtol, atol, expected_param, expected_error',
[
('1e-5', 1e-8, 'rtol', TypeError), # str
(None, 1e-8, 'rtol', TypeError), # None
(1e-5, '1e-8', 'atol', TypeError), # str
(1e-5, None, 'atol', TypeError), # None
(-1e-5, 1e-8, 'rtol', ValueError), # negative rtol
(1e-5, -1e-8, 'atol', ValueError), # negative atol
],
ids=[
'rtol as string',
'rtol as None',
'atol as string',
'atol as None',
'rtol negative',
'atol negative',
],
)
def test_get_parameters_near_bounds_errors(
self, analysis_base_with_components, rtol, atol, expected_param, expected_error
):
with pytest.raises(expected_error) as exc:
analysis_base_with_components.get_parameters_near_bounds(
rtol=rtol,
atol=atol,
)

assert expected_param in str(exc.value)

def test_not_finite_parameters(self, analysis_base_with_components):
# WHEN
components = analysis_base_with_components.sample_model.get_component_collection(
Q_index=0
).components
components[0].area.value = np.inf
components[1].center.value = np.nan

# THEN
near_bounds = analysis_base_with_components.get_parameters_near_bounds()

# EXPECT
assert isinstance(near_bounds, list)
assert len(near_bounds) == 2
assert components[0].area in near_bounds
assert components[1].center in near_bounds

#############
# Private methods
#############
Expand Down
Loading
Loading