Skip to content
Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ CMakeLists.txt.user*
.cache/
*.log
*.zip
.claude/
200 changes: 95 additions & 105 deletions src/easydynamics/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ def __init__(
parameters are added.
"""

# Avoid triggering updates before the object is fully
# initialized
self._call_updaters = False
self._analysis_list: list[Analysis1d] = []
self._analysis_list_is_dirty = True
super().__init__(
display_name=display_name,
unique_name=unique_name,
Expand All @@ -80,13 +79,6 @@ def __init__(
extra_parameters=extra_parameters,
)

self._analysis_list = []
if self.Q is not None:
self._create_analysis_list()

# Now we can allow updates to trigger recalculations
self._call_updaters = True

#############
# Properties
#############
Expand All @@ -101,6 +93,7 @@ def analysis_list(self) -> list[Analysis1d]:
list[Analysis1d]
A list of Analysis1d objects, one for each Q index.
"""
self._ensure_analysis_list_current()
return self._analysis_list

@analysis_list.setter
Expand Down Expand Up @@ -131,6 +124,58 @@ def analysis_list(self, _value: list[Analysis1d]) -> None:
#############
# Other methods
#############
def rebin(
self,
dimensions: dict[str, int | sc.Variable],
confirm: bool = False,
) -> None:
"""
Rebin the experiment data along specified dimensions and update the analysis.

If Q values change (in count or magnitude), ``confirm=True`` is required. This clears Q
from ``sample_model`` and ``instrument_model`` (including resolution and background
sub-models) so they can accept the new Q values when the analysis list is next rebuilt.

Parameters
----------
dimensions : dict[str, int | sc.Variable]
A dictionary mapping dimension names to number of bins (int) or bin edges
(sc.Variable).
confirm : bool, default=False
Must be ``True`` when rebinning changes the Q values (count or magnitude), since this
clears Q from all models. Raises ``ValueError`` otherwise.

Raises
------
ValueError
If rebinning changes Q and ``confirm`` is not ``True``.
"""
old_Q = np.asarray(self.Q.values) if self.Q is not None else None
old_binned_data = self.experiment._binned_data # noqa: SLF001

self.experiment.rebin(dimensions)
new_Q = np.asarray(self.Q.values) if self.Q is not None else None

q_changed = (
old_Q is not None
and new_Q is not None
and (len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q))
)

if q_changed and not confirm:
self.experiment._binned_data = old_binned_data # noqa: SLF001
raise ValueError(
'Rebinning changed Q values, which requires clearing Q from sample_model and '
'instrument_model (including resolution and background sub-models). '
'Pass confirm=True to proceed.'
)

if q_changed:
self.sample_model.clear_Q(confirm=True)
self.instrument_model.clear_Q(confirm=True)

self._analysis_list_is_dirty = True

