From 3385cf26aab1208b3fb832c1d7c7a8bde412b645 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 15 Jun 2026 17:35:18 +0200 Subject: [PATCH 1/6] wip: lazy load (analyzer + extensions) --- src/spikeinterface/core/sortinganalyzer.py | 56 +++++++++++++++------- src/spikeinterface/core/sortingfolder.py | 6 +-- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b5885598fe..38ef7266f8 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -217,7 +217,9 @@ def create_sorting_analyzer( return sorting_analyzer -def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer": +def load_sorting_analyzer( + folder, load_extensions=True, format="auto", backend_options=None, lazy=False +) -> "SortingAnalyzer": """ Load a SortingAnalyzer object from disk. @@ -245,7 +247,9 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_o The loaded SortingAnalyzer """ - return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options) + return SortingAnalyzer.load( + folder, load_extensions=load_extensions, format=format, backend_options=backend_options, lazy=lazy + ) class SortingAnalyzer: @@ -407,7 +411,7 @@ def create( return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None): + def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None, lazy=False): """ Load folder or zarr. The recording can be given if the recording location has changed. @@ -422,11 +426,11 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe if format == "binary_folder": sorting_analyzer = SortingAnalyzer.load_from_binary_folder( - folder, recording=recording, backend_options=backend_options + folder, recording=recording, backend_options=backend_options, lazy=lazy ) elif format == "zarr": sorting_analyzer = SortingAnalyzer.load_from_zarr( - folder, recording=recording, backend_options=backend_options + folder, recording=recording, backend_options=backend_options, lazy=lazy ) if not is_path_remote(str(folder)): @@ -532,15 +536,23 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) @classmethod - def load_from_binary_folder(cls, folder, recording=None, backend_options=None): + def load_from_binary_folder(cls, folder, recording=None, backend_options=None, lazy=False): from .loading import load folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" # load internal sorting copy in memory + if lazy: + numpy_folder_kwargs = dict(mmap_mode="r") + copy_spike_vector = False + else: + numpy_folder_kwargs = dict() + copy_spike_vector = True sorting = NumpySorting.from_sorting( - NumpyFolderSorting(folder / "sorting"), with_metadata=True, copy_spike_vector=True + NumpyFolderSorting(folder / "sorting", **numpy_folder_kwargs), + with_metadata=True, + copy_spike_vector=copy_spike_vector, ) # Try to load the recording if not provided @@ -698,7 +710,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) @classmethod - def load_from_zarr(cls, folder, recording=None, backend_options=None): + def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False): import zarr from .loading import load @@ -722,6 +734,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): ) # load internal sorting in memory + if lazy: + copy_spike_vector = False sorting = NumpySorting.from_sorting( ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), with_metadata=True, @@ -1894,7 +1908,7 @@ def get_saved_extension_names(self): return saved_extension_names - def get_extension(self, extension_name: str): + def get_extension(self, extension_name: str, lazy: bool = False): """ Get a AnalyzerExtension. If not loaded then load is automatic. @@ -1906,13 +1920,13 @@ def get_extension(self, extension_name: str): return self.extensions[extension_name] elif self.format != "memory" and self.has_extension(extension_name): - self.load_extension(extension_name) + self.load_extension(extension_name, lazy=lazy) return self.extensions[extension_name] else: return None - def load_extension(self, extension_name: str): + def load_extension(self, extension_name: str, lazy: bool = False): """ Load an extension from a folder or zarr into the `ResultSorting.extensions` dict. @@ -1920,6 +1934,8 @@ def load_extension(self, extension_name: str): ---------- extension_name : str The extension name. + lazy : bool, default: False + If True, array data are not loaded in memory, but kept as memmap/zarr arrays Returns ------- @@ -1936,7 +1952,7 @@ def load_extension(self, extension_name: str): if extension_class is None: return None - extension_instance = extension_class.load(self) + extension_instance = extension_class.load(self, lazy=lazy) self.extensions[extension_name] = extension_instance @@ -2414,20 +2430,20 @@ def _get_zarr_extension_group(self, mode="r+"): return extension_group @classmethod - def load(cls, sorting_analyzer): + def load(cls, sorting_analyzer, lazy=False): ext = cls(sorting_analyzer) ext.load_params() ext.load_run_info() if ext.run_info is not None: if ext.run_info["run_completed"]: - ext.load_data() + ext.load_data(lazy=lazy) if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() if len(ext.data) > 0: return ext else: # this is for back-compatibility of old analyzers - ext.load_data() + ext.load_data(lazy=lazy) if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() if len(ext.data) > 0: @@ -2527,7 +2543,7 @@ def load_params(self): self.params = params - def load_data(self): + def load_data(self, lazy=False): ext_data = None if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() @@ -2550,7 +2566,8 @@ def load_data(self): # and have a link to the old buffer on windows then it fails # ext_data = np.load(ext_data_file, mmap_mode="r") # so we go back to full loading - ext_data = np.load(ext_data_file) + kwargs = dict(mmap_mode="r") if lazy else dict() + ext_data = np.load(ext_data_file, **kwargs) elif ext_data_file.suffix == ".csv": import pandas as pd @@ -2587,7 +2604,10 @@ def load_data(self): ext_data = ext_data_[0] else: # this load in memmory - ext_data = np.array(ext_data_) + if lazy: + ext_data = ext_data_ + else: + ext_data = np.array(ext_data_) self.set_data(ext_data_name, ext_data) if len(self.data) == 0: diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py index c0d66393d2..2dba9d4465 100644 --- a/src/spikeinterface/core/sortingfolder.py +++ b/src/spikeinterface/core/sortingfolder.py @@ -24,7 +24,7 @@ class NumpyFolderSorting(BaseSorting): mode = "folder" name = "NumpyFolder" - def __init__(self, folder_path): + def __init__(self, folder_path, mmap_mode=None): folder_path = Path(folder_path) with open(folder_path / "numpysorting_info.json", "r") as f: @@ -36,7 +36,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.spikes = np.load(folder_path / "spikes.npy") + self.spikes = np.load(folder_path / "spikes.npy", mmap_mode=mmap_mode) for segment_index in range(num_segments): self.add_sorting_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids)) @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(folder_path.absolute()), mmap_mode=mmap_mode) @staticmethod def write_sorting(sorting, save_path): From 121a0559a91bc9f3003dbf5a36b422c0f1473c35 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 18 Jun 2026 19:09:51 -0600 Subject: [PATCH 2/6] todo --- src/spikeinterface/core/zarrextractors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index bbc797c693..7e8dd72a47 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -289,6 +289,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, BaseSorting.__init__(self, sampling_frequency, unit_ids) + # TODO: make a virtual memmap view of the spike vector or override to_spike_vector to behave like + # a memmap spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) spikes["sample_index"] = spikes_group["sample_index"][:] spikes["unit_index"] = spikes_group["unit_index"][:] From 6fdd1947a1c27eb7aa69658c14e66a208e632324 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 15:46:24 +0200 Subject: [PATCH 3/6] feat: move lazy logic to load only --- src/spikeinterface/core/sortinganalyzer.py | 35 +++++++++++----------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 38ef7266f8..8814c30a36 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -283,6 +283,7 @@ def __init__( sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, backend_options: dict | None = None, + lazy: bool = False, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -308,6 +309,9 @@ def __init__( # (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) self._backend_options = {} if backend_options is None else backend_options + # the lazy flag is used to load the extensions in a lazy way (only when needed) + self._lazy = lazy + # extensions are not loaded at init self.extensions = dict() @@ -549,6 +553,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None, l else: numpy_folder_kwargs = dict() copy_spike_vector = True + sorting = NumpySorting.from_sorting( NumpyFolderSorting(folder / "sorting", **numpy_folder_kwargs), with_metadata=True, @@ -613,6 +618,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None, l sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, + lazy=lazy, ) sorting_analyzer.folder = folder @@ -733,13 +739,12 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False "Please consider re-generating the SortingAnalyzer object." ) - # load internal sorting in memory - if lazy: - copy_spike_vector = False + # TODO: make a Virtual memmap of ZarrSorting spike vector + copy_spike_vector = False if lazy else True sorting = NumpySorting.from_sorting( ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), with_metadata=True, - copy_spike_vector=True, + copy_spike_vector=copy_spike_vector, ) # load recording if possible @@ -784,6 +789,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, + lazy=lazy, ) sorting_analyzer.folder = folder @@ -1908,7 +1914,7 @@ def get_saved_extension_names(self): return saved_extension_names - def get_extension(self, extension_name: str, lazy: bool = False): + def get_extension(self, extension_name: str): """ Get a AnalyzerExtension. If not loaded then load is automatic. @@ -1920,13 +1926,13 @@ def get_extension(self, extension_name: str, lazy: bool = False): return self.extensions[extension_name] elif self.format != "memory" and self.has_extension(extension_name): - self.load_extension(extension_name, lazy=lazy) + self.load_extension(extension_name) return self.extensions[extension_name] else: return None - def load_extension(self, extension_name: str, lazy: bool = False): + def load_extension(self, extension_name: str): """ Load an extension from a folder or zarr into the `ResultSorting.extensions` dict. @@ -1934,8 +1940,6 @@ def load_extension(self, extension_name: str, lazy: bool = False): ---------- extension_name : str The extension name. - lazy : bool, default: False - If True, array data are not loaded in memory, but kept as memmap/zarr arrays Returns ------- @@ -1952,7 +1956,7 @@ def load_extension(self, extension_name: str, lazy: bool = False): if extension_class is None: return None - extension_instance = extension_class.load(self, lazy=lazy) + extension_instance = extension_class.load(self, lazy=self._lazy) self.extensions[extension_name] = extension_instance @@ -2563,9 +2567,8 @@ def load_data(self, lazy=False): ext_data = json.load(f) elif ext_data_file.suffix == ".npy": # The lazy loading of an extension is complicated because if we compute again - # and have a link to the old buffer on windows then it fails - # ext_data = np.load(ext_data_file, mmap_mode="r") - # so we go back to full loading + # and have a link to the old buffer on windows then it fails. + # So, by default, we use full loading, but lazy can be requested on demand. kwargs = dict(mmap_mode="r") if lazy else dict() ext_data = np.load(ext_data_file, **kwargs) elif ext_data_file.suffix == ".csv": @@ -2603,11 +2606,7 @@ def load_data(self, lazy=False): elif "object" in ext_data_.attrs: ext_data = ext_data_[0] else: - # this load in memmory - if lazy: - ext_data = ext_data_ - else: - ext_data = np.array(ext_data_) + ext_data = ext_data_ if lazy else np.array(ext_data_[:]) self.set_data(ext_data_name, ext_data) if len(self.data) == 0: From 10dfd9f3060e0343461b26cb837ca6a1b6fe6482 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 16:59:38 +0200 Subject: [PATCH 4/6] test: add tests for lazy mode --- src/spikeinterface/core/sortinganalyzer.py | 7 ++- .../core/tests/test_sortinganalyzer.py | 58 ++++++++++++++++++- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8814c30a36..3c8dae4188 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -437,7 +437,7 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe folder, recording=recording, backend_options=backend_options, lazy=lazy ) - if not is_path_remote(str(folder)): + if not is_path_remote(str(folder)) and not lazy: if load_extensions: sorting_analyzer.load_all_saved_extension() @@ -1008,6 +1008,11 @@ def _save_or_select_or_merge_or_split( new_sorting_analyzer : SortingAnalyzer The newly created SortingAnalyzer object. """ + if self._lazy: + raise ValueError( + "Cannot save, select, merge or split units when the SortingAnalyzer is lazy. " + "Please load the SortingAnalyzer with lazy=False." + ) if self.has_recording(): recording = self._recording elif self.has_temporary_recording(): diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..43d4b153fb 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -119,7 +119,7 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset): assert "number" in sorting_analyzer.sorting.get_property_keys() sorting_analyzer_reloded = load_sorting_analyzer(folder, format="auto") assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys() - assert "number" in sorting_analyzer.sorting.get_property_keys() + assert "number" in sorting_analyzer_reloded.sorting.get_property_keys() def test_SortingAnalyzer_zarr(tmp_path, dataset): @@ -201,7 +201,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): assert "number" in sorting_analyzer.sorting.get_property_keys() sorting_analyzer_reloded = load_sorting_analyzer(sorting_analyzer.folder, format="auto") assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys() - assert "number" in sorting_analyzer.sorting.get_property_keys() + assert "number" in sorting_analyzer_reloded.sorting.get_property_keys() def test_create_by_dict(): @@ -325,6 +325,60 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations()) +def test_load_in_lazy_mode_binary(tmp_path, dataset): + recording, sorting = dataset + + folder = tmp_path / "test_SortingAnalyzer_binary_folder" + if folder.exists(): + shutil.rmtree(folder) + + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None + ) + + sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) + # load in lazy mode and check that extension data are memmap + sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True) + template_ext = sorting_analyzer_lazy.get_extension("templates") + template_data = template_ext.data + for key, value in template_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, np.memmap) + spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes") + spike_amplitudes_data = spike_amplitudes_ext.data + for key, value in spike_amplitudes_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, np.memmap) + + +def test_load_in_lazy_mode_zarr(tmp_path, dataset): + import zarr + + recording, sorting = dataset + + folder = tmp_path / "test_SortingAnalyzer_zarr_folder.zarr" + if folder.exists(): + shutil.rmtree(folder) + + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + ) + + sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) + # load in lazy mode and check that extension data are zarr arrays + sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True) + template_ext = sorting_analyzer_lazy.get_extension("templates") + template_data = template_ext.data + for key, value in template_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, zarr.Array) + spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes") + spike_amplitudes_data = spike_amplitudes_ext.data + for key, value in spike_amplitudes_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, zarr.Array) + + def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) From 21a6a596f0d999e7c05dd521603b450a1e45f839 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 17:18:42 +0200 Subject: [PATCH 5/6] feat: implement ZarrSpikeVector - a memmap like lazy spike vector for zarr --- src/spikeinterface/core/sortinganalyzer.py | 16 +- .../core/tests/test_sortinganalyzer.py | 11 +- src/spikeinterface/core/zarrextractors.py | 140 ++++++++++++++++-- 3 files changed, 153 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 3c8dae4188..1588de85bb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -739,10 +739,20 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False "Please consider re-generating the SortingAnalyzer object." ) - # TODO: make a Virtual memmap of ZarrSorting spike vector - copy_spike_vector = False if lazy else True + if lazy: + copy_spike_vector = False + lazy_spike_vector = True + else: + copy_spike_vector = True + lazy_spike_vector = False + sorting = NumpySorting.from_sorting( - ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), + ZarrSortingExtractor( + folder, + zarr_group="sorting", + storage_options=storage_options, + lazy_spike_vector=lazy_spike_vector, + ), with_metadata=True, copy_spike_vector=copy_spike_vector, ) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 43d4b153fb..e0411bc9cd 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -337,8 +337,11 @@ def test_load_in_lazy_mode_binary(tmp_path, dataset): ) sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) - # load in lazy mode and check that extension data are memmap + # load in lazy mode and check that spike vector and extension data are memmap sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True) + + assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), np.memmap) + template_ext = sorting_analyzer_lazy.get_extension("templates") template_data = template_ext.data for key, value in template_data.items(): @@ -353,6 +356,7 @@ def test_load_in_lazy_mode_binary(tmp_path, dataset): def test_load_in_lazy_mode_zarr(tmp_path, dataset): import zarr + from spikeinterface.core.zarrextractors import ZarrSpikeVector recording, sorting = dataset @@ -365,8 +369,11 @@ def test_load_in_lazy_mode_zarr(tmp_path, dataset): ) sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) - # load in lazy mode and check that extension data are zarr arrays + # load in lazy mode and check that spikevector is ZarrSpikeVector andextension data are zarr arrays sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True) + + assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), ZarrSpikeVector) + template_ext = sorting_analyzer_lazy.get_extension("templates") template_data = template_ext.data for key, value in template_data.items(): diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 7e8dd72a47..e2f1b9cd7b 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -241,6 +241,112 @@ def get_traces( return traces +class _ZarrSegmentIndex: + """Lazy segment_index array derived from segment_slices stored in zarr.""" + + def __init__(self, segment_slices: np.ndarray, n: int): + self._segment_slices = segment_slices + self._n = n + + def __len__(self) -> int: + return self._n + + def __array__(self, dtype=None): + arr = np.empty(self._n, dtype="int64") + for seg_idx, (s0, s1) in enumerate(self._segment_slices): + arr[s0:s1] = seg_idx + return arr if dtype is None else arr.astype(dtype) + + def __getitem__(self, key): + return np.asarray(self)[key] + + def __eq__(self, other): + return np.asarray(self) == other + + +class ZarrSpikeVector: + """ + Virtual structured spike vector backed by zarr arrays. + + Mimics a memmap-backed numpy structured array with fields + (sample_index, unit_index, segment_index) without loading any data + at construction time. Data is read from zarr lazily: + + * Field access (``spikes["sample_index"]``) returns the zarr array + (or a lazy segment-index object). + * Slice access (``spikes[s0:s1]``) materialises only that slice. + * ``np.asarray(spikes)`` materialises the full array. + + The zarr arrays are assumed to be stored in sorted order + (segment_index ASC, sample_index ASC, unit_index ASC), which is the + ordering guaranteed by :func:`add_sorting_to_zarr_group`. + """ + + def __init__(self, spikes_group, segment_slices: np.ndarray): + self._sample_index = spikes_group["sample_index"] + self._unit_index = spikes_group["unit_index"] + self._segment_slices = np.asarray(segment_slices, dtype="int64") + self._n = len(self._sample_index) + self.dtype = np.dtype(minimum_spike_dtype) + + @property + def size(self) -> int: + return self._n + + def __len__(self) -> int: + return self._n + + def __getitem__(self, key): + if isinstance(key, str): + if key == "sample_index": + return self._sample_index + elif key == "unit_index": + return self._unit_index + elif key == "segment_index": + return _ZarrSegmentIndex(self._segment_slices, self._n) + else: + raise KeyError(f"ZarrSpikeVector has no field {key!r}") + + if isinstance(key, (int, np.integer)): + idx = int(key) + if idx < 0: + idx += self._n + result = np.empty(1, dtype=self.dtype) + result["sample_index"][0] = self._sample_index[idx] + result["unit_index"][0] = self._unit_index[idx] + result["segment_index"][0] = int(np.searchsorted(self._segment_slices[:, 0], idx, side="right")) - 1 + return result[0] + + if isinstance(key, slice): + start, stop, step = key.indices(self._n) + n = len(range(start, stop, step)) + result = np.empty(n, dtype=self.dtype) + result["sample_index"] = self._sample_index[start:stop:step] + result["unit_index"] = self._unit_index[start:stop:step] + if step == 1: + seg_index = np.empty(n, dtype="int64") + for seg_idx, (s0, s1) in enumerate(self._segment_slices): + lo = max(start, int(s0)) - start + hi = min(stop, int(s1)) - start + if hi > lo: + seg_index[lo:hi] = seg_idx + result["segment_index"] = seg_index + else: + result["segment_index"] = _ZarrSegmentIndex(self._segment_slices, self._n)[start:stop:step] + return result + + # fallback for fancy/boolean indexing: materialise then index + return np.asarray(self)[key] + + def __array__(self, dtype=None): + arr = np.empty(self._n, dtype=self.dtype) + arr["sample_index"] = self._sample_index[:] + arr["unit_index"] = self._unit_index[:] + for seg_idx, (s0, s1) in enumerate(self._segment_slices): + arr["segment_index"][s0:s1] = seg_idx + return arr if dtype is None else arr.astype(dtype) + + class ZarrSortingExtractor(BaseSorting): """ SortingExtractor for a zarr format @@ -257,13 +363,23 @@ class ZarrSortingExtractor(BaseSorting): Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. zarr_group : str or None, default: None Optional zarr group path to load the sorting from. This can be used when the sorting is not stored at the root, but in sub group. + lazy_spike_vector : bool, default: False + If True, the spike vector is loaded lazily. This can be useful for large sortings with many spikes. + If False, the spike vector is loaded in memory. Default: False + Returns ------- sorting : ZarrSortingExtractor The sorting Extractor """ - def __init__(self, folder_path: Path | str, storage_options: dict | None = None, zarr_group: str | None = None): + def __init__( + self, + folder_path: Path | str, + storage_options: dict | None = None, + zarr_group: str | None = None, + lazy_spike_vector: bool = False, + ): folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) @@ -289,15 +405,21 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, BaseSorting.__init__(self, sampling_frequency, unit_ids) - # TODO: make a virtual memmap view of the spike vector or override to_spike_vector to behave like - # a memmap - spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) - spikes["sample_index"] = spikes_group["sample_index"][:] - spikes["unit_index"] = spikes_group["unit_index"][:] - for i, (start, end) in enumerate(segment_slices_list): - spikes["segment_index"][start:end] = i - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + if lazy_spike_vector: + spikes = ZarrSpikeVector(spikes_group, segment_slices_list) + else: + # Materialize the spike vector in memory and sort it by (segment_index, sample_index, unit_index) + spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) + spikes["sample_index"] = spikes_group["sample_index"][:] + spikes["unit_index"] = spikes_group["unit_index"][:] + for i, (start, end) in enumerate(segment_slices_list): + spikes["segment_index"][start:end] = i + spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + self._cached_spike_vector = spikes + # pre-populate segment slices so _get_spike_vector_segment_slices() never + # needs to materialise the full segment_index array + self._cached_spike_vector_segment_slices = np.asarray(segment_slices_list, dtype="int64") for segment_index in range(num_segments): soring_segment = SpikeVectorSortingSegment(spikes, segment_index, unit_ids) From e1a1e75a149264acfe93e01129b7ba459387e66a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 24 Jun 2026 17:19:26 +0200 Subject: [PATCH 6/6] fix: add lazy spike vector as kwarg --- src/spikeinterface/core/zarrextractors.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index e2f1b9cd7b..a832129fe1 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -437,7 +437,12 @@ def __init__( if annotations is not None: self.annotate(**annotations) - self._kwargs = {"folder_path": folder_path_kwarg, "storage_options": storage_options, "zarr_group": zarr_group} + self._kwargs = { + "folder_path": folder_path_kwarg, + "storage_options": storage_options, + "zarr_group": zarr_group, + "lazy_spike_vector": lazy_spike_vector, + } @staticmethod def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options: dict | None = None, **kwargs):