Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 25 additions & 71 deletions src/easydynamics/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ def plot_data_and_model(
values available for plotting.
RuntimeError
If not in a Jupyter notebook environment.
TypeError
If plot_components or add_background is not True or False.

Returns
-------
Expand Down Expand Up @@ -279,14 +277,9 @@ def plot_data_and_model(
'No Q values available for plotting. Please check the experiment data.'
)

if not isinstance(plot_components, bool):
raise TypeError('plot_components must be True or False.')

if not isinstance(add_background, bool):
raise TypeError('add_background must be True or False.')

if not isinstance(plot_residuals, bool):
raise TypeError('plot_residuals must be True or False.')
self._verify_bool(plot_components, 'plot_components')
self._verify_bool(add_background, 'add_background')
self._verify_bool(plot_residuals, 'plot_residuals')

if energy is None:
energy = self.energy
Expand All @@ -300,39 +293,8 @@ def plot_data_and_model(
include_residuals=plot_residuals,
)

plot_kwargs_defaults = {
'title': self.display_name,
'linestyle': {},
'marker': {},
'color': {},
'markerfacecolor': {},
'keep': 'energy',
}

for key in data_and_model:
if key == 'Data':
plot_kwargs_defaults['linestyle'][key] = 'none'
plot_kwargs_defaults['marker'][key] = 'o'
plot_kwargs_defaults['color'][key] = 'black'
plot_kwargs_defaults['markerfacecolor'][key] = 'none'

elif key == 'Model':
plot_kwargs_defaults['linestyle'][key] = '-'
plot_kwargs_defaults['marker'][key] = None
plot_kwargs_defaults['color'][key] = 'red'
plot_kwargs_defaults['markerfacecolor'][key] = 'none'

elif key == 'Residuals':
plot_kwargs_defaults['linestyle'][key] = 'none'
plot_kwargs_defaults['marker'][key] = 'o'
plot_kwargs_defaults['color'][key] = 'blue'
plot_kwargs_defaults['markerfacecolor'][key] = 'none'

else:
plot_kwargs_defaults['linestyle'][key] = '--'
plot_kwargs_defaults['marker'][key] = None

# Overwrite defaults with any user-provided kwargs
plot_kwargs_defaults = self._build_plot_style_defaults(data_and_model)
plot_kwargs_defaults['keep'] = 'energy'
plot_kwargs_defaults.update(kwargs)

if plot_residuals:
Expand Down Expand Up @@ -384,9 +346,6 @@ def data_and_model_to_datagroup(
If there is no data to include in the DataGroup, or if there are no Q values available
for creating the DataGroup.

TypeError
If add_background is not True or False. If include_components is not True or False.

Returns
-------
sc.DataGroup
Expand All @@ -401,14 +360,9 @@ def data_and_model_to_datagroup(
'No Q values available for creating DataGroup. Please check the experiment data.'
)

if not isinstance(add_background, bool):
raise TypeError('add_background must be True or False.')

if not isinstance(include_components, bool):
raise TypeError('include_components must be True or False.')

if not isinstance(include_residuals, bool):
raise TypeError('include_residuals must be True or False.')
self._verify_bool(add_background, 'add_background')
self._verify_bool(include_components, 'include_components')
self._verify_bool(include_residuals, 'include_residuals')

energy = self._verify_energy(energy)

Expand Down Expand Up @@ -442,6 +396,8 @@ def parameters_to_dataset(self) -> sc.Dataset:
------
UnitError
If there are inconsistent units for the same parameter across different Q values.
ValueError
If duplicate parameter names exist for the same Q index.

Returns
-------
Expand All @@ -452,20 +408,28 @@ def parameters_to_dataset(self) -> sc.Dataset:

ds = sc.Dataset(coords={'Q': self.Q})

# Collect all parameter names
all_names = {
# Collect all parameter names in first-seen order
all_names = dict.fromkeys(
param.name
for analysis in self.analysis_list
for param in analysis.get_all_parameters()
}
)

# Storage
values = {name: [] for name in all_names}
variances = {name: [] for name in all_names}
units = {}

for analysis in self.analysis_list:
pars = {p.name: p for p in analysis.get_all_parameters()}
all_params = analysis.get_all_parameters()
param_names = [p.name for p in all_params]
if len(param_names) != len(set(param_names)):
dups = sorted({n for n in param_names if param_names.count(n) > 1})
raise ValueError(
f'Duplicate parameter names at Q_index {analysis.Q_index}: {dups}. '
'Rename components so all parameters have unique names.'
)
pars = {p.name: p for p in all_params}

for name in all_names:
if name in pars:
Expand Down Expand Up @@ -652,7 +616,7 @@ def _create_analysis_list(self) -> None:
for Q_index in range(len(self.Q)):
analysis = Analysis1d(
display_name=f'{self.display_name}_Q{Q_index}',
unique_name=(f'{self.unique_name}_Q{Q_index}'),
unique_name=f'{self.unique_name}_Q{Q_index}',
experiment=self.experiment,
sample_model=self.sample_model,
instrument_model=self.instrument_model,
Expand Down Expand Up @@ -682,8 +646,6 @@ def _fit_single_Q(self, Q_index: int) -> FitResults:
The results of the fit for the specified Q index.
"""

Q_index = self._verify_Q_index(Q_index)

return self.analysis_list[Q_index].fit()

def _fit_all_Q_independently(self) -> list[FitResults]:
Expand Down Expand Up @@ -720,9 +682,7 @@ def _fit_all_Q_simultaneously(self) -> FitResults:
ws.append(weight)

# Make sure the convolver is up to date for this Q index
analysis1d._convolver = analysis1d._create_convolver( # noqa: SLF001
energy=x
)
analysis1d.refresh_convolver(energy=x)

mf = MultiFitter(
fit_objects=self.analysis_list,
Expand Down Expand Up @@ -799,18 +759,12 @@ def _create_components_dataset(
The energy values to use for calculating the components. If None, uses the energy from
the experiment.

Raises
------
TypeError
If add_background is not True or False.

Returns
-------
sc.Dataset
A scipp Dataset where each entry is a component of the model, with dimensions "Q".
"""
if not isinstance(add_background, bool):
raise TypeError('add_background must be True or False.')
self._verify_bool(add_background, 'add_background')

if energy is None:
energy = self.energy
Expand Down
Loading
Loading