def calculate(
self,
Q_index: int | None = None,
Expand Down Expand Up @@ -249,8 +294,6 @@ def plot_data_and_model(
values available for plotting.
RuntimeError
If not in a Jupyter notebook environment.
TypeError
If plot_components or add_background is not True or False.

Returns
-------
Expand Down Expand Up @@ -279,14 +322,9 @@ def plot_data_and_model(
'No Q values available for plotting. Please check the experiment data.'
)

if not isinstance(plot_components, bool):
raise TypeError('plot_components must be True or False.')

if not isinstance(add_background, bool):
raise TypeError('add_background must be True or False.')

if not isinstance(plot_residuals, bool):
raise TypeError('plot_residuals must be True or False.')
self._verify_bool(plot_components, 'plot_components')
self._verify_bool(add_background, 'add_background')
self._verify_bool(plot_residuals, 'plot_residuals')

if energy is None:
energy = self.energy
Expand All @@ -300,39 +338,8 @@ def plot_data_and_model(
include_residuals=plot_residuals,
)

plot_kwargs_defaults = {
'title': self.display_name,
'linestyle': {},
'marker': {},
'color': {},
'markerfacecolor': {},
'keep': 'energy',
}

for key in data_and_model:
if key == 'Data':
plot_kwargs_defaults['linestyle'][key] = 'none'
plot_kwargs_defaults['marker'][key] = 'o'
plot_kwargs_defaults['color'][key] = 'black'
plot_kwargs_defaults['markerfacecolor'][key] = 'none'

elif key == 'Model':
plot_kwargs_defaults['linestyle'][key] = '-'
plot_kwargs_defaults['marker'][key] = None
plot_kwargs_defaults['color'][key] = 'red'
plot_kwargs_defaults['markerfacecolor'][key] = 'none'

elif key == 'Residuals':
plot_kwargs_defaults['linestyle'][key] = 'none'
plot_kwargs_defaults['marker'][key] = 'o'
plot_kwargs_defaults['color'][key] = 'blue'
plot_kwargs_defaults['markerfacecolor'][key] = 'none'

else:
plot_kwargs_defaults['linestyle'][key] = '--'
plot_kwargs_defaults['marker'][key] = None

# Overwrite defaults with any user-provided kwargs
plot_kwargs_defaults = self._build_plot_style_defaults(data_and_model)
plot_kwargs_defaults['keep'] = 'energy'
plot_kwargs_defaults.update(kwargs)

if plot_residuals:
Expand Down Expand Up @@ -384,9 +391,6 @@ def data_and_model_to_datagroup(
If there is no data to include in the DataGroup, or if there are no Q values available
for creating the DataGroup.

TypeError
If add_background is not True or False. If include_components is not True or False.

Returns
-------
sc.DataGroup
Expand All @@ -401,14 +405,9 @@ def data_and_model_to_datagroup(
'No Q values available for creating DataGroup. Please check the experiment data.'
)

if not isinstance(add_background, bool):
raise TypeError('add_background must be True or False.')

if not isinstance(include_components, bool):
raise TypeError('include_components must be True or False.')

if not isinstance(include_residuals, bool):
raise TypeError('include_residuals must be True or False.')
self._verify_bool(add_background, 'add_background')
self._verify_bool(include_components, 'include_components')
self._verify_bool(include_residuals, 'include_residuals')

energy = self._verify_energy(energy)

Expand Down Expand Up @@ -442,6 +441,8 @@ def parameters_to_dataset(self) -> sc.Dataset:
------
UnitError
If there are inconsistent units for the same parameter across different Q values.
ValueError
If duplicate parameter names exist for the same Q index.

Returns
-------
Expand All @@ -452,20 +453,28 @@ def parameters_to_dataset(self) -> sc.Dataset:

ds = sc.Dataset(coords={'Q': self.Q})

# Collect all parameter names
all_names = {
# Collect all parameter names in first-seen order
all_names = dict.fromkeys(
param.name
for analysis in self.analysis_list
for param in analysis.get_all_parameters()
}
)

# Storage
values = {name: [] for name in all_names}
variances = {name: [] for name in all_names}
units = {}

for analysis in self.analysis_list:
pars = {p.name: p for p in analysis.get_all_parameters()}
all_params = analysis.get_all_parameters()
param_names = [p.name for p in all_params]
if len(param_names) != len(set(param_names)):
dups = sorted({n for n in param_names if param_names.count(n) > 1})
raise ValueError(
f'Duplicate parameter names at Q_index {analysis.Q_index}: {dups}. '
'Rename components so all parameters have unique names.'
)
pars = {p.name: p for p in all_params}

for name in all_names:
if name in pars:
Expand Down Expand Up @@ -603,45 +612,37 @@ def free_energy_offset(self, Q_index: int | None = None) -> None:
def _on_experiment_changed(self) -> None:
"""
Update the Q values in the sample and instrument models when the experiment changes.

Also update all the Analysis1d objects with the new experiment.
"""
if self._call_updaters:
super()._on_experiment_changed()
for analysis in self.analysis_list:
analysis.experiment = self.experiment
super()._on_experiment_changed()
self._analysis_list_is_dirty = True

def _on_sample_model_changed(self) -> None:
"""
Update the Q values in the sample model when the sample model changes.

Also update all the Analysis1d objects with the new sample model.
"""
if self._call_updaters:
super()._on_sample_model_changed()
for analysis in self.analysis_list:
analysis.sample_model = self.sample_model
super()._on_sample_model_changed()
self._analysis_list_is_dirty = True

def _on_instrument_model_changed(self) -> None:
"""
Update the Q values in the instrument model when the instrument model changes.

Also update all the Analysis1d objects with the new instrument model.
"""
if self._call_updaters:
super()._on_instrument_model_changed()
for analysis in self.analysis_list:
analysis.instrument_model = self.instrument_model
super()._on_instrument_model_changed()
self._analysis_list_is_dirty = True

def _on_convolution_settings_changed(self) -> None:
"""
Update the convolution settings in all Analysis1d objects when the convolution settings
change.
Update the convolution settings when they change.
"""
if self._call_updaters:
super()._on_convolution_settings_changed()
for analysis1d in self.analysis_list:
analysis1d.convolution_settings = self.convolution_settings
super()._on_convolution_settings_changed()
self._analysis_list_is_dirty = True

def _ensure_analysis_list_current(self) -> None:
"""Rebuild the analysis list if any dependency has changed since it was last built."""
if self._analysis_list_is_dirty:
if self.Q is not None:
self._create_analysis_list()
self._analysis_list_is_dirty = False

def _create_analysis_list(self) -> None:
"""
Expand All @@ -652,7 +653,6 @@ def _create_analysis_list(self) -> None:
for Q_index in range(len(self.Q)):
analysis = Analysis1d(
display_name=f'{self.display_name}_Q{Q_index}',
unique_name=(f'{self.unique_name}_Q{Q_index}'),
experiment=self.experiment,
sample_model=self.sample_model,
instrument_model=self.instrument_model,
Expand Down Expand Up @@ -682,8 +682,6 @@ def _fit_single_Q(self, Q_index: int) -> FitResults:
The results of the fit for the specified Q index.
"""

Q_index = self._verify_Q_index(Q_index)

return self.analysis_list[Q_index].fit()

def _fit_all_Q_independently(self) -> list[FitResults]:
Expand Down Expand Up @@ -720,9 +718,7 @@ def _fit_all_Q_simultaneously(self) -> FitResults:
ws.append(weight)

# Make sure the convolver is up to date for this Q index
analysis1d._convolver = analysis1d._create_convolver( # noqa: SLF001
energy=x
)
analysis1d.refresh_convolver(energy=x)

mf = MultiFitter(
fit_objects=self.analysis_list,
Expand Down Expand Up @@ -799,18 +795,12 @@ def _create_components_dataset(
The energy values to use for calculating the components. If None, uses the energy from
the experiment.

Raises
------
TypeError
If add_background is not True or False.

Returns
-------
sc.Dataset
A scipp Dataset where each entry is a component of the model, with dimensions "Q".
"""
if not isinstance(add_background, bool):
raise TypeError('add_background must be True or False.')
self._verify_bool(add_background, 'add_background')

if energy is None:
energy = self.energy
Expand Down
Loading