diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index f1c88e0f15..8b5b6de3b8 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -11,7 +11,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel class ClusteringBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index be1cf18fbf..eae0bf0e59 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -5,8 +5,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel from pathlib import Path @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder): # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) # sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index") + extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True) + spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py index 22b519ec78..a529141323 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py @@ -28,7 +28,7 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_in_uV=False, **job_kwargs ) noise_levels = get_noise_levels(recording, **job_kwargs) - sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25) + sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index b9207caaa3..82a51e8292 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -6,7 +6,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel @pytest.mark.skip() @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder): sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index c32d919bc7..7a683a3118 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -166,9 +166,12 @@ # template tools from .template_tools import ( get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, + get_template_main_channel_peak_shift, + get_template_main_channel_amplitude, + # this is not needed anymore + get_template_extremum_channel, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_channel_peak_shift, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_amplitude, # keep for backward compatibility can be removed in 0.105 ) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 3eeede22b5..7f8d210209 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -381,7 +381,7 @@ class ComputeTemplates(AnalyzerExtension): extension_name = "templates" depend_on = ["random_spikes|waveforms"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = True need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index fc8373cdfb..46da558388 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -13,6 +13,10 @@ class BaseSorting(BaseExtractor): Abstract class representing several segment several units and relative spiketrains. """ + _main_properties = [ + "main_channel_index", + ] + def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) self._sampling_frequency = float(sampling_frequency) @@ -913,6 +917,7 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices + # TODO sam : change extremum_channel_inds to main_channel_index with vector def to_spike_vector( self, concatenated=True, @@ -933,7 +938,8 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be computed with `get_template_extremum_channel(we, outputs="index")` + This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True) + use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). This caching only occurs when extremum_channel_inds=None. diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..95b1f37347 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2427,6 +2427,10 @@ def generate_ground_truth_recording( **generate_templates_kwargs, ) sorting.set_property("gt_unit_locations", unit_locations) + distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :2], axis=2) + main_channel_index = np.argmin(distances, axis=1) + sorting.set_property("main_channel_index", main_channel_index) + else: assert templates.shape[0] == num_units diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index afe2e1e4ca..4871568ead 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -145,6 +145,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea return (local_peaks,) +# TODO sam replace extremum_channels_indices by main_channel_index + + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b5885598fe..724c88dc50 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -44,6 +44,9 @@ from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open, _write_object_array from .node_pipeline import run_node_pipeline +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection + # high level function def create_sorting_analyzer( @@ -51,6 +54,10 @@ def create_sorting_analyzer( recording, format="memory", folder=None, + main_channel_index=None, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", + num_spikes_for_main_channel=100, sparse=True, sparsity=None, set_sparsity_by_dict_key=False, @@ -58,7 +65,9 @@ def create_sorting_analyzer( return_in_uV=True, overwrite=False, backend_options=None, - **sparsity_kwargs, + sparsity_kwargs=None, + seed=None, + **job_kwargs, ) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -68,6 +77,11 @@ def create_sorting_analyzer( This object will be also use used for plotting purpose. + The main_channel_index can be externally provided. If not then this is taken from + sorting property. If not then the main_channel_index is estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. + Parameters ---------- @@ -81,6 +95,17 @@ def create_sorting_analyzer( The mode to store analyzer. If "folder", the analyzer is stored on disk in the specified folder. The "folder" argument must be specified in case of mode "folder". If "memory" is used, the analyzer is stored in RAM. Use this option carefully! + main_channel_index : None | np.array + The main_channel_index can be externally provided + main_channel_peak_sign : "both" | "neg" + In case when the main_channel_index is estimated wich sign to consider "both" or "neg". + main_channel_peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + num_spikes_for_main_channel : int, default: 100 + How many spikes per units to compute the main channel. sparse : bool, default: True If True, then a sparsity mask is computed using the `estimate_sparsity()` function using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. @@ -106,8 +131,8 @@ def create_sorting_analyzer( * storage_options: dict | None (fsspec storage options) * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) - - sparsity_kwargs : keyword arguments + sparsity_kwargs : dict | None + Dict given to estimate the sparsity. Returns ------- @@ -143,6 +168,9 @@ def create_sorting_analyzer( sparsity off (or give external sparsity) like this. """ + if sparsity_kwargs is None: + sparsity_kwargs = dict() + if isinstance(sorting, dict) and isinstance(recording, dict): if sorting.keys() != recording.keys(): @@ -167,27 +195,55 @@ def create_sorting_analyzer( return_in_uV=return_in_uV, overwrite=overwrite, backend_options=backend_options, - **sparsity_kwargs, + sparsity_kwargs=sparsity_kwargs, + **job_kwargs, ) - if format != "memory" and not is_path_remote(folder): - folder = clean_zarr_folder_name(folder) if format == "zarr" else folder - if Path(folder).is_dir(): - if overwrite: - shutil.rmtree(folder) - else: - raise ValueError(f"Folder {folder} already exists! Use overwrite=True to overwrite it.") + if format != "memory": + if format == "zarr": + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) + if not is_path_remote(folder): + if Path(folder).is_dir(): + if not overwrite: + raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.") + else: + shutil.rmtree(folder) + + # retrieve or compute the main channel index per unit + if main_channel_index is None: + if "main_channel_index" in sorting.get_property_keys(): + main_channel_index = sorting.get_property("main_channel_index") + + if main_channel_index is None: + # this is weird but due to the cyclic import + from .template_tools import estimate_main_channel_from_recording + + main_channel_index = estimate_main_channel_from_recording( + recording, + sorting, + peak_sign=main_channel_peak_sign, + peak_mode=main_channel_peak_mode, + num_spikes_for_main_channel=num_spikes_for_main_channel, + seed=seed, + **job_kwargs, + ) # handle sparsity if sparsity is not None: # some checks assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" - error_msg = "If external sparsity is given, unit_ids must match sorting" - assert np.array_equal(sorting.unit_ids, sparsity.unit_ids), error_msg - error_msg = "If external sparsity is given, channel_ids must match recording" - assert np.array_equal(recording.channel_ids, sparsity.channel_ids), error_msg + assert np.array_equal( + sorting.unit_ids, sparsity.unit_ids + ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" + assert np.array_equal( + recording.channel_ids, sparsity.channel_ids + ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" + assert all( + sparsity.mask[u, c] for u, c in enumerate(main_channel_index) + ), "sparsity si not constistentent with main_channel_index" elif sparse: - sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs) + sparsity = estimate_sparsity(sorting, recording, main_channel_index=main_channel_index, **sparsity_kwargs) else: sparsity = None @@ -209,6 +265,9 @@ def create_sorting_analyzer( recording, format=format, folder=folder, + main_channel_index=main_channel_index, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -278,6 +337,8 @@ def __init__( format: str | None = None, sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options: dict | None = None, ): # very fast init because checks are done in load and create @@ -288,6 +349,9 @@ def __init__( self.format = format self.sparsity = sparsity self.return_in_uV = return_in_uV + self.main_channel_peak_sign = main_channel_peak_sign + self.main_channel_peak_mode = main_channel_peak_mode + self._main_channel_index = None # For backward compatibility self.return_scaled = return_in_uV @@ -341,12 +405,17 @@ def create( "zarr", ] = "memory", folder=None, + main_channel_index=None, sparsity=None, return_scaled=None, return_in_uV=True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options=None, ): assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" + assert main_channel_index is not None, "To create a SortingAnalyzer you need to specify the main_channel_index" + # some checks if sorting.sampling_frequency != recording.sampling_frequency: if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5): @@ -376,8 +445,18 @@ def create( sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) + # This will ensure that the sorting saved always will have this main_channel + if format == "memory": - sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) + sorting_analyzer = cls.create_memory( + sorting, + recording, + sparsity, + return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes=None, + ) elif format == "binary_folder": sorting_analyzer = cls.create_binary_folder( folder, @@ -385,6 +464,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -398,12 +479,16 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) else: raise ValueError("SortingAnalyzer.create: wrong format") + sorting_analyzer.set_sorting_property("main_channel_index", main_channel_index, save=True) + return sorting_analyzer @classmethod @@ -436,7 +521,9 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe return sorting_analyzer @classmethod - def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attributes): + def create_memory( + cls, sorting, recording, sparsity, return_in_uV, main_channel_peak_sign, main_channel_peak_mode, rec_attributes + ): # used by create and save_as if rec_attributes is None: @@ -457,11 +544,24 @@ def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attribute format="memory", sparsity=sparsity, return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, ) return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_binary_folder( + cls, + folder, + sorting, + recording, + sparsity, + return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes, + backend_options, + ): # used by create and save_as folder = Path(folder) @@ -525,16 +625,158 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV settings_file = folder / f"settings.json" settings = dict( return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, ) with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, indent=4) return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod + def _handle_backward_compatibility_settings_pre_init(cls, settings, sorting, sparsity): + """ + backward compatibility before the __init__ to handle the settings: + * return_scaled > return_in_uV + * main_channel_peak_sign + * main_channel_peak_mode + + Note : + * see also _handle_backward_compatibility_settings_post_init + * there is also something at extension level to handle changes in paramaters with deferents mechanism + """ + + new_settings = dict() + new_settings.update(settings) + if "return_scaled" in settings: + new_settings["return_in_uV"] = new_settings.pop("return_scaled") + elif "return_in_uV" in settings: + pass + else: + # old version did not have settings at all + new_settings["return_in_uV"] = True + + if "main_channel_peak_sign" not in settings: + # before 0.104.0 was not in main_channel_peak_sign + # TODO make something more fancy that exlore the previous params of extension + new_settings["main_channel_peak_sign"] = "both" + new_settings["main_channel_peak_mode"] = "extremum" + + return new_settings + + @property + def main_channel_index(self): + if self._main_channel_index is None: + sorting_main_channel_index = self.get_sorting_property("main_channel_index") + if sorting_main_channel_index is not None: + self._main_channel_index = sorting_main_channel_index + else: + self._main_channel_index = self.compute_main_channel_backwards_compatibility() + + return self._main_channel_index + + def compute_main_channel_backwards_compatibility(self): + """ + Computes the `main_channel_index` for an old analyzer. Logic is: + + 1) If you have the `templates` extension: use this to find the main_channel_index, with + default settings and restricted to the sparsity of the analyzer. + + 2) If you do not have the `templates` extension but do have `waveforms`: compute + `templates`, then go to 1) + + 3) If you do not have `templates` or `waveforms`, but do have `sparsity = True`, we + will take the "average" channel as the `main_channel_index` + + 4) Failing that, if you have an attached `recording`, compute the `main_channel_index` + using the accumulator. + + 5) If you have a dense analyzer with no `templates`, `waveforms` or `recording`, your + analyzer is not compatible with newer versions of SpikeInterface. Raise an error + and ask the user to attach a recording. + """ + + warnings.warn( + "This sorting analyzer is from an an older version of spikeinterface. " + "For future compatibility we will compute the `main_channel_indices`." + ) + + main_channel_index = None + + templates_array = None + peak_sign = "both" + peak_mode = "extremum" + + # Case 1 + if self.has_extension("templates"): + templates = self.get_extension("templates") + for k in ("average", "median"): + if k in templates.data: + from .template_tools import _get_main_channel_from_template_array + + templates_array = templates.data[k] + break + else: + # Case 2 - from waveforms + if self.has_extension("waveforms") and self.has_extension("random_spikes"): + from spikeinterface.core.analyzer_extension_core import ComputeTemplates + + templates = ComputeTemplates(self) + templates_array = templates.data["average"] + + if templates_array is not None: + from .template_tools import _get_main_channel_from_template_array + + main_channel_index = _get_main_channel_from_template_array( + templates_array, peak_mode, peak_sign, templates.nbefore + ) + + # Case 3 + if self.is_sparse(): + channel_locations = self.get_channel_locations() + sparsity = self.sparsity + main_channel_index = [] + for channel_indices in sparsity.unit_id_to_channel_indices.values(): + unit_channel_locations = channel_locations[channel_indices] + average_unit_channel_location = np.average(unit_channel_locations, axis=0) + distance_from_average_channel = np.linalg.norm( + channel_locations - average_unit_channel_location, axis=1 + ) + closest_channel_index = np.argmin(distance_from_average_channel) + main_channel_index.append(closest_channel_index) + + return main_channel_index + + # Case 4 + if self.has_recording() or self.has_temporary_recording(): + from .template_tools import estimate_main_channel_from_recording + + main_channel_index = estimate_main_channel_from_recording( + self.recording, + self.sorting, + peak_sign=peak_sign, + peak_mode=peak_mode, + num_spikes_for_main_channel=100, + ) + + return main_channel_index + + raise ValueError( + "This analyzer is dense, and has no attached recording, waveforms or templates. Hence we cannot estimate the `main_channel_indices`, making the analyzer incompatible with newer versions of spikeinterface. Please attach a recording to continue, or re-create your analyzer from scratch." + ) + + # # this is only in memory + # self.sorting.set_property("main_channel_index", main_channel_index) + # # TODO @alessio @chris : we need to discuss this + # # this save also to disk but maybe there is no write for the analyzer... + # self.set_sorting_property("main_channel_index", main_channel_index, save=True) + @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -585,21 +827,27 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): if settings_file.exists(): with open(settings_file, "r") as f: settings = json.load(f) + need_to_create = False else: + need_to_create = True + settings = dict() + + settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) + + if need_to_create: warnings.warn("settings.json not found for this folder writing one with return_in_uV=True") - settings = dict(return_in_uV=True) with open(settings_file, "w") as f: json.dump(check_json(settings), f, indent=4) - return_in_uV = settings.get("return_in_uV", settings.get("return_scaled", True)) - sorting_analyzer = SortingAnalyzer( sorting=sorting, recording=recording, rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV=settings["return_in_uV"], + main_channel_peak_sign=settings["main_channel_peak_sign"], + main_channel_peak_mode=settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -614,7 +862,18 @@ def _get_zarr_root(self, mode="r+"): return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_zarr( + cls, + folder, + sorting, + recording, + sparsity, + return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes, + backend_options, + ): # used by create and save_as import zarr from .zarrextractors import add_sorting_to_zarr_group @@ -637,7 +896,11 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) - settings = dict(return_in_uV=return_in_uV) + settings = dict( + return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) zarr_root.attrs["settings"] = check_json(settings) # the recording @@ -702,6 +965,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) @@ -758,9 +1023,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): else: sparsity = None - return_in_uV = zarr_root.attrs["settings"].get( - "return_in_uV", zarr_root.attrs["settings"].get("return_scaled", True) - ) + settings = zarr_root.attrs["settings"] + settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -768,7 +1032,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV=settings["return_in_uV"], + main_channel_peak_sign=settings["main_channel_peak_sign"], + main_channel_peak_mode=settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -871,6 +1137,21 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n """ return self.sorting.get_property(key, ids=ids) + def get_main_channels(self, outputs="index", with_dict=False): + """ """ + main_channel_index = self.main_channel_index + if outputs == "index": + main_chans = main_channel_index + elif outputs == "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + def are_units_mergeable( self, merge_unit_groups: list[str | int], @@ -1092,7 +1373,13 @@ def _save_or_select_or_merge_or_split( if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( - sorting_provenance, recording, sparsity, self.return_in_uV, self.rec_attributes + sorting_provenance, + recording, + sparsity, + self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes, ) elif format == "binary_folder": @@ -1105,6 +1392,8 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, self.rec_attributes, backend_options=backend_options, ) @@ -1118,6 +1407,8 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, self.rec_attributes, backend_options=backend_options, ) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index a203c2ff05..d5129e8c6d 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -28,8 +28,6 @@ In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. - peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels. num_channels : int Number of channels for "best_channels" method. radius_um : float @@ -82,14 +80,14 @@ class ChannelSparsity: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels) Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um) Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold) Using a template energy threshold: >>> sparsity = ChannelSparsity.from_energy(sorting_analyzer, threshold) @@ -323,7 +321,7 @@ def from_best_channels( mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer) for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] @@ -332,7 +330,7 @@ def from_best_channels( ## Some convinient function to compute sparsity from several strategy @classmethod - def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign=None): """ Construct sparsity from N closest channels Use the "num_channels" argument to specify the number of channels. @@ -368,15 +366,17 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): + def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, channel_locations, radius_um): """ - Construct sparsity from a radius around the best channel. + Construct sparsity from a radius around the main channel. Use the "radius_um" argument to specify the radius in um. Parameters ---------- - templates_or_sorting_analyzer : Templates | SortingAnalyzer - A Templates or a SortingAnalyzer object. + main_channel_index : np.array + Main channel index per units. + channel_locations : np.array + Channel locations of the recording. radius_um : float Radius in um for "radius" method. peak_sign : "neg" | "pos" | "both" @@ -387,27 +387,48 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): sparsity : ChannelSparsity The estimated sparsity. """ - from .template_tools import get_template_extremum_channel - - mask = np.zeros( - (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" - ) - channel_locations = templates_or_sorting_analyzer.get_channel_locations() + mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") - for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = best_chan[unit_id] - (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) + for unit_ind, main_chan in enumerate(main_channel_index): + (chan_inds,) = np.nonzero(distances[main_chan, :] <= radius_um) mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) + return cls(mask, unit_ids, channel_ids) + + @classmethod + def from_radius(cls, templates_or_sorting_analyzer, radius_um): + """ + Construct sparsity from a radius around the main channel. + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + radius_um : float + Radius in um for "radius" method. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. + """ + main_channel_index = templates_or_sorting_analyzer.get_main_channels(outputs="index") + channel_locations = templates_or_sorting_analyzer.get_channel_locations() + return cls.from_radius_and_main_channel( + templates_or_sorting_analyzer.unit_ids, + templates_or_sorting_analyzer.channel_ids, + main_channel_index, + channel_locations, + radius_um, + ) @classmethod def from_snr( cls, templates_or_sorting_analyzer, threshold, - amplitude_mode="extremum", - peak_sign="neg", + amplitude_mode=None, + peak_sign=None, noise_levels=None, ): """ @@ -458,7 +479,7 @@ def from_snr( mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_in_uV=return_in_uV + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=amplitude_mode ) for unit_ind, unit_id in enumerate(unit_ids): @@ -467,7 +488,7 @@ def from_snr( return cls(mask, unit_ids, channel_ids) @classmethod - def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode=None, peak_sign=None): """ Construct sparsity from a threshold based on template amplitude. The amplitude is computed with the specified amplitude mode and it is assumed @@ -513,7 +534,7 @@ def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_in_uV=True + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=amplitude_mode ) for unit_ind, unit_id in enumerate(unit_ids): @@ -623,12 +644,10 @@ def compute_sparsity( method: Literal[ "radius", "best_channels", "closest_channels", "snr", "amplitude", "energy", "by_property" ] = "radius", - peak_sign: Literal["neg", "pos", "both"] = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - amplitude_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", ) -> ChannelSparsity: """ Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. @@ -674,29 +693,25 @@ def compute_sparsity( if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels) elif method == "closest_channels": assert num_channels is not None, "For the 'closest_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_closest_channels(templates_or_sorting_analyzer, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates_or_sorting_analyzer, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates_or_sorting_analyzer, radius_um) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, - peak_sign=peak_sign, - amplitude_mode=amplitude_mode, ) elif method == "amplitude": assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_amplitude( templates_or_sorting_analyzer, threshold, - amplitude_mode=amplitude_mode, - peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -722,13 +737,14 @@ def estimate_sparsity( ms_before: float = 1.0, ms_after: float = 2.5, method: Literal["radius", "best_channels", "closest_channels", "amplitude", "snr", "by_property"] = "radius", - peak_sign: Literal["neg", "pos", "both"] = "neg", + peak_sign: Literal["neg", "pos", "both"] = "both", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, amplitude_mode: Literal["extremum", "peak_to_peak"] = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, + main_channel_index: np.ndarray | list | None = None, **job_kwargs, ): """ @@ -737,11 +753,10 @@ def estimate_sparsity( For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. These can be computed with the `get_noise_levels()` function. - Contrary to the previous implementation: - * all units are computed in one read of recording - * it doesn't require a folder - * it doesn't consume too much memory - * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + If main_channel_index is given and method="radius" then there is not need estimate + the templates otherwise the templates must be estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. @@ -760,6 +775,9 @@ def estimate_sparsity( noise_levels : np.ndarray | None, default: None Noise levels required for the "snr" and "energy" methods. You can use the `get_noise_levels()` function to compute them. + main_channel_index : np.array | None, default: None + Main channel indicies can be provided in case of method="radius", this avoid the + `estimate_templates_with_accumulator()` which is slow. {} Returns @@ -769,6 +787,7 @@ def estimate_sparsity( """ # Can't be done at module because this is a cyclic import, too bad from .template import Templates + from .template_tools import get_templates_array_from_recording_and_sorting assert method in ("radius", "best_channels", "closest_channels", "snr", "amplitude", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " @@ -784,32 +803,21 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - if method != "by_property": - nbefore = ms_to_samples(ms_before, recording.sampling_frequency) - nafter = ms_to_samples(ms_after, recording.sampling_frequency) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, + if method == "radius" and main_channel_index is not None: + assert main_channel_index.size == sorting.unit_ids.size + chan_locs = recording.get_channel_locations() + sparsity = ChannelSparsity.from_radius_and_main_channel( + sorting.unit_ids, recording.channel_ids, main_channel_index, chan_locs, radius_um ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_in_uV=False, - job_name="estimate_sparsity", - **job_kwargs, + + elif method != "by_property": + + templates_array = get_templates_array_from_recording_and_sorting( + recording, sorting, ms_before, ms_after, num_spikes_for_sparsity, 2205, **job_kwargs ) + + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + templates = Templates( templates_array=templates_array, sampling_frequency=recording.sampling_frequency, @@ -830,7 +838,7 @@ def estimate_sparsity( sparsity = ChannelSparsity.from_closest_channels(templates, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates, radius_um) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" assert noise_levels is not None, ( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 67ba1179b0..6005fcec69 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import json from dataclasses import dataclass, field, astuple, replace @@ -481,3 +483,27 @@ def get_channel_locations(self) -> np.ndarray: assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations + + def get_main_channels( + self, + peak_sign: Literal["neg", "both", "pos"] = "both", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + outputs="index", + with_dict=False, + ): + from .template_tools import _get_main_channel_from_template_array + + templates_array = self.get_dense_templates() + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, self.nbefore) + + if outputs == "index": + main_chans = main_channel_index + elif outputs == "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index b6a6b90bb2..bed53139ed 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -2,12 +2,17 @@ import numpy as np +from spikeinterface.core.core_tools import ms_to_samples from .template import Templates +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection from .sortinganalyzer import SortingAnalyzer +import warnings + def get_dense_templates_array( - one_object: Templates | SortingAnalyzer, return_in_uV: bool = True, operator="average" + one_object: Templates | SortingAnalyzer, return_in_uV: None | bool = None, operator="average" ) -> np.ndarray: """ Return dense templates as numpy array from either a Templates object or a SortingAnalyzer. @@ -28,13 +33,13 @@ def get_dense_templates_array( The dense templates (num_units, num_samples, num_channels) """ if isinstance(one_object, Templates): - if return_in_uV != one_object.is_in_uV: + if return_in_uV is not None and return_in_uV != one_object.is_in_uV: raise ValueError( f"get_dense_templates_array: return_in_uV={return_in_uV} is not possible Templates has the reverse" ) templates_array = one_object.get_dense_templates() elif isinstance(one_object, SortingAnalyzer): - if return_in_uV != one_object.return_in_uV: + if return_in_uV is not None and return_in_uV != one_object.return_in_uV: raise ValueError( f"get_dense_templates_array: return_in_uV={return_in_uV} is not possible SortingAnalyzer has the reverse" ) @@ -48,6 +53,33 @@ def get_dense_templates_array( return templates_array +def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: None | bool = None): + """ + Return template on main channel for each unit, as numpy array from either a Templates object or a SortingAnalyzer. + + Parameters + ---------- + one_object : Templates | SortingAnalyzer + The Templates or SortingAnalyzer objects. If SortingAnalyzer, it needs the "templates" extension. + return_in_uV : bool, default: True + If True, templates are scaled. + + Returns + ------- + main_channel_templates : np.ndarray + The dense templates (num_units, num_samples) + """ + # TODO later: do not load the dense templates array if this is not necessary (when sprse internally) + main_channels = one_object.get_main_channels(outputs="index", with_dict=False) + templates_array = get_dense_templates_array(one_object, return_in_uV=return_in_uV) + num_units = templates_array.shape[0] + num_samples = templates_array.shape[1] + main_channel_templates = np.zeros((num_units, num_samples), dtype=templates_array.dtype) + for i in range(num_units): + main_channel_templates[i, :] = templates_array[i, :, main_channels[i]] + return main_channel_templates + + def _get_nbefore(one_object): if isinstance(one_object, Templates): return one_object.nbefore @@ -62,28 +94,28 @@ def _get_nbefore(one_object): def get_template_amplitudes( templates_or_sorting_analyzer, - peak_sign: Literal["neg", "pos", "both"] = "neg", - mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", - return_in_uV: bool = True, + peak_sign: None | Literal["neg", "pos", "both"] = None, + peak_mode: None | Literal["extremum", "at_index", "peak_to_peak"] = None, abs_value: bool = True, operator: str = "average", + override_peak_error: bool = False, ): """ - Get amplitude per channel for each unit. + Get amplitude on every channel for each unit. Parameters ---------- templates_or_sorting_analyzer : Templates | SortingAnalyzer A Templates or a SortingAnalyzer object - peak_sign : "neg" | "pos" | "both" + peak_sign : None | "neg" | "pos" | "both" + Used only when input is Templates. Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : None | "extremum" | "at_index" | "peak_to_peak", default: None + Used only when input is Templates. Where the amplitude is computed * "extremum" : take the peak value (max or min depending on `peak_sign`) * "at_index" : take value at `nbefore` index * "peak_to_peak" : take the peak-to-peak amplitude - return_in_uV : bool, default True - The amplitude is scaled or not. abs_value : bool = True Whether the extremum amplitude should be returned as an absolute value or not operator : str, default: "average" @@ -95,8 +127,35 @@ def get_template_amplitudes( peak_values : dict Dictionary with unit ids as keys and template amplitudes as values """ + + if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + if not override_peak_error: + assert ( + peak_sign is None + ), "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" + assert ( + peak_mode is None + ), "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" + if peak_sign is None: + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + if peak_mode is None: + peak_mode = templates_or_sorting_analyzer.main_channel_peak_mode + return_in_uV = templates_or_sorting_analyzer.return_in_uV + elif isinstance(templates_or_sorting_analyzer, Templates): + return_in_uV = templates_or_sorting_analyzer.is_in_uV + if peak_sign is None: + warnings.warn("get_template_amplitudes() with Templates should provide a peak_sign") + peak_sign = "both" + if peak_mode is None: + warnings.warn("get_template_amplitudes() with Templates should provide a peak_mode") + peak_mode = "extremum" + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" + assert peak_mode in ( + "extremum", + "at_index", + "peak_to_peak", + ), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" unit_ids = templates_or_sorting_analyzer.unit_ids before = _get_nbefore(templates_or_sorting_analyzer) @@ -110,19 +169,19 @@ def get_template_amplitudes( for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] - if mode == "extremum": + if peak_mode == "extremum": if peak_sign == "both": values = np.max(np.abs(template), axis=0) elif peak_sign == "neg": values = np.min(template, axis=0) elif peak_sign == "pos": values = np.max(template, axis=0) - elif mode == "at_index": + elif peak_mode == "at_index": if peak_sign == "both": values = np.abs(template[before, :]) elif peak_sign in ["neg", "pos"]: values = template[before, :] - elif mode == "peak_to_peak": + elif peak_mode == "peak_to_peak": values = np.ptp(template, axis=0) if abs_value: @@ -133,14 +192,112 @@ def get_template_amplitudes( return peak_values +def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore): + """ + Get the main channel for each template in a `templates_array`, which has dimensions + (num_units) x (num time samples) x (num channels) + """ + # Step1 : max on time axis + if peak_mode == "extremum": + if peak_sign == "both": + values = np.max(np.abs(templates_array), axis=1) + elif peak_sign == "neg": + values = -np.min(templates_array, axis=1) + elif peak_sign == "pos": + values = np.max(templates_array, axis=1) + elif peak_mode == "at_index": + if peak_sign == "both": + values = np.abs(templates_array[:, nbefore, :]) + elif peak_sign in ["neg", "pos"]: + values = templates_array[:, nbefore, :] + elif peak_mode == "peak_to_peak": + values = np.ptp(templates_array, axis=1) + + # Step2: max on channel axis + main_channel_index = np.argmax(values, axis=1) + + return main_channel_index + + +def estimate_main_channel_from_recording( + recording, + sorting, + peak_sign: Literal["neg", "both", "pos"] = "both", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + num_spikes_for_main_channel=100, + ms_before=1.0, + ms_after=2.5, + seed=None, + **job_kwargs, +): + """ + Estimate the main channel from recording using `estimate_templates_with_accumulator()` + + """ + + if peak_sign == "pos": + warnings.warn( + "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " + "should revert the traces instead" + ) + + templates_array = get_templates_array_from_recording_and_sorting( + recording, sorting, ms_before, ms_after, num_spikes_for_main_channel, seed + ) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore) + + return main_channel_index + + +def get_templates_array_from_recording_and_sorting( + recording, sorting, ms_before, ms_after, num_spikes_for_main_channel, seed, **job_kwargs +): + + sampling_frequency = recording.sampling_frequency + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_main_channel, + margin_size=max(nbefore, nafter), + seed=seed, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_in_uV=False, + job_name="estimate_main_channel", + **job_kwargs, + ) + + return templates_array + + +# TODO remove this in 0.106.0 def get_template_extremum_channel( templates_or_sorting_analyzer, peak_sign: Literal["neg", "pos", "both"] = "neg", - mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", outputs: Literal["id", "index"] = "id", operator: str = "average", ): """ + Deprecated will be removed in 0.106.0. + Use analyzer.get_main_channels() or tempates.get_main_channels(peak_sign=...) instead. + + Compute the channel with the extremum peak for each unit. Parameters @@ -149,7 +306,7 @@ def get_template_extremum_channel( A Templates or a SortingAnalyzer object peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" Where the amplitude is computed * "extremum" : take the peak value (max or min depending on `peak_sign`) * "at_index" : take value at `nbefore` index @@ -163,49 +320,63 @@ def get_template_extremum_channel( Returns ------- - extremum_channels : dict + main_channels : dict Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ - assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" - assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`" - - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids + warnings.warn( + "get_template_extremum_channel() is deprecated and will be removed in version 0.106.0 use analyzer.get_main_channels() instead", + category=DeprecationWarning, + ) - # if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise - # we use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + assert ( + peak_sign is None + ), "get_template_extremum_channel() peak_sign is now contained in SortingAnalyzer, should be None here" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, with_dict=True) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_extremum_channel() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels( + outputs=outputs, peak_sign=peak_sign, with_dict=True + ) + + return main_channels + + +# TODO remove this in 0.106.0 +def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None): + """ + Depracted will be removed in 0.106.0. + Use get_template_main_channel_peak_shift() instead. - peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_in_uV=return_in_uV, operator=operator - ) - extremum_channels_id = {} - extremum_channels_index = {} - for unit_id in unit_ids: - max_ind = np.argmax(np.abs(peak_values[unit_id])) - extremum_channels_id[unit_id] = channel_ids[max_ind] - extremum_channels_index[unit_id] = max_ind + In some situations spike sorters could return a spike index with a small shift related to the waveform peak. + This function estimates and return these alignment shifts for the mean template. + This function is internally used by `compute_spike_amplitudes()` to accurately retrieve the spike amplitudes. - if outputs == "id": - return extremum_channels_id - elif outputs == "index": - return extremum_channels_index + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object + peak_sign : None + Sign of the template to find extremum channels + Returns + ------- + shifts : dict + Dictionary with unit ids as keys and shifts as values + """ -def get_template_extremum_channel_peak_shift( - templates_or_sorting_analyzer, - peak_sign: Literal["neg", "pos", "both"] = "neg", - operator: Literal["average", "median"] = "average", -): + warnings.warn( + "get_template_extremum_channel_peak_shift() is deprecated use get_template_main_channel_peak_shift() instead" + "Will be removed in 0.106.0", + category=DeprecationWarning, + ) + return get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True) + + +def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. @@ -215,7 +386,7 @@ def get_template_extremum_channel_peak_shift( ---------- templates_or_sorting_analyzer : Templates | SortingAnalyzer A Templates or a SortingAnalyzer object - peak_sign : "neg" | "pos" | "both" + peak_sign : None | "neg" | "pos" | "both" Sign of the template to find extremum channels operator : str, default: "average" If the "templates" extension of the SortingAnalyzer contains several operators (e.g., "average" and "median"), @@ -226,55 +397,52 @@ def get_template_extremum_channel_peak_shift( shifts : dict Dictionary with unit ids as keys and shifts as values """ - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - nbefore = _get_nbefore(templates_or_sorting_analyzer) - - extremum_channels_ids = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, operator=operator - ) - shifts = {} - - # We need to use the SortingAnalyzer return_scaled - # We need to use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + assert peak_sign is None + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") + peak_sign = "both" - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_in_uV=return_in_uV) - - for unit_ind, unit_id in enumerate(unit_ids): - template = templates_array[unit_ind, :, :] + unit_ids = templates_or_sorting_analyzer.unit_ids + nbefore = _get_nbefore(templates_or_sorting_analyzer) - chan_id = extremum_channels_ids[unit_id] - chan_ind = list(channel_ids).index(chan_id) + main_channel_templates = get_main_channel_templates_array(templates_or_sorting_analyzer) + shifts = [] + for unit_ind, unit_id in enumerate(unit_ids): + template = main_channel_templates[unit_ind] if peak_sign == "both": - peak_pos = np.argmax(np.abs(template[:, chan_ind])) + peak_pos = np.argmax(np.abs(template)) elif peak_sign == "neg": - peak_pos = np.argmin(template[:, chan_ind]) + peak_pos = np.argmin(template) elif peak_sign == "pos": - peak_pos = np.argmax(template[:, chan_ind]) + peak_pos = np.argmax(template) shift = peak_pos - nbefore - shifts[unit_id] = shift + shifts.append(shift) + + if with_dict: + shifts = dict(zip(unit_ids, shifts)) + else: + shifts = np.array(shifts) return shifts +# TODO remove this in 0.105.0 def get_template_extremum_amplitude( templates_or_sorting_analyzer, peak_sign: Literal["neg", "pos", "both"] = "neg", - mode: Literal["extremum", "at_index", "peak_to_peak"] = "at_index", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "at_index", abs_value: bool = True, operator: Literal["average", "median"] = "average", ): """ + Depracted will be removed in 0.105.0. + Use get_template_main_channel_amplitude() instead. + Computes amplitudes on the best channel. Parameters @@ -283,7 +451,7 @@ def get_template_extremum_amplitude( A Templates or a SortingAnalyzer object peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" Where the amplitude is computed * "extremum": take the peak value (max or min depending on `peak_sign`) * "at_index": take value at `nbefore` index @@ -300,35 +468,72 @@ def get_template_extremum_amplitude( amplitudes : dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode) + warnings.warn( + "get_template_extremum_amplitude() is deprecated use get_template_main_channel_amplitude() instead" + "Will be removed in 0.106.0", + category=DeprecationWarning, + ) + return get_template_main_channel_amplitude( + templates_or_sorting_analyzer, + peak_sign=peak_sign, + peak_mode=peak_mode, + abs_value=abs_value, + ) + +def get_template_main_channel_amplitude( + templates_or_sorting_analyzer, + peak_sign: None | Literal["neg", "pos", "both"] = None, + peak_mode: None | Literal["extremum", "at_index", "peak_to_peak"] = None, + abs_value: bool = True, + with_dict=True, +): + """ + Computes amplitudes on the best channel. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum": take the peak value (max or min depending on `peak_sign`) + * "at_index": take value at `nbefore` index + * "peak_to_peak": take the peak-to-peak amplitude + abs_value : bool = True + Whether the extremum amplitude should be returned as an absolute value or not + + + Returns + ------- + amplitudes : dict + Dictionary with unit ids as keys and amplitudes as values + """ if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels( + outputs="index", peak_sign=peak_sign, with_dict=False + ) extremum_amplitudes = get_template_amplitudes( - templates_or_sorting_analyzer, - peak_sign=peak_sign, - mode=mode, - return_in_uV=return_in_uV, - abs_value=abs_value, - operator=operator, + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=peak_mode, abs_value=abs_value ) - unit_amplitudes = {} - for unit_id in unit_ids: - channel_id = extremum_channels_ids[unit_id] - best_channel = list(channel_ids).index(channel_id) - unit_amplitudes[unit_id] = extremum_amplitudes[unit_id][best_channel] + unit_ids = templates_or_sorting_analyzer.unit_ids + unit_amplitudes = [] + for unit_ind, unit_id in enumerate(unit_ids): + chan_ind = main_channels[unit_ind] + unit_amplitudes.append(extremum_amplitudes[unit_id][chan_ind]) + + if with_dict: + unit_amplitudes_dict = dict(zip(unit_ids, unit_amplitudes)) + return unit_amplitudes_dict return unit_amplitudes diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index c9d6e888f9..979113a2aa 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -228,3 +228,7 @@ def test_remote_analyzer(): "quality_metrics", ]: assert ext in analyzer.get_saved_extension_names() + + +if __name__ == "__main__": + test_remote_analyzer() diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 61c2fda873..d3c3b719c5 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import create_sorting_analyzer, generate_ground_truth_recording from spikeinterface.core.base import spike_peak_dtype from spikeinterface.core.job_tools import divide_time_series_into_chunks @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..7a30ac87b2 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -52,6 +52,12 @@ def dataset(): def test_SortingAnalyzer_memory(tmp_path, dataset): recording, sorting = dataset + + # Note the sorting contain already main_channel_index + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + assert np.array_equal(sorting_analyzer.get_main_channels(), sorting.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -75,6 +81,16 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): assert "quality" in sorting_analyzer.sorting.get_property_keys() assert "number" in sorting_analyzer.sorting.get_property_keys() + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + + # Create when main_channel_index is not given : this is estimated + sorting2 = sorting.clone() + sorting2._properties.pop("main_channel_index") + print(sorting2.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting2, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting2, cache_folder=tmp_path) + def test_SortingAnalyzer_binary_folder(tmp_path, dataset): recording, sorting = dataset @@ -349,7 +365,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): assert ext is None assert sorting_analyzer.has_recording() - # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -615,12 +630,11 @@ def _set_params(self, param0=5.5): return params def _get_pipeline_nodes(self): - from spikeinterface.core.template_tools import get_template_extremum_channel recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") + extremum_channel_inds = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) @@ -718,9 +732,9 @@ def test_runtime_dependencies(dataset): tmp_path = Path("test_SortingAnalyzer") dataset = get_dataset() test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies(dataset) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 6e85221621..372cc2094d 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -303,28 +303,25 @@ def test_compute_sparsity(): # this is needed for method="energy" # using object SortingAnalyzer - sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2) + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0) sparsity = compute_sparsity(sorting_analyzer, method="closest_channels", num_channels=2) - sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") - sparsity = compute_sparsity( - sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" - ) - sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() - sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") - sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + sparsity = compute_sparsity(templates, method="best_channels", num_channels=2) + sparsity = compute_sparsity(templates, method="radius", radius_um=50.0) + sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5) sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2) if __name__ == "__main__": - # test_ChannelSparsity() - # test_estimate_sparsity() + test_ChannelSparsity() + test_estimate_sparsity() test_compute_sparsity() diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index a28680612a..e78bed0bd0 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -7,9 +7,8 @@ from spikeinterface import Templates from spikeinterface.core import ( get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, + get_template_main_channel_peak_shift, + get_template_main_channel_amplitude, ) @@ -56,24 +55,16 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer): def test_get_template_amplitudes(sorting_analyzer): peak_values = get_template_amplitudes(sorting_analyzer) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - peak_values = get_template_amplitudes(templates, abs_value=True) - peak_to_peak_values = get_template_amplitudes(templates, mode="peak_to_peak") + peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="extremum", abs_value=True) + peak_to_peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="peak_to_peak") assert np.all(ptp > p for ptp, p in zip(peak_to_peak_values.values(), peak_values.values())) -def test_get_template_extremum_channel(sorting_analyzer): - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign="both") - print(extremum_channels_ids) - templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates, peak_sign="both") - print(extremum_channels_ids) - - -def test_get_template_extremum_channel_peak_shift(sorting_analyzer): - shifts = get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign="neg") +def test_get_template_main_channel_peak_shift(sorting_analyzer): + shifts = get_template_main_channel_peak_shift(sorting_analyzer) print(shifts) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - shifts = get_template_extremum_channel_peak_shift(templates, peak_sign="neg") + shifts = get_template_main_channel_peak_shift(templates, peak_sign="both") # DEBUG # import matplotlib.pyplot as plt @@ -91,13 +82,14 @@ def test_get_template_extremum_channel_peak_shift(sorting_analyzer): # plt.show() -def test_get_template_extremum_amplitude(sorting_analyzer): +def test_get_template_main_channel_amplitude(sorting_analyzer): - extremum_channels_ids = get_template_extremum_amplitude(sorting_analyzer, peak_sign="both") + extremum_channels_ids = get_template_main_channel_amplitude(sorting_analyzer) print(extremum_channels_ids) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - extremum_channels_ids = get_template_extremum_amplitude(templates, peak_sign="both") + + extremum_channels_ids = get_template_main_channel_amplitude(templates, peak_sign="both", peak_mode="extremum") if __name__ == "__main__": @@ -107,6 +99,5 @@ def test_get_template_extremum_amplitude(sorting_analyzer): print(sorting_analyzer) test_get_template_amplitudes(sorting_analyzer) - test_get_template_extremum_channel(sorting_analyzer) - test_get_template_extremum_channel_peak_shift(sorting_analyzer) - test_get_template_extremum_amplitude(sorting_analyzer) + test_get_template_main_channel_peak_shift(sorting_analyzer) + test_get_template_main_channel_amplitude(sorting_analyzer) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index b1510fb27d..43ed16fbd8 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -493,7 +493,13 @@ def _read_old_waveforms_extractor_binary(folder, sorting): sorting = load(folder / "sorting.pickle", base_folder=folder) sorting_analyzer = SortingAnalyzer.create_memory( - sorting, recording, sparsity=sparsity, return_in_uV=return_in_uV, rec_attributes=rec_attributes + sorting, + recording, + sparsity=sparsity, + main_channel_peak_mode="extremum", + main_channel_peak_sign="neg", + return_in_uV=return_in_uV, + rec_attributes=rec_attributes, ) # waveforms diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index aa7987b7fe..f998bcb8d4 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -3,7 +3,7 @@ from spikeinterface import SortingAnalyzer -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift, get_template_amplitudes from spikeinterface.postprocessing import align_sorting _remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes") @@ -17,7 +17,6 @@ def remove_redundant_units( agreement_threshold: float = 0.2, duplicate_threshold: float = 0.8, remove_strategy: str = "minimum_shift", - peak_sign: str = "neg", extra_outputs: bool = False, ) -> BaseSorting | tuple[BaseSorting, list[tuple[int, int]]]: """ @@ -52,9 +51,6 @@ def remove_redundant_units( If shifts are equal then the "highest_amplitude" is used * "highest_amplitude" : keep the unit with the best amplitude on unshifted max. * "max_spikes" : keep the unit with more spikes - - peak_sign : "neg" | "pos" | "both", default: "neg" - Used when remove_strategy="highest_amplitude" extra_outputs : bool, default: False If True, will return the redundant pairs. unit_peak_shifts : dict @@ -79,7 +75,8 @@ def remove_redundant_units( if align and unit_peak_shifts is None: assert sorting_analyzer is not None, "For align=True must give a SortingAnalyzer or explicit unit_peak_shifts" - unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign=peak_sign) + + unit_peak_shifts = get_template_main_channel_peak_shift(sorting_analyzer, with_dict=True) if align: sorting_aligned = align_sorting(sorting, unit_peak_shifts) @@ -97,7 +94,7 @@ def remove_redundant_units( if remove_strategy in ("minimum_shift", "highest_amplitude"): # this is the values at spike index ! - peak_values = get_template_amplitudes(sorting_analyzer, peak_sign=peak_sign, mode="at_index") + peak_values = get_template_amplitudes(sorting_analyzer, peak_mode="at_index", override_peak_error=True) peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()} if remove_strategy == "minimum_shift": diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 4761660f68..97a7ce7678 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -4,7 +4,7 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs import spikeinterface.widgets as sw -from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude +from spikeinterface.core import get_template_main_channel_amplitude from spikeinterface.postprocessing import compute_correlograms @@ -99,10 +99,11 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series( - get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id") - ) - units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) + # max_on_channel_id is kept (oold name) + units["max_on_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) + units["main_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) + + units["amplitude"] = pd.Series(get_template_main_channel_amplitude(sorting_analyzer)) units.to_csv(output_folder / "unit list.csv", sep="\t") unit_colors = sw.get_unit_colors(sorting) diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index d86df0c6c8..c9c0879331 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -23,14 +23,21 @@ def make_sorting_analyzer(sparse=True, with_group=False): ) if with_group: + + # this is a bit painful. We need to figure out which units belong to which group + initial_sorting_analyzer = create_sorting_analyzer( + sorting=sorting, recording=recording, format="memory", sparse=True + ) + main_channels = initial_sorting_analyzer.main_channel_index + sorting_groups = [0 if main_channel < 4 else 1 for main_channel in main_channels] recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) - sorting.set_property("group", [0, 0, 1, 1]) + sorting.set_property("group", sorting_groups) + # Now we know that, we can make an analyzer with a working sparsity_group sorting_analyzer_unused = create_sorting_analyzer( sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=None ) sparsity_group = compute_sparsity(sorting_analyzer_unused, method="by_property", by_property="group") - sorting_analyzer = create_sorting_analyzer( sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=sparsity_group ) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 928fc999ab..7d52d4f500 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -100,7 +100,7 @@ def export_to_ibl_gui( output_folder.mkdir(parents=True, exist_ok=True) ### Save spikes info ### - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) # spikes.clusters @@ -135,7 +135,8 @@ def export_to_ibl_gui( np.save(output_folder / "clusters.waveforms.npy", templates) # cluster channels - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") np.save(output_folder / "clusters.channels.npy", cluster_channels) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 244575aff2..1776f60ec8 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -12,7 +12,6 @@ _ensure_seed, synthesize_amplitude_factor, ) -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.motion import Motion @@ -127,8 +126,7 @@ def select_templates( min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided" # get template amplitudes and depth - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) mask = np.ones(templates.num_units, dtype=bool) if min_amplitude is not None or max_amplitude is not None: @@ -142,7 +140,7 @@ def select_templates( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) if min_amplitude is not None: mask &= amplitudes >= min_amplitude if max_amplitude is not None: @@ -151,7 +149,7 @@ def select_templates( assert templates.probe is not None, "Templates should have a probe to filter based on depth" depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] if min_depth is not None: mask &= unit_depths >= min_depth if max_depth is not None: @@ -190,8 +188,7 @@ def scale_template_to_range( Templates The scaled templates. """ - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -203,7 +200,7 @@ def scale_template_to_range( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) # scale templates to meet min_amplitude and max_amplitude range min_scale = np.min(amplitudes) / min_amplitude @@ -264,11 +261,10 @@ def relocate_templates( """ seed = _ensure_seed(seed) - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] assert margin >= 0, "margin should be positive" top_margin = np.max(channel_depths) + margin diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 1f404ea3f7..03d7cd5ef8 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -107,9 +107,9 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - from spikeinterface.core.template_tools import get_template_extremum_channel - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 3d6911b383..e633b8429a 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -17,8 +17,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core import SortingAnalyzer, NumpySorting from spikeinterface.core.template_tools import ( - get_template_extremum_channel, - get_template_extremum_amplitude, + get_template_main_channel_amplitude, get_dense_templates_array, ) from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate @@ -143,9 +142,6 @@ class PresenceRatio(BaseMetric): def compute_snrs( sorting_analyzer, unit_ids=None, - peak_sign: str = "both", - peak_mode: str = "extremum", - operator: str = "median", ): """ Compute signal to noise ratio. @@ -156,12 +152,14 @@ def compute_snrs( A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the SNR. If None, all units are used. - peak_sign : "neg" | "pos" | "both", default: "neg" - The sign of the template to compute best channels. - peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" - How to compute the amplitude. - Extremum takes the maxima/minima - At_index takes the value at t=sorting_analyzer.nbefore. + peak_sign_for_signal : "neg" | "pos" | "both", default: "neg" + Which extremum of the template to use for the "signal". Use either the minimum ("neg"), + maximum ("pos") or absmax ("both") + peak_mode_for_signal : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + How to compute the amplitude: + "extremum" takes the maxima/minima + "at_index" takes the value at t=sorting_analyzer.nbefore. + "peak_to_peak" takes the difference between the maximum and minimum operator : "median" | "average", default: "median" The operator to apply to retrieve templates and amplitudes. @@ -177,35 +175,15 @@ def compute_snrs( noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() - assert peak_sign in ("neg", "pos", "both") - assert peak_mode in ("extremum", "at_index", "peak_to_peak") - - channel_ids = sorting_analyzer.channel_ids - - if operator not in ("median", "average"): - raise ValueError(f"Invalid operator: {operator}. Expected 'median' or 'average'.") - if operator == "median" and not sorting_analyzer.has_extension("waveforms"): - warnings.warn( - "Operator 'median' requires 'waveforms' extension. Falling back to 'average'. " - "To use 'median', please compute the 'waveforms' extension first with: analyzer.compute('waveforms')" - ) - operator = "average" - - extremum_channels_ids = get_template_extremum_channel( - sorting_analyzer, peak_sign=peak_sign, mode=peak_mode, operator=operator - ) - unit_amplitudes = get_template_extremum_amplitude( - sorting_analyzer, peak_sign=peak_sign, mode=peak_mode, operator=operator - ) - - # make a dict to access by chan_id - noise_levels = dict(zip(channel_ids, noise_levels)) + main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=False) snrs = {} for unit_id in unit_ids: - chan_id = extremum_channels_ids[unit_id] - noise = noise_levels[chan_id] - amplitude = unit_amplitudes[unit_id] + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) + chan_ind = main_channel_index[unit_index] + noise = noise_levels[chan_ind] + amplitude = unit_amplitudes[unit_index] snrs[unit_id] = np.abs(amplitude) / noise return snrs @@ -214,7 +192,7 @@ def compute_snrs( class SNR(BaseMetric): metric_name = "snr" metric_function = compute_snrs - metric_params = {"peak_sign": "both", "peak_mode": "extremum"} + metric_params = {} metric_columns = {"snr": float} metric_descriptions = {"snr": "Signal to noise ratio for each unit."} depend_on = ["noise_levels", "templates"] @@ -1430,7 +1408,10 @@ def compute_sd_ratio( noise_levels = get_noise_levels( sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs ) - best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", peak_sign=peak_sign) + + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + + n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) if correct_for_template_itself: n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) @@ -1466,7 +1447,7 @@ def compute_sd_ratio( else: unit_std = np.std(spk_amp) - best_channel = best_channels[unit_id] + best_channel = main_channels[unit_id] std_noise = noise_levels[best_channel] if correct_for_template_itself: diff --git a/src/spikeinterface/metrics/quality/pca_metrics.py b/src/spikeinterface/metrics/quality/pca_metrics.py index bede198016..26bed55ac5 100644 --- a/src/spikeinterface/metrics/quality/pca_metrics.py +++ b/src/spikeinterface/metrics/quality/pca_metrics.py @@ -653,7 +653,7 @@ def nearest_neighbors_isolation( if sorting_analyzer.is_sparse(): sparsity = sorting_analyzer.sparsity else: - sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=radius_um) closest_chans_target_unit = sparsity.unit_id_to_channel_indices[this_unit_id] n_channels_target_unit = len(closest_chans_target_unit) # select other units that have a minimum spatial overlap with target unit @@ -867,7 +867,7 @@ def nearest_neighbors_noise_overlap( if sorting_analyzer.is_sparse(): sparsity = sorting_analyzer.sparsity else: - sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=radius_um) noise_cluster = noise_cluster[:, :, sparsity.unit_id_to_channel_indices[this_unit_id]] # compute weighted noise snippet (Z) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index a72e3d3775..04b7e12e5d 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -3,7 +3,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension from spikeinterface.core.sorting_tools import cast_periods_to_unit_period_dtype @@ -162,7 +161,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): all_labels = sorting_analyzer.sorting.unit_ids[spike_unit_indices] # Get extremum channels for neighbor selection in sparse mode - extremum_channels = get_template_extremum_channel(sorting_analyzer, peak_sign=self.params["peak_sign"]) + + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric @@ -177,7 +177,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): if sorting_analyzer.is_sparse(): neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ - other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids + other_unit for other_unit in unit_ids if main_channels[other_unit] in neighbor_channel_ids ] neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) else: diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index e267b176ce..8a288d21fe 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -639,15 +639,20 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): ) # check that, when parameters are changed, the data and metadata are updated - old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + old_presence_ratio_data = deepcopy(quality_metric_extension.get_data()["presence_ratio"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + { + "quality_metrics": { + "metric_names": ["presence_ratio"], + "metric_params": {"presence_ratio": {"bin_duration_s": "10"}}, + } + } ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - new_snr_data = new_quality_metric_extension.get_data()["snr"].values + new_presence_ratio_data = new_quality_metric_extension.get_data()["presence_ratio"].values - assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert np.all(old_presence_ratio_data != new_presence_ratio_data) + assert new_quality_metric_extension.params["metric_params"]["presence_ratio"]["bin_duration_s"] == "10" def test_metric_names_in_same_order(small_sorting_analyzer): diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index e9ac6838e2..37f00d4376 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -3,7 +3,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array +from spikeinterface.core.template_tools import get_dense_templates_array from .metrics import ( get_trough_and_peak_idx, @@ -238,7 +238,6 @@ def _prepare_data(self, sorting_analyzer, unit_ids): if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] min_thresh_detect_peaks_troughs = self.params.get("min_thresh_detect_peaks_troughs", 0.3) edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.09) @@ -255,10 +254,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - operator = self.params.get("template_operator", "average") - extremum_channel_indices = get_template_extremum_channel( - sorting_analyzer, peak_sign=peak_sign, outputs="index", operator=operator - ) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + + operator = self.params["template_operator"] all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True, operator=operator) analyzer_channel_locations = sorting_analyzer.get_channel_locations() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 67a46bff8a..948878028e 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -1,8 +1,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.core_tools import ms_to_samples -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -103,10 +102,7 @@ def _get_pipeline_nodes(self): else: cut_out_after = nafter - peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # collisions handle_collisions = self.params["handle_collisions"] diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 7975097629..4d07eb8a5b 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -6,7 +6,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity -from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -99,7 +99,8 @@ def compute_monopolar_triangulation( chan_inds = sparsity.unit_id_to_channel_indices[unit_id] neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = get_template_extremum_channel(sorting_analyzer_or_templates, outputs="index") + + best_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -161,9 +162,7 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" if sorting_analyzer_or_templates.sparsity is None: - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) else: sparsity = sorting_analyzer_or_templates.sparsity @@ -205,7 +204,7 @@ def compute_center_of_mass( def compute_grid_convolution( sorting_analyzer_or_templates: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: str = "neg", + invert_prototype_waveform: bool = True, radius_um: float = 40.0, upsampling_um: float = 5, sigma_ms: float = 0.25, @@ -223,8 +222,8 @@ def compute_grid_convolution( A SortingAnalyzer or Templates object unit_ids: str | int | None A list of unit_id to restrci the computation - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute best channels + invert_prototype_waveform: bool, default: True + If True, will multiply the prototype waveform by -1 radius_um : float, default: 40.0 Radius to consider for the fake templates upsampling_um : float, default: 5 @@ -269,7 +268,7 @@ def compute_grid_convolution( time_axis = np.arange(-nbefore, nafter) * 1000 / fs if prototype is None: prototype = np.exp(-(time_axis**2) / (2 * (sigma_ms**2))) - if peak_sign == "neg": + if invert_prototype_waveform: prototype *= -1 prototype = prototype[:, np.newaxis] @@ -278,7 +277,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - peak_channels = get_template_extremum_channel(sorting_analyzer_or_templates, peak_sign, outputs="index") + main_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) weights_sparsity_mask = weights > 0 @@ -286,7 +285,7 @@ def compute_grid_convolution( unit_location = np.zeros((len(unit_ids), 3), dtype="float64") for i, unit_id in enumerate(unit_ids): - main_chan = peak_channels[unit_id] + main_chan = main_channels[unit_id] wf = templates[i, :, :] nearest_mask = nearest_template_mask[main_chan, :] channel_mask = np.sum(weights_sparsity_mask[:, :, nearest_mask], axis=(0, 2)) > 0 @@ -661,15 +660,10 @@ def get_convolution_weights( def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: Literal["neg", "pos", "both"] = "neg", - mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", - operator: Literal["average", "median"] = "average", ) -> np.ndarray: """ Localize a unit using max channel. - This uses internally `get_template_extremum_channel()` - Parameters ---------- @@ -677,30 +671,22 @@ def compute_location_max_channel( A SortingAnalyzer or Templates object unit_ids: list[str] | list[int] | None A list of unit_id to restrict the computation - peak_sign : "neg" | "pos" | "both" - Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" - Where the amplitude is computed - * "extremum" : take the peak value (max or min depending on `peak_sign`) - * "at_index" : take value at `nbefore` index - * "peak_to_peak" : take the peak-to-peak amplitude Returns ------- unit_locations: np.ndarray 2d """ - extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index", operator=operator - ) - contact_locations = templates_or_sorting_analyzer.get_channel_locations()[:, :2] + extremum_channels_index = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + + contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: unit_ids = templates_or_sorting_analyzer.unit_ids else: unit_ids = np.asarray(unit_ids) unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") for i, unit_id in enumerate(unit_ids): - unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id], :2] return unit_locations diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 1e28b751d1..eed0314816 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -2,7 +2,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type @@ -12,31 +12,24 @@ class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): Needs "templates" to be computed first. Computes spike amplitudes from the template's peak channel for every spike. - - Parameters - ---------- - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute extremum channel used to retrieve spike amplitudes. """ extension_name = "spike_amplitudes" depend_on = ["templates"] nodepipeline_variables = ["amplitudes"] - def _set_params(self, peak_sign="neg"): - return super()._set_params(peak_sign=peak_sign) + def _set_params(self): + return super()._set_params() def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) - peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + + peak_shifts = get_template_main_channel_peak_shift(self.sorting_analyzer, with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d11d4eb2c7..5fb15a6ea1 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -1,5 +1,4 @@ from spikeinterface.core.sortinganalyzer import register_result_extension -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.node_pipeline import SpikeRetriever from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -14,8 +13,6 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - peak_sign : "neg" | "pos" | "both", default: "neg" - The peak sign to use when looking for the template extremum channel. spike_retriever_kwargs : dict Arguments to control the spike retriever behavior. See `spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`. @@ -38,14 +35,12 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs if "spike_retriver_kwargs" in self.params: - self.params["peak_sign"] = self.params["spike_retriver_kwargs"].get("peak_sign", "neg") self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs") def _set_params( self, ms_before=0.5, ms_after=0.5, - peak_sign="neg", spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, @@ -55,7 +50,6 @@ def _set_params( return super()._set_params( ms_before=ms_before, ms_after=ms_after, - peak_sign=peak_sign, spike_retriever_kwargs=spike_retriever_kwargs, method=method, method_kwargs=method_kwargs, @@ -66,10 +60,8 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["peak_sign"] - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) retriever_kwargs = { "channel_from_template": True, diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index e3c45fe8ef..df80a2a34c 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -7,7 +7,7 @@ create_sorting_analyzer, generate_ground_truth_recording, set_global_job_kwargs, - get_template_extremum_amplitude, + get_template_main_channel_amplitude, ) from spikeinterface.core.generate import inject_some_split_units @@ -85,7 +85,7 @@ def get_dataset_to_merge(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(get_template_main_channel_amplitude(analyzer_raw, with_dict=False))[::-1] split_ids = sorting.unit_ids[sort_by_amp][:3] sorting_with_splits, split_unit_ids = inject_some_split_units( @@ -116,7 +116,7 @@ def get_dataset_to_split(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(get_template_main_channel_amplitude(analyzer_raw, with_dict=False))[::-1] large_units = sorting.unit_ids[sort_by_amp][:2] return recording, sorting, large_units diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index f457bd9250..645314237a 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,7 +8,7 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", peak_sign="both"), + dict(method="center_of_mass"), dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=True)), dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass"), diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 85e0da7220..f164facf1f 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -376,7 +376,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sparsity_snr = compute_sparsity( dense_templates, method="snr", - amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=threshold, ) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3ed0c7ee61..f3b29fa541 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,13 +53,11 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - from spikeinterface.core.template_tools import get_template_extremum_channel + main_channels = self.templates.get_main_channels(peak_sign=self.peak_sign, outputs="index", with_dict=False) - best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") - best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( - channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2 + channel_locations[:, None] - channel_locations[main_channels][np.newaxis, :], axis=2 ) self.neighborhood_mask = template_distances <= neighborhood_radius_um else: @@ -70,7 +68,9 @@ def __init__( from spikeinterface.core.sparsity import compute_sparsity sparsity = compute_sparsity( - templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign + templates, + method="radius", + radius_um=sparsity_radius_um, ) else: sparsity = templates.sparsity diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index e228d70953..f96ee035e9 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -3,7 +3,6 @@ import numpy as np from spikeinterface.core import ( get_channel_distances, - get_template_extremum_channel, ) from spikeinterface.core.core_tools import ms_to_samples @@ -222,12 +221,11 @@ def __init__( self.sparse_templates_array_static = templates.templates_array self.dtype = self.sparse_templates_array_static.dtype - extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") # as numpy vector - self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") + self.main_channels = templates.get_main_channels(peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions - unit_locations = channel_locations[self.extremum_channel] + unit_locations = channel_locations[self.main_channels] self.channel_locations = channel_locations # distance between units diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 3bb3378775..26a28b67f1 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -13,7 +13,7 @@ from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain from spikeinterface.core.core_tools import ms_to_samples @@ -602,7 +602,6 @@ def clean_templates( sparsity = compute_sparsity( templates, method="snr", - amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=sparsify_threshold, ) @@ -620,7 +619,7 @@ def clean_templates( if max_jitter_ms is not None: max_jitter = int(max_jitter_ms * templates.sampling_frequency / 1000.0) n_before = len(templates.unit_ids) - shifts = get_template_extremum_channel_peak_shift(templates) + shifts = get_template_main_channel_peak_shift(templates, with_dict=True) to_select = [] for unit_id in templates.unit_ids: if np.abs(shifts[unit_id]) <= max_jitter: @@ -637,7 +636,6 @@ def clean_templates( sparsity = compute_sparsity( templates.to_dense(), method="snr", - amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=min_snr, ) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 8505a706e9..952165fd24 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -120,9 +120,9 @@ def __init__( sparsity = sorting_analyzer.sparsity else: if sparsity is None: - # in this case, we construct a sparsity dictionary only with the best channel - extremum_channel_ids = get_template_extremum_channel(sorting_analyzer) - unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} + # in this case, we construct a sparsity dictionary only with the main channel + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) + unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_analyzer.unit_ids, diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index ba2f939b80..6551f6e7b0 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -4,7 +4,7 @@ from .utils import get_unit_colors -from spikeinterface.core.template_tools import get_template_extremum_amplitude +from spikeinterface.core.template_tools import get_template_main_channel_amplitude class UnitDepthsWidget(BaseWidget): @@ -20,13 +20,9 @@ class UnitDepthsWidget(BaseWidget): by matplotlib. If None, default colors are chosen using the `get_some_colors` function. depth_axis : int, default: 1 The dimension of unit_locations that is depth - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of peak for amplitudes """ - def __init__( - self, sorting_analyzer, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs - ): + def __init__(self, sorting_analyzer, unit_colors=None, depth_axis=1, backend=None, **backend_kwargs): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) @@ -42,7 +38,7 @@ def __init__( unit_locations = ulc.get_data(outputs="numpy") unit_depths = unit_locations[:, depth_axis] - unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer) unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids]) num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="array") diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 6a212b1d0e..7d88a523d4 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -2,7 +2,6 @@ import numpy as np from probeinterface import ProbeGroup -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from .base import BaseWidget, to_attr @@ -84,10 +83,10 @@ def __init__( if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) for unit_id in unit_ids: if np.any(np.isnan(unit_locations[unit_id])): - unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]] + unit_locations[unit_id] = channel_locations[main_channels[unit_id]] data_plot = dict( all_unit_ids=sorting.unit_ids, diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index ec781b5470..0ab6800462 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -3,7 +3,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -135,12 +134,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] if np.isnan(x) or np.isnan(y): warnings.warn(f"Unit {unit_id} location contains NaN values. Replacing NaN extremum channel location.") - x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]] + x, y = sorting_analyzer.get_channel_locations()[main_channels[unit_id]] ax_unit_locations.set_xlim(x - 80, x + 80) ax_unit_locations.set_ylim(y - 250, y + 250) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 24994bb570..fb9050a9b7 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -3,7 +3,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core import ChannelSparsity, get_template_extremum_channel +from spikeinterface.core import ChannelSparsity class UnitWaveformDensityMapWidget(BaseWidget): @@ -23,8 +23,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored use_max_channel : bool, default: False Use only the max channel - peak_sign : "neg" | "pos" | "both", default: "neg" - Used to detect max channel only when use_max_channel=True unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. @@ -41,7 +39,6 @@ def __init__( sparsity=None, same_axis=False, use_max_channel=False, - peak_sign="neg", unit_colors=None, backend=None, **backend_kwargs, @@ -59,9 +56,7 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = get_template_extremum_channel( - sorting_analyzer, mode="extremum", peak_sign=peak_sign, outputs="index" - ) + max_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_analyzer.is_sparse():