diff --git a/ncrf/_model.py b/ncrf/_model.py
index 0093b95..7533595 100644
--- a/ncrf/_model.py
+++ b/ncrf/_model.py
@@ -907,6 +907,114 @@ def _copy_from_data(self, data):
         self.tstop = data.tstop
         self.gaussian_fwhm = data.gaussian_fwhm
 
+    def reduce_data(self):
+        """Drop the cached training data to shrink the model for storage.
+
+        Sets ``self._data`` to ``None`` so that the model is smaller when saved to
+        disk (for example with pickle). The metadata needed by
+        :meth:`reconstruct_data` is kept in ``self._reducemeta`` under the keys
+        ``"meg_shapes"`` (one shape tuple per entry of ``self._data.meg``) and
+        ``"nlevel"``.
+
+        Raises
+        ------
+        RuntimeError
+            If ``self._data`` is already ``None``.
+        """
+        if self._data is None:
+            raise RuntimeError("Model is already reduced (self._data is None).")
+
+        # Build the metadata completely before assigning it, so that a failure
+        # while reading the data cannot leave a half-filled dict behind.
+        self._reducemeta = {
+            "meg_shapes": [tuple(meg_i.shape) for meg_i in self._data.meg],
+            "nlevel": self._data.nlevel,
+        }
+        self._data = None
+
+    def reconstruct_data(
+        self,
+        meg: Sequence[object],
+        stim: Sequence[object],
+        attach: bool = False,
+    ):
+        """Rebuild the :class:`RegressionData` dropped by :meth:`reduce_data`.
+
+        Parameters
+        ----------
+        meg : sequence
+            MEG data; must be the same as the data used for fitting.
+        stim : sequence
+            Stimuli, one entry per entry of ``meg``; must be the same as the
+            stimuli used for fitting. An entry may itself be a list or tuple.
+        attach : bool
+            If ``True``, also store the reconstructed data in ``self._data``.
+
+        Returns
+        -------
+        data : RegressionData
+            The reconstructed data, after the validation below has passed.
+
+        Raises
+        ------
+        RuntimeError
+            If no metadata from :meth:`reduce_data` is available.
+        ValueError
+            If ``meg`` and ``stim`` differ in length, or if the reconstructed
+            data does not match the stored MEG shapes or the model's explained
+            variance.
+        """
+        if len(meg) != len(stim):
+            raise ValueError("Meg size and stim size do not match.")
+
+        # The metadata only exists after reduce_data(); fail with a clear message
+        # instead of an AttributeError/TypeError on the lookups below.
+        meta = getattr(self, "_reducemeta", None)
+        if not meta:
+            raise RuntimeError("No reduce metadata available; call reduce_data() first.")
+
+        data = RegressionData(
+            tstart=self.tstart,
+            tstop=self.tstop,
+            nlevel=meta["nlevel"],
+            baseline=self._stim_baseline,
+            scaling=self._stim_scaling,
+            stim_is_single=self._stim_is_single,
+            gaussian_fwhm=self.gaussian_fwhm,
+        )
+
+        # add_data() receives copies, never the caller's own objects.
+        for meg_i, stim_i in zip(meg, stim):
+            meg_copy = meg_i.copy()
+            if isinstance(stim_i, (list, tuple)):
+                stim_copy = [s.copy() for s in stim_i]
+            else:
+                stim_copy = stim_i.copy()
+            data.add_data(meg_copy, stim_copy)
+
+        data.post_normalization()
+        data._prewhiten(self._whitening_filter)
+        data._precompute()
+
+        # Validate meg size and shapes
+        expected_shapes = meta["meg_shapes"]
+        if len(data.meg) != len(expected_shapes):
+            raise ValueError(" Input MEG List-size mismatches with metadata.")
+        for meg_proc, expected in zip(data.meg, expected_shapes):
+            if tuple(meg_proc.shape) != tuple(expected):
+                raise ValueError("MEG shape (channel count or timepoints) mismatches with metadata.")
+
+        # Validate explained variance
+        ev_orig = float(self.explained_var)
+        ev_recon = float(self.compute_explained_variance(data))
+        if not np.allclose(ev_recon, ev_orig, rtol=1e-10, atol=1e-8):
+            raise ValueError("Explained variance mismatch! ")
+
+        if attach:
+            self._data = data
+
+        return data
+
     def _construct_f(self, data):
         """creates instances of objective function and its gradient to be passes to the FASTA algorithm
 
diff --git a/ncrf/tests/test_ncrf.py b/ncrf/tests/test_ncrf.py
index fd96c63..8589659 100644
--- a/ncrf/tests/test_ncrf.py
+++ b/ncrf/tests/test_ncrf.py
@@ -19,6 +19,15 @@ def test_ncrf():
 
     # 1 stimulus
     model = fit_ncrf(meg, stim, fwd, emptyroom, tstop=0.2, normalize='l1', mu=0.0019444, n_iter=3, n_iterc=3, n_iterf=10, do_post_normalization=False)
+    # check reduce
+    model.reduce_data()
+    assert model._data is None
+    assert model._reducemeta is not None
+    # check reconstruct
+    model.reconstruct_data(meg, stim, attach=True)
+    assert model._data is not None
+    assert len(model._reducemeta["meg_shapes"]) == len(meg), "Attempting to reconstruct with wrong meg data."
+
     # check residual and explained var
     np.testing.assert_allclose(model.explained_var, 0.00641890144769941, rtol=0.001)
     np.testing.assert_allclose(model.voxelwise_explained_variance.sum(), 0.08261162457414245, rtol=0.001)