diff --git a/.gitignore b/.gitignore index 6dc595c7..761f512c 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ CMakeLists.txt.user* .cache/ *.log *.zip +.claude/ \ No newline at end of file diff --git a/src/easydynamics/analysis/analysis.py b/src/easydynamics/analysis/analysis.py index c7d01ef8..665f28f8 100644 --- a/src/easydynamics/analysis/analysis.py +++ b/src/easydynamics/analysis/analysis.py @@ -66,9 +66,8 @@ def __init__( parameters are added. """ - # Avoid triggering updates before the object is fully - # initialized - self._call_updaters = False + self._analysis_list: list[Analysis1d] = [] + self._analysis_list_is_dirty = True super().__init__( display_name=display_name, unique_name=unique_name, @@ -80,13 +79,6 @@ def __init__( extra_parameters=extra_parameters, ) - self._analysis_list = [] - if self.Q is not None: - self._create_analysis_list() - - # Now we can allow updates to trigger recalculations - self._call_updaters = True - ############# # Properties ############# @@ -101,6 +93,7 @@ def analysis_list(self) -> list[Analysis1d]: list[Analysis1d] A list of Analysis1d objects, one for each Q index. """ + self._ensure_analysis_list_current() return self._analysis_list @analysis_list.setter @@ -131,6 +124,58 @@ def analysis_list(self, _value: list[Analysis1d]) -> None: ############# # Other methods ############# + def rebin( + self, + dimensions: dict[str, int | sc.Variable], + confirm: bool = False, + ) -> None: + """ + Rebin the experiment data along specified dimensions and update the analysis. + + If Q values change (in count or magnitude), ``confirm=True`` is required. This clears Q + from ``sample_model`` and ``instrument_model`` (including resolution and background + sub-models) so they can accept the new Q values when the analysis list is next rebuilt. + + Parameters + ---------- + dimensions : dict[str, int | sc.Variable] + A dictionary mapping dimension names to number of bins (int) or bin edges + (sc.Variable). + confirm : bool, default=False + Must be ``True`` when rebinning changes the Q values (count or magnitude), since this + clears Q from all models. Raises ``ValueError`` otherwise. + + Raises + ------ + ValueError + If rebinning changes Q and ``confirm`` is not ``True``. + """ + old_Q = np.asarray(self.Q.values) if self.Q is not None else None + old_binned_data = self.experiment._binned_data # noqa: SLF001 + + self.experiment.rebin(dimensions) + new_Q = np.asarray(self.Q.values) if self.Q is not None else None + + q_changed = ( + old_Q is not None + and new_Q is not None + and (len(old_Q) != len(new_Q) or not np.allclose(old_Q, new_Q)) + ) + + if q_changed and not confirm: + self.experiment._binned_data = old_binned_data # noqa: SLF001 + raise ValueError( + 'Rebinning changed Q values, which requires clearing Q from sample_model and ' + 'instrument_model (including resolution and background sub-models). ' + 'Pass confirm=True to proceed.' + ) + + if q_changed: + self.sample_model.clear_Q(confirm=True) + self.instrument_model.clear_Q(confirm=True) + + self._analysis_list_is_dirty = True + def calculate( self, Q_index: int | None = None, @@ -249,8 +294,6 @@ def plot_data_and_model( values available for plotting. RuntimeError If not in a Jupyter notebook environment. - TypeError - If plot_components or add_background is not True or False. Returns ------- @@ -279,14 +322,9 @@ def plot_data_and_model( 'No Q values available for plotting. Please check the experiment data.' ) - if not isinstance(plot_components, bool): - raise TypeError('plot_components must be True or False.') - - if not isinstance(add_background, bool): - raise TypeError('add_background must be True or False.') - - if not isinstance(plot_residuals, bool): - raise TypeError('plot_residuals must be True or False.') + self._verify_bool(plot_components, 'plot_components') + self._verify_bool(add_background, 'add_background') + self._verify_bool(plot_residuals, 'plot_residuals') if energy is None: energy = self.energy @@ -300,39 +338,8 @@ def plot_data_and_model( include_residuals=plot_residuals, ) - plot_kwargs_defaults = { - 'title': self.display_name, - 'linestyle': {}, - 'marker': {}, - 'color': {}, - 'markerfacecolor': {}, - 'keep': 'energy', - } - - for key in data_and_model: - if key == 'Data': - plot_kwargs_defaults['linestyle'][key] = 'none' - plot_kwargs_defaults['marker'][key] = 'o' - plot_kwargs_defaults['color'][key] = 'black' - plot_kwargs_defaults['markerfacecolor'][key] = 'none' - - elif key == 'Model': - plot_kwargs_defaults['linestyle'][key] = '-' - plot_kwargs_defaults['marker'][key] = None - plot_kwargs_defaults['color'][key] = 'red' - plot_kwargs_defaults['markerfacecolor'][key] = 'none' - - elif key == 'Residuals': - plot_kwargs_defaults['linestyle'][key] = 'none' - plot_kwargs_defaults['marker'][key] = 'o' - plot_kwargs_defaults['color'][key] = 'blue' - plot_kwargs_defaults['markerfacecolor'][key] = 'none' - - else: - plot_kwargs_defaults['linestyle'][key] = '--' - plot_kwargs_defaults['marker'][key] = None - - # Overwrite defaults with any user-provided kwargs + plot_kwargs_defaults = self._build_plot_style_defaults(data_and_model) + plot_kwargs_defaults['keep'] = 'energy' plot_kwargs_defaults.update(kwargs) if plot_residuals: @@ -384,9 +391,6 @@ def data_and_model_to_datagroup( If there is no data to include in the DataGroup, or if there are no Q values available for creating the DataGroup. - TypeError - If add_background is not True or False. If include_components is not True or False. - Returns ------- sc.DataGroup @@ -401,14 +405,9 @@ def data_and_model_to_datagroup( 'No Q values available for creating DataGroup. Please check the experiment data.' ) - if not isinstance(add_background, bool): - raise TypeError('add_background must be True or False.') - - if not isinstance(include_components, bool): - raise TypeError('include_components must be True or False.') - - if not isinstance(include_residuals, bool): - raise TypeError('include_residuals must be True or False.') + self._verify_bool(add_background, 'add_background') + self._verify_bool(include_components, 'include_components') + self._verify_bool(include_residuals, 'include_residuals') energy = self._verify_energy(energy) @@ -442,6 +441,8 @@ def parameters_to_dataset(self) -> sc.Dataset: ------ UnitError If there are inconsistent units for the same parameter across different Q values. + ValueError + If duplicate parameter names exist for the same Q index. Returns ------- @@ -452,12 +453,12 @@ def parameters_to_dataset(self) -> sc.Dataset: ds = sc.Dataset(coords={'Q': self.Q}) - # Collect all parameter names - all_names = { + # Collect all parameter names in first-seen order + all_names = dict.fromkeys( param.name for analysis in self.analysis_list for param in analysis.get_all_parameters() - } + ) # Storage values = {name: [] for name in all_names} @@ -465,7 +466,15 @@ def parameters_to_dataset(self) -> sc.Dataset: units = {} for analysis in self.analysis_list: - pars = {p.name: p for p in analysis.get_all_parameters()} + all_params = analysis.get_all_parameters() + param_names = [p.name for p in all_params] + if len(param_names) != len(set(param_names)): + dups = sorted({n for n in param_names if param_names.count(n) > 1}) + raise ValueError( + f'Duplicate parameter names at Q_index {analysis.Q_index}: {dups}. ' + 'Rename components so all parameters have unique names.' + ) + pars = {p.name: p for p in all_params} for name in all_names: if name in pars: @@ -603,45 +612,37 @@ def free_energy_offset(self, Q_index: int | None = None) -> None: def _on_experiment_changed(self) -> None: """ Update the Q values in the sample and instrument models when the experiment changes. - - Also update all the Analysis1d objects with the new experiment. """ - if self._call_updaters: - super()._on_experiment_changed() - for analysis in self.analysis_list: - analysis.experiment = self.experiment + super()._on_experiment_changed() + self._analysis_list_is_dirty = True def _on_sample_model_changed(self) -> None: """ Update the Q values in the sample model when the sample model changes. - - Also update all the Analysis1d objects with the new sample model. """ - if self._call_updaters: - super()._on_sample_model_changed() - for analysis in self.analysis_list: - analysis.sample_model = self.sample_model + super()._on_sample_model_changed() + self._analysis_list_is_dirty = True def _on_instrument_model_changed(self) -> None: """ Update the Q values in the instrument model when the instrument model changes. - - Also update all the Analysis1d objects with the new instrument model. """ - if self._call_updaters: - super()._on_instrument_model_changed() - for analysis in self.analysis_list: - analysis.instrument_model = self.instrument_model + super()._on_instrument_model_changed() + self._analysis_list_is_dirty = True def _on_convolution_settings_changed(self) -> None: """ - Update the convolution settings in all Analysis1d objects when the convolution settings - change. + Update the convolution settings when they change. """ - if self._call_updaters: - super()._on_convolution_settings_changed() - for analysis1d in self.analysis_list: - analysis1d.convolution_settings = self.convolution_settings + super()._on_convolution_settings_changed() + self._analysis_list_is_dirty = True + + def _ensure_analysis_list_current(self) -> None: + """Rebuild the analysis list if any dependency has changed since it was last built.""" + if self._analysis_list_is_dirty: + if self.Q is not None: + self._create_analysis_list() + self._analysis_list_is_dirty = False def _create_analysis_list(self) -> None: """ @@ -652,7 +653,6 @@ def _create_analysis_list(self) -> None: for Q_index in range(len(self.Q)): analysis = Analysis1d( display_name=f'{self.display_name}_Q{Q_index}', - unique_name=(f'{self.unique_name}_Q{Q_index}'), experiment=self.experiment, sample_model=self.sample_model, instrument_model=self.instrument_model, @@ -682,8 +682,6 @@ def _fit_single_Q(self, Q_index: int) -> FitResults: The results of the fit for the specified Q index. """ - Q_index = self._verify_Q_index(Q_index) - return self.analysis_list[Q_index].fit() def _fit_all_Q_independently(self) -> list[FitResults]: @@ -720,9 +718,7 @@ def _fit_all_Q_simultaneously(self) -> FitResults: ws.append(weight) # Make sure the convolver is up to date for this Q index - analysis1d._convolver = analysis1d._create_convolver( # noqa: SLF001 - energy=x - ) + analysis1d.refresh_convolver(energy=x) mf = MultiFitter( fit_objects=self.analysis_list, @@ -799,18 +795,12 @@ def _create_components_dataset( The energy values to use for calculating the components. If None, uses the energy from the experiment. - Raises - ------ - TypeError - If add_background is not True or False. - Returns ------- sc.Dataset A scipp Dataset where each entry is a component of the model, with dimensions "Q". """ - if not isinstance(add_background, bool): - raise TypeError('add_background must be True or False.') + self._verify_bool(add_background, 'add_background') if energy is None: energy = self.energy diff --git a/src/easydynamics/analysis/analysis1d.py b/src/easydynamics/analysis/analysis1d.py index 480b5fd0..3739bc68 100644 --- a/src/easydynamics/analysis/analysis1d.py +++ b/src/easydynamics/analysis/analysis1d.py @@ -91,10 +91,8 @@ def __init__( self._masked_energy = None self._fit_result = None - if self._Q_index is not None: - self._convolver = self._create_convolver() - else: - self._convolver = None + self._convolver = None + self._convolver_is_dirty = True ############# # Properties @@ -150,6 +148,9 @@ def calculate(self, energy: sc.Variable | None = None) -> np.ndarray: """ energy = self._verify_energy(energy) self._convolver = self._create_convolver(energy=energy) + # Mark dirty so the next fit() call rebuilds the convolver with the standard + # (unmasked) energy grid rather than reusing this plot-path grid. + self._convolver_is_dirty = True return self._calculate(energy=energy) @@ -171,11 +172,17 @@ def _calculate(self, energy: sc.Variable | None = None) -> np.ndarray: The calculated model prediction. """ - sample_intensity = self._evaluate_sample(energy=energy) - - background_intensity = self._evaluate_background(energy=energy) - - return sample_intensity + background_intensity + Q_index = self._require_Q_index() + sample = self._evaluate_with_convolution( + self.sample_model.get_component_collection(Q_index), + energy, + convolver=self._convolver, + ) + background = self._evaluate_direct( + self.instrument_model.background_model.get_component_collection(Q_index), + energy, + ) + return sample + background def fit(self) -> FitResults: """ @@ -197,8 +204,13 @@ def fit(self) -> FitResults: if self._experiment is None: raise ValueError('No experiment is associated with this Analysis.') - # Create convolver once to reuse during fitting - self._convolver = self._create_convolver() + if ( + self.sample_model.component_collections_is_dirty + or self.instrument_model.resolution_model.component_collections_is_dirty + ): + self._convolver_is_dirty = True + + self._ensure_convolver_current() fitter = EasyScienceFitter( fit_object=self, @@ -309,38 +321,7 @@ def plot_data_and_model( include_residuals=plot_residuals, ) - plot_kwargs_defaults = { - 'title': self.display_name, - 'linestyle': {}, - 'marker': {}, - 'color': {}, - 'markerfacecolor': {}, - } - - for key in data_and_model: - if key == 'Data': - plot_kwargs_defaults['linestyle'][key] = 'none' - plot_kwargs_defaults['marker'][key] = 'o' - plot_kwargs_defaults['color'][key] = 'black' - plot_kwargs_defaults['markerfacecolor'][key] = 'none' - - elif key == 'Model': - plot_kwargs_defaults['linestyle'][key] = '-' - plot_kwargs_defaults['marker'][key] = None - plot_kwargs_defaults['color'][key] = 'red' - plot_kwargs_defaults['markerfacecolor'][key] = 'none' - - elif key == 'Residuals': - plot_kwargs_defaults['linestyle'][key] = 'none' - plot_kwargs_defaults['marker'][key] = 'o' - plot_kwargs_defaults['color'][key] = 'blue' - plot_kwargs_defaults['markerfacecolor'][key] = 'none' - - else: - plot_kwargs_defaults['linestyle'][key] = '--' - plot_kwargs_defaults['marker'][key] = None - - # Overwrite defaults with any user-provided kwargs + plot_kwargs_defaults = self._build_plot_style_defaults(data_and_model) plot_kwargs_defaults.update(kwargs) if plot_residuals: @@ -391,8 +372,6 @@ def data_and_model_to_datagroup( If no data is available in the experiment to include in the DataGroup. If no Q values are available in the experiment to create the DataGroup. If Q_index is not set to create the DataGroup. - TypeError - If add_background is not a boolean. If include_components is not a boolean. Returns ------- @@ -409,14 +388,9 @@ def data_and_model_to_datagroup( 'No Q values available for creating DataGroup. Please check the experiment data.' ) - if not isinstance(add_background, bool): - raise TypeError('add_background must be True or False.') - - if not isinstance(include_components, bool): - raise TypeError('include_components must be True or False.') - - if not isinstance(include_residuals, bool): - raise TypeError('include_residuals must be True or False.') + self._verify_bool(add_background, 'add_background') + self._verify_bool(include_components, 'include_components') + self._verify_bool(include_residuals, 'include_residuals') if self.Q_index is None: raise ValueError('Q_index must be set to create DataGroup.') @@ -454,6 +428,26 @@ def free_energy_offset(self) -> None: """Free the energy offset parameter for the current Q index.""" self.instrument_model.free_energy_offset(Q_index=self._require_Q_index()) + def rebin(self, dimensions: dict[str, int | sc.Variable]) -> None: + """ + Rebin the experiment data along specified dimensions and update the analysis. + + Parameters + ---------- + dimensions : dict[str, int | sc.Variable] + A dictionary mapping dimension names to number of bins (int) or bin edges + (sc.Variable). + """ + self.experiment.rebin(dimensions) + if self._Q_index is not None and self.experiment is not None: + self._masked_energy = self.experiment.get_masked_energy(Q_index=self._Q_index) + self._convolver_is_dirty = True + + def refresh_convolver(self, energy: sc.Variable | None = None) -> None: + """Refresh the pre-built Convolution object for the current Q index.""" + self._convolver = self._create_convolver(energy=energy) + self._convolver_is_dirty = False + ############# # Private methods: small utilities ############# @@ -482,12 +476,38 @@ def _on_Q_index_changed(self) -> None: """ Handle changes to the Q index. - This method is called whenever the Q index is changed. It updates the Convolution object - for the new Q index and the masked energy from the experiment for the new Q index. + This method is called whenever the Q index is changed. It updates the masked energy from + the experiment for the new Q index and marks the convolver as dirty. """ masked_energy = self.experiment.get_masked_energy(Q_index=self._Q_index) self._masked_energy = masked_energy - self._convolver = self._create_convolver() + self._convolver_is_dirty = True + + def _on_experiment_changed(self) -> None: + """Mark the convolver as dirty when the experiment changes.""" + super()._on_experiment_changed() + self._convolver_is_dirty = True + + def _on_sample_model_changed(self) -> None: + """Mark the convolver as dirty when the sample model changes.""" + super()._on_sample_model_changed() + self._convolver_is_dirty = True + + def _on_instrument_model_changed(self) -> None: + """Mark the convolver as dirty when the instrument model changes.""" + super()._on_instrument_model_changed() + self._convolver_is_dirty = True + + def _on_convolution_settings_changed(self) -> None: + """Mark the convolver as dirty when the convolution settings change.""" + super()._on_convolution_settings_changed() + self._convolver_is_dirty = True + + def _ensure_convolver_current(self) -> None: + """Rebuild the convolver if any dependency has changed since it was last built.""" + if self._convolver_is_dirty: + self._convolver = self._create_convolver() + self._convolver_is_dirty = False def _calculate_energy_with_offset( self, @@ -532,98 +552,62 @@ def _calculate_energy_with_offset( # Private methods: evaluation ############# - def _evaluate_components( + def _evaluate_with_convolution( self, components: ComponentCollection | ModelComponent, + energy: sc.Variable | None, convolver: Convolution | None = None, - convolve: bool = True, - energy: sc.Variable | None = None, - apply_detailed_balance: bool = False, ) -> np.ndarray: """ - Calculate the contribution of a set of components, optionally convolving with the - resolution. + Evaluate sample components, applying convolution and detailed balance as appropriate. - If convolve is True and a Convolution object is provided (for full model evaluation), we - use it to perform the convolution of the components with the resolution. If convolve is - True but no Convolution object is provided, create a new Convolution object for the given - components (for individual components). If convolve is False, evaluate the components - directly without convolution (for background). + Uses the pre-built convolver when provided (fit path, for performance). If no convolver is + given, creates a temporary one per call (plot path for individual components). Falls back + to direct evaluation with detailed balance if there is no resolution model. Parameters ---------- components : ComponentCollection | ModelComponent - The components to evaluate. + The sample components to evaluate. + energy : sc.Variable | None + Energy grid to use. If None, uses the masked energy from the experiment. convolver : Convolution | None, default=None - An optional Convolution object to use for convolution. If None, a new Convolution - object will be created if convolve is True. - convolve : bool, default=True - Whether to perform convolution with the resolution. - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. - apply_detailed_balance : bool, default=False - Whether to apply detailed balance correction. - + Pre-built Convolution to reuse. If None, a new one is created if needed. Returns ------- np.ndarray - The evaluated contribution of the components. + The evaluated sample contribution. """ - Q_index = self._require_Q_index() if energy is None: energy = self._masked_energy - energy_offset = self.instrument_model.get_energy_offset(Q_index) - energy_with_offset = self._calculate_energy_with_offset( - energy=energy, - energy_offset=energy_offset, - ) - - # If there are no components, return zero if isinstance(components, ComponentCollection) and components.is_empty: return np.zeros_like(energy.values) - # If a convolver is provided, we use it. This allows reusing the - # same convolver for multiple evaluations during fitting for - # performance reasons. if convolver is not None: return convolver.convolution() - # No convolution can happen for multiple reasons: - # Case 1: convolve=False, used for evaluating background components, where we don't want - # to convolve with the resolution. In this case, apply_detailed_balance is False, - # and we evaluate the components without DBF regardles of the settings - # Case 2: convolve=True but there is no resolution_model. In this case, - # apply_detailed_balance is True. We apply DBF if temperature is provided and - # the settings say to use detailed balance. - + energy_offset = self.instrument_model.get_energy_offset(Q_index) + energy_with_offset = self._calculate_energy_with_offset(energy, energy_offset) resolution = self.instrument_model.resolution_model.get_component_collection(Q_index) - if not convolve or resolution.is_empty: - result_no_convolution = components.evaluate(energy_with_offset) + + if resolution.is_empty: + result = components.evaluate(energy_with_offset) if ( - apply_detailed_balance - and self.temperature is not None + self.temperature is not None and self.detailed_balance_settings.use_detailed_balance ): - DBF = detailed_balance_factor( + result *= detailed_balance_factor( energy=energy_with_offset, temperature=self.temperature, divide_by_temperature=self.detailed_balance_settings.normalize_detailed_balance, energy_unit=self.unit, ) - result_no_convolution *= DBF - return result_no_convolution + return result - # If no convolver is provided, we create a new one. This is for - # evaluating individual components for plotting, where - # performance is not important. We already handled the case of - # background components above, so we know that this is for sample components, - # where detailed balance settings should be applied. - - conv = Convolution( + return Convolution( energy=energy, sample_components=components, resolution_components=resolution, @@ -631,77 +615,22 @@ def _evaluate_components( convolution_settings=self.convolution_settings, temperature=self.temperature, detailed_balance_settings=self.detailed_balance_settings, - ) - return conv.convolution() - - def _evaluate_sample( - self, - energy: sc.Variable | None = None, - ) -> np.ndarray: - """ - Evaluate the sample contribution for a given Q index. - - Assumes that self._convolver is up to date. - - Parameters - ---------- - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. - - Returns - ------- - np.ndarray - The evaluated sample contribution. - """ - Q_index = self._require_Q_index() - components = self.sample_model.get_component_collection(Q_index=Q_index) - return self._evaluate_components( - components=components, - convolver=self._convolver, - convolve=True, - energy=energy, - apply_detailed_balance=True, - ) + ).convolution() - def _evaluate_sample_component( + def _evaluate_direct( self, - component: ModelComponent, - energy: sc.Variable | None = None, + components: ComponentCollection | ModelComponent, + energy: sc.Variable | None, ) -> np.ndarray: """ - Evaluate a single sample component for the chosen Q index. - - Parameters - ---------- - component : ModelComponent - The sample component to evaluate. - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. - - Returns - ------- - np.ndarray - The evaluated sample component contribution. - """ - return self._evaluate_components( - components=component, - convolver=None, - convolve=True, - energy=energy, - apply_detailed_balance=True, - ) - - def _evaluate_background(self, energy: sc.Variable | None = None) -> np.ndarray: - """ - Evaluate the background contribution for the chosen Q index. + Evaluate background components directly — no convolution, no detailed balance factor. Parameters ---------- - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. + components : ComponentCollection | ModelComponent + The background components to evaluate. + energy : sc.Variable | None + Energy grid to use. If None, uses the masked energy from the experiment. Returns ------- @@ -709,46 +638,15 @@ def _evaluate_background(self, energy: sc.Variable | None = None) -> np.ndarray: The evaluated background contribution. """ Q_index = self._require_Q_index() - background_components = self.instrument_model.background_model.get_component_collection( - Q_index=Q_index - ) - return self._evaluate_components( - components=background_components, - convolver=None, - convolve=False, - energy=energy, - apply_detailed_balance=False, - ) - - def _evaluate_background_component( - self, - component: ModelComponent, - energy: sc.Variable | None = None, - ) -> np.ndarray: - """ - Evaluate a single background component for the chosen Q index. - - Parameters - ---------- - component : ModelComponent - The background component to evaluate. - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. + if energy is None: + energy = self._masked_energy - Returns - ------- - np.ndarray - The evaluated background component contribution. - """ + if isinstance(components, ComponentCollection) and components.is_empty: + return np.zeros_like(energy.values) - return self._evaluate_components( - components=component, - convolver=None, - convolve=False, - energy=energy, - apply_detailed_balance=False, - ) + energy_offset = self.instrument_model.get_energy_offset(Q_index) + energy_with_offset = self._calculate_energy_with_offset(energy, energy_offset) + return components.evaluate(energy_with_offset) def _create_convolver( self, @@ -798,66 +696,6 @@ def _create_convolver( # Private methods: create scipp arrays for plotting ############# - def _create_component_scipp_array( - self, - component: ModelComponent, - background: np.ndarray | None = None, - energy: sc.Variable | None = None, - ) -> sc.DataArray: - """ - Create a scipp DataArray for a single component. - - Adds the background if it is not None. - - Parameters - ---------- - component : ModelComponent - The component to evaluate. - background : np.ndarray | None, default=None - Optional background to add to the component. - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. - - Returns - ------- - sc.DataArray - The model calculation of the component. - """ - - values = self._evaluate_sample_component(component=component, energy=energy) - if background is not None: - values += background - return self._to_scipp_array(values=values, energy=energy) - - def _create_background_component_scipp_array( - self, - component: ModelComponent, - energy: sc.Variable | None = None, - ) -> sc.DataArray: - """ - Create a scipp DataArray for a single background component. - - Parameters - ---------- - component : ModelComponent - The component to evaluate. - energy : sc.Variable | None, default=None - Optional energy grid to use for evaluation. If None, the energy grid from the - experiment is used. - - Returns - ------- - sc.DataArray - The model calculation of the component. - """ - - values = self._evaluate_background_component( - component=component, - energy=energy, - ) - return self._to_scipp_array(values=values, energy=energy) - def _create_model_array(self, energy: sc.Variable | None = None) -> sc.DataArray: """ Create a scipp DataArray for the full sample model including background. @@ -902,44 +740,47 @@ def _create_components_dataset_single_Q( self, add_background: bool = True, energy: sc.Variable | None = None, - ) -> dict[str, sc.DataArray]: + ) -> sc.Dataset: """ Create sc.DataArrays for all sample and background components. Parameters ---------- add_background : bool, default=True - Whether to add background components. + Whether to add the background to each sample component. energy : sc.Variable | None, default=None Optional energy grid to use for evaluation. If None, the energy grid from the experiment is used. Returns ------- - dict[str, sc.DataArray] - A dictionary of component names to their corresponding sc.DataArrays. + sc.Dataset + A Dataset of component names to their corresponding sc.DataArrays. """ - scipp_arrays = {} - sample_components = self.sample_model.get_component_collection(Q_index=self.Q_index) + Q_index = self.Q_index + if energy is None: + energy = self._masked_energy background_components = self.instrument_model.background_model.get_component_collection( - Q_index=self.Q_index + Q_index=Q_index + ) + background_values = ( + self._evaluate_direct(background_components, energy) if add_background else None ) - if energy is None: - energy = self._masked_energy - - background = self._evaluate_background(energy=energy) if add_background else None + result: dict[str, sc.DataArray] = {} + for component in self.sample_model.get_component_collection(Q_index=Q_index): + values = self._evaluate_with_convolution(component, energy) + if background_values is not None: + values = values + background_values + result[component.display_name] = self._to_scipp_array(values, energy) - for component in sample_components: - scipp_arrays[component.display_name] = self._create_component_scipp_array( - component=component, background=background, energy=energy - ) for component in background_components: - scipp_arrays[component.display_name] = self._create_background_component_scipp_array( - component=component, energy=energy + result[component.display_name] = self._to_scipp_array( + self._evaluate_direct(component, energy), energy ) - return sc.Dataset(scipp_arrays) + + return sc.Dataset(result) def _to_scipp_array( self, diff --git a/src/easydynamics/analysis/analysis_base.py b/src/easydynamics/analysis/analysis_base.py index ac30a9ea..8e19a90d 100644 --- a/src/easydynamics/analysis/analysis_base.py +++ b/src/easydynamics/analysis/analysis_base.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Iterable + import numpy as np import scipp as sc from easyscience.variable import Parameter @@ -103,17 +105,7 @@ def __init__( 'convolution_settings must be an instance of ConvolutionSettings or None.' ) - if extra_parameters is not None: - if isinstance(extra_parameters, Parameter): - self._extra_parameters = [extra_parameters] - elif isinstance(extra_parameters, list) and all( - isinstance(p, Parameter) for p in extra_parameters - ): - self._extra_parameters = extra_parameters - else: - raise TypeError('extra_parameters must be a Parameter or a list of Parameters.') - else: - self._extra_parameters = [] + self.extra_parameters = extra_parameters if detailed_balance_settings is None: self._detailed_balance_settings = DetailedBalanceSettings() @@ -266,7 +258,7 @@ def energy(self) -> sc.Variable | None: Returns ------- sc.Variable | None - The energy values from the associated. + The energy values from the associated Experiment, if available, and None if not. """ return self.experiment.energy @@ -451,26 +443,10 @@ def get_parameters_near_bounds( ------- list[Parameter] A list of parameters that are near their bounds. - - Raises - ------ - TypeError - If rtol or atol is not a float. - ValueError - If rtol or atol is negative. """ - if not isinstance(rtol, (int, float)): - raise TypeError(f'rtol must be a float. Got {type(rtol)}.') - - if rtol < 0: - raise ValueError(f'rtol must be non-negative. Got {rtol}.') - - if not isinstance(atol, (int, float)): - raise TypeError(f'atol must be a float. Got {type(atol)}.') - - if atol < 0: - raise ValueError(f'atol must be non-negative. Got {atol}.') + self._verify_nonneg_float(rtol, 'rtol') + self._verify_nonneg_float(atol, 'atol') parameters = self.get_all_parameters() at_bounds = [] @@ -573,6 +549,97 @@ def _verify_energy(self, energy: sc.Variable | None) -> sc.Variable | None: raise TypeError(f'Energy must be a sc.Variable or None. Got {type(energy)}.') return energy + @staticmethod + def _verify_bool(value: object, name: str) -> None: + """ + Raise TypeError if value is not a bool. + + Parameters + ---------- + value : object + The object to verify. + name : str + The name of the object for use in the error message. + + Raises + ------ + TypeError + If value is not a bool. + """ + if not isinstance(value, bool): + raise TypeError(f'{name} must be True or False.') + + @staticmethod + def _verify_nonneg_float(value: object, name: str) -> None: + """ + Raise TypeError or ValueError if value is not a non-negative number. + + Parameters + ---------- + value : object + The object to verify. + name : str + The name of the object for use in the error message. + + Raises + ------ + TypeError + If value is not an int or float. + ValueError + If value is negative. + """ + if not isinstance(value, (int, float)): + raise TypeError(f'{name} must be a float. Got {type(value)}.') + if value < 0: + raise ValueError(f'{name} must be non-negative. Got {value}.') + + def _build_plot_style_defaults(self, keys: Iterable[str]) -> dict: + """ + Build default plot style kwargs for the given DataGroup keys. + + Parameters + ---------- + keys : Iterable[str] + The DataGroup keys to build plot style defaults for. Recognized values are ``"Data"``, + ``"Model"``, and ``"Residuals"``; any other key gets a dashed line style. + + Returns + ------- + dict + A dict of plot style kwargs including ``title``, ``linestyle``, ``marker``, ``color``, + and ``markerfacecolor``. + """ + linestyle: dict = {} + marker: dict = {} + color: dict = {} + markerfacecolor: dict = {} + for key in keys: + if key == 'Data': + linestyle[key] = 'none' + marker[key] = 'o' + color[key] = 'black' + markerfacecolor[key] = 'none' + elif key == 'Model': + linestyle[key] = '-' + marker[key] = None + color[key] = 'red' + markerfacecolor[key] = 'none' + elif key == 'Residuals': + linestyle[key] = 'none' + marker[key] = 'o' + color[key] = 'blue' + markerfacecolor[key] = 'none' + else: + linestyle[key] = '--' + marker[key] = None + return { + 'title': self.display_name, + 'linestyle': linestyle, + 'marker': marker, + 'color': color, + 'markerfacecolor': markerfacecolor, + } + ############# # Dunder methods ############# @@ -587,6 +654,6 @@ def __repr__(self) -> str: A string representation of the Analysis. """ return ( - f' {self.__class__.__name__} (display_name={self.display_name}, ' + f'{self.__class__.__name__} (display_name={self.display_name}, ' f'unique_name={self.unique_name})' ) diff --git a/src/easydynamics/analysis/parameter_analysis.py b/src/easydynamics/analysis/parameter_analysis.py index eb3e3b4f..2283ba83 100644 --- a/src/easydynamics/analysis/parameter_analysis.py +++ b/src/easydynamics/analysis/parameter_analysis.py @@ -529,6 +529,14 @@ def _get_xyweight_from_dataset( # Dunder methods ############# def __repr__(self) -> str: + """ + Return a string representation of the ParameterAnalysis. + + Returns + ------- + str + A string representation of the ParameterAnalysis. + """ cls = self.__class__.__name__ n_params = len(self._parameters) if isinstance(self._parameters, sc.Dataset) else 0 diff --git a/src/easydynamics/base_classes/easydynamics_list.py b/src/easydynamics/base_classes/easydynamics_list.py index b45248b3..fbadba2a 100644 --- a/src/easydynamics/base_classes/easydynamics_list.py +++ b/src/easydynamics/base_classes/easydynamics_list.py @@ -91,6 +91,7 @@ def insert(self, index: int, value: ProtectedType_) -> None: def append(self, value: ProtectedType_) -> None: """ Append an item to the end of the list. + Parameters ---------- value : ProtectedType_ diff --git a/src/easydynamics/convolution/analytical_convolution.py b/src/easydynamics/convolution/analytical_convolution.py index 5ab305cf..dc07a26f 100644 --- a/src/easydynamics/convolution/analytical_convolution.py +++ b/src/easydynamics/convolution/analytical_convolution.py @@ -4,8 +4,6 @@ from typing import ClassVar import numpy as np -import scipp as sc -from easyscience.variable import Parameter from scipy.special import voigt_profile from easydynamics.convolution.convolution_base import ConvolutionBase @@ -15,7 +13,6 @@ from easydynamics.sample_model import Voigt from easydynamics.sample_model.component_collection import ComponentCollection from easydynamics.sample_model.components.model_component import ModelComponent -from easydynamics.utils.utils import Numeric class AnalyticalConvolution(ConvolutionBase): @@ -37,46 +34,6 @@ class AnalyticalConvolution(ConvolutionBase): ('Voigt', 'Voigt'): '_convolute_voigt_voigt', } - def __init__( - self, - energy: np.ndarray | sc.Variable, - unit: str | sc.Unit = 'meV', - sample_components: ComponentCollection | ModelComponent | None = None, - resolution_components: ComponentCollection | ModelComponent | None = None, - energy_offset: Numeric | Parameter = 0.0, - display_name: str | None = 'MyConvolution', - unique_name: str | None = None, - ) -> None: - """ - Initialize an AnalyticalConvolution. - - Parameters - ---------- - energy : np.ndarray | sc.Variable - 1D array of energy values where the convolution is evaluated. - unit : str | sc.Unit, default='meV' - The unit of the energy. - sample_components : ComponentCollection | ModelComponent | None, default=None - The sample model to be convolved. - resolution_components : ComponentCollection | ModelComponent | None, default=None - The resolution model to convolve with. - energy_offset : Numeric | Parameter, default=0.0 - An offset to shift the energy values by. - display_name : str | None, default='MyConvolution' - Display name of the model. - unique_name : str | None, default=None - Unique name of the model. If None, a unique name will be generated. - """ - super().__init__( - energy=energy, - unit=unit, - sample_components=sample_components, - resolution_components=resolution_components, - energy_offset=energy_offset, - display_name=display_name, - unique_name=unique_name, - ) - def convolution( self, ) -> np.ndarray: @@ -209,7 +166,7 @@ def _convolute_delta_any( sample_component : DeltaFunction The sample component to be convolved. resolution_components : ComponentCollection | ModelComponent - : The resolution model to convolve with. + The resolution model to convolve with. Returns ------- diff --git a/src/easydynamics/convolution/convolution.py b/src/easydynamics/convolution/convolution.py index 8a11e05f..d2151e42 100644 --- a/src/easydynamics/convolution/convolution.py +++ b/src/easydynamics/convolution/convolution.py @@ -36,9 +36,10 @@ class Convolution(NumericalConvolutionBase): """ # When these attributes are changed, the convolution plan - # needs to be rebuilt + # needs to be rebuilt. + # Note: the public 'energy' property setter always writes to '_energy', so '_energy' alone + # is sufficient — listing 'energy' separately would cause a double invalidation. _invalidate_plan_on_change: ClassVar[dict[str, object]] = { - 'energy', '_energy', '_energy_grid', '_sample_components', @@ -180,8 +181,7 @@ def _check_if_pair_is_analytic( Raises ------ TypeError - If either component is not a ModelComponent, or if the resolution component is a - DeltaFunction. + If the resolution component is a DeltaFunction. Returns ------- @@ -189,18 +189,6 @@ def _check_if_pair_is_analytic( True if the component pair can be handled analytically, False otherwise. """ - if not isinstance(sample_component, ModelComponent): - raise TypeError( - f'`sample_component` is an instance of {type(sample_component).__name__}, \ - but must be a ModelComponent.' - ) - - if not isinstance(resolution_component, ModelComponent): - raise TypeError( - f'`resolution_component` is an instance of {type(resolution_component).__name__}, \ - but must be a ModelComponent.' - ) - if isinstance(resolution_component, DeltaFunction): raise TypeError( 'resolution components contains delta functions. This is not supported.' diff --git a/src/easydynamics/convolution/convolution_base.py b/src/easydynamics/convolution/convolution_base.py index dd98d5e9..c8516cb0 100644 --- a/src/easydynamics/convolution/convolution_base.py +++ b/src/easydynamics/convolution/convolution_base.py @@ -151,25 +151,6 @@ def energy_with_offset(self) -> sc.Variable: energy_with_offset.values = self.energy.values - self.energy_offset.value return energy_with_offset - @energy_with_offset.setter - def energy_with_offset(self, _value: sc.Variable) -> None: - """ - Energy with offset is a read-only property derived from energy and energy_offset. - - Parameters - ---------- - _value : sc.Variable - The value to set (ignored). - - Raises - ------ - AttributeError - Always raised since energy_with_offset is read-only. - """ - raise AttributeError( - 'Energy with offset is a read-only property derived from energy and energy_offset.' - ) - @property def energy(self) -> sc.Variable: """ diff --git a/src/easydynamics/convolution/numerical_convolution.py b/src/easydynamics/convolution/numerical_convolution.py index d39550cb..7deac048 100644 --- a/src/easydynamics/convolution/numerical_convolution.py +++ b/src/easydynamics/convolution/numerical_convolution.py @@ -2,17 +2,10 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np -import scipp as sc -from easyscience.variable import Parameter from scipy.signal import fftconvolve from easydynamics.convolution.numerical_convolution_base import NumericalConvolutionBase -from easydynamics.sample_model.component_collection import ComponentCollection -from easydynamics.sample_model.components.model_component import ModelComponent -from easydynamics.settings.convolution_settings import ConvolutionSettings -from easydynamics.settings.detailed_balance_settings import DetailedBalanceSettings from easydynamics.utils.detailed_balance import detailed_balance_factor -from easydynamics.utils.utils import Numeric class NumericalConvolution(NumericalConvolutionBase): @@ -24,62 +17,6 @@ class NumericalConvolution(NumericalConvolutionBase): applied to the sample model. """ - def __init__( - self, - energy: np.ndarray | sc.Variable, - sample_components: ComponentCollection | ModelComponent, - resolution_components: ComponentCollection | ModelComponent, - energy_offset: Numeric | Parameter = 0.0, - convolution_settings: ConvolutionSettings | None = None, - temperature: Parameter | Numeric | None = None, - temperature_unit: str | sc.Unit = 'K', - detailed_balance_settings: DetailedBalanceSettings | None = None, - unit: str | sc.Unit = 'meV', - display_name: str | None = 'MyConvolution', - unique_name: str | None = None, - ) -> None: - """ - Initialize the NumericalConvolution object. - - Parameters - ---------- - energy : np.ndarray | sc.Variable - 1D array of energy values where the convolution is evaluated. - sample_components : ComponentCollection | ModelComponent - The sample model to be convolved. - resolution_components : ComponentCollection | ModelComponent - The resolution model to convolve with. - energy_offset : Numeric | Parameter, default=0.0 - An energy offset to apply to the energy values before convolution. - convolution_settings : ConvolutionSettings | None, default=None - The settings for the convolution. - temperature : Parameter | Numeric | None, default=None - The temperature to use for detailed balance correction. - temperature_unit : str | sc.Unit, default='K' - The unit of the temperature parameter. - detailed_balance_settings : DetailedBalanceSettings | None, default=None - The settings for detailed balance. If None, default settings will be used. - unit : str | sc.Unit, default='meV' - The unit of the energy. - display_name : str | None, default='MyConvolution' - Display name of the model. - unique_name : str | None, default=None - Unique name of the model. If None, a unique name will be generated. - """ - super().__init__( - energy=energy, - sample_components=sample_components, - resolution_components=resolution_components, - energy_offset=energy_offset, - convolution_settings=convolution_settings, - temperature=temperature, - temperature_unit=temperature_unit, - detailed_balance_settings=detailed_balance_settings, - unit=unit, - display_name=display_name, - unique_name=unique_name, - ) - def convolution( self, ) -> np.ndarray: @@ -96,6 +33,7 @@ def convolution( # settings before convolution. if not self.convolution_settings.convolution_plan_is_valid: self._energy_grid = self._create_energy_grid() + self.convolution_settings.convolution_plan_is_valid = True # Give warnings if peaks are very wide or very narrow if not self.convolution_settings.suppress_warnings: diff --git a/src/easydynamics/convolution/numerical_convolution_base.py b/src/easydynamics/convolution/numerical_convolution_base.py index 9a6e80cd..da17cae8 100644 --- a/src/easydynamics/convolution/numerical_convolution_base.py +++ b/src/easydynamics/convolution/numerical_convolution_base.py @@ -116,6 +116,7 @@ def __init__( # When upsample_factor>1, we evaluate on this grid and # interpolate back to the original values at the end self._energy_grid = self._create_energy_grid() + self._convolution_settings.convolution_plan_is_valid = True @property def convolution_settings(self) -> ConvolutionSettings: @@ -179,30 +180,13 @@ def upsample_factor(self) -> Numeric | None: @upsample_factor.setter def upsample_factor(self, factor: Numeric | None) -> None: """ - Set the upsample factor and recreate the dense grid. + Set the upsample factor. Parameters ---------- factor : Numeric | None The new upsample factor. - - Raises - ------ - TypeError - If factor is not a number or None. - ValueError - If factor is not greater than 1. """ - if factor is None: - self.convolution_settings.upsample_factor = factor - return - - if not isinstance(factor, Numeric): - raise TypeError('Upsample factor must be a numerical value or None.') - factor = float(factor) - if factor <= 1.0: - raise ValueError('Upsample factor must be greater than 1.') - self.convolution_settings.upsample_factor = factor @property @@ -224,7 +208,7 @@ def extension_factor(self) -> float: @extension_factor.setter def extension_factor(self, factor: Numeric) -> None: """ - Set the extension factor and recreate the dense grid. + Set the extension factor. The extension factor determines how much the energy range is extended on both sides before convolution. 0.2 means extending by 20% of the original energy span on each side. @@ -233,21 +217,8 @@ def extension_factor(self, factor: Numeric) -> None: ---------- factor : Numeric The new extension factor. - - Raises - ------ - TypeError - If factor is not a number. - ValueError - If factor is negative. """ - - if not isinstance(factor, Numeric): - raise TypeError('Extension factor must be a number.') - if factor < 0.0: - raise ValueError('Extension factor must be non-negative.') - - self.convolution_settings.extension_factor = float(factor) + self.convolution_settings.extension_factor = factor @property def temperature(self) -> Parameter | None: @@ -397,16 +368,13 @@ def _create_energy_grid( else: energy_dense_centered = energy_dense - energy_grid = EnergyGrid( + return EnergyGrid( energy_dense=energy_dense, energy_dense_centered=energy_dense_centered, energy_dense_step=energy_dense_step, energy_span_dense=energy_span_dense, energy_even_length_offset=energy_even_length_offset, ) - self._energy_grid = energy_grid - self.convolution_settings.convolution_plan_is_valid = True - return energy_grid def _check_width_thresholds( self, diff --git a/src/easydynamics/sample_model/component_collection.py b/src/easydynamics/sample_model/component_collection.py index f0bbd2d7..b556103a 100644 --- a/src/easydynamics/sample_model/component_collection.py +++ b/src/easydynamics/sample_model/component_collection.py @@ -106,6 +106,8 @@ def __init__( unique_name=unique_name, ) + self._warn_if_duplicate_names() + # ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ @@ -196,6 +198,7 @@ def append_component(self, component: ModelComponent | ComponentCollection) -> N self.extend(component) else: self.append(component) + self._warn_if_duplicate_names() def list_component_names(self) -> list[str]: """ @@ -336,6 +339,27 @@ def free_all_parameters(self) -> None: for param in self.get_fittable_parameters(): param.fixed = False + # ------------------------------------------------------------------ + # Private methods + # ------------------------------------------------------------------ + + def _warn_if_duplicate_names(self) -> None: + """Warn if any two components share the same name.""" + names = [c.name for c in self] + seen: set[str] = set() + dups: set[str] = set() + for name in names: + if name in seen: + dups.add(name) + seen.add(name) + if dups: + warnings.warn( + f'Duplicate component names in ComponentCollection: {sorted(dups)}. ' + 'Components with the same name will produce duplicate parameter names.', + UserWarning, + stacklevel=3, + ) + # ------------------------------------------------------------------ # Dunder methods # ------------------------------------------------------------------ @@ -357,6 +381,14 @@ def __repr__(self) -> str: ) def to_dict(self) -> dict: + """ + Serialise the ComponentCollection to a dictionary. + + Returns + ------- + dict + Dictionary representation of the ComponentCollection. + """ return { '@module': self.__class__.__module__, '@class': self.__class__.__name__, @@ -368,6 +400,19 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, obj_dict: dict) -> ComponentCollection: + """ + Deserialise a ComponentCollection from its dictionary representation. + + Parameters + ---------- + obj_dict : dict + Dictionary representation of the ComponentCollection, as produced by to_dict(). + + Returns + ------- + ComponentCollection + The deserialised ComponentCollection. + """ def deserialise_component(d: dict) -> ModelComponent: """ diff --git a/src/easydynamics/sample_model/components/gaussian.py b/src/easydynamics/sample_model/components/gaussian.py index cab505b1..8b7f824a 100644 --- a/src/easydynamics/sample_model/components/gaussian.py +++ b/src/easydynamics/sample_model/components/gaussian.py @@ -49,7 +49,7 @@ def __init__( area : Numeric, default=1.0 Area of the Gaussian. center : Numeric | None, default=None - Center of the Gaussian. If None. + Center of the Gaussian. If None, defaults to 0 and is fixed. width : Numeric, default=1.0 Standard deviation. unit : str | sc.Unit, default='meV' diff --git a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py index 0d04d352..a8632272 100644 --- a/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py +++ b/src/easydynamics/sample_model/diffusion_model/jump_translational_diffusion.py @@ -297,8 +297,6 @@ def create_component_collections( """ Create ComponentCollection components for the diffusion model at given Q values. - TypeError If component_display_name is not a string. If component_name is not a string. - Returns ------- list[ComponentCollection] diff --git a/src/easydynamics/sample_model/instrument_model.py b/src/easydynamics/sample_model/instrument_model.py index fd0ae458..e906d14a 100644 --- a/src/easydynamics/sample_model/instrument_model.py +++ b/src/easydynamics/sample_model/instrument_model.py @@ -106,6 +106,8 @@ def __init__( unit=self.unit, fixed=False, ) + self._energy_offsets: list = [] + self._energy_offsets_is_dirty = True self._Q = _validate_and_convert_Q(Q) self._on_Q_change() @@ -350,6 +352,7 @@ def convert_unit(self, unit_str: str | sc.Unit) -> None: self._background_model.convert_unit(unit) self._resolution_model.convert_unit(unit) self._energy_offset.convert_unit(unit) + self._ensure_energy_offsets_current() for offset in self._energy_offsets: offset.convert_unit(unit) @@ -382,6 +385,7 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: if self._Q is None: return [] + self._ensure_energy_offsets_current() if Q_index is None: variables = [self._energy_offsets[i] for i in range(len(self._Q))] else: @@ -441,6 +445,7 @@ def get_energy_offset( if self._Q is None: raise ValueError('No Q values are set in the InstrumentModel.') + self._ensure_energy_offsets_current() if Q_index is None: return self._energy_offsets @@ -508,6 +513,7 @@ def _fix_or_free_energy_offset(self, Q_index: int | None = None, fixed: bool = T If Q_index is out of bounds for the Q values in the InstrumentModel. """ + self._ensure_energy_offsets_current() if Q_index is None: for offset in self._energy_offsets: offset.fixed = fixed @@ -521,6 +527,12 @@ def _fix_or_free_energy_offset(self, Q_index: int | None = None, fixed: bool = T ) self._energy_offsets[Q_index].fixed = fixed + def _ensure_energy_offsets_current(self) -> None: + """Rebuild energy offset Parameters if Q has changed since they were last built.""" + if self._energy_offsets_is_dirty: + self._generate_energy_offsets() + self._energy_offsets_is_dirty = False + def _generate_energy_offsets(self) -> None: """Generate energy offset Parameters for each Q value.""" if self._Q is None: @@ -531,12 +543,13 @@ def _generate_energy_offsets(self) -> None: def _on_Q_change(self) -> None: """Handle changes to the Q values.""" - self._generate_energy_offsets() + self._energy_offsets_is_dirty = True self.resolution_model.Q = self.Q self.background_model.Q = self.Q def _on_energy_offset_change(self) -> None: """Handle changes to the energy offset.""" + self._ensure_energy_offsets_current() for offset in self._energy_offsets: offset.value = self._energy_offset.value diff --git a/src/easydynamics/sample_model/model_base.py b/src/easydynamics/sample_model/model_base.py index aa76b384..c6c773d2 100644 --- a/src/easydynamics/sample_model/model_base.py +++ b/src/easydynamics/sample_model/model_base.py @@ -68,11 +68,11 @@ def __init__( ) self._components = ComponentCollection() + self._component_collections: list[ComponentCollection] = [] + self._component_collections_is_dirty = True if isinstance(components, (ModelComponent, ComponentCollection)): self.append_component(components) - self._generate_component_collections() - def evaluate( self, x: Numeric | list | np.ndarray | sc.Variable | sc.DataArray ) -> list[np.ndarray]: @@ -98,11 +98,9 @@ def evaluate( the list will match the number of Q values in the model. """ + self._ensure_component_collections_current() if not self._component_collections: - raise ValueError( - 'No components in the model to evaluate. ' - 'Run generate_component_collections() first' - ) + raise ValueError('No components in the model to evaluate.') return [collection.evaluate(x) for collection in self._component_collections] # ------------------------------------------------------------------ @@ -175,6 +173,18 @@ def components(self, value: ModelComponent | ComponentCollection | None) -> None if value is not None: self.append_component(value) + @property + def component_collections_is_dirty(self) -> bool: + """ + Return whether component collections need to be rebuilt before use. + + Returns + ------- + bool + ``True`` if component collections have not been built yet or are stale. + """ + return self._component_collections_is_dirty + @property def Q(self) -> np.ndarray | None: """ @@ -295,8 +305,8 @@ def free_all_parameters(self) -> None: def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: """ Get all Parameters and Descriptors from all ComponentCollections in the ModelBase. - Parameters Ignores the Parameters and Descriptors in self._components as these are just - templates. + + Ignores the Parameters and Descriptors in self._components as these are just templates. Parameters ---------- @@ -318,6 +328,7 @@ def get_all_variables(self, Q_index: int | None = None) -> list[Parameter]: ModelBase. """ + self._ensure_component_collections_current() if Q_index is None: all_vars = [ var @@ -354,8 +365,9 @@ def get_component_collection(self, Q_index: int) -> ComponentCollection: Returns ------- ComponentCollection - The ComponentCollection at the. + The ComponentCollection at the given Q index. """ + self._ensure_component_collections_current() if not isinstance(Q_index, int): raise TypeError(f'Q_index must be an int, got {type(Q_index).__name__}') if Q_index < 0 or Q_index >= len(self._component_collections): @@ -367,6 +379,7 @@ def get_component_collection(self, Q_index: int) -> ComponentCollection: def normalize_area(self) -> None: """Normalize the area of the model across all Q values.""" + self._ensure_component_collections_current() for collection in self._component_collections: collection.normalize_area() @@ -374,6 +387,14 @@ def normalize_area(self) -> None: # Private methods # ------------------------------------------------------------------ + def _ensure_component_collections_current(self) -> None: + """ + Rebuild component collections if any dependency has changed since they were last built. + """ + if self._component_collections_is_dirty: + self._generate_component_collections() + self._component_collections_is_dirty = False + def _generate_component_collections(self) -> None: """Generate ComponentCollections for each Q value.""" @@ -387,11 +408,11 @@ def _generate_component_collections(self) -> None: def _on_Q_change(self) -> None: """Handle changes to the Q values.""" - self._generate_component_collections() + self._component_collections_is_dirty = True def _on_components_change(self) -> None: """Handle changes to the components.""" - self._generate_component_collections() + self._component_collections_is_dirty = True # ------------------------------------------------------------------ # dunder methods diff --git a/src/easydynamics/sample_model/resolution_model.py b/src/easydynamics/sample_model/resolution_model.py index 6531d0b7..7ded2568 100644 --- a/src/easydynamics/sample_model/resolution_model.py +++ b/src/easydynamics/sample_model/resolution_model.py @@ -130,6 +130,7 @@ def from_sample_model( ) if sample_model.Q is not None: + resolution_model._ensure_component_collections_current() for index in range(len(sample_model.Q)): resolution_model._component_collections[index] = copy( sample_model.get_component_collection(Q_index=index) diff --git a/src/easydynamics/sample_model/sample_model.py b/src/easydynamics/sample_model/sample_model.py index c391ece7..a6203071 100644 --- a/src/easydynamics/sample_model/sample_model.py +++ b/src/easydynamics/sample_model/sample_model.py @@ -150,7 +150,7 @@ def append_diffusion_model(self, diffusion_model: DiffusionModelBase) -> None: ) diffusion_model.Q = self.Q self._diffusion_models.append(diffusion_model) - self._generate_component_collections() + self._component_collections_is_dirty = True def remove_diffusion_model(self, name: str) -> None: """ @@ -169,7 +169,7 @@ def remove_diffusion_model(self, name: str) -> None: for i, dm in enumerate(self.diffusion_models): if dm.name == name: del self.diffusion_models[i] - self._generate_component_collections() + self._component_collections_is_dirty = True return raise ValueError( f'No DiffusionModel with name {name} found. \n' @@ -179,7 +179,7 @@ def remove_diffusion_model(self, name: str) -> None: def clear_diffusion_models(self) -> None: """Clear all DiffusionModels from the SampleModel.""" self.diffusion_models = [] - self._generate_component_collections() + self._component_collections_is_dirty = True # ------------------------------------------------------------------ # Properties @@ -538,17 +538,14 @@ def _on_diffusion_models_change(self) -> None: """Handle changes to the diffusion models.""" for diffusion_model in self.diffusion_models: diffusion_model.Q = self.Q - - self._generate_component_collections() + self._component_collections_is_dirty = True def _on_Q_change(self) -> None: """Handle changes to the Q values.""" - for diffusion_model in self.diffusion_models: - # This may be too aggressive diffusion_model.clear_Q(confirm=True) diffusion_model.Q = self.Q - self._generate_component_collections() + self._component_collections_is_dirty = True # ------------------------------------------------------------------ # dunder methods diff --git a/tests/unit/easydynamics/analysis/test_analysis.py b/tests/unit/easydynamics/analysis/test_analysis.py index 7e76bef2..516b718e 100644 --- a/tests/unit/easydynamics/analysis/test_analysis.py +++ b/tests/unit/easydynamics/analysis/test_analysis.py @@ -532,6 +532,15 @@ def test_parameters_to_dataset_different_units(self, analysis): assert parameter_name in parameters_dataset assert 'Q' in parameters_dataset[parameter_name].dims + def test_parameters_to_dataset_raises_on_duplicate_names(self, analysis): + # Add a second Gaussian with the same parameter names as the first + analysis.sample_model.append_component( + Gaussian(name='GaussianName', display_name='Gaussian2', area=0.5) + ) + + with pytest.raises(ValueError, match='Duplicate parameter names'): + analysis.parameters_to_dataset() + @pytest.mark.parametrize( 'parameter_names', [ @@ -750,7 +759,7 @@ def test_fit_single_Q_invalid_Q_index(self, analysis): IndexError, match='must be a valid index', ): - analysis._fit_single_Q(Q_index=3) + analysis.fit(Q_index=3) def test_fit_all_Q_independently(self, analysis): # WHEN @@ -913,3 +922,120 @@ def test_create_components_dataset_single_Q(self, analysis_single_Q): assert components_dataset.coords['Q'].dims == ('Q',) assert components_dataset.sizes['Q'] == 1 assert components_dataset.coords['Q'].ndim == 1 + + def test_ensure_analysis_list_current_clears_dirty_when_Q_is_none(self): + # An Analysis with no experiment has Q=None; _ensure_analysis_list_current should still + # clear the dirty flag without attempting to build the list. + + # WHEN + analysis = Analysis(display_name='NoQ') + assert analysis._analysis_list_is_dirty is True + assert analysis.Q is None + + # THEN + result = analysis.analysis_list + + # EXPECT - dirty flag cleared, list stays empty + assert analysis._analysis_list_is_dirty is False + assert result == [] + + def test_rebin_marks_analysis_list_dirty(self, analysis): + # WHEN - force build of analysis_list so it is no longer dirty + _ = analysis.analysis_list + assert analysis._analysis_list_is_dirty is False + + # THEN - energy rebin leaves Q unchanged, so no confirm required + with patch.object(analysis.experiment, 'rebin'): + analysis.rebin({'energy': 2}) + + # EXPECT + assert analysis._analysis_list_is_dirty is True + + def test_rebin_rebuilds_analysis_list(self, analysis): + # WHEN - 3 Q values → 3 Analysis1d objects; rebin to 1 Q value (confirm required) + assert len(analysis.analysis_list) == 3 + analysis.rebin({'Q': 1}, confirm=True) + + # THEN + result = analysis.analysis_list + + # EXPECT - list rebuilt with 1 Analysis1d for the single remaining Q + assert len(result) == 1 + assert result[0].Q_index == 0 + + def test_rebin_raises_without_confirm_when_Q_count_changes(self, analysis): + # WHEN - rebin Q from 3 to 1 (count changes) without confirm + # THEN / EXPECT + with pytest.raises(ValueError, match='confirm=True'): + analysis.rebin({'Q': 1}) + + def test_rebin_raises_without_confirm_when_Q_values_change_same_count(self, analysis): + # WHEN - simulate a rebin that keeps count but shifts Q values + # (e.g. Q=[1,2,3] → Q=[2,3,4] via non-uniform binning) + old_Q = analysis.Q + new_Q = sc.array(dims=['Q'], values=[2.0, 3.0, 4.0], unit='1/Angstrom') + + def fake_rebin(_dims: dict) -> None: + analysis.experiment._binned_data = analysis.experiment._binned_data.assign_coords( + Q=new_Q + ) + + # THEN / EXPECT - raises without confirm and rolls back + with ( + patch.object(analysis.experiment, 'rebin', side_effect=fake_rebin), + pytest.raises(ValueError, match='confirm=True'), + ): + analysis.rebin({'Q': 3}) + + # EXPECT - experiment Q was rolled back to original + assert sc.allclose(analysis.Q, old_Q) + + def test_rebin_rolls_back_experiment_on_failed_confirm(self, analysis): + # WHEN - rebin Q without confirm (would change count) + old_Q = analysis.Q + + with pytest.raises(ValueError, match='confirm=True'): + analysis.rebin({'Q': 1}) + + # EXPECT - experiment Q was rolled back; analysis is unchanged + assert sc.allclose(analysis.Q, old_Q) + assert len(analysis.analysis_list) == 3 + + def test_rebin_without_Q_change_does_not_require_confirm(self, analysis): + # WHEN - energy rebin leaves Q unchanged, so no confirm required + # THEN / EXPECT - no error raised + with patch.object(analysis.experiment, 'rebin'): + analysis.rebin({'energy': 2}) + + def test_rebin_clears_Q_from_models_when_Q_count_changes(self, analysis): + # WHEN + assert analysis.sample_model.Q is not None + assert analysis.instrument_model.Q is not None + + # THEN + analysis.rebin({'Q': 1}, confirm=True) + + # EXPECT - Q has been propagated back to models with the new single-Q dimension + # (cleared then repopulated when analysis_list is rebuilt) + # At this point (before accessing analysis_list), models have Q=None + assert analysis.sample_model.Q is None + assert analysis.instrument_model.Q is None + + # After rebuild, models get the new Q + _ = analysis.analysis_list + assert len(analysis.sample_model.Q) == 1 + assert len(analysis.instrument_model.Q) == 1 + + def test_direct_experiment_rebin_does_not_update_analysis_list(self, analysis): + # This test documents a known limitation: calling experiment.rebin() directly bypasses + # Analysis and leaves the analysis list stale. Always use Analysis.rebin() instead. + + # WHEN - force build so the list is clean + _ = analysis.analysis_list + assert analysis._analysis_list_is_dirty is False + + # THEN - rebinning via experiment directly (bypasses Analysis) + analysis.experiment.rebin({'Q': 1}) + + # EXPECT - analysis_list is NOT marked dirty (callers must use Analysis.rebin()) + assert analysis._analysis_list_is_dirty is False diff --git a/tests/unit/easydynamics/analysis/test_analysis1d.py b/tests/unit/easydynamics/analysis/test_analysis1d.py index 9df8815e..5bd57163 100644 --- a/tests/unit/easydynamics/analysis/test_analysis1d.py +++ b/tests/unit/easydynamics/analysis/test_analysis1d.py @@ -121,15 +121,15 @@ def test__calculate_adds_sample_and_background(self, analysis1d): sample = np.array([1.0, 2.0, 3.0]) background = np.array([0.5, 0.5, 0.5]) - analysis1d._evaluate_sample = MagicMock(return_value=sample) - analysis1d._evaluate_background = MagicMock(return_value=background) + analysis1d._evaluate_with_convolution = MagicMock(return_value=sample) + analysis1d._evaluate_direct = MagicMock(return_value=background) result = analysis1d._calculate() np.testing.assert_array_equal(result, sample + background) - analysis1d._evaluate_sample.assert_called_once() - analysis1d._evaluate_background.assert_called_once() + analysis1d._evaluate_with_convolution.assert_called_once() + analysis1d._evaluate_direct.assert_called_once() def test_fit_raises_if_no_experiment(self, analysis1d): # WHEN THEN @@ -455,13 +455,13 @@ def test_require_Q_index_raises_if_no_Q_index(self, analysis1d): def test_on_Q_index_changed(self, analysis1d): # WHEN - analysis1d._create_convolver = MagicMock() + analysis1d._convolver_is_dirty = False # THEN analysis1d._on_Q_index_changed() # EXPECT - analysis1d._create_convolver.assert_called_once() + assert analysis1d._convolver_is_dirty is True def test_verify_energy(self, analysis1d): # WHEN @@ -533,61 +533,59 @@ def test_calculate_energy_with_offset_raises_if_incompatible_units(self, analysi # Private methods: evaluation ############# - def test_evaluate_components_no_components(self, analysis1d): + def test_evaluate_direct_no_components(self, analysis1d): # WHEN components = ComponentCollection() # THEN - result = analysis1d._evaluate_components(components=components) + result = analysis1d._evaluate_direct(components=components, energy=None) # EXPECT assert isinstance(result, np.ndarray) assert result.shape == (len(analysis1d.experiment.energy),) assert np.all(result == pytest.approx(0.0)) - def test_evaluate_components_no_convolution(self, analysis1d): + def test_evaluate_direct(self, analysis1d): # WHEN components = Polynomial(coefficients=[1.0]) + # THEN - result = analysis1d._evaluate_components( - components=components, convolver=None, convolve=False - ) + result = analysis1d._evaluate_direct(components=components, energy=None) + # EXPECT assert np.array_equal(result, np.array([1.0, 1.0, 1.0])) - def test_evaluate_components_convolution(self, analysis1d): + def test_evaluate_with_convolution_uses_convolver(self, analysis1d): # WHEN components = Gaussian() convolver = MagicMock() convolver.convolution = MagicMock(return_value=np.array([1, 2, 3])) # THEN - result = analysis1d._evaluate_components( - components=components, convolver=convolver, convolve=True + result = analysis1d._evaluate_with_convolution( + components=components, energy=None, convolver=convolver ) # EXPECT convolver.convolution.assert_called_once() assert result is convolver.convolution.return_value - def test_evaluate_components_empty_resolution(self, analysis1d): + def test_evaluate_with_convolution_no_resolution(self, analysis1d): # WHEN components = MagicMock() components.evaluate = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) # The default analysis1d has no resolution model components, so - # no convolution should be applied even if convolve=True + # evaluate is called directly even though we want convolution. # THEN - result = analysis1d._evaluate_components( - components=components, convolver=None, convolve=True - ) + result = analysis1d._evaluate_with_convolution(components=components, energy=None) # EXPECT components.evaluate.assert_called_once() assert np.array_equal(result, np.array([1.0, 2.0, 3.0])) - def test_evaluate_components_empty_resolution_DBF(self, analysis1d): + def test_evaluate_with_convolution_no_resolution_DBF(self, analysis1d): # WHEN components = MagicMock() components.evaluate = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) @@ -596,19 +594,16 @@ def test_evaluate_components_empty_resolution_DBF(self, analysis1d): analysis1d.sample_model.temperature = 10 mock_dbf = np.array([10.0, 10.0, 10.0]) - # The default analysis1d has no resolution model components, so - # no convolution should be applied even if convolve=True + # The default analysis1d has no resolution model components with patch( 'easydynamics.analysis.analysis1d.detailed_balance_factor', return_value=mock_dbf, ) as dbf_mock: - # WHEN - result = analysis1d._evaluate_components( + # THEN + result = analysis1d._evaluate_with_convolution( components=components, - convolver=None, - convolve=True, - apply_detailed_balance=True, + energy=None, ) # EXPECT @@ -620,18 +615,13 @@ def test_evaluate_components_empty_resolution_DBF(self, analysis1d): assert np.array_equal(result, expected) def test_evaluate_with_resolution(self, analysis1d): - # WHEN (set up the resolution model and create a component to - # evaluate) + # WHEN (set up the resolution model and create a component to evaluate) analysis1d.instrument_model.resolution_model.components = Gaussian() components = Gaussian() with patch('easydynamics.analysis.analysis1d.Convolution') as MockConvolution: # THEN - analysis1d._evaluate_components( - components=components, - convolver=None, - convolve=True, - ) + analysis1d._evaluate_with_convolution(components=components, energy=None) # EXPECT # Ensure constructor called once @@ -661,98 +651,6 @@ def test_evaluate_with_resolution(self, analysis1d): # and check that convolution() was called MockConvolution.return_value.convolution.assert_called_once_with() - def test_evaluate_sample(self, analysis1d): - # WHEN - analysis1d.sample_model.get_component_collection = MagicMock() - analysis1d._evaluate_components = MagicMock() - - # THEN - analysis1d._evaluate_sample() - - # EXPECT - - # The correct component collection is requested with the correct - # Q_index - analysis1d.sample_model.get_component_collection.assert_called_once_with( - Q_index=analysis1d.Q_index - ) - - # The components are evaluated with the correct convolver and - # convolve=True - analysis1d._evaluate_components.assert_called_once_with( - components=analysis1d.sample_model.get_component_collection(), - convolver=analysis1d._convolver, - convolve=True, - energy=None, - apply_detailed_balance=True, - ) - - def test_evaluate_sample_component(self, analysis1d): - # WHEN - analysis1d._evaluate_components = MagicMock() - component = object() - - # THEN - analysis1d._evaluate_sample_component(component=component) - - # EXPECT - - # The components are evaluated with the correct convolver and - # convolve=True - analysis1d._evaluate_components.assert_called_once_with( - components=component, - convolver=None, - convolve=True, - energy=None, - apply_detailed_balance=True, - ) - - def test_evaluate_background(self, analysis1d): - # WHEN - analysis1d.instrument_model.background_model.get_component_collection = MagicMock() - analysis1d._evaluate_components = MagicMock() - - # THEN - analysis1d._evaluate_background() - - # EXPECT - - # The correct component collection is requested with the correct - # Q_index - analysis1d.instrument_model.background_model.get_component_collection.assert_called_once_with( - Q_index=analysis1d.Q_index - ) - - # The components are evaluated with the correct convolver and - # convolve=True - analysis1d._evaluate_components.assert_called_once_with( - components=analysis1d.instrument_model.background_model.get_component_collection(), - convolver=None, - convolve=False, - energy=None, - apply_detailed_balance=False, - ) - - def test_evaluate_background_component(self, analysis1d): - # WHEN - analysis1d._evaluate_components = MagicMock() - component = object() - - # THEN - analysis1d._evaluate_background_component(component=component) - - # EXPECT - - # The components are evaluated with the correct convolver and - # convolve=True - analysis1d._evaluate_components.assert_called_once_with( - components=component, - convolver=None, - convolve=False, - energy=None, - apply_detailed_balance=False, - ) - def test_create_convolver(self, analysis1d): # WHEN # Mock sample components @@ -843,88 +741,6 @@ def test_create_residuals_array_no_Q_index_raises(self, analysis1d): with pytest.raises(ValueError, match='Q_index must be set'): analysis1d._create_residuals_array() - @pytest.mark.parametrize( - 'background', - [ - None, - np.array([0.5, 0.5, 0.5]), - ], - ids=[ - 'No background', - 'With background', - ], - ) - def test_create_component_scipp_array(self, analysis1d, background): - """ - Test that _create_component_scipp_array correctly evaluates - the component, adds the background and calls _to_scipp_array - with the correct values. - """ - # WHEN - - # Mock the functions that will be called. - analysis1d._evaluate_sample_component = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) - - analysis1d._to_scipp_array = MagicMock() - - component = object() - - # THEN - analysis1d._create_component_scipp_array(component=component, background=background) - - # EXPECT - analysis1d._evaluate_sample_component.assert_called_once_with( - component=component, energy=None - ) - - expected_values = np.array([1.0, 2.0, 3.0]) - if background is not None: - expected_values += background - - analysis1d._to_scipp_array.assert_called_once() - - # Extract the actual call - _, kwargs = analysis1d._to_scipp_array.call_args - - np.testing.assert_array_equal( - kwargs['values'], - expected_values, - ) - - def test_create_background_component_scipp_array(self, analysis1d): - """Test that _create_background_component_scipp_array correctly - evaluates the component, adds the background and calls - _to_scipp_array with the correct values.""" - - # WHEN - - # Mock the functions that will be called. - analysis1d._evaluate_background_component = MagicMock( - return_value=np.array([1.0, 2.0, 3.0]) - ) - analysis1d._to_scipp_array = MagicMock() - - component = object() - - # THEN - analysis1d._create_background_component_scipp_array(component=component) - - # EXPECT - analysis1d._evaluate_background_component.assert_called_once_with( - component=component, - energy=None, - ) - - analysis1d._to_scipp_array.assert_called_once() - - # Extract the actual call - _, kwargs = analysis1d._to_scipp_array.call_args - - np.testing.assert_array_equal( - kwargs['values'], - np.array([1.0, 2.0, 3.0]), - ) - def test_create_model_array(self, analysis1d): """Test that _create_model_array correctly evaluates the full model and calls _to_scipp_array with the @@ -996,65 +812,44 @@ def test_create_components_dataset_single_Q( return_value=background_collection ) - # ---- Background evaluation ---- + # ---- Evaluation mocks ---- background_value = np.array([11.0, 21.0, 31.0]) - analysis1d._evaluate_background = MagicMock(return_value=background_value) + sample_value = np.array([1.0, 2.0, 3.0]) - # ---- Return scipp DataArrays ---- - fake_sample_da = sc.DataArray(data=sc.array(dims=['energy'], values=[1.0, 2.0, 3.0])) - - analysis1d._create_component_scipp_array = MagicMock(return_value=fake_sample_da) - - fake_background_da = sc.DataArray(data=sc.array(dims=['energy'], values=[4.0, 5.0, 6.0])) + analysis1d._evaluate_direct = MagicMock(return_value=background_value) + analysis1d._evaluate_with_convolution = MagicMock(return_value=sample_value) - analysis1d._create_background_component_scipp_array = MagicMock( - return_value=fake_background_da - ) + # ---- Return scipp DataArrays ---- + fake_da = sc.DataArray(data=sc.array(dims=['energy'], values=[1.0, 2.0, 3.0])) + analysis1d._to_scipp_array = MagicMock(return_value=fake_da) # THEN dataset = analysis1d._create_components_dataset_single_Q(add_background=add_background) # EXPECT - # The correct component collections are requested with the - # correct Q_index + # The correct component collections are requested with the correct Q_index analysis1d.sample_model.get_component_collection.assert_called_once_with( Q_index=analysis1d.Q_index ) - analysis1d.instrument_model.background_model.get_component_collection.assert_called_once_with( Q_index=analysis1d.Q_index ) - # Background is evaluated if add_background=True, and not - # evaluated if False + # _evaluate_direct is called once for the background collection (for add_background), + # and once for the background component — if add_background=False only the component call. if add_background: - analysis1d._evaluate_background.assert_called_once() - expected_background = background_value - else: - analysis1d._evaluate_background.assert_not_called() - expected_background = None - - # The sample component scipp array is created with the correct - # component and background - analysis1d._create_component_scipp_array.assert_called_once() - _, kwargs = analysis1d._create_component_scipp_array.call_args - - assert kwargs['component'] is sample_component - - if expected_background is None: - assert kwargs['background'] is None + # First call: background_collection to get total background value + # Second call: background_component for its individual array + assert analysis1d._evaluate_direct.call_count == 2 else: - np.testing.assert_array_equal( - kwargs['background'], - expected_background, - ) + # Only called for the background component + assert analysis1d._evaluate_direct.call_count == 1 - # Background component creation - analysis1d._create_background_component_scipp_array.assert_called_once() - _, kwargs = analysis1d._create_background_component_scipp_array.call_args - assert kwargs['component'] is background_component - assert sc.identical(kwargs['energy'], analysis1d.energy) + # _evaluate_with_convolution called once for the sample component + analysis1d._evaluate_with_convolution.assert_called_once_with( + sample_component, analysis1d._masked_energy + ) # Dataset content assert isinstance(dataset, sc.Dataset) @@ -1080,3 +875,170 @@ def test_to_scipp_array(self, analysis1d): scipp_array.coords['Q'].values, analysis1d.experiment.Q[analysis1d.Q_index].values, ) + + def test_rebin_marks_convolver_dirty(self, analysis1d): + # WHEN + analysis1d._convolver_is_dirty = False + + # THEN + analysis1d.rebin({'Q': 1}) + + # EXPECT + assert analysis1d._convolver_is_dirty is True + + def test_rebin_refreshes_masked_energy(self, analysis1d): + # WHEN - capture the current masked_energy object + energy_before = analysis1d._masked_energy + + # THEN - patch get_masked_energy to verify it is called during rebin + with patch.object( + analysis1d.experiment, + 'get_masked_energy', + wraps=analysis1d.experiment.get_masked_energy, + ) as mock_get_energy: + analysis1d.rebin({'Q': 1}) + + # EXPECT - get_masked_energy was called to refresh _masked_energy + mock_get_energy.assert_called_once_with(Q_index=analysis1d.Q_index) + # And _masked_energy is a different object now (re-fetched from experiment) + assert analysis1d._masked_energy is not energy_before + + def test_fit_marks_convolver_dirty_when_sample_model_components_change(self, analysis1d): + """Issue #68: fit() should detect in-place component changes and rebuild the convolver.""" + # WHEN - simulate state after a previous fit (convolver built, not dirty) + analysis1d._create_convolver = MagicMock(return_value=None) + analysis1d._convolver_is_dirty = False + analysis1d.sample_model._component_collections_is_dirty = False + + # THEN - append a component in-place (doesn't go through Analysis1d setters) + analysis1d.sample_model.append_component(Gaussian(name='NewGaussian')) + assert analysis1d.sample_model._component_collections_is_dirty is True + assert analysis1d._convolver_is_dirty is False # not yet propagated + + # WHEN - fit() should propagate the dirty flag and rebuild the convolver + with patch( + 'easydynamics.analysis.analysis1d.EasyScienceFitter', + return_value=MagicMock(fit=MagicMock(return_value=MagicMock())), + ): + analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + return_value=( + np.array([1.0, 2.0, 3.0]), + np.array([1.0, 2.0, 3.0]), + np.array([1.0, 1.0, 1.0]), + np.array([True, True, True]), + ) + ) + analysis1d.fit() + + # EXPECT - convolver was rebuilt (_ensure_convolver_current called _create_convolver) + analysis1d._create_convolver.assert_called_once() + + def test_fit_does_not_rebuild_convolver_when_nothing_changed(self, analysis1d): + """fit() should not call _create_convolver if nothing has changed since last fit.""" + # WHEN - build convolver and clear all dirty flags + analysis1d._create_convolver = MagicMock(return_value=None) + analysis1d._convolver_is_dirty = False + analysis1d.sample_model._component_collections_is_dirty = False + analysis1d.instrument_model.resolution_model._component_collections_is_dirty = False + + # THEN - call fit() with nothing changed + with patch( + 'easydynamics.analysis.analysis1d.EasyScienceFitter', + return_value=MagicMock(fit=MagicMock(return_value=MagicMock())), + ): + analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + return_value=( + np.array([1.0, 2.0, 3.0]), + np.array([1.0, 2.0, 3.0]), + np.array([1.0, 1.0, 1.0]), + np.array([True, True, True]), + ) + ) + analysis1d.fit() + + # EXPECT - _create_convolver was NOT called (convolver reused) + analysis1d._create_convolver.assert_not_called() + + def test_rebin_rebins_experiment(self, analysis1d): + """rebin() should delegate to experiment.rebin().""" + # WHEN + with patch.object( + analysis1d.experiment, 'rebin', wraps=analysis1d.experiment.rebin + ) as mock_rebin: + analysis1d.rebin({'Q': 1}) + + # EXPECT - experiment.rebin was called with the correct dimensions + mock_rebin.assert_called_once_with({'Q': 1}) + + def test_rebin_without_Q_index_does_not_crash(self): + """rebin() with no Q_index set should not try to refresh masked_energy.""" + # WHEN + experiment = Experiment() + analysis1d = Analysis1d(experiment=experiment) + assert analysis1d.Q_index is None + + # THEN / EXPECT - no error, and _masked_energy stays None + # (no data loaded, so rebin would raise from experiment, just test the branch) + assert analysis1d._masked_energy is None + + def test_on_sample_model_changed_marks_convolver_dirty(self, analysis1d): + # WHEN - clear the dirty flag first + analysis1d._convolver_is_dirty = False + + # THEN - replace the sample model via the public setter + analysis1d.sample_model = SampleModel(components=Gaussian(name='NewGaussian')) + + # EXPECT + assert analysis1d._convolver_is_dirty is True + + def test_on_instrument_model_changed_marks_convolver_dirty(self, analysis1d): + # WHEN - clear the dirty flag first + analysis1d._convolver_is_dirty = False + + # THEN - replace the instrument model via the public setter + analysis1d.instrument_model = InstrumentModel() + + # EXPECT + assert analysis1d._convolver_is_dirty is True + + def test_evaluate_with_convolution_returns_zeros_for_empty_collection(self, analysis1d): + # WHEN + empty_collection = ComponentCollection() + energy = analysis1d._masked_energy + + # THEN + result = analysis1d._evaluate_with_convolution(empty_collection, energy) + + # EXPECT + assert result.shape == energy.values.shape + np.testing.assert_array_equal(result, 0.0) + + def test_fit_marks_convolver_dirty_when_resolution_model_components_change(self, analysis1d): + """Issue #68: fit() should detect resolution_model component changes.""" + # WHEN - simulate state after a previous fit + analysis1d._create_convolver = MagicMock(return_value=None) + analysis1d._convolver_is_dirty = False + analysis1d.sample_model._component_collections_is_dirty = False + analysis1d.instrument_model.resolution_model._component_collections_is_dirty = False + + # THEN - mark resolution_model dirty in-place (doesn't go through Analysis1d setters) + analysis1d.instrument_model.resolution_model._component_collections_is_dirty = True + assert analysis1d._convolver_is_dirty is False + + # WHEN - fit() should propagate and rebuild + with patch( + 'easydynamics.analysis.analysis1d.EasyScienceFitter', + return_value=MagicMock(fit=MagicMock(return_value=MagicMock())), + ): + analysis1d.experiment._extract_x_y_weights_only_finite = MagicMock( + return_value=( + np.array([1.0, 2.0, 3.0]), + np.array([1.0, 2.0, 3.0]), + np.array([1.0, 1.0, 1.0]), + np.array([True, True, True]), + ) + ) + analysis1d.fit() + + # EXPECT + analysis1d._create_convolver.assert_called_once() diff --git a/tests/unit/easydynamics/analysis/test_analysis_base.py b/tests/unit/easydynamics/analysis/test_analysis_base.py index 120e0af0..e9f0426a 100644 --- a/tests/unit/easydynamics/analysis/test_analysis_base.py +++ b/tests/unit/easydynamics/analysis/test_analysis_base.py @@ -140,12 +140,12 @@ def test_init_calls_on_experiment_changed(self): ( {'extra_parameters': 123}, TypeError, - 'extra_parameters must be a Parameter or a list of Parameters.', + 'extra_parameters must be a Parameter, a list of Parameters, or None.', ), ( {'extra_parameters': [123]}, TypeError, - 'extra_parameters must be a Parameter or a list of Parameters.', + 'extra_parameters must be a Parameter, a list of Parameters, or None.', ), ], ids=[ diff --git a/tests/unit/easydynamics/convolution/test_convolution.py b/tests/unit/easydynamics/convolution/test_convolution.py index 59191f7c..dbb4290d 100644 --- a/tests/unit/easydynamics/convolution/test_convolution.py +++ b/tests/unit/easydynamics/convolution/test_convolution.py @@ -391,40 +391,6 @@ def test_check_if_pair_is_analytic_raises_with_delta_in_resolution(self, default resolution_component=resolution_component, ) - @pytest.mark.parametrize( - 'sample_component,resolution_component', - [ - ( - 'NotAModelComponent', - Gaussian(name='G', area=1.0, center=0.0, width=0.1), - ), - ( - Gaussian(name='G', area=1.0, center=0.0, width=0.1), - 'NotAModelComponent', - ), - ], - ids=['invalid_sample_component', 'invalid_resolution_component'], - ) - def test_check_if_pair_is_analytic_raises_with_invalid_types( - self, default_convolution, sample_component, resolution_component - ): - """ - Test that _check_if_pair_is_analytic raises TypeError when given - invalid component types. - """ - # WHEN - conv = default_convolution - - # THEN EXPECT - with pytest.raises( - TypeError, - match='must be a ModelComponent', - ): - conv._check_if_pair_is_analytic( - sample_component=sample_component, - resolution_component=resolution_component, - ) - @pytest.mark.parametrize( 'analytical_component', [True, False], diff --git a/tests/unit/easydynamics/convolution/test_convolution_base.py b/tests/unit/easydynamics/convolution/test_convolution_base.py index a300890b..fe8ceef6 100644 --- a/tests/unit/easydynamics/convolution/test_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_convolution_base.py @@ -252,10 +252,7 @@ def test_energy_offset_setter_invalid_type_raises(self, convolution_base): def test_energy_with_offset_setter_raises(self, convolution_base): # WHEN THEN EXPECT - with pytest.raises( - AttributeError, - match='is a read-only property', - ): + with pytest.raises(AttributeError): convolution_base.energy_with_offset = 5 def test_sample_components_property(self, convolution_base): diff --git a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py index a084b898..a40e4388 100644 --- a/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py +++ b/tests/unit/easydynamics/convolution/test_numerical_convolution_base.py @@ -169,7 +169,9 @@ def test_energy_setter(self, default_numerical_convolution_base): # THEN # Force regeneration of energy grid - default_numerical_convolution_base._create_energy_grid() + default_numerical_convolution_base._energy_grid = ( + default_numerical_convolution_base._create_energy_grid() + ) # EXPECT assert default_numerical_convolution_base._energy_grid.energy_dense.shape[0] == round( @@ -204,7 +206,9 @@ def test_upsample_factor_setter( ) # Force regeneration of energy grid - default_numerical_convolution_base._create_energy_grid() + default_numerical_convolution_base._energy_grid = ( + default_numerical_convolution_base._create_energy_grid() + ) # EXPECT: correct factor + grid size assert default_numerical_convolution_base.upsample_factor == new_upsample_factor @@ -257,7 +261,9 @@ def test_extension_factor_setter(self, default_numerical_convolution_base): # THEN # Force regeneration of energy grid - default_numerical_convolution_base._create_energy_grid() + default_numerical_convolution_base._energy_grid = ( + default_numerical_convolution_base._create_energy_grid() + ) # EXPECT assert default_numerical_convolution_base.extension_factor == new_extension_factor diff --git a/tests/unit/easydynamics/sample_model/test_component_collection.py b/tests/unit/easydynamics/sample_model/test_component_collection.py index 14835ec3..7ae15ed6 100644 --- a/tests/unit/easydynamics/sample_model/test_component_collection.py +++ b/tests/unit/easydynamics/sample_model/test_component_collection.py @@ -465,3 +465,25 @@ def test_copy(self, component_collection): assert param_copy.min == param_orig.min assert param_copy.max == param_orig.max assert param_copy.fixed == param_orig.fixed + + def test_warns_on_duplicate_names_at_init(self): + g1 = Gaussian(name='SameName', display_name='Display1', area=1.0) + g2 = Gaussian(name='SameName', display_name='Display2', area=2.0) + + with pytest.warns(UserWarning, match='Duplicate component names'): + ComponentCollection(components=[g1, g2]) + + def test_warns_on_duplicate_names_on_append(self): + g1 = Gaussian(name='SameName', display_name='Display1', area=1.0) + g2 = Gaussian(name='SameName', display_name='Display2', area=2.0) + collection = ComponentCollection(components=[g1]) + + with pytest.warns(UserWarning, match='Duplicate component names'): + collection.append_component(g2) + + def test_no_warning_with_unique_names(self, recwarn): + g1 = Gaussian(name='Name1', display_name='Display1', area=1.0) + g2 = Gaussian(name='Name2', display_name='Display2', area=2.0) + ComponentCollection(components=[g1, g2]) + user_warnings = [w for w in recwarn.list if issubclass(w.category, UserWarning)] + assert not user_warnings diff --git a/tests/unit/easydynamics/sample_model/test_instrument_model.py b/tests/unit/easydynamics/sample_model/test_instrument_model.py index a61acff1..cb716fc4 100644 --- a/tests/unit/easydynamics/sample_model/test_instrument_model.py +++ b/tests/unit/easydynamics/sample_model/test_instrument_model.py @@ -277,6 +277,8 @@ def test_convert_unit_calls_all_children(self, instrument_model): new_unit = 'eV' # THEN + # Ensure energy offsets are built before mocking + instrument_model._ensure_energy_offsets_current() # Mock downstream convert_unit calls instrument_model._background_model.convert_unit = MagicMock() instrument_model._resolution_model.convert_unit = MagicMock() @@ -432,14 +434,13 @@ def test_generate_energy_offsets(self, instrument_model): def test_Q_setter(self, instrument_model_without_Q): # WHEN - instrument_model_without_Q._generate_energy_offsets = MagicMock() first_new_Q = np.array([1.0, 2.0, 3.0]) # THEN instrument_model_without_Q.Q = first_new_Q # EXPECT - instrument_model_without_Q._generate_energy_offsets.assert_called_once() + assert instrument_model_without_Q._energy_offsets_is_dirty is True np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q) np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q) @@ -455,17 +456,16 @@ def test_Q_setter(self, instrument_model_without_Q): instrument_model_without_Q.Q = new_Q # EXPECT - # No new calls to _generate_energy_offsets, and Q values remain unchanged - instrument_model_without_Q._generate_energy_offsets.assert_called_once() + # Q values remain unchanged np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q) np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q) - # THEN + # THEN - set Q to an equivalent scipp Variable; values match so should be accepted new_Q = sc.Variable(dims=['Q'], values=[1.0, 2.0, 3.0], unit='1/angstrom') + instrument_model_without_Q.Q = new_Q - # EXPECT - # No new calls to _generate_energy_offsets, and Q values remain unchanged - instrument_model_without_Q._generate_energy_offsets.assert_called_once() + # EXPECT - Q propagated to child models, offsets marked dirty again + assert instrument_model_without_Q._energy_offsets_is_dirty is True np.testing.assert_array_equal(instrument_model_without_Q.background_model.Q, first_new_Q) np.testing.assert_array_equal(instrument_model_without_Q.resolution_model.Q, first_new_Q) diff --git a/tests/unit/easydynamics/sample_model/test_model_base.py b/tests/unit/easydynamics/sample_model/test_model_base.py index 4bd60457..dfcbd1ad 100644 --- a/tests/unit/easydynamics/sample_model/test_model_base.py +++ b/tests/unit/easydynamics/sample_model/test_model_base.py @@ -72,6 +72,7 @@ def test_evaluate_calls_all_component_collections(self, model_base): collection2.evaluate.return_value = np.array([4.0, 5.0, 6.0]) model_base._component_collections = [collection1, collection2] + model_base._component_collections_is_dirty = False # THEN result = model_base.evaluate(x) @@ -88,6 +89,7 @@ def test_evaluate_no_component_collections_raises(self, model_base): x = np.array([0.0, 1.0, 2.0]) model_base._component_collections = [] + model_base._component_collections_is_dirty = False # THEN / EXPECT with pytest.raises(ValueError, match='No components'): diff --git a/tests/unit/easydynamics/sample_model/test_sample_model.py b/tests/unit/easydynamics/sample_model/test_sample_model.py index 75514fc6..bbbb30b9 100644 --- a/tests/unit/easydynamics/sample_model/test_sample_model.py +++ b/tests/unit/easydynamics/sample_model/test_sample_model.py @@ -389,6 +389,7 @@ def test_evaluate_calls_dbf(self, sample_model): collection2.evaluate.return_value = np.array([4.0, 5.0, 6.0]) sample_model._component_collections = [collection1, collection2] + sample_model._component_collections_is_dirty = False with patch('easydynamics.sample_model.sample_model.detailed_balance_factor') as mock_dbf: mock_dbf.return_value = np.array([10.0, 10.0, 10.0]) # simplified DBF @@ -436,6 +437,7 @@ def test_evaluate_doesnt_call_dbf_when_disabled( collection2.evaluate.return_value = np.array([4.0, 5.0, 6.0]) sample_model._component_collections = [collection1, collection2] + sample_model._component_collections_is_dirty = False sample_model.temperature = temperature sample_model.use_detailed_balance = use_detailed_balance