From 5e05b805ff2fe40210c0810f4d1abf1de1e8bde2 Mon Sep 17 00:00:00 2001 From: Maryam Valian Date: Thu, 29 Jan 2026 14:29:45 -0500 Subject: [PATCH 1/4] ENH: add NCRF.reduce_data() + recon_data(); add tests (closes #43) --- ncrf/_model.py | 122 ++++++++++++++++++++++++++++++++++++++++ ncrf/tests/test_ncrf.py | 8 +++ 2 files changed, 130 insertions(+) diff --git a/ncrf/_model.py b/ncrf/_model.py index 0093b95..23f41fb 100644 --- a/ncrf/_model.py +++ b/ncrf/_model.py @@ -907,6 +907,128 @@ def _copy_from_data(self, data): self.tstop = data.tstop self.gaussian_fwhm = data.gaussian_fwhm + def reduce_data(self): + """ + Reduce model size and save meta data in model.megmeta. + """ + if getattr(self, "_data", None) is None: + raise ValueError("self._data is None.") + + if not hasattr(self._data, "meg"): + raise ValueError("self._data.meg is missing.") + + if not hasattr(self, "megmeta") or self.megmeta is None: + self.megmeta = {} + + meglen = len(self._data.meg) + meg_shapes = [] + for i in range(meglen): + meg_i = self._data.meg[i] + if not hasattr(meg_i, "shape"): + raise TypeError(f"self._data.meg[{i}] has no .shape attribute.") + meg_shapes.append(tuple(meg_i.shape)) + + self.megmeta["meglen"] = meglen + self.megmeta["meg_shapes"] = meg_shapes + + self._stim_normalization = { + "s_normalization": self._stim_normalization, + "nlevel": self._data.nlevel, + } + + self._data = None + print("Data removed from model successfully!") + return + + def reconstruct_data( + self, + meg_all, + stim_all, + in_place=False, + do_post_normalization=True, + attach=False, + ): + """ + Reconstruct RegressionData + in_place = False : meg and stim data will not change. + + """ + if len(meg_all) != len(stim_all): + raise ValueError("Meg size and stim size do not match.") + + whitening_filter = getattr(self, "_whitening_filter", None) + if whitening_filter is None: + raise ValueError("self._whitening_filter is None.") + + stim_norm = getattr(self, "_stim_normalization", None) + if not isinstance(stim_norm, dict) or "nlevel" not in stim_norm: + raise ValueError("self._stim_normalization invalid") + nlevel = stim_norm["nlevel"] + + data = RegressionData( + tstart=self.tstart, + tstop=self.tstop, + nlevel=nlevel, + baseline=getattr(self, "_stim_baseline", None), + scaling=getattr(self, "_stim_scaling", None), + stim_is_single=getattr(self, "_stim_is_single", None), + gaussian_fwhm=self.gaussian_fwhm, + ) + + for meg, stim in zip(meg_all, stim_all): + if not in_place: + if isinstance(stim, (list, tuple)): + stim = [s.copy() for s in stim] + else: + stim = stim.copy() + data.add_data(meg, stim) + + if do_post_normalization: + data.post_normalization() + + data._prewhiten(whitening_filter) + data._precompute() + + # Validate meg size and shapes + if not hasattr(self, "megmeta") or self.megmeta is None: + raise ValueError("self.megmeta is missing.") + if "meglen" not in self.megmeta or "meg_shapes" not in self.megmeta: + raise ValueError('self.megmeta must contain "meglen" and "meg_shapes".') + + expected_meglen = self.megmeta["meglen"] + expected_shapes = self.megmeta["meg_shapes"] + + if len(data.meg) != expected_meglen: + raise ValueError("Processed MEG size mismatch") + + if len(expected_shapes) != expected_meglen: + raise ValueError("self.megmeta['meg_shapes'] length mismatch.") + + for i, meg_proc in enumerate(data.meg): + got = tuple(meg_proc.shape) + exp = tuple(expected_shapes[i]) + if got != exp: + raise ValueError(f"Processed MEG shape mismatch at run {i}") + + # Validate explained variance + if not hasattr(self, "explained_var"): + raise ValueError("self.explained_var is missing.") + + ev_orig = float(self.explained_var) + ev_recon = float(self.compute_explained_variance(data)) + + if np.allclose(ev_recon, ev_orig, rtol=1e-10, atol=1e-8): + print("Explained Variance Checked Successfully!") + else: + diff = abs(ev_recon - ev_orig) + raise ValueError(f"Explained variance mismatch! diff={diff}") + + if attach: + self._data = data + print("Reconstructed Data attached successfully!") + + return data + def _construct_f(self, data): """creates instances of objective function and its gradient to be passes to the FASTA algorithm diff --git a/ncrf/tests/test_ncrf.py b/ncrf/tests/test_ncrf.py index fd96c63..5a38bff 100644 --- a/ncrf/tests/test_ncrf.py +++ b/ncrf/tests/test_ncrf.py @@ -19,6 +19,14 @@ def test_ncrf(): # 1 stimulus model = fit_ncrf(meg, stim, fwd, emptyroom, tstop=0.2, normalize='l1', mu=0.0019444, n_iter=3, n_iterc=3, n_iterf=10, do_post_normalization=False) + # check reduce + model.reduce_data() + assert model._data is None + assert getattr(model, "megmeta", None) is not None + # check reconstruct + model.reconstruct_data(meg, stim, attach=True) + assert model._data is not None + # check residual and explained var np.testing.assert_allclose(model.explained_var, 0.00641890144769941, rtol=0.001) np.testing.assert_allclose(model.voxelwise_explained_variance.sum(), 0.08261162457414245, rtol=0.001) From a7bee22b5ccd694c57135dc42aba26f320be9d38 Mon Sep 17 00:00:00 2001 From: Maryam Valian Date: Thu, 5 Feb 2026 16:11:40 -0500 Subject: [PATCH 2/4] Fix based on PR feedback --- ncrf/_model.py | 112 ++++++++++++++-------------------------- ncrf/tests/test_ncrf.py | 2 +- 2 files changed, 40 insertions(+), 74 deletions(-) diff --git a/ncrf/_model.py b/ncrf/_model.py index 23f41fb..d10f820 100644 --- a/ncrf/_model.py +++ b/ncrf/_model.py @@ -909,123 +909,89 @@ def _copy_from_data(self, data): def reduce_data(self): """ - Reduce model size and save meta data in model.megmeta. - """ - if getattr(self, "_data", None) is None: - raise ValueError("self._data is None.") - - if not hasattr(self._data, "meg"): - raise ValueError("self._data.meg is missing.") + Reduce the model size for storage. - if not hasattr(self, "megmeta") or self.megmeta is None: - self.megmeta = {} + This method removes the cached training data stored in 'self._data' to reduce + the size of the model when saving to disk (for example with pickle). + A small amount of metadata needed for later reconstruction is stored in + 'self._reducemeta' including meg_shape and nlevel. + """ + if self._data is None: + raise RuntimeError("Model is already reduced (self._data is None).") + self._reducemeta = {} meglen = len(self._data.meg) meg_shapes = [] for i in range(meglen): meg_i = self._data.meg[i] - if not hasattr(meg_i, "shape"): - raise TypeError(f"self._data.meg[{i}] has no .shape attribute.") meg_shapes.append(tuple(meg_i.shape)) - self.megmeta["meglen"] = meglen - self.megmeta["meg_shapes"] = meg_shapes - - self._stim_normalization = { - "s_normalization": self._stim_normalization, - "nlevel": self._data.nlevel, - } + self._reducemeta["meg_shapes"] = meg_shapes + self._reducemeta["nlevel"] = self._data.nlevel self._data = None - print("Data removed from model successfully!") - return def reconstruct_data( self, - meg_all, - stim_all, - in_place=False, - do_post_normalization=True, - attach=False, + meg: Sequence[object], + stim: Sequence[object], + attach: bool=False, ): """ Reconstruct RegressionData - in_place = False : meg and stim data will not change. + meg : sequence of meg same as the one used for fitting + stim : sequence of stim same as the one used for fitting + attach=True : reconstructed data will be added to model._data + + Returns data : RegressionData """ - if len(meg_all) != len(stim_all): + if len(meg) != len(stim): raise ValueError("Meg size and stim size do not match.") - whitening_filter = getattr(self, "_whitening_filter", None) - if whitening_filter is None: - raise ValueError("self._whitening_filter is None.") - - stim_norm = getattr(self, "_stim_normalization", None) - if not isinstance(stim_norm, dict) or "nlevel" not in stim_norm: - raise ValueError("self._stim_normalization invalid") - nlevel = stim_norm["nlevel"] - data = RegressionData( tstart=self.tstart, tstop=self.tstop, - nlevel=nlevel, - baseline=getattr(self, "_stim_baseline", None), - scaling=getattr(self, "_stim_scaling", None), - stim_is_single=getattr(self, "_stim_is_single", None), + nlevel=self._reducemeta["nlevel"], + baseline=self._stim_baseline, + scaling=self._stim_scaling, + stim_is_single=self._stim_is_single, gaussian_fwhm=self.gaussian_fwhm, ) - for meg, stim in zip(meg_all, stim_all): - if not in_place: - if isinstance(stim, (list, tuple)): - stim = [s.copy() for s in stim] - else: - stim = stim.copy() - data.add_data(meg, stim) - - if do_post_normalization: - data.post_normalization() + for meg_i, stim_i in zip(meg, stim): + meg_i = meg_i.copy() + if isinstance(stim_i, (list, tuple)): + stim_i = [s.copy() for s in stim_i] + else: + stim_i = stim_i.copy() + data.add_data(meg_i, stim_i) - data._prewhiten(whitening_filter) + data.post_normalization() + data._prewhiten(self._whitening_filter) data._precompute() # Validate meg size and shapes - if not hasattr(self, "megmeta") or self.megmeta is None: - raise ValueError("self.megmeta is missing.") - if "meglen" not in self.megmeta or "meg_shapes" not in self.megmeta: - raise ValueError('self.megmeta must contain "meglen" and "meg_shapes".') - - expected_meglen = self.megmeta["meglen"] - expected_shapes = self.megmeta["meg_shapes"] + expected_shapes = self._reducemeta["meg_shapes"] + expected_meglen = len(expected_shapes) if len(data.meg) != expected_meglen: - raise ValueError("Processed MEG size mismatch") - - if len(expected_shapes) != expected_meglen: - raise ValueError("self.megmeta['meg_shapes'] length mismatch.") + raise ValueError(" Input MEG List-size mismatches with metadata.") for i, meg_proc in enumerate(data.meg): got = tuple(meg_proc.shape) exp = tuple(expected_shapes[i]) if got != exp: - raise ValueError(f"Processed MEG shape mismatch at run {i}") + raise ValueError(f"MEG shape (channel count or timepoints) mismatches with metadata.") # Validate explained variance - if not hasattr(self, "explained_var"): - raise ValueError("self.explained_var is missing.") - ev_orig = float(self.explained_var) ev_recon = float(self.compute_explained_variance(data)) - - if np.allclose(ev_recon, ev_orig, rtol=1e-10, atol=1e-8): - print("Explained Variance Checked Successfully!") - else: - diff = abs(ev_recon - ev_orig) - raise ValueError(f"Explained variance mismatch! diff={diff}") + if not np.allclose(ev_recon, ev_orig, rtol=1e-10, atol=1e-8): + raise ValueError(f"Explained variance mismatch! ") if attach: self._data = data - print("Reconstructed Data attached successfully!") return data diff --git a/ncrf/tests/test_ncrf.py b/ncrf/tests/test_ncrf.py index 5a38bff..9b6ffb5 100644 --- a/ncrf/tests/test_ncrf.py +++ b/ncrf/tests/test_ncrf.py @@ -22,7 +22,7 @@ def test_ncrf(): # check reduce model.reduce_data() assert model._data is None - assert getattr(model, "megmeta", None) is not None + assert model._reducemeta is not None # check reconstruct model.reconstruct_data(meg, stim, attach=True) assert model._data is not None From ff918e71e6212f0cb0e0b366c6f118652ec962ed Mon Sep 17 00:00:00 2001 From: Maryam Valian Date: Thu, 5 Feb 2026 16:40:24 -0500 Subject: [PATCH 3/4] Fix flake8 formatting --- ncrf/_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ncrf/_model.py b/ncrf/_model.py index d10f820..7533595 100644 --- a/ncrf/_model.py +++ b/ncrf/_model.py @@ -935,7 +935,7 @@ def reconstruct_data( self, meg: Sequence[object], stim: Sequence[object], - attach: bool=False, + attach: bool = False, ): """ Reconstruct RegressionData @@ -982,13 +982,13 @@ def reconstruct_data( got = tuple(meg_proc.shape) exp = tuple(expected_shapes[i]) if got != exp: - raise ValueError(f"MEG shape (channel count or timepoints) mismatches with metadata.") + raise ValueError("MEG shape (channel count or timepoints) mismatches with metadata.") # Validate explained variance ev_orig = float(self.explained_var) ev_recon = float(self.compute_explained_variance(data)) if not np.allclose(ev_recon, ev_orig, rtol=1e-10, atol=1e-8): - raise ValueError(f"Explained variance mismatch! ") + raise ValueError("Explained variance mismatch! ") if attach: self._data = data From 282bcc2586132235470f2462677d5c1f93371652 Mon Sep 17 00:00:00 2001 From: Maryam Valian Date: Mon, 9 Feb 2026 14:20:28 -0500 Subject: [PATCH 4/4] Add test for wrong data --- ncrf/tests/test_ncrf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ncrf/tests/test_ncrf.py b/ncrf/tests/test_ncrf.py index 9b6ffb5..8589659 100644 --- a/ncrf/tests/test_ncrf.py +++ b/ncrf/tests/test_ncrf.py @@ -26,6 +26,7 @@ def test_ncrf(): # check reconstruct model.reconstruct_data(meg, stim, attach=True) assert model._data is not None + assert len(model._reducemeta["meg_shapes"]) == len(meg), "Attempting to reconstruct with wrong meg data." # check residual and explained var np.testing.assert_allclose(model.explained_var, 0.00641890144769941, rtol=0.001)