Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def create_sorting_analyzer(
return sorting_analyzer


def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer":
def load_sorting_analyzer(
folder, load_extensions=True, format="auto", backend_options=None, lazy=False
) -> "SortingAnalyzer":
"""
Load a SortingAnalyzer object from disk.

Expand Down Expand Up @@ -245,7 +247,9 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_o
The loaded SortingAnalyzer

"""
return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options)
return SortingAnalyzer.load(
folder, load_extensions=load_extensions, format=format, backend_options=backend_options, lazy=lazy
)


class SortingAnalyzer:
Expand Down Expand Up @@ -279,6 +283,7 @@ def __init__(
sparsity: ChannelSparsity | None = None,
return_in_uV: bool = True,
backend_options: dict | None = None,
lazy: bool = False,
):
# very fast init because checks are done in load and create
self.sorting = sorting
Expand All @@ -304,6 +309,9 @@ def __init__(
# (additional saving options for creating and saving datasets, e.g. compression/filters for zarr)
self._backend_options = {} if backend_options is None else backend_options

# the lazy flag is used to load the extensions in a lazy way (only when needed)
self._lazy = lazy

# extensions are not loaded at init
self.extensions = dict()

Expand Down Expand Up @@ -407,7 +415,7 @@ def create(
return sorting_analyzer

@classmethod
def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None):
def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None, lazy=False):
"""
Load folder or zarr.
The recording can be given if the recording location has changed.
Expand All @@ -422,14 +430,14 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe

if format == "binary_folder":
sorting_analyzer = SortingAnalyzer.load_from_binary_folder(
folder, recording=recording, backend_options=backend_options
folder, recording=recording, backend_options=backend_options, lazy=lazy
)
elif format == "zarr":
sorting_analyzer = SortingAnalyzer.load_from_zarr(
folder, recording=recording, backend_options=backend_options
folder, recording=recording, backend_options=backend_options, lazy=lazy
)

if not is_path_remote(str(folder)):
if not is_path_remote(str(folder)) and not lazy:
if load_extensions:
sorting_analyzer.load_all_saved_extension()

Expand Down Expand Up @@ -532,15 +540,24 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV
return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options)

@classmethod
def load_from_binary_folder(cls, folder, recording=None, backend_options=None):
def load_from_binary_folder(cls, folder, recording=None, backend_options=None, lazy=False):
from .loading import load

folder = Path(folder)
assert folder.is_dir(), f"This folder does not exists {folder}"

# load internal sorting copy in memory
if lazy:
numpy_folder_kwargs = dict(mmap_mode="r")
copy_spike_vector = False
else:
numpy_folder_kwargs = dict()
copy_spike_vector = True

sorting = NumpySorting.from_sorting(
NumpyFolderSorting(folder / "sorting"), with_metadata=True, copy_spike_vector=True
NumpyFolderSorting(folder / "sorting", **numpy_folder_kwargs),
with_metadata=True,
copy_spike_vector=copy_spike_vector,
)

# Try to load the recording if not provided
Expand Down Expand Up @@ -601,6 +618,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None):
sparsity=sparsity,
return_in_uV=return_in_uV,
backend_options=backend_options,
lazy=lazy,
)
sorting_analyzer.folder = folder

Expand Down Expand Up @@ -698,7 +716,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options)

@classmethod
def load_from_zarr(cls, folder, recording=None, backend_options=None):
def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False):
import zarr
from .loading import load

Expand All @@ -721,11 +739,22 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
"Please consider re-generating the SortingAnalyzer object."
)

# load internal sorting in memory
if lazy:
copy_spike_vector = False
lazy_spike_vector = True
else:
copy_spike_vector = True
lazy_spike_vector = False

sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
ZarrSortingExtractor(
folder,
zarr_group="sorting",
storage_options=storage_options,
lazy_spike_vector=lazy_spike_vector,
),
with_metadata=True,
copy_spike_vector=True,
copy_spike_vector=copy_spike_vector,
)

# load recording if possible
Expand Down Expand Up @@ -770,6 +799,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
sparsity=sparsity,
return_in_uV=return_in_uV,
backend_options=backend_options,
lazy=lazy,
)
sorting_analyzer.folder = folder

Expand Down Expand Up @@ -988,6 +1018,11 @@ def _save_or_select_or_merge_or_split(
new_sorting_analyzer : SortingAnalyzer
The newly created SortingAnalyzer object.
"""
if self._lazy:
raise ValueError(
"Cannot save, select, merge or split units when the SortingAnalyzer is lazy. "
"Please load the SortingAnalyzer with lazy=False."
)
if self.has_recording():
recording = self._recording
elif self.has_temporary_recording():
Expand Down Expand Up @@ -1936,7 +1971,7 @@ def load_extension(self, extension_name: str):
if extension_class is None:
return None

extension_instance = extension_class.load(self)
extension_instance = extension_class.load(self, lazy=self._lazy)

self.extensions[extension_name] = extension_instance

