Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions ncrf/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,94 @@ def _copy_from_data(self, data):
self.tstop = data.tstop
self.gaussian_fwhm = data.gaussian_fwhm

def reduce_data(self):
    """Drop the cached training data to shrink the model for storage.

    Sets ``self._data`` to ``None`` so that it is not written out when the
    model is saved to disk (for example with pickle). Before the data is
    dropped, the metadata needed for a later reconstruction is recorded in
    ``self._reducemeta``:

    - ``"meg_shapes"``: list holding the shape (as a tuple) of every entry
      of ``self._data.meg``, in order.
    - ``"nlevel"``: the ``nlevel`` attribute of ``self._data``.

    Raises
    ------
    RuntimeError
        If ``self._data`` is already ``None``.
    """
    if self._data is None:
        raise RuntimeError("Model is already reduced (self._data is None).")

    data = self._data
    # Build the metadata in a single assignment so that a failure while
    # reading the shapes never leaves a half-filled ``_reducemeta`` behind.
    self._reducemeta = {
        "meg_shapes": [tuple(meg_i.shape) for meg_i in data.meg],
        "nlevel": data.nlevel,
    }

    self._data = None

def reconstruct_data(
    self,
    meg: Sequence[object],
    stim: Sequence[object],
    attach: bool = False,
):
    """Rebuild the RegressionData that :meth:`reduce_data` removed.

    A new ``RegressionData`` is created from the model's stored settings
    and the ``nlevel`` recorded in ``self._reducemeta``. Copies of the
    given recordings are added to it, and it is then passed through
    ``post_normalization()``, ``_prewhiten(self._whitening_filter)`` and
    ``_precompute()``. The result is checked against the recorded MEG
    shapes and against ``self.explained_var`` before it is returned.

    Parameters
    ----------
    meg : sequence
        MEG recordings, the same as the ones used for fitting. Each entry
        is copied before it is added, so the inputs are not modified here.
    stim : sequence
        Stimuli, the same as the ones used for fitting, one entry per
        entry of ``meg``. An entry may itself be a list or tuple, in which
        case each of its elements is copied.
    attach : bool
        If True, the reconstructed data is also stored in ``self._data``.

    Returns
    -------
    data : RegressionData
        The reconstructed data.

    Raises
    ------
    ValueError
        If ``meg`` and ``stim`` differ in length, if the number or the
        shapes of the processed MEG entries differ from the recorded
        metadata, or if the explained variance computed on the
        reconstructed data does not match ``self.explained_var``.
    RuntimeError
        If no reduction metadata is available on the model.
    """
    # NOTE(review): ``meg`` and ``stim`` are iterated pairwise below, so a
    # single recording has to be passed as a one-element sequence - confirm
    # that callers do this.
    if len(meg) != len(stim):
        raise ValueError("Meg size and stim size do not match.")

    # Fail early with a clear message instead of an opaque
    # AttributeError/KeyError halfway through the reconstruction.
    meta = getattr(self, "_reducemeta", None)
    if not meta:
        raise RuntimeError(
            "No reduction metadata available (self._reducemeta is not set); "
            "call reduce_data() before reconstruct_data()."
        )

    data = RegressionData(
        tstart=self.tstart,
        tstop=self.tstop,
        nlevel=meta["nlevel"],
        baseline=self._stim_baseline,
        scaling=self._stim_scaling,
        stim_is_single=self._stim_is_single,
        gaussian_fwhm=self.gaussian_fwhm,
    )

    for meg_i, stim_i in zip(meg, stim):
        meg_i = meg_i.copy()
        if isinstance(stim_i, (list, tuple)):
            stim_i = [s.copy() for s in stim_i]
        else:
            stim_i = stim_i.copy()
        data.add_data(meg_i, stim_i)

    data.post_normalization()
    data._prewhiten(self._whitening_filter)
    data._precompute()

    # Validate meg size and shapes against what reduce_data() recorded.
    expected_shapes = meta["meg_shapes"]
    if len(data.meg) != len(expected_shapes):
        raise ValueError(" Input MEG List-size mismatches with metadata.")

    for meg_proc, expected in zip(data.meg, expected_shapes):
        if tuple(meg_proc.shape) != tuple(expected):
            raise ValueError("MEG shape (channel count or timepoints) mismatches with metadata.")

    # Validate explained variance: the reconstructed data must reproduce
    # the value stored on the model.
    ev_orig = float(self.explained_var)
    ev_recon = float(self.compute_explained_variance(data))
    if not np.allclose(ev_recon, ev_orig, rtol=1e-10, atol=1e-8):
        raise ValueError("Explained variance mismatch! ")

    if attach:
        self._data = data

    return data

def _construct_f(self, data):
"""creates instances of objective function and its gradient to be passes to the FASTA algorithm

Expand Down
9 changes: 9 additions & 0 deletions ncrf/tests/test_ncrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ def test_ncrf():
# 1 stimulus
model = fit_ncrf(meg, stim, fwd, emptyroom, tstop=0.2, normalize='l1', mu=0.0019444, n_iter=3, n_iterc=3,
n_iterf=10, do_post_normalization=False)
# check reduce
model.reduce_data()
assert model._data is None
assert model._reducemeta is not None
# check reconstruct
model.reconstruct_data(meg, stim, attach=True)
assert model._data is not None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also check that an error is raised when attempting to reconstruct with the wrong data?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you address this?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Maryamvalian reminder to let us know when you address a comment (I noticed you pushed a related commit) – what I meant here is, try to deliberately reconstruct the data with the wrong data, something like this:

meg_wrong = meg.sub(time=(0, 4))
with pytest.raises(ValueError):
    model.reconstruct_data(meg_wrong, stim, attach=True)

assert len(model._reducemeta["meg_shapes"]) == len(meg), "Attempting to reconstruct with wrong meg data."

# check residual and explained var
np.testing.assert_allclose(model.explained_var, 0.00641890144769941, rtol=0.001)
np.testing.assert_allclose(model.voxelwise_explained_variance.sum(), 0.08261162457414245, rtol=0.001)
Expand Down
Loading