Expand Down Expand Up @@ -2414,20 +2449,20 @@ def _get_zarr_extension_group(self, mode="r+"):
return extension_group

@classmethod
def load(cls, sorting_analyzer):
def load(cls, sorting_analyzer, lazy=False):
ext = cls(sorting_analyzer)
ext.load_params()
ext.load_run_info()
if ext.run_info is not None:
if ext.run_info["run_completed"]:
ext.load_data()
ext.load_data(lazy=lazy)
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
return ext
else:
# this is for back-compatibility of old analyzers
ext.load_data()
ext.load_data(lazy=lazy)
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
Expand Down Expand Up @@ -2527,7 +2562,7 @@ def load_params(self):

self.params = params

def load_data(self):
def load_data(self, lazy=False):
ext_data = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
Expand All @@ -2547,10 +2582,10 @@ def load_data(self):
ext_data = json.load(f)
elif ext_data_file.suffix == ".npy":
# The lazy loading of an extension is complicated because if we compute again
# and have a link to the old buffer on windows then it fails
# ext_data = np.load(ext_data_file, mmap_mode="r")
# so we go back to full loading
ext_data = np.load(ext_data_file)
# and have a link to the old buffer on windows then it fails.
# So, by default, we use full loading, but lazy can be requested on demand.
kwargs = dict(mmap_mode="r") if lazy else dict()
ext_data = np.load(ext_data_file, **kwargs)
elif ext_data_file.suffix == ".csv":
import pandas as pd

Expand Down Expand Up @@ -2586,8 +2621,7 @@ def load_data(self):
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
# this load in memmory
ext_data = np.array(ext_data_)
ext_data = ext_data_ if lazy else np.array(ext_data_[:])
self.set_data(ext_data_name, ext_data)

if len(self.data) == 0:
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/sortingfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class NumpyFolderSorting(BaseSorting):
mode = "folder"
name = "NumpyFolder"

def __init__(self, folder_path):
def __init__(self, folder_path, mmap_mode=None):
folder_path = Path(folder_path)

with open(folder_path / "numpysorting_info.json", "r") as f:
Expand All @@ -36,7 +36,7 @@ def __init__(self, folder_path):

BaseSorting.__init__(self, sampling_frequency, unit_ids)

self.spikes = np.load(folder_path / "spikes.npy")
self.spikes = np.load(folder_path / "spikes.npy", mmap_mode=mmap_mode)

for segment_index in range(num_segments):
self.add_sorting_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids))
Expand All @@ -47,7 +47,7 @@ def __init__(self, folder_path):
folder_metadata = folder_path
self.load_metadata_from_folder(folder_metadata)

self._kwargs = dict(folder_path=str(folder_path.absolute()))
self._kwargs = dict(folder_path=str(folder_path.absolute()), mmap_mode=mmap_mode)
Comment thread
alejoe91 marked this conversation as resolved.

@staticmethod
def write_sorting(sorting, save_path):
Expand Down
65 changes: 63 additions & 2 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
assert "number" in sorting_analyzer.sorting.get_property_keys()
sorting_analyzer_reloded = load_sorting_analyzer(folder, format="auto")
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer_reloded.sorting.get_property_keys()


def test_SortingAnalyzer_zarr(tmp_path, dataset):
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
assert "number" in sorting_analyzer.sorting.get_property_keys()
sorting_analyzer_reloded = load_sorting_analyzer(sorting_analyzer.folder, format="auto")
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer_reloded.sorting.get_property_keys()


def test_create_by_dict():
Expand Down Expand Up @@ -325,6 +325,67 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset):
assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations())


def test_load_in_lazy_mode_binary(tmp_path, dataset):
recording, sorting = dataset

folder = tmp_path / "test_SortingAnalyzer_binary_folder"
if folder.exists():
shutil.rmtree(folder)

sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None
)

sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])
# load in lazy mode and check that spike vector and extension data are memmap
sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True)

assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), np.memmap)

template_ext = sorting_analyzer_lazy.get_extension("templates")
template_data = template_ext.data
for key, value in template_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, np.memmap)
spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes")
spike_amplitudes_data = spike_amplitudes_ext.data
for key, value in spike_amplitudes_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, np.memmap)


def test_load_in_lazy_mode_zarr(tmp_path, dataset):
import zarr
from spikeinterface.core.zarrextractors import ZarrSpikeVector

recording, sorting = dataset

folder = tmp_path / "test_SortingAnalyzer_zarr_folder.zarr"
if folder.exists():
shutil.rmtree(folder)

sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None
)

sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])
# load in lazy mode and check that spikevector is ZarrSpikeVector andextension data are zarr arrays
sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True)

assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), ZarrSpikeVector)

template_ext = sorting_analyzer_lazy.get_extension("templates")
template_data = template_ext.data
for key, value in template_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, zarr.Array)
spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes")
spike_amplitudes_data = spike_amplitudes_ext.data
for key, value in spike_amplitudes_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, zarr.Array)


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

register_result_extension(DummyAnalyzerExtension)
Expand Down
Loading
Loading