From 6dd726997b86310640378550e6fc8da27c6b659a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Feb 2026 16:59:55 +0100 Subject: [PATCH 01/20] Try to remove get_template_extremum_channel() --- .../benchmark/benchmark_clustering.py | 1 - .../tests/test_benchmark_clustering.py | 5 +- .../tests/test_benchmark_peak_detection.py | 3 +- src/spikeinterface/core/basesorting.py | 7 +- src/spikeinterface/core/generate.py | 4 + src/spikeinterface/core/node_pipeline.py | 2 + src/spikeinterface/core/sortinganalyzer.py | 89 +++++++++++++++++-- src/spikeinterface/core/sparsity.py | 62 ++++++++++--- src/spikeinterface/core/template.py | 24 +++++ src/spikeinterface/core/template_tools.py | 88 ++++++++++++++++++ .../core/tests/test_node_pipeline.py | 6 +- .../core/tests/test_sortinganalyzer.py | 31 +++++-- src/spikeinterface/exporters/report.py | 9 +- src/spikeinterface/exporters/to_ibl.py | 6 +- src/spikeinterface/generation/hybrid_tools.py | 19 ++-- .../generation/splitting_tools.py | 4 +- .../metrics/quality/misc_metrics.py | 15 ++-- .../metrics/quality/quality_metrics.py | 6 +- .../metrics/template/template_metrics.py | 5 +- .../postprocessing/amplitude_scalings.py | 7 +- .../postprocessing/localization_tools.py | 19 ++-- .../postprocessing/spike_amplitudes.py | 7 +- .../postprocessing/spike_locations.py | 5 +- .../sortingcomponents/matching/nearest.py | 6 +- .../sortingcomponents/matching/tdc_peeler.py | 7 +- .../widgets/spikes_on_traces.py | 7 +- src/spikeinterface/widgets/unit_locations.py | 5 +- src/spikeinterface/widgets/unit_summary.py | 6 +- .../widgets/unit_waveforms_density_map.py | 7 +- 29 files changed, 346 insertions(+), 116 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index ba9fa53a51..997eac878a 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -13,7 +13,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel class ClusteringBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index be1cf18fbf..5660b68fda 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -5,8 +5,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel from pathlib import Path @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder): # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) # sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index") + extremum_channel_inds = gt_analyzer.get_main_channel(outputs="index", with_dict=True) + spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index b9207caaa3..86b6bde5c5 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -6,7 +6,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel @pytest.mark.skip() @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder): sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index e17731c70e..b11801c0be 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -15,6 +15,9 @@ class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. """ + _main_properties = [ + "main_channel_index", + ] def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) @@ -786,6 +789,7 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices + # TODO sam : change extremum_channel_inds to main_channel_index with vector def to_spike_vector( self, concatenated=True, @@ -806,7 +810,8 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be computed with `get_template_extremum_channel(we, outputs="index")` + This dict can be given by analyzer.get_main_channel(outputs="index", with_dict=True) + use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). This caching only occurs when extremum_channel_inds=None. diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 5d7ca1917a..454540663a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2444,6 +2444,10 @@ def generate_ground_truth_recording( **generate_templates_kwargs, ) sorting.set_property("gt_unit_locations", unit_locations) + distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2) + main_channel_index = np.argmin(distances, axis=1) + sorting.set_property("main_channel_index", main_channel_index) + else: assert templates.shape[0] == num_units diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1609f11d17..01340f2e68 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -132,6 +132,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea return (local_peaks,) +# TODO sam replace extremum_channels_indices by main_channel_index + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8de45210cd..fe125193d3 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -45,6 +45,10 @@ from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open from .node_pipeline import run_node_pipeline +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection + + # high level function def create_sorting_analyzer( @@ -52,6 +56,10 @@ def create_sorting_analyzer( recording, format="memory", folder=None, + main_channel_index=None, + main_channel_peak_sign="both", + main_channel_mode="extremum", + num_spikes_for_main_channel=100, sparse=True, sparsity=None, set_sparsity_by_dict_key=False, @@ -59,7 +67,9 @@ def create_sorting_analyzer( return_in_uV=True, overwrite=False, backend_options=None, - **sparsity_kwargs, + sparsity_kwargs=None, + seed=None, + **job_kwargs ) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -69,6 +79,11 @@ def create_sorting_analyzer( This object will be also use used for plotting purpose. + The main_channel_index can be externally provided. If not then this is taken from + sorting property. If not then the main_channel_index is estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. + Parameters ---------- @@ -82,6 +97,12 @@ def create_sorting_analyzer( The mode to store analyzer. If "folder", the analyzer is stored on disk in the specified folder. The "folder" argument must be specified in case of mode "folder". If "memory" is used, the analyzer is stored in RAM. Use this option carefully! + main_channel_index : None | np.array + The main_channel_index can be externally provided + main_channel_peak_sign : "both" | "neg" + In case when the main_channel_index is estimated wich sign to consider "both" or "neg". + num_spikes_for_main_channel : int, default: 100 + How many spikes per units to compute the main channel. sparse : bool, default: True If True, then a sparsity mask is computed using the `estimate_sparsity()` function using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. @@ -107,8 +128,8 @@ def create_sorting_analyzer( * storage_options: dict | None (fsspec storage options) * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) - - sparsity_kwargs : keyword arguments + sparsity_kwargs : dict | None + Dict given to estimate the sparsity. Returns ------- @@ -144,6 +165,9 @@ def create_sorting_analyzer( sparsity off (or give external sparsity) like this. """ + if sparsity_kwargs is None: + sparsity_kwargs = dict() + if isinstance(sorting, dict) and isinstance(recording, dict): if sorting.keys() != recording.keys(): @@ -168,9 +192,14 @@ def create_sorting_analyzer( return_in_uV=return_in_uV, overwrite=overwrite, backend_options=backend_options, - **sparsity_kwargs, + sparsity_kwargs=sparsity_kwargs, + **job_kwargs ) + # normal case + + + if format != "memory": if format == "zarr": if not is_path_remote(folder): @@ -182,6 +211,26 @@ def create_sorting_analyzer( else: shutil.rmtree(folder) + + + # retrieve or compute the main channel index per unit + if main_channel_index is None: + if "main_channel_index" in sorting.get_property_keys(): + main_channel_index = sorting.get_property("main_channel_index") + + if main_channel_index is None: + # this is weird but due to the cyclic import + from .template_tools import estimate_main_channel_from_recording + main_channel_index = estimate_main_channel_from_recording( + recording, + sorting, + main_channel_peak_sign=main_channel_peak_sign, + mode=main_channel_mode, + num_spikes_for_main_channel=num_spikes_for_main_channel, + seed=seed, + **job_kwargs + ) + # handle sparsity if sparsity is not None: # some checks @@ -192,8 +241,9 @@ def create_sorting_analyzer( assert np.array_equal( recording.channel_ids, sparsity.channel_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" + assert all(sparsity.mask[u, c] for u, c in enumerate(main_channel_index)), "sparsity si not constistentent with main_channel_index" elif sparse: - sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs) + sparsity = estimate_sparsity(sorting, recording, main_channel_index=main_channel_index, **sparsity_kwargs) else: sparsity = None @@ -215,6 +265,7 @@ def create_sorting_analyzer( recording, format=format, folder=folder, + main_channel_index=main_channel_index, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -347,6 +398,7 @@ def create( "zarr", ] = "memory", folder=None, + main_channel_index=None, sparsity=None, return_scaled=None, return_in_uV=True, @@ -381,7 +433,10 @@ def create( from spikeinterface.curation.remove_excess_spikes import RemoveExcessSpikesSorting sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) - + + # This will ensure that the sorting saved always will have this main_channel + sorting.set_property("main_channel_index", main_channel_index) + if format == "memory": sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) elif format == "binary_folder": @@ -541,6 +596,8 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -713,6 +770,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) @@ -881,6 +940,24 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n Array of values for the property """ return self.sorting.get_property(key, ids=ids) + + def get_main_channel(self, outputs="index", with_dict=False): + """ + + """ + main_channel_index = self.get_sorting_property("main_channel_index") + if outputs is "index": + main_chans = main_channel_index + elif outputs is "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + def are_units_mergeable( self, diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index ee19601068..70bb4cdb46 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -365,9 +365,39 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): + def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, channel_locations, radius_um): """ - Construct sparsity from a radius around the best channel. + Construct sparsity from a radius around the main channel. + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + main_channel_index : np.array + Main channel index per units. + channel_locations : np.array + Channel locations of the recording. + radius_um : float + Radius in um for "radius" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. + """ + mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") + distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) + for unit_ind, main_chan in enumerate(main_channel_index): + (chan_inds,) = np.nonzero(distances[main_chan, :] <= radius_um) + mask[unit_ind, chan_inds] = True + return cls(mask, unit_ids, channel_ids) + + + @classmethod + def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="both"): + """ + Construct sparsity from a radius around the main channel. Use the "radius_um" argument to specify the radius in um. Parameters @@ -384,16 +414,14 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): sparsity : ChannelSparsity The estimated sparsity. """ - from .template_tools import get_template_extremum_channel - mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) channel_locations = templates_or_sorting_analyzer.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") + main_channel_index = templates_or_sorting_analyzer.get_main_channel(outputs="index") for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = best_chan[unit_id] + chan_ind = main_channel_index[unit_ind] (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) mask[unit_ind, chan_inds] = True return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @@ -724,6 +752,7 @@ def estimate_sparsity( amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, + main_channel_index: np.ndarray | list | None = None, **job_kwargs, ): """ @@ -732,11 +761,10 @@ def estimate_sparsity( For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. These can be computed with the `get_noise_levels()` function. - Contrary to the previous implementation: - * all units are computed in one read of recording - * it doesn't require a folder - * it doesn't consume too much memory - * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + If main_channel_index is given and method="radius" then there is not need estimate + the templates otherwise the templates must be estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. @@ -755,6 +783,9 @@ def estimate_sparsity( noise_levels : np.array | None, default: None Noise levels required for the "snr" and "energy" methods. You can use the `get_noise_levels()` function to compute them. + main_channel_index : np.array | None, default: None + Main channel indicies can be provided in case of method="radius", this avoid the + `estimate_templates_with_accumulator()` which is slow. {} Returns @@ -779,7 +810,14 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - if method != "by_property": + if method == "radius" and main_channel_index is not None: + assert main_channel_index.size == sorting.unit_ids.size + chan_locs = recording.get_channel_locations() + sparsity = ChannelSparsity.from_radius_and_main_channel( + sorting.unit_ids, recording.channel_ids, main_channel_index, chan_locs, radius_um + ) + + elif method != "by_property": nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 91d25bece6..151f6e2b8d 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -483,3 +483,27 @@ def get_channel_locations(self) -> np.ndarray: assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations + + def get_main_channel(self, + main_channel_peak_sign: "neg" | "both" | "pos" = "both", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + outputs="index", + with_dict=True + ): + from .template_tools import _get_main_channel_from_template_array + + templates_array = self.get_dense_templates() + main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, self.nbefore) + + if outputs is "index": + main_chans = main_channel_index + elif outputs is "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index ecc878e1f4..fe04540476 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -2,8 +2,13 @@ import numpy as np from .template import Templates +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection from .sortinganalyzer import SortingAnalyzer +import warnings + + def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True): """ @@ -126,6 +131,86 @@ def get_template_amplitudes( return peak_values + +def _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore): + # Step1 : max on time axis + if mode == "extremum": + if main_channel_peak_sign == "both": + values = np.max(np.abs(templates_array), axis=1) + elif main_channel_peak_sign == "neg": + values = -np.min(templates_array, axis=1) + elif main_channel_peak_sign == "pos": + values = np.max(templates_array, axis=1) + elif mode == "at_index": + if main_channel_peak_sign == "both": + values = np.abs(templates_array[:, nbefore, :]) + elif main_channel_peak_sign in ["neg", "pos"]: + values = templates_array[:, nbefore, :] + elif mode == "peak_to_peak": + values = np.ptp(templates_array, axis=1) + + # Step2: max on channel axis + main_channel_index = np.argmax(values, axis=1) + + return main_channel_index + +def estimate_main_channel_from_recording( + recording, + sorting, + main_channel_peak_sign: "neg" | "both" | "pos" = "both", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + num_spikes_for_main_channel=100, + ms_before = 1.0, + ms_after = 2.5, + seed=None, + **job_kwargs +): + """ + Estimate the main channel from recording using `estimate_templates_with_accumulator()` + + """ + + if main_channel_peak_sign == "pos": + warnings.warn( + "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " \ + "should revert the traces instead" + ) + + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_main_channel, + margin_size=max(nbefore, nafter), + seed=seed, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_in_uV=False, + job_name="estimate_main_channel", + **job_kwargs, + ) + + main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore) + + return main_channel_index + + + + + def get_template_extremum_channel( templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", @@ -156,6 +241,9 @@ def get_template_extremum_channel( Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ + warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channel() instead") + # TODO make a better logic here + assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`" diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4f8e600a3f..74131dac8d 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import create_sorting_analyzer, generate_ground_truth_recording from spikeinterface.core.base import spike_peak_dtype from spikeinterface.core.job_tools import divide_recording_into_chunks @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..7be23dbf85 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -52,6 +52,12 @@ def dataset(): def test_SortingAnalyzer_memory(tmp_path, dataset): recording, sorting = dataset + + # Note the sorting contain already main_channel_index + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + assert np.array_equal(sorting_analyzer.get_main_channel() , sorting.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -75,6 +81,16 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): assert "quality" in sorting_analyzer.sorting.get_property_keys() assert "number" in sorting_analyzer.sorting.get_property_keys() + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + + # Create when main_channel_index is not given : this is estimated + sorting2 = sorting.clone() + sorting2._properties.pop("main_channel_index") + print(sorting2.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting2, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting2, cache_folder=tmp_path) + def test_SortingAnalyzer_binary_folder(tmp_path, dataset): recording, sorting = dataset @@ -615,12 +631,11 @@ def _set_params(self, param0=5.5): return params def _get_pipeline_nodes(self): - from spikeinterface.core.template_tools import get_template_extremum_channel recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") + extremum_channel_inds = self.sorting_analyzer.get_main_channel( outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) @@ -718,9 +733,9 @@ def test_runtime_dependencies(dataset): tmp_path = Path("test_SortingAnalyzer") dataset = get_dataset() test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies(dataset) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index fe9fb3ba52..8c12591c6e 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -6,7 +6,7 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs import spikeinterface.widgets as sw -from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude +from spikeinterface.core import get_template_extremum_amplitude from spikeinterface.postprocessing import compute_correlograms @@ -101,9 +101,10 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series( - get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id") - ) + # max_on_channel_id is kept (oold name) + units["max_on_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) + units["main_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) + units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) units.to_csv(output_folder / "unit list.csv", sep="\t") diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 6559e89d52..b0445b4753 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -10,7 +10,6 @@ from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks from spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import export_to_phy @@ -102,7 +101,7 @@ def export_to_ibl_gui( output_folder.mkdir(parents=True, exist_ok=True) ### Save spikes info ### - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) # spikes.clusters @@ -137,7 +136,8 @@ def export_to_ibl_gui( np.save(output_folder / "clusters.waveforms.npy", templates) # cluster channels - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") np.save(output_folder / "clusters.channels.npy", cluster_channels) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index bbab9262af..72d8cea634 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -13,7 +13,6 @@ InjectTemplatesRecording, _ensure_seed, ) -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.motion import Motion @@ -128,8 +127,8 @@ def select_templates( min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided" # get template amplitudes and depth - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + mask = np.ones(templates.num_units, dtype=bool) if min_amplitude is not None or max_amplitude is not None: @@ -143,7 +142,7 @@ def select_templates( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) if min_amplitude is not None: mask &= amplitudes >= min_amplitude if max_amplitude is not None: @@ -152,7 +151,7 @@ def select_templates( assert templates.probe is not None, "Templates should have a probe to filter based on depth" depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] if min_depth is not None: mask &= unit_depths >= min_depth if max_depth is not None: @@ -191,8 +190,7 @@ def scale_template_to_range( Templates The scaled templates. """ - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -204,7 +202,7 @@ def scale_template_to_range( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) # scale templates to meet min_amplitude and max_amplitude range min_scale = np.min(amplitudes) / min_amplitude @@ -265,11 +263,10 @@ def relocate_templates( """ seed = _ensure_seed(seed) - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] assert margin >= 0, "margin should be positive" top_margin = np.max(channel_depths) + margin diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 1f404ea3f7..e2314319a1 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -107,9 +107,9 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - from spikeinterface.core.template_tools import get_template_extremum_channel - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 3d26f7a85e..f4a6afcd9b 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -20,7 +20,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting from spikeinterface.core.template_tools import ( - get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) @@ -182,16 +181,13 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) + main_channel_index = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) - # make a dict to access by chan_id - noise_levels = dict(zip(channel_ids, noise_levels)) - snrs = {} for unit_id in unit_ids: - chan_id = extremum_channels_ids[unit_id] - noise = noise_levels[chan_id] + chan_ind = main_channel_index[unit_id] + noise = noise_levels[chan_ind] amplitude = unit_amplitudes[unit_id] snrs[unit_id] = np.abs(amplitude) / noise @@ -1294,7 +1290,8 @@ def compute_sd_ratio( noise_levels = get_noise_levels( sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs ) - best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs) + main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) if correct_for_template_itself: @@ -1330,7 +1327,7 @@ def compute_sd_ratio( else: unit_std = np.std(spk_amp) - best_channel = best_channels[unit_id] + best_channel = main_channels[unit_id] std_noise = noise_levels[best_channel] n_samples = sorting_analyzer.get_total_samples() diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 5476aa405a..e4ce1a68fe 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -5,7 +5,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension @@ -137,7 +136,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): all_labels = sorting_analyzer.sorting.unit_ids[spike_unit_indices] # Get extremum channels for neighbor selection in sparse mode - extremum_channels = get_template_extremum_channel(sorting_analyzer) + + main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric @@ -152,7 +152,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): if sorting_analyzer.is_sparse(): neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ - other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids + other_unit for other_unit in unit_ids if main_channels[other_unit] in neighbor_channel_ids ] neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) else: diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 85ef9e22cb..7f6175e47a 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -12,7 +12,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array +from spikeinterface.core.template_tools import get_dense_templates_array from .metrics import get_trough_and_peak_idx, single_channel_metrics, multi_channel_metrics @@ -189,7 +189,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) channel_locations = sorting_analyzer.get_channel_locations() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 473798fe7c..6cae681f15 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,7 +3,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -105,9 +105,8 @@ def _get_pipeline_nodes(self): cut_out_after = nafter peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + # collisions handle_collisions = self.params["handle_collisions"] diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 671b9bb239..e180b868ab 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -6,7 +6,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity -from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -99,7 +99,9 @@ def compute_monopolar_triangulation( chan_inds = sparsity.unit_id_to_channel_indices[unit_id] neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = get_template_extremum_channel(sorting_analyzer_or_templates, outputs="index") + + best_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) + unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -278,7 +280,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - peak_channels = get_template_extremum_channel(sorting_analyzer_or_templates, peak_sign, outputs="index") + main_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) weights_sparsity_mask = weights > 0 @@ -286,7 +288,7 @@ def compute_grid_convolution( unit_location = np.zeros((len(unit_ids), 3), dtype="float64") for i, unit_id in enumerate(unit_ids): - main_chan = peak_channels[unit_id] + main_chan = main_channels[unit_id] wf = templates[i, :, :] nearest_mask = nearest_template_mask[main_chan, :] channel_mask = np.sum(weights_sparsity_mask[:, :, nearest_mask], axis=(0, 2)) > 0 @@ -661,14 +663,10 @@ def get_convolution_weights( def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> np.ndarray: """ Localize a unit using max channel. - This uses internally `get_template_extremum_channel()` - Parameters ---------- @@ -689,9 +687,8 @@ def compute_location_max_channel( unit_locations: np.ndarray 2d """ - extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" - ) + extremum_channels_index = templates_or_sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: unit_ids = templates_or_sorting_analyzer.unit_ids diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 993d1a105d..b5c7ff1f6d 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -4,7 +4,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type @@ -35,9 +35,8 @@ def _get_pipeline_nodes(self): peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) spike_retriever_node = SpikeRetriever( diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d4e226aa99..cf5adb6198 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -4,7 +4,6 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.sortinganalyzer import register_result_extension -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.node_pipeline import SpikeRetriever from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -75,9 +74,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) retriever = SpikeRetriever( sorting, diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 44389cc503..7ac4d36d0a 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,13 +53,11 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - from spikeinterface.core.template_tools import get_template_extremum_channel + main_channels = self.templates.get_main_channel(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) - best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") - best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( - channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2 + channel_locations[:, None] - channel_locations[main_channels][np.newaxis, :], axis=2 ) self.neighborhood_mask = template_distances <= neighborhood_radius_um else: diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index d3ae787a4b..1622ab83a5 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -6,7 +6,6 @@ import numpy as np from spikeinterface.core import ( get_channel_distances, - get_template_extremum_channel, ) from spikeinterface.sortingcomponents.peak_detection.method_list import ( @@ -222,12 +221,12 @@ def __init__( self.sparse_templates_array_static = templates.templates_array self.dtype = self.sparse_templates_array_static.dtype - extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") + # as numpy vector - self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") + self.main_channels = templates.get_main_channel(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions - unit_locations = channel_locations[self.extremum_channel] + unit_locations = channel_locations[self.main_channels] self.channel_locations = channel_locations # distance between units diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 505027f79a..6f8ba998e7 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -8,7 +8,6 @@ from .utils import get_unit_colors from .traces import TracesWidget from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from spikeinterface.core.baserecording import BaseRecording from spikeinterface.core.basesorting import BaseSorting @@ -121,9 +120,9 @@ def __init__( sparsity = sorting_analyzer.sparsity else: if sparsity is None: - # in this case, we construct a sparsity dictionary only with the best channel - extremum_channel_ids = get_template_extremum_channel(sorting_analyzer) - unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} + # in this case, we construct a sparsity dictionary only with the main channel + main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) + unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_analyzer.unit_ids, diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index c55c802f9b..69bc2b05cf 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -4,7 +4,6 @@ import numpy as np from probeinterface import ProbeGroup -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from .base import BaseWidget, to_attr @@ -86,10 +85,10 @@ def __init__( if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) for unit_id in unit_ids: if np.any(np.isnan(unit_locations[unit_id])): - unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]] + unit_locations[unit_id] = channel_locations[main_channels[unit_id]] data_plot = dict( all_unit_ids=sorting.unit_ids, diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fb26a228ef..411cd34203 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -4,7 +4,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -136,12 +135,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] if np.isnan(x) or np.isnan(y): warnings.warn(f"Unit {unit_id} location contains NaN values. Replacing NaN extremum channel location.") - x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]] + x, y = sorting_analyzer.get_channel_locations()[main_channels[unit_id]] ax_unit_locations.set_xlim(x - 80, x + 80) ax_unit_locations.set_ylim(y - 250, y + 250) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9543cbf734..91854ec0eb 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -5,7 +5,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core import ChannelSparsity, get_template_extremum_channel +from spikeinterface.core import ChannelSparsity class UnitWaveformDensityMapWidget(BaseWidget): @@ -43,7 +43,6 @@ def __init__( sparsity=None, same_axis=False, use_max_channel=False, - peak_sign="neg", unit_colors=None, backend=None, **backend_kwargs, @@ -61,9 +60,7 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = get_template_extremum_channel( - sorting_analyzer, mode="extremum", peak_sign=peak_sign, outputs="index" - ) + max_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_analyzer.is_sparse(): From eede722009b7aae012d1c7e8aac92c44d3d88a90 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Feb 2026 19:03:30 +0100 Subject: [PATCH 02/20] Put main_channel_peak_sign and main_channel_peak_mode in analyzer settings to reload them --- .../tests/test_benchmark_clustering.py | 2 +- .../tests/test_benchmark_peak_detection.py | 2 +- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 123 +++++++++++++++--- src/spikeinterface/core/sparsity.py | 27 ++-- src/spikeinterface/core/template.py | 2 +- src/spikeinterface/core/template_tools.py | 14 +- .../core/tests/test_node_pipeline.py | 4 +- .../core/tests/test_sortinganalyzer.py | 5 +- .../core/tests/test_sparsity.py | 4 +- src/spikeinterface/exporters/report.py | 4 +- src/spikeinterface/exporters/to_ibl.py | 4 +- src/spikeinterface/generation/hybrid_tools.py | 6 +- .../generation/splitting_tools.py | 2 +- .../metrics/quality/misc_metrics.py | 4 +- .../metrics/quality/quality_metrics.py | 2 +- .../metrics/template/template_metrics.py | 2 +- .../postprocessing/amplitude_scalings.py | 3 +- .../postprocessing/localization_tools.py | 6 +- .../postprocessing/spike_amplitudes.py | 2 +- .../postprocessing/spike_locations.py | 3 +- .../sortingcomponents/matching/nearest.py | 2 +- .../sortingcomponents/matching/tdc_peeler.py | 2 +- .../widgets/spikes_on_traces.py | 2 +- src/spikeinterface/widgets/unit_locations.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 2 +- .../widgets/unit_waveforms_density_map.py | 4 +- 27 files changed, 158 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index 5660b68fda..eae0bf0e59 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -31,7 +31,7 @@ def test_benchmark_clustering(create_cache_folder): # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) # sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = gt_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index 86b6bde5c5..82a51e8292 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -29,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder): sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index b11801c0be..fd8476e2b6 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -810,7 +810,7 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be given by analyzer.get_main_channel(outputs="index", with_dict=True) + This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True) use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index fe125193d3..f1fe4b81c9 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -58,7 +58,7 @@ def create_sorting_analyzer( folder=None, main_channel_index=None, main_channel_peak_sign="both", - main_channel_mode="extremum", + main_channel_peak_mode="extremum", num_spikes_for_main_channel=100, sparse=True, sparsity=None, @@ -101,6 +101,11 @@ def create_sorting_analyzer( The main_channel_index can be externally provided main_channel_peak_sign : "both" | "neg" In case when the main_channel_index is estimated wich sign to consider "both" or "neg". + main_channel_peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude num_spikes_for_main_channel : int, default: 100 How many spikes per units to compute the main channel. sparse : bool, default: True @@ -225,7 +230,7 @@ def create_sorting_analyzer( recording, sorting, main_channel_peak_sign=main_channel_peak_sign, - mode=main_channel_mode, + peak_mode=main_channel_peak_mode, num_spikes_for_main_channel=num_spikes_for_main_channel, seed=seed, **job_kwargs @@ -266,6 +271,8 @@ def create_sorting_analyzer( format=format, folder=folder, main_channel_index=main_channel_index, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -335,6 +342,8 @@ def __init__( format: str | None = None, sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options: dict | None = None, ): # very fast init because checks are done in load and create @@ -345,6 +354,8 @@ def __init__( self.format = format self.sparsity = sparsity self.return_in_uV = return_in_uV + self.main_channel_peak_sign = main_channel_peak_sign + self.main_channel_peak_mode = main_channel_peak_mode # For backward compatibility self.return_scaled = return_in_uV @@ -402,6 +413,8 @@ def create( sparsity=None, return_scaled=None, return_in_uV=True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options=None, ): assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" @@ -435,10 +448,14 @@ def create( sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) # This will ensure that the sorting saved always will have this main_channel + assert main_channel_index is not None sorting.set_property("main_channel_index", main_channel_index) if format == "memory": - sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) + sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes=None) elif format == "binary_folder": sorting_analyzer = cls.create_binary_folder( folder, @@ -446,6 +463,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -459,6 +478,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -497,7 +518,10 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe return sorting_analyzer @classmethod - def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attributes): + def create_memory(cls, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes): # used by create and save_as if rec_attributes is None: @@ -518,11 +542,18 @@ def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attribute format="memory", sparsity=sparsity, return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as folder = Path(folder) @@ -586,12 +617,45 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV settings_file = folder / f"settings.json" settings = dict( return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, ) with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, indent=4) return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod + def _handle_backward_compatibility(cls, settings, sorting, sparsity): + # backward compatibility at analyzer level + # (there is also something similar at extension level) + + new_settings = dict() + new_settings.update(settings) + if "return_scaled" in settings: + new_settings["return_in_uV"] = new_settings.pop("return_scaled") + elif "return_in_uV" in settings: + pass + else: + # old version did not have settings at all + new_settings["return_in_uV"] = True + + retrospect_main_channel_index = None + if "main_channel_peak_sign" not in settings: + # before 0.104.0 was not in main_channel_peak_sign + # TODO make something more fancy that exlore the previous params of extension + new_settings["main_channel_peak_sign"] = "both" + new_settings["main_channel_peak_mode"] = "extremum" + + if "main_channel_index" not in sorting.get_property_keys(): + # TODO + raise NotImplementedError("backward compatibility with main_channel_index is not implemented yet") + + if retrospect_main_channel_index is not None: + sorting.set_property("main_channel_index", retrospect_main_channel_index) + + return new_settings + @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load @@ -653,13 +717,19 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): if settings_file.exists(): with open(settings_file, "r") as f: settings = json.load(f) + need_to_create = False else: + need_to_create = True + settings = dict() + + settings = cls._handle_backward_compatibility(settings, sorting, sparsity) + + if need_to_create: warnings.warn("settings.json not found for this folder writing one with return_in_uV=True") - settings = dict(return_in_uV=True) with open(settings_file, "w") as f: json.dump(check_json(settings), f, indent=4) - return_in_uV = settings.get("return_in_uV", settings.get("return_scaled", True)) + sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -667,7 +737,9 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV = settings["return_in_uV"], + main_channel_peak_sign = settings["main_channel_peak_sign"], + main_channel_peak_mode = settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -682,7 +754,11 @@ def _get_zarr_root(self, mode="r+"): return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as import zarr import numcodecs @@ -706,7 +782,11 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) - settings = dict(return_in_uV=return_in_uV) + settings = dict( + return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) zarr_root.attrs["settings"] = check_json(settings) # the recording @@ -827,10 +907,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): ) else: sparsity = None - - return_in_uV = zarr_root.attrs["settings"].get( - "return_in_uV", zarr_root.attrs["settings"].get("return_scaled", True) - ) + + settings = zarr_root.attrs["settings"] + settings = cls._handle_backward_compatibility(settings, sorting, sparsity) sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -838,7 +917,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV = settings["return_in_uV"], + main_channel_peak_sign = settings["main_channel_peak_sign"], + main_channel_peak_mode = settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -941,7 +1022,7 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n """ return self.sorting.get_property(key, ids=ids) - def get_main_channel(self, outputs="index", with_dict=False): + def get_main_channels(self, outputs="index", with_dict=False): """ """ @@ -1180,7 +1261,10 @@ def _save_or_select_or_merge_or_split( if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( - sorting_provenance, recording, sparsity, self.return_in_uV, self.rec_attributes + sorting_provenance, recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes ) elif format == "binary_folder": @@ -1193,6 +1277,9 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes, backend_options=backend_options, ) @@ -1206,6 +1293,8 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, self.rec_attributes, backend_options=backend_options, ) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 70bb4cdb46..e82d5f19e9 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -29,8 +29,6 @@ In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. - peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels. num_channels : int Number of channels for "best_channels" method. radius_um : float @@ -83,14 +81,14 @@ class ChannelSparsity: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels) Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um) Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold) Using a template energy threshold: >>> sparsity = ChannelSparsity.from_energy(sorting_analyzer, threshold) @@ -395,7 +393,7 @@ def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, @classmethod - def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="both"): + def from_radius(cls, templates_or_sorting_analyzer, radius_um): """ Construct sparsity from a radius around the main channel. Use the "radius_um" argument to specify the radius in um. @@ -414,17 +412,14 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="both") sparsity : ChannelSparsity The estimated sparsity. """ - mask = np.zeros( - (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" - ) + main_channel_index = templates_or_sorting_analyzer.get_main_channels(outputs="index") channel_locations = templates_or_sorting_analyzer.get_channel_locations() - distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - main_channel_index = templates_or_sorting_analyzer.get_main_channel(outputs="index") - for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = main_channel_index[unit_ind] - (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) - mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) + return cls.from_radius_and_main_channel( + templates_or_sorting_analyzer.unit_ids, + templates_or_sorting_analyzer.channel_ids, + main_channel_index, + channel_locations, + radius_um) @classmethod def from_snr( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 151f6e2b8d..4aac641372 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -484,7 +484,7 @@ def get_channel_locations(self) -> np.ndarray: channel_locations = self.probe.contact_positions return channel_locations - def get_main_channel(self, + def get_main_channels(self, main_channel_peak_sign: "neg" | "both" | "pos" = "both", mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", outputs="index", diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index fe04540476..584bfaee7e 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -132,21 +132,21 @@ def get_template_amplitudes( -def _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore): +def _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore): # Step1 : max on time axis - if mode == "extremum": + if peak_mode == "extremum": if main_channel_peak_sign == "both": values = np.max(np.abs(templates_array), axis=1) elif main_channel_peak_sign == "neg": values = -np.min(templates_array, axis=1) elif main_channel_peak_sign == "pos": values = np.max(templates_array, axis=1) - elif mode == "at_index": + elif peak_mode == "at_index": if main_channel_peak_sign == "both": values = np.abs(templates_array[:, nbefore, :]) elif main_channel_peak_sign in ["neg", "pos"]: values = templates_array[:, nbefore, :] - elif mode == "peak_to_peak": + elif peak_mode == "peak_to_peak": values = np.ptp(templates_array, axis=1) # Step2: max on channel axis @@ -158,7 +158,7 @@ def estimate_main_channel_from_recording( recording, sorting, main_channel_peak_sign: "neg" | "both" | "pos" = "both", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", num_spikes_for_main_channel=100, ms_before = 1.0, ms_after = 2.5, @@ -203,7 +203,7 @@ def estimate_main_channel_from_recording( **job_kwargs, ) - main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, nbefore) + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore) return main_channel_index @@ -241,7 +241,7 @@ def get_template_extremum_channel( Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ - warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channel() instead") + warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channels() instead") # TODO make a better logic here assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 74131dac8d..7a29a3cee6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channel( outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 7be23dbf85..765d5bec7e 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -56,7 +56,7 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): # Note the sorting contain already main_channel_index sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) - assert np.array_equal(sorting_analyzer.get_main_channel() , sorting.get_property("main_channel_index")) + assert np.array_equal(sorting_analyzer.get_main_channels() , sorting.get_property("main_channel_index")) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -365,7 +365,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): assert ext is None assert sorting_analyzer.has_recording() - # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -635,7 +634,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = self.sorting_analyzer.get_main_channel( outputs="index", with_dict=True) + extremum_channel_inds = self.sorting_analyzer.get_main_channels( outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index c865068e4a..b013c9ca90 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -285,7 +285,7 @@ def test_compute_sparsity(): # using object SortingAnalyzer sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0) sparsity = compute_sparsity(sorting_analyzer, method="closest_channels", num_channels=2) sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") sparsity = compute_sparsity( @@ -299,7 +299,7 @@ def test_compute_sparsity(): templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(templates, method="radius", radius_um=50.0) sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 8c12591c6e..66b71a1b1c 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -102,8 +102,8 @@ def export_report( units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" # max_on_channel_id is kept (oold name) - units["max_on_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) - units["main_channel_id"] = sorting_analyzer.get_main_channel(outputs="id", with_dict=False) + units["max_on_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) + units["main_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) units.to_csv(output_folder / "unit list.csv", sep="\t") diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index b0445b4753..9a0d5847fb 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -101,7 +101,7 @@ def export_to_ibl_gui( output_folder.mkdir(parents=True, exist_ok=True) ### Save spikes info ### - extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) # spikes.clusters @@ -136,7 +136,7 @@ def export_to_ibl_gui( np.save(output_folder / "clusters.waveforms.npy", templates) # cluster channels - extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") np.save(output_folder / "clusters.channels.npy", cluster_channels) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 72d8cea634..406fbff18e 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -127,7 +127,7 @@ def select_templates( min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided" # get template amplitudes and depth - main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) mask = np.ones(templates.num_units, dtype=bool) @@ -190,7 +190,7 @@ def scale_template_to_range( Templates The scaled templates. """ - main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -263,7 +263,7 @@ def relocate_templates( """ seed = _ensure_seed(seed) - main_channel_indices = templates.get_main_channel(outputs="index", with_dict=False) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] unit_depths = channel_depths[main_channel_indices] diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index e2314319a1..03d7cd5ef8 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -108,7 +108,7 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - extremum_channel_inds = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index f4a6afcd9b..128a509e80 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -181,7 +181,7 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids - main_channel_index = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) snrs = {} @@ -1290,7 +1290,7 @@ def compute_sd_ratio( noise_levels = get_noise_levels( sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs ) - main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index e4ce1a68fe..405a31d068 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -137,7 +137,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): # Get extremum channels for neighbor selection in sparse mode - main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 7f6175e47a..ed7ded47e6 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -189,7 +189,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - extremum_channel_indices = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 6cae681f15..dfef0e3bc6 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -104,8 +104,7 @@ def _get_pipeline_nodes(self): else: cut_out_after = nafter - peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # collisions diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e180b868ab..fb9c89c521 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -100,7 +100,7 @@ def compute_monopolar_triangulation( neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) + best_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) unit_location = np.zeros((unit_ids.size, 4), dtype="float64") @@ -280,7 +280,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - main_channels = sorting_analyzer_or_templates.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) weights_sparsity_mask = weights > 0 @@ -687,7 +687,7 @@ def compute_location_max_channel( unit_locations: np.ndarray 2d """ - extremum_channels_index = templates_or_sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_index = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=True) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index b5c7ff1f6d..5a7aebf728 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -35,7 +35,7 @@ def _get_pipeline_nodes(self): peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV - extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index cf5adb6198..c6c9eca021 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -73,8 +73,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] - extremum_channels_indices = self.sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) retriever = SpikeRetriever( sorting, diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 7ac4d36d0a..0e46ba4df2 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,7 +53,7 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - main_channels = self.templates.get_main_channel(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) + main_channels = self.templates.get_main_channels(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 1622ab83a5..96b667e404 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -223,7 +223,7 @@ def __init__( # as numpy vector - self.main_channels = templates.get_main_channel(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) + self.main_channels = templates.get_main_channels(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions unit_locations = channel_locations[self.main_channels] diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 6f8ba998e7..23f662d14a 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -121,7 +121,7 @@ def __init__( else: if sparsity is None: # in this case, we construct a sparsity dictionary only with the main channel - main_channels = sorting_analyzer.get_main_channel(outputs="id", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 69bc2b05cf..2483f0a792 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -85,7 +85,7 @@ def __init__( if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") - main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) for unit_id in unit_ids: if np.any(np.isnan(unit_locations[unit_id])): unit_locations[unit_id] = channel_locations[main_channels[unit_id]] diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 411cd34203..93c18a01e3 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -135,7 +135,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - main_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 91854ec0eb..10f9651bc6 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -25,8 +25,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored use_max_channel : bool, default: False Use only the max channel - peak_sign : "neg" | "pos" | "both", default: "neg" - Used to detect max channel only when use_max_channel=True unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. @@ -60,7 +58,7 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = sorting_analyzer.get_main_channel(outputs="index", with_dict=True) + max_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_analyzer.is_sparse(): From 7b6a1095f6f84bfb78f8842f384676454ec37d6a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 17 Feb 2026 18:51:24 +0100 Subject: [PATCH 03/20] continue the tedious refactoring --- src/spikeinterface/core/__init__.py | 10 +- src/spikeinterface/core/sortinganalyzer.py | 90 +++++- src/spikeinterface/core/template.py | 10 +- src/spikeinterface/core/template_tools.py | 299 ++++++++++++------ src/spikeinterface/core/tests/test_loading.py | 5 + .../core/tests/test_template_tools.py | 34 +- .../curation/remove_redundant.py | 4 +- src/spikeinterface/exporters/report.py | 4 +- .../metrics/quality/misc_metrics.py | 4 +- .../postprocessing/spike_amplitudes.py | 4 +- .../tests/test_multi_extensions.py | 6 +- src/spikeinterface/sortingcomponents/tools.py | 4 +- src/spikeinterface/widgets/unit_depths.py | 4 +- 13 files changed, 319 insertions(+), 159 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 24c64162ee..5fba540db0 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -163,9 +163,13 @@ # template tools from .template_tools import ( get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, + get_template_main_channel_peak_shift, + get_template_main_channel_amplitude, + + # this is not needed anymore + get_template_extremum_channel, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_channel_peak_shift, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_amplitude, # keep for backward compatibility can be removed in 0.105 ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ad68b3343a..57f558851b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -229,7 +229,7 @@ def create_sorting_analyzer( main_channel_index = estimate_main_channel_from_recording( recording, sorting, - main_channel_peak_sign=main_channel_peak_sign, + peak_sign=main_channel_peak_sign, peak_mode=main_channel_peak_mode, num_spikes_for_main_channel=num_spikes_for_main_channel, seed=seed, @@ -626,9 +626,17 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) @classmethod - def _handle_backward_compatibility(cls, settings, sorting, sparsity): - # backward compatibility at analyzer level - # (there is also something similar at extension level) + def _handle_backward_compatibility_settings_pre_init(cls, settings, sorting, sparsity): + """ + backward compatibility before the __init__ to handle the settings: + * return_scaled > return_in_uV + * main_channel_peak_sign + * main_channel_peak_mode + + Note : + * see also _handle_backward_compatibility_settings_post_init + * there is also something at extension level to handle changes in paramaters with deferents mechanism + """ new_settings = dict() new_settings.update(settings) @@ -640,21 +648,73 @@ def _handle_backward_compatibility(cls, settings, sorting, sparsity): # old version did not have settings at all new_settings["return_in_uV"] = True - retrospect_main_channel_index = None if "main_channel_peak_sign" not in settings: # before 0.104.0 was not in main_channel_peak_sign # TODO make something more fancy that exlore the previous params of extension new_settings["main_channel_peak_sign"] = "both" new_settings["main_channel_peak_mode"] = "extremum" - if "main_channel_index" not in sorting.get_property_keys(): - # TODO - raise NotImplementedError("backward compatibility with main_channel_index is not implemented yet") + return new_settings + + def _handle_backward_compatibility_settings_post_init(self): + """ + backward compatibility after the __init__ to : + * main_channel_index + + Note : + * see also _handle_backward_compatibility_settings_pre_init + * there is also something at extension level to handle changes in paramaters with deferents mechanism + """ + + + if "main_channel_index" not in self.sorting.get_property_keys(): + + warnings.warn("This loaded analyzer is from an older verion main_channel_index need to be computed from templates") + + main_channel_index = None + if self.has_extension("templates"): + # first try to load templates extension + ext = self.get_extension("templates") + + for k in ("average", "median"): + if k in ext.data: + from .template_tools import _get_main_channel_from_template_array + templates_array = ext.data[k] + # TODO @alessio @chris : we need to discuss this + peak_sign = "both" # or "neg" ????? + peak_mode = "extremum" + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, ext.nbefore) + break + + if main_channel_index is None: + if not self.has_recording(): + # TODO @alessio @chris : we need to discuss this + # what to do in this case ??????? + raise ValueError("This analyzer cannot be load and is from an old version, the recording is not available") + else: - if retrospect_main_channel_index is not None: - sorting.set_property("main_channel_index", retrospect_main_channel_index) + # otherwise we need to estimate the + + from .template_tools import estimate_main_channel_from_recording + # TODO @alessio @chris : we need to discuss this + peak_sign = "both" # or "neg" ????? + peak_mode = "extremum" + + main_channel_index = estimate_main_channel_from_recording( + self.recording, + self.sorting, + peak_sign=peak_sign, + peak_mode=peak_mode, + num_spikes_for_main_channel=100, + seed=None, + ) + + # this is only in memory + self.sorting.set_property("main_channel_index", main_channel_index) + # TODO @alessio @chris : we need to discuss this + # this save also to disk but maybe there is no write for the analyzer... + self.set_sorting_property("main_channel_index", main_channel_index, save=True) - return new_settings @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): @@ -722,7 +782,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): need_to_create = True settings = dict() - settings = cls._handle_backward_compatibility(settings, sorting, sparsity) + settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) if need_to_create: warnings.warn("settings.json not found for this folder writing one with return_in_uV=True") @@ -909,7 +969,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): sparsity = None settings = zarr_root.attrs["settings"] - settings = cls._handle_backward_compatibility(settings, sorting, sparsity) + settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -1027,9 +1087,9 @@ def get_main_channels(self, outputs="index", with_dict=False): """ main_channel_index = self.get_sorting_property("main_channel_index") - if outputs is "index": + if outputs == "index": main_chans = main_channel_index - elif outputs is "id": + elif outputs == "id": main_chans = self.channel_ids[main_channel_index] else: raise ValueError("wrong outputs") diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 4aac641372..20c3b7b09a 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -485,19 +485,19 @@ def get_channel_locations(self) -> np.ndarray: return channel_locations def get_main_channels(self, - main_channel_peak_sign: "neg" | "both" | "pos" = "both", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_sign: "neg" | "both" | "pos" = "both", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", outputs="index", with_dict=True ): from .template_tools import _get_main_channel_from_template_array templates_array = self.get_dense_templates() - main_channel_index = _get_main_channel_from_template_array(templates_array, mode, main_channel_peak_sign, self.nbefore) + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, self.nbefore) - if outputs is "index": + if outputs == "index": main_chans = main_channel_index - elif outputs is "id": + elif outputs == "id": main_chans = self.channel_ids[main_channel_index] else: raise ValueError("wrong outputs") diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 584bfaee7e..de347b49db 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -51,6 +51,33 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_in raise ValueError("Input should be Templates or SortingAnalyzer") return templates_array +def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True): + """ + Return dense templates as numpy array from either a Templates object or a SortingAnalyzer. + + Parameters + ---------- + one_object : Templates | SortingAnalyzer + The Templates or SortingAnalyzer objects. If SortingAnalyzer, it needs the "templates" extension. + return_in_uV : bool, default: True + If True, templates are scaled. + + Returns + ------- + main_channel_templates : np.ndarray + The dense templates (num_units, num_samples) + """ + # TODO later: do not load the dense templates array if this is not necessary (when sprse internally) + main_channels = one_object.get_main_channels(outputs="index", with_dict=False) + templates_array = get_dense_templates_array(one_object, return_in_uV=return_in_uV) + num_units = templates_array.shape[0] + num_samples = templates_array.shape[1] + main_channel_templates = np.zeros((num_units, num_samples), dtype=templates_array.dtype) + for i in range(num_units): + main_channel_templates[i, :] = templates_array[i, :, main_channels[i]] + return main_channel_templates + + def _get_nbefore(one_object): if isinstance(one_object, Templates): @@ -66,9 +93,9 @@ def _get_nbefore(one_object): def get_template_amplitudes( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", - return_in_uV: bool = True, + peak_sign: None | "neg" | "pos" | "both" = None, + peak_mode: None | "extremum" | "at_index" | "peak_to_peak" = None, + # return_in_uV: bool = True, abs_value: bool = True, ): """ @@ -78,15 +105,15 @@ def get_template_amplitudes( ---------- templates_or_sorting_analyzer : Templates | SortingAnalyzer A Templates or a SortingAnalyzer object - peak_sign : "neg" | "pos" | "both" + peak_sign : None | "neg" | "pos" | "both" + Used only when input is Templates. Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : None | "extremum" | "at_index" | "peak_to_peak", default: None + Used only when input is Templates. Where the amplitude is computed * "extremum" : take the peak value (max or min depending on `peak_sign`) * "at_index" : take value at `nbefore` index * "peak_to_peak" : take the peak-to-peak amplitude - return_in_uV : bool, default True - The amplitude is scaled or not. abs_value : bool = True Whether the extremum amplitude should be returned as an absolute value or not @@ -95,8 +122,24 @@ def get_template_amplitudes( peak_values : dict Dictionary with unit ids as keys and template amplitudes as values """ + + if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + assert peak_sign is None, "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" + assert peak_mode is None, "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + peak_mode = templates_or_sorting_analyzer.main_channel_peak_mode + return_in_uV = templates_or_sorting_analyzer.return_in_uV + elif isinstance(templates_or_sorting_analyzer, Templates): + return_in_uV = templates_or_sorting_analyzer.is_in_uV + if peak_sign is None: + warnings.warn("get_template_amplitudes() with Templates should provide a peak_sign") + peak_sign = "both" + if peak_mode is None: + warnings.warn("get_template_amplitudes() with Templates should provide a peak_mode") + peak_mode = "extremum" + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" + assert peak_mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" unit_ids = templates_or_sorting_analyzer.unit_ids before = _get_nbefore(templates_or_sorting_analyzer) @@ -108,19 +151,19 @@ def get_template_amplitudes( for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] - if mode == "extremum": + if peak_mode == "extremum": if peak_sign == "both": values = np.max(np.abs(template), axis=0) elif peak_sign == "neg": values = np.min(template, axis=0) elif peak_sign == "pos": values = np.max(template, axis=0) - elif mode == "at_index": + elif peak_mode == "at_index": if peak_sign == "both": values = np.abs(template[before, :]) elif peak_sign in ["neg", "pos"]: values = template[before, :] - elif mode == "peak_to_peak": + elif peak_mode == "peak_to_peak": values = np.ptp(template, axis=0) if abs_value: @@ -132,19 +175,19 @@ def get_template_amplitudes( -def _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore): +def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore): # Step1 : max on time axis if peak_mode == "extremum": - if main_channel_peak_sign == "both": + if peak_sign == "both": values = np.max(np.abs(templates_array), axis=1) - elif main_channel_peak_sign == "neg": + elif peak_sign == "neg": values = -np.min(templates_array, axis=1) - elif main_channel_peak_sign == "pos": + elif peak_sign == "pos": values = np.max(templates_array, axis=1) elif peak_mode == "at_index": - if main_channel_peak_sign == "both": + if peak_sign == "both": values = np.abs(templates_array[:, nbefore, :]) - elif main_channel_peak_sign in ["neg", "pos"]: + elif peak_sign in ["neg", "pos"]: values = templates_array[:, nbefore, :] elif peak_mode == "peak_to_peak": values = np.ptp(templates_array, axis=1) @@ -154,10 +197,11 @@ def _get_main_channel_from_template_array(templates_array, peak_mode, main_chann return main_channel_index + def estimate_main_channel_from_recording( recording, sorting, - main_channel_peak_sign: "neg" | "both" | "pos" = "both", + peak_sign: "neg" | "both" | "pos" = "both", peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", num_spikes_for_main_channel=100, ms_before = 1.0, @@ -170,7 +214,7 @@ def estimate_main_channel_from_recording( """ - if main_channel_peak_sign == "pos": + if peak_sign == "pos": warnings.warn( "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " \ "should revert the traces instead" @@ -203,7 +247,7 @@ def estimate_main_channel_from_recording( **job_kwargs, ) - main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, main_channel_peak_sign, nbefore) + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore) return main_channel_index @@ -211,13 +255,18 @@ def estimate_main_channel_from_recording( +# TODO remove this in 0.105.0 def get_template_extremum_channel( templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", outputs: "id" | "index" = "id", ): """ + Depracted will be removed in 0.105.0. + Use analyzer.get_main_channels() or tempates.get_main_channels(peak_sign=...) instead. + + Compute the channel with the extremum peak for each unit. Parameters @@ -226,7 +275,7 @@ def get_template_extremum_channel( A Templates or a SortingAnalyzer object peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" Where the amplitude is computed * "extremum" : take the peak value (max or min depending on `peak_sign`) * "at_index" : take value at `nbefore` index @@ -237,49 +286,32 @@ def get_template_extremum_channel( Returns ------- - extremum_channels : dict + main_channels : dict Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channels() instead") - # TODO make a better logic here - - assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" - assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`" - - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - # if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise - # we use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + assert peak_sign is None, "get_template_extremum_channel() peak_sign is now contained in SortingAnalyzer, should be None here" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, with_dict=True) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_extremum_channel() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, peak_sign=peak_sign, with_dict=True) + + return main_channels - peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_in_uV=return_in_uV - ) - extremum_channels_id = {} - extremum_channels_index = {} - for unit_id in unit_ids: - max_ind = np.argmax(np.abs(peak_values[unit_id])) - extremum_channels_id[unit_id] = channel_ids[max_ind] - extremum_channels_index[unit_id] = max_ind - if outputs == "id": - return extremum_channels_id - elif outputs == "index": - return extremum_channels_index -def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg"): +# TODO remove this in 0.105.0 +def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None): """ + Depracted will be removed in 0.105.0. + Use get_template_main_channel_peak_shift() instead. + In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. This function is internally used by `compute_spike_amplitudes()` to accurately retrieve the spike amplitudes. @@ -288,7 +320,7 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak ---------- templates_or_sorting_analyzer : Templates | SortingAnalyzer A Templates or a SortingAnalyzer object - peak_sign : "neg" | "pos" | "both" + peak_sign : None Sign of the template to find extremum channels Returns @@ -296,52 +328,82 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak shifts : dict Dictionary with unit ids as keys and shifts as values """ - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - nbefore = _get_nbefore(templates_or_sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign) + warnings.warn("get_template_extremum_channel_peak_shift() is deprecated use get_template_main_channel_peak_shift() instead" + "Will be removed in 0.105.0" + ) - shifts = {} + return get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True) - # We need to use the SortingAnalyzer return_scaled - # We need to use the Templates is_in_uV - if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_in_uV=return_in_uV) +def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True): + """ + In some situations spike sorters could return a spike index with a small shift related to the waveform peak. + This function estimates and return these alignment shifts for the mean template. + This function is internally used by `compute_spike_amplitudes()` to accurately retrieve the spike amplitudes. - for unit_ind, unit_id in enumerate(unit_ids): - template = templates_array[unit_ind, :, :] + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object + peak_sign : None | "neg" | "pos" | "both" + Sign of the template to find extremum channels + + Returns + ------- + shifts : dict + Dictionary with unit ids as keys and shifts as values + """ + - chan_id = extremum_channels_ids[unit_id] - chan_ind = list(channel_ids).index(chan_id) + if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + assert peak_sign is None + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", peak_sign=peak_sign, with_dict=False) + + unit_ids = templates_or_sorting_analyzer.unit_ids + nbefore = _get_nbefore(templates_or_sorting_analyzer) + + main_channel_templates = get_main_channel_templates_array(templates_or_sorting_analyzer) + + shifts = [] + for unit_ind, unit_id in enumerate(unit_ids): + chan_ind = main_channels[unit_ind] + template = main_channel_templates[chan_ind] if peak_sign == "both": - peak_pos = np.argmax(np.abs(template[:, chan_ind])) + peak_pos = np.argmax(np.abs(template)) elif peak_sign == "neg": - peak_pos = np.argmin(template[:, chan_ind]) + peak_pos = np.argmin(template) elif peak_sign == "pos": - peak_pos = np.argmax(template[:, chan_ind]) + peak_pos = np.argmax(template) shift = peak_pos - nbefore - shifts[unit_id] = shift + shifts.append(shift) + + if with_dict: + shifts = dict(zip(unit_ids, shifts)) + else: + shifts = np.array(shifts) return shifts +# TODO remove this in 0.105.0 def get_template_extremum_amplitude( templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", abs_value: bool = True, ): """ + Depracted will be removed in 0.105.0. + Use get_template_main_channel_amplitude() instead. + Computes amplitudes on the best channel. Parameters @@ -350,7 +412,7 @@ def get_template_extremum_amplitude( A Templates or a SortingAnalyzer object peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" Where the amplitude is computed * "extremum": take the peak value (max or min depending on `peak_sign`) * "at_index": take value at `nbefore` index @@ -364,30 +426,67 @@ def get_template_extremum_amplitude( amplitudes : dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode) + warnings.warn("get_template_extremum_amplitude() is deprecated use get_template_main_channel_amplitude() instead" + "Will be removed in 0.105.0" + ) + return get_template_main_channel_amplitude( + templates_or_sorting_analyzer, + peak_sign=peak_sign, + peak_mode=peak_mode, + abs_value=abs_value, + ) + +def get_template_main_channel_amplitude( + templates_or_sorting_analyzer, + peak_sign: None | "neg" | "pos" | "both" = None, + peak_mode: None | "extremum" | "at_index" | "peak_to_peak" = None, + abs_value: bool = True, + with_dict=True, +): + """ + Computes amplitudes on the best channel. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum": take the peak value (max or min depending on `peak_sign`) + * "at_index": take value at `nbefore` index + * "peak_to_peak": take the peak-to-peak amplitude + abs_value : bool = True + Whether the extremum amplitude should be returned as an absolute value or not + + + Returns + ------- + amplitudes : dict + Dictionary with unit ids as keys and amplitudes as values + """ if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", peak_sign=peak_sign, with_dict=False) extremum_amplitudes = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_in_uV=return_in_uV, abs_value=abs_value + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=peak_mode, abs_value=abs_value ) - unit_amplitudes = {} - for unit_id in unit_ids: - channel_id = extremum_channels_ids[unit_id] - best_channel = list(channel_ids).index(channel_id) - unit_amplitudes[unit_id] = extremum_amplitudes[unit_id][best_channel] + unit_ids = templates_or_sorting_analyzer.unit_ids + unit_amplitudes = [] + for unit_ind, unit_id in enumerate(unit_ids): + chan_ind = main_channels[unit_ind] + unit_amplitudes.append(extremum_amplitudes[unit_id][chan_ind]) + + if with_dict: + unit_amplitudes = dict(zip()) - return unit_amplitudes + return unit_amplitudes \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index bfaf97ec4a..44a7da5424 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -227,3 +227,8 @@ def test_remote_analyzer(): "quality_metrics", ]: assert ext in analyzer.get_saved_extension_names() + + +if __name__ == "__main__": + test_remote_analyzer() + \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index a28680612a..5e4cd1f12c 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -7,9 +7,8 @@ from spikeinterface import Templates from spikeinterface.core import ( get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, + get_template_main_channel_peak_shift, + get_template_main_channel_amplitude, ) @@ -56,24 +55,17 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer): def test_get_template_amplitudes(sorting_analyzer): peak_values = get_template_amplitudes(sorting_analyzer) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - peak_values = get_template_amplitudes(templates, abs_value=True) - peak_to_peak_values = get_template_amplitudes(templates, mode="peak_to_peak") + peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="extremum", abs_value=True) + peak_to_peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="peak_to_peak") assert np.all(ptp > p for ptp, p in zip(peak_to_peak_values.values(), peak_values.values())) -def test_get_template_extremum_channel(sorting_analyzer): - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign="both") - print(extremum_channels_ids) - templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates, peak_sign="both") - print(extremum_channels_ids) - -def test_get_template_extremum_channel_peak_shift(sorting_analyzer): - shifts = get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign="neg") +def test_get_template_main_channel_peak_shift(sorting_analyzer): + shifts = get_template_main_channel_peak_shift(sorting_analyzer) print(shifts) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - shifts = get_template_extremum_channel_peak_shift(templates, peak_sign="neg") + shifts = get_template_main_channel_peak_shift(templates, peak_sign="both") # DEBUG # import matplotlib.pyplot as plt @@ -91,13 +83,14 @@ def test_get_template_extremum_channel_peak_shift(sorting_analyzer): # plt.show() -def test_get_template_extremum_amplitude(sorting_analyzer): +def test_get_template_main_channel_amplitude(sorting_analyzer): - extremum_channels_ids = get_template_extremum_amplitude(sorting_analyzer, peak_sign="both") + extremum_channels_ids = get_template_main_channel_amplitude(sorting_analyzer) print(extremum_channels_ids) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - extremum_channels_ids = get_template_extremum_amplitude(templates, peak_sign="both") + + extremum_channels_ids = get_template_main_channel_amplitude(templates, peak_sign="both", peak_mode="extremum") if __name__ == "__main__": @@ -107,6 +100,5 @@ def test_get_template_extremum_amplitude(sorting_analyzer): print(sorting_analyzer) test_get_template_amplitudes(sorting_analyzer) - test_get_template_extremum_channel(sorting_analyzer) - test_get_template_extremum_channel_peak_shift(sorting_analyzer) - test_get_template_extremum_amplitude(sorting_analyzer) + test_get_template_main_channel_peak_shift(sorting_analyzer) + test_get_template_main_channel_amplitude(sorting_analyzer) diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index e860de10ab..b62969d441 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -4,7 +4,7 @@ from spikeinterface import SortingAnalyzer -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift, get_template_amplitudes from spikeinterface.postprocessing import align_sorting _remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes") @@ -81,7 +81,7 @@ def remove_redundant_units( if align and unit_peak_shifts is None: assert sorting_analyzer is not None, "For align=True must give a SortingAnalyzer or explicit unit_peak_shifts" - unit_peak_shifts = get_template_extremum_channel_peak_shift(sorting_analyzer) + unit_peak_shifts = get_template_main_channel_peak_shift(sorting_analyzer, with_dict=True) if align: sorting_aligned = align_sorting(sorting, unit_peak_shifts) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 66b71a1b1c..a3d6027f27 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -6,7 +6,7 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs import spikeinterface.widgets as sw -from spikeinterface.core import get_template_extremum_amplitude +from spikeinterface.core import get_template_main_channel_amplitude from spikeinterface.postprocessing import compute_correlograms @@ -105,7 +105,7 @@ def export_report( units["max_on_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) units["main_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) - units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) + units["amplitude"] = pd.Series(get_template_main_channel_amplitude(sorting_analyzer)) units.to_csv(output_folder / "unit list.csv", sep="\t") unit_colors = sw.get_unit_colors(sorting) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index b57b1cad4b..7c937bf785 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -20,7 +20,7 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting from spikeinterface.core.template_tools import ( - get_template_extremum_amplitude, + get_template_main_channel_amplitude, get_dense_templates_array, ) from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate @@ -182,7 +182,7 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) - unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign, mode=peak_mode) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=True) snrs = {} for unit_id in unit_ids: diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 5a7aebf728..aa8768146e 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -4,7 +4,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type @@ -37,7 +37,7 @@ def _get_pipeline_nodes(self): extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) - peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) + peak_shifts = get_template_main_channel_peak_shift(self.sorting_analyzer, with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index 87ff3cdeb7..4a17c7a829 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -7,7 +7,7 @@ create_sorting_analyzer, generate_ground_truth_recording, set_global_job_kwargs, - get_template_extremum_amplitude, + get_template_main_channel_amplitude, ) from spikeinterface.core.generate import inject_some_split_units @@ -85,7 +85,7 @@ def get_dataset_to_merge(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(get_template_main_channel_amplitude(analyzer_raw, with_dict=False))[::-1] split_ids = sorting.unit_ids[sort_by_amp][:3] sorting_with_splits, split_unit_ids = inject_some_split_units( @@ -116,7 +116,7 @@ def get_dataset_to_split(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(list(get_template_main_channel_amplitude(analyzer_raw).values()))[::-1] large_units = sorting.unit_ids[sort_by_amp][:2] return recording, sorting, large_units diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 8a2b64b6d2..3c133ad129 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -15,7 +15,7 @@ from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain @@ -629,7 +629,7 @@ def clean_templates( if max_jitter_ms is not None: max_jitter = int(max_jitter_ms * templates.sampling_frequency / 1000.0) n_before = len(templates.unit_ids) - shifts = get_template_extremum_channel_peak_shift(templates) + shifts = get_template_main_channel_peak_shift(templates, with_dict=True) to_select = [] for unit_id in templates.unit_ids: if np.abs(shifts[unit_id]) <= max_jitter: diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index b759beb963..8c1cbb3af9 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -6,7 +6,7 @@ from .utils import get_unit_colors -from spikeinterface.core.template_tools import get_template_extremum_amplitude +from spikeinterface.core.template_tools import get_template_main_channel_amplitude class UnitDepthsWidget(BaseWidget): @@ -44,7 +44,7 @@ def __init__( unit_locations = ulc.get_data(outputs="numpy") unit_depths = unit_locations[:, depth_axis] - unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer) unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids]) num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="array") From 84e8d379cf88d429cd9f06381f7f269789f19654 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 25 Mar 2026 18:57:13 +0100 Subject: [PATCH 04/20] wip peak sign remove --- src/spikeinterface/core/sortinganalyzer.py | 8 ++++++-- src/spikeinterface/core/sparsity.py | 4 +--- src/spikeinterface/core/tests/test_sparsity.py | 4 ++-- src/spikeinterface/postprocessing/localization_tools.py | 7 ------- src/spikeinterface/postprocessing/spike_amplitudes.py | 8 +++----- src/spikeinterface/postprocessing/spike_locations.py | 5 ----- src/spikeinterface/widgets/unit_depths.py | 4 +--- 7 files changed, 13 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 53495aa528..3892e68c51 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -413,6 +413,8 @@ def create( backend_options=None, ): assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" + assert main_channel_index is not None, "To create a SortingAnalyzer you need to specify the main_channel_index" + # some checks if sorting.sampling_frequency != recording.sampling_frequency: if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5): @@ -443,8 +445,8 @@ def create( sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) # This will ensure that the sorting saved always will have this main_channel - assert main_channel_index is not None - sorting.set_property("main_channel_index", main_channel_index) + + if format == "memory": sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, @@ -481,6 +483,8 @@ def create( else: raise ValueError("SortingAnalyzer.create: wrong format") + sorting_analyzer.set_sorting_property.set_sorting_property("main_channel_index", main_channel_index, save=True) + return sorting_analyzer @classmethod diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 52184abe93..9d766732fa 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -406,8 +406,6 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um): A Templates or a SortingAnalyzer object. radius_um : float Radius in um for "radius" method. - peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels. Returns ------- @@ -860,7 +858,7 @@ def estimate_sparsity( sparsity = ChannelSparsity.from_closest_channels(templates, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates, radius_um) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" assert noise_levels is not None, ( diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 2b5fd95c1a..071de348e0 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -325,6 +325,6 @@ def test_compute_sparsity(): if __name__ == "__main__": - # test_ChannelSparsity() - # test_estimate_sparsity() + test_ChannelSparsity() + test_estimate_sparsity() test_compute_sparsity() diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index dd5ead8f7e..9022ea12a1 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -674,13 +674,6 @@ def compute_location_max_channel( A SortingAnalyzer or Templates object unit_ids: list[str] | list[int] | None A list of unit_id to restrict the computation - peak_sign : "neg" | "pos" | "both" - Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" - Where the amplitude is computed - * "extremum" : take the peak value (max or min depending on `peak_sign`) - * "at_index" : take value at `nbefore` index - * "peak_to_peak" : take the peak-to-peak amplitude Returns ------- diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 1d1c199e50..45d8a6d0ae 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -15,22 +15,20 @@ class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): Parameters ---------- - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute extremum channel used to retrieve spike amplitudes. + """ extension_name = "spike_amplitudes" depend_on = ["templates"] nodepipeline_variables = ["amplitudes"] - def _set_params(self, peak_sign="neg"): - return super()._set_params(peak_sign=peak_sign) + def _set_params(self): + return super()._set_params() def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 7499aa8b30..5fb15a6ea1 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -13,8 +13,6 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - peak_sign : "neg" | "pos" | "both", default: "neg" - The peak sign to use when looking for the template extremum channel. spike_retriever_kwargs : dict Arguments to control the spike retriever behavior. See `spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`. @@ -37,14 +35,12 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs if "spike_retriver_kwargs" in self.params: - self.params["peak_sign"] = self.params["spike_retriver_kwargs"].get("peak_sign", "neg") self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs") def _set_params( self, ms_before=0.5, ms_after=0.5, - peak_sign="neg", spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, @@ -54,7 +50,6 @@ def _set_params( return super()._set_params( ms_before=ms_before, ms_after=ms_after, - peak_sign=peak_sign, spike_retriever_kwargs=spike_retriever_kwargs, method=method, method_kwargs=method_kwargs, diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 5080159422..378948817b 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -20,12 +20,10 @@ class UnitDepthsWidget(BaseWidget): by matplotlib. If None, default colors are chosen using the `get_some_colors` function. depth_axis : int, default: 1 The dimension of unit_locations that is depth - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of peak for amplitudes """ def __init__( - self, sorting_analyzer, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs + self, sorting_analyzer, unit_colors=None, depth_axis=1, backend=None, **backend_kwargs ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) From 91206acabc7833b483b8ecf20796f7f058101f1a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 14:56:41 +0100 Subject: [PATCH 05/20] Change Template's default for get_main_channels output format --- src/spikeinterface/core/template.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 9cb407589c..619ee22021 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import json from dataclasses import dataclass, field, astuple, replace @@ -483,10 +485,10 @@ def get_channel_locations(self) -> np.ndarray: return channel_locations def get_main_channels(self, - peak_sign: "neg" | "both" | "pos" = "both", - peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_sign : Literal["neg","both","pos"] = "both", + peak_mode : Literal["extremum","at_index","peak_to_peak"] = "extremum", outputs="index", - with_dict=True + with_dict=False, ): from .template_tools import _get_main_channel_from_template_array From 6e00b9221853e879d76441739b6b45e9504b51ba Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 15:22:19 +0100 Subject: [PATCH 06/20] get most tests passing --- src/spikeinterface/core/sortinganalyzer.py | 142 ++++++++++-------- src/spikeinterface/core/sparsity.py | 32 ++-- src/spikeinterface/core/template_tools.py | 104 +++++++------ .../core/tests/test_sparsity.py | 9 +- .../tests/test_spike_locations.py | 2 +- 5 files changed, 154 insertions(+), 135 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0283dddd6c..70f149bc23 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -48,7 +48,6 @@ from .sorting_tools import random_spikes_selection - # high level function def create_sorting_analyzer( sorting, @@ -68,7 +67,7 @@ def create_sorting_analyzer( backend_options=None, sparsity_kwargs=None, seed=None, - **job_kwargs + **job_kwargs, ) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -197,7 +196,7 @@ def create_sorting_analyzer( overwrite=overwrite, backend_options=backend_options, sparsity_kwargs=sparsity_kwargs, - **job_kwargs + **job_kwargs, ) if format != "memory": @@ -211,16 +210,15 @@ def create_sorting_analyzer( else: shutil.rmtree(folder) - - # retrieve or compute the main channel index per unit if main_channel_index is None: if "main_channel_index" in sorting.get_property_keys(): main_channel_index = sorting.get_property("main_channel_index") - + if main_channel_index is None: # this is weird but due to the cyclic import from .template_tools import estimate_main_channel_from_recording + main_channel_index = estimate_main_channel_from_recording( recording, sorting, @@ -228,7 +226,7 @@ def create_sorting_analyzer( peak_mode=main_channel_peak_mode, num_spikes_for_main_channel=num_spikes_for_main_channel, seed=seed, - **job_kwargs + **job_kwargs, ) # handle sparsity @@ -241,7 +239,9 @@ def create_sorting_analyzer( assert np.array_equal( recording.channel_ids, sparsity.channel_ids ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" - assert all(sparsity.mask[u, c] for u, c in enumerate(main_channel_index)), "sparsity si not constistentent with main_channel_index" + assert all( + sparsity.mask[u, c] for u, c in enumerate(main_channel_index) + ), "sparsity si not constistentent with main_channel_index" elif sparse: sparsity = estimate_sparsity(sorting, recording, main_channel_index=main_channel_index, **sparsity_kwargs) else: @@ -443,16 +443,19 @@ def create( from spikeinterface.curation.remove_excess_spikes import RemoveExcessSpikesSorting sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) - + # This will ensure that the sorting saved always will have this main_channel - - - + if format == "memory": - sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, - main_channel_peak_sign, - main_channel_peak_mode, - rec_attributes=None) + sorting_analyzer = cls.create_memory( + sorting, + recording, + sparsity, + return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes=None, + ) elif format == "binary_folder": sorting_analyzer = cls.create_binary_folder( folder, @@ -483,7 +486,7 @@ def create( else: raise ValueError("SortingAnalyzer.create: wrong format") - sorting_analyzer.set_sorting_property.set_sorting_property("main_channel_index", main_channel_index, save=True) + sorting_analyzer.set_sorting_property("main_channel_index", main_channel_index, save=True) return sorting_analyzer @@ -517,10 +520,9 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe return sorting_analyzer @classmethod - def create_memory(cls, sorting, recording, sparsity, return_in_uV, - main_channel_peak_sign, - main_channel_peak_mode, - rec_attributes): + def create_memory( + cls, sorting, recording, sparsity, return_in_uV, main_channel_peak_sign, main_channel_peak_mode, rec_attributes + ): # used by create and save_as if rec_attributes is None: @@ -543,16 +545,22 @@ def create_memory(cls, sorting, recording, sparsity, return_in_uV, return_in_uV=return_in_uV, main_channel_peak_sign=main_channel_peak_sign, main_channel_peak_mode=main_channel_peak_mode, - ) return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, - main_channel_peak_sign, - main_channel_peak_mode, - - rec_attributes, backend_options): + def create_binary_folder( + cls, + folder, + sorting, + recording, + sparsity, + return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes, + backend_options, + ): # used by create and save_as folder = Path(folder) @@ -654,49 +662,56 @@ def _handle_backward_compatibility_settings_pre_init(cls, settings, sorting, spa new_settings["main_channel_peak_mode"] = "extremum" return new_settings - + def _handle_backward_compatibility_settings_post_init(self): """ backward compatibility after the __init__ to : * main_channel_index - + Note : * see also _handle_backward_compatibility_settings_pre_init * there is also something at extension level to handle changes in paramaters with deferents mechanism """ - if "main_channel_index" not in self.sorting.get_property_keys(): - - warnings.warn("This loaded analyzer is from an older verion main_channel_index need to be computed from templates") + + warnings.warn( + "This loaded analyzer is from an older verion main_channel_index need to be computed from templates" + ) main_channel_index = None if self.has_extension("templates"): # first try to load templates extension ext = self.get_extension("templates") - + for k in ("average", "median"): if k in ext.data: from .template_tools import _get_main_channel_from_template_array + templates_array = ext.data[k] # TODO @alessio @chris : we need to discuss this - peak_sign = "both" # or "neg" ????? + peak_sign = "both" # or "neg" ????? peak_mode = "extremum" - main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, ext.nbefore) + main_channel_index = _get_main_channel_from_template_array( + templates_array, peak_mode, peak_sign, ext.nbefore + ) break if main_channel_index is None: if not self.has_recording(): # TODO @alessio @chris : we need to discuss this # what to do in this case ??????? - raise ValueError("This analyzer cannot be load and is from an old version, the recording is not available") + raise ValueError( + "This analyzer cannot be load and is from an old version, the recording is not available" + ) else: - # otherwise we need to estimate the + # otherwise we need to estimate the from .template_tools import estimate_main_channel_from_recording + # TODO @alessio @chris : we need to discuss this - peak_sign = "both" # or "neg" ????? + peak_sign = "both" # or "neg" ????? peak_mode = "extremum" main_channel_index = estimate_main_channel_from_recording( @@ -714,7 +729,6 @@ def _handle_backward_compatibility_settings_post_init(self): # this save also to disk but maybe there is no write for the analyzer... self.set_sorting_property("main_channel_index", main_channel_index, save=True) - @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load @@ -783,17 +797,15 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): with open(settings_file, "w") as f: json.dump(check_json(settings), f, indent=4) - - sorting_analyzer = SortingAnalyzer( sorting=sorting, recording=recording, rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - return_in_uV = settings["return_in_uV"], - main_channel_peak_sign = settings["main_channel_peak_sign"], - main_channel_peak_mode = settings["main_channel_peak_mode"], + return_in_uV=settings["return_in_uV"], + main_channel_peak_sign=settings["main_channel_peak_sign"], + main_channel_peak_mode=settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -808,11 +820,18 @@ def _get_zarr_root(self, mode="r+"): return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, - main_channel_peak_sign, - main_channel_peak_mode, - - rec_attributes, backend_options): + def create_zarr( + cls, + folder, + sorting, + recording, + sparsity, + return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes, + backend_options, + ): # used by create and save_as import zarr from .zarrextractors import add_sorting_to_zarr_group @@ -961,7 +980,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): ) else: sparsity = None - + settings = zarr_root.attrs["settings"] settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) @@ -971,9 +990,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - return_in_uV = settings["return_in_uV"], - main_channel_peak_sign = settings["main_channel_peak_sign"], - main_channel_peak_mode = settings["main_channel_peak_mode"], + return_in_uV=settings["return_in_uV"], + main_channel_peak_sign=settings["main_channel_peak_sign"], + main_channel_peak_mode=settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -1075,11 +1094,9 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n Array of values for the property """ return self.sorting.get_property(key, ids=ids) - - def get_main_channels(self, outputs="index", with_dict=False): - """ - """ + def get_main_channels(self, outputs="index", with_dict=False): + """ """ main_channel_index = self.get_sorting_property("main_channel_index") if outputs == "index": main_chans = main_channel_index @@ -1087,13 +1104,12 @@ def get_main_channels(self, outputs="index", with_dict=False): main_chans = self.channel_ids[main_channel_index] else: raise ValueError("wrong outputs") - + if with_dict: return dict(zip(self.unit_ids, main_chans)) else: return main_chans - def are_units_mergeable( self, merge_unit_groups: list[str | int], @@ -1315,10 +1331,13 @@ def _save_or_select_or_merge_or_split( if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( - sorting_provenance, recording, sparsity, self.return_in_uV, + sorting_provenance, + recording, + sparsity, + self.return_in_uV, self.main_channel_peak_sign, self.main_channel_peak_mode, - self.rec_attributes + self.rec_attributes, ) elif format == "binary_folder": @@ -1333,7 +1352,6 @@ def _save_or_select_or_merge_or_split( self.return_in_uV, self.main_channel_peak_sign, self.main_channel_peak_mode, - self.rec_attributes, backend_options=backend_options, ) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 396306b0af..457717ad6c 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -321,7 +321,7 @@ def from_best_channels( mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer) for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] @@ -330,7 +330,7 @@ def from_best_channels( ## Some convinient function to compute sparsity from several strategy @classmethod - def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign=None): """ Construct sparsity from N closest channels Use the "num_channels" argument to specify the number of channels. @@ -394,7 +394,6 @@ def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) - @classmethod def from_radius(cls, templates_or_sorting_analyzer, radius_um): """ @@ -420,15 +419,16 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um): templates_or_sorting_analyzer.channel_ids, main_channel_index, channel_locations, - radius_um) + radius_um, + ) @classmethod def from_snr( cls, templates_or_sorting_analyzer, threshold, - amplitude_mode="extremum", - peak_sign="neg", + amplitude_mode=None, + peak_sign=None, noise_levels=None, ): """ @@ -479,7 +479,7 @@ def from_snr( mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_in_uV=return_in_uV + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=amplitude_mode ) for unit_ind, unit_id in enumerate(unit_ids): @@ -488,7 +488,7 @@ def from_snr( return cls(mask, unit_ids, channel_ids) @classmethod - def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode=None, peak_sign=None): """ Construct sparsity from a threshold based on template amplitude. The amplitude is computed with the specified amplitude mode and it is assumed @@ -534,7 +534,7 @@ def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_in_uV=True + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=amplitude_mode ) for unit_ind, unit_id in enumerate(unit_ids): @@ -695,29 +695,25 @@ def compute_sparsity( if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" - sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels) elif method == "closest_channels": assert num_channels is not None, "For the 'closest_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_closest_channels(templates_or_sorting_analyzer, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates_or_sorting_analyzer, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates_or_sorting_analyzer, radius_um) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, - peak_sign=peak_sign, - amplitude_mode=amplitude_mode, ) elif method == "amplitude": assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_amplitude( templates_or_sorting_analyzer, threshold, - amplitude_mode=amplitude_mode, - peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -760,7 +756,7 @@ def estimate_sparsity( These can be computed with the `get_noise_levels()` function. If main_channel_index is given and method="radius" then there is not need estimate - the templates otherwise the templates must be estimated using + the templates otherwise the templates must be estimated using `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse the recording. @@ -782,7 +778,7 @@ def estimate_sparsity( Noise levels required for the "snr" and "energy" methods. You can use the `get_noise_levels()` function to compute them. main_channel_index : np.array | None, default: None - Main channel indicies can be provided in case of method="radius", this avoid the + Main channel indicies can be provided in case of method="radius", this avoid the `estimate_templates_with_accumulator()` which is slow. {} @@ -808,7 +804,7 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - if method == "radius" and main_channel_index is not None: + if method == "radius" and main_channel_index is not None: assert main_channel_index.size == sorting.unit_ids.size chan_locs = recording.get_channel_locations() sparsity = ChannelSparsity.from_radius_and_main_channel( diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 8d7b9443bf..6228390932 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -10,7 +10,6 @@ import warnings - def get_dense_templates_array( one_object: Templates | SortingAnalyzer, return_in_uV: bool = True, operator="average" ) -> np.ndarray: @@ -52,6 +51,7 @@ def get_dense_templates_array( raise ValueError("Input should be Templates or SortingAnalyzer") return templates_array + def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True): """ Return dense templates as numpy array from either a Templates object or a SortingAnalyzer. @@ -79,7 +79,6 @@ def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, re return main_channel_templates - def _get_nbefore(one_object): if isinstance(one_object, Templates): return one_object.nbefore @@ -94,8 +93,8 @@ def _get_nbefore(one_object): def get_template_amplitudes( templates_or_sorting_analyzer, - peak_sign: None | "neg" | "pos" | "both" = None, - peak_mode: None | "extremum" | "at_index" | "peak_to_peak" = None, + peak_sign: None | Literal["neg", "pos", "both"] = None, + peak_mode: None | Literal["extremum", "at_index", "peak_to_peak"] = None, # return_in_uV: bool = True, abs_value: bool = True, operator: str = "average", @@ -129,8 +128,12 @@ def get_template_amplitudes( """ if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - assert peak_sign is None, "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" - assert peak_mode is None, "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" + assert ( + peak_sign is None + ), "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" + assert ( + peak_mode is None + ), "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign peak_mode = templates_or_sorting_analyzer.main_channel_peak_mode return_in_uV = templates_or_sorting_analyzer.return_in_uV @@ -144,7 +147,11 @@ def get_template_amplitudes( peak_mode = "extremum" assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" - assert peak_mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" + assert peak_mode in ( + "extremum", + "at_index", + "peak_to_peak", + ), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" unit_ids = templates_or_sorting_analyzer.unit_ids before = _get_nbefore(templates_or_sorting_analyzer) @@ -181,7 +188,6 @@ def get_template_amplitudes( return peak_values - def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore): # Step1 : max on time axis if peak_mode == "extremum": @@ -198,7 +204,7 @@ def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, values = templates_array[:, nbefore, :] elif peak_mode == "peak_to_peak": values = np.ptp(templates_array, axis=1) - + # Step2: max on channel axis main_channel_index = np.argmax(values, axis=1) @@ -206,15 +212,15 @@ def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, def estimate_main_channel_from_recording( - recording, - sorting, - peak_sign: "neg" | "both" | "pos" = "both", - peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", - num_spikes_for_main_channel=100, - ms_before = 1.0, - ms_after = 2.5, - seed=None, - **job_kwargs + recording, + sorting, + peak_sign: Literal["neg", "both", "pos"] = "both", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + num_spikes_for_main_channel=100, + ms_before=1.0, + ms_after=2.5, + seed=None, + **job_kwargs, ): """ Estimate the main channel from recording using `estimate_templates_with_accumulator()` @@ -223,11 +229,10 @@ def estimate_main_channel_from_recording( if peak_sign == "pos": warnings.warn( - "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " \ + "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " "should revert the traces instead" ) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) @@ -259,15 +264,12 @@ def estimate_main_channel_from_recording( return main_channel_index - - - # TODO remove this in 0.105.0 def get_template_extremum_channel( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", - outputs: "id" | "index" = "id", + peak_sign: Literal["neg", "pos", "both"] = "neg", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + outputs: Literal["id", "index"] = "id", operator: str = "average", ): """ @@ -304,17 +306,19 @@ def get_template_extremum_channel( warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channels() instead") if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - assert peak_sign is None, "get_template_extremum_channel() peak_sign is now contained in SortingAnalyzer, should be None here" + assert ( + peak_sign is None + ), "get_template_extremum_channel() peak_sign is now contained in SortingAnalyzer, should be None here" main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, with_dict=True) elif isinstance(templates_or_sorting_analyzer, Templates): if peak_sign is None: warnings.warn("get_template_extremum_channel() with Templates should provide a peak_sign") peak_sign = "both" - main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, peak_sign=peak_sign, with_dict=True) - - return main_channels - + main_channels = templates_or_sorting_analyzer.get_main_channels( + outputs=outputs, peak_sign=peak_sign, with_dict=True + ) + return main_channels # TODO remove this in 0.105.0 @@ -340,9 +344,10 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak Dictionary with unit ids as keys and shifts as values """ - warnings.warn("get_template_extremum_channel_peak_shift() is deprecated use get_template_main_channel_peak_shift() instead" - "Will be removed in 0.105.0" - ) + warnings.warn( + "get_template_extremum_channel_peak_shift() is deprecated use get_template_main_channel_peak_shift() instead" + "Will be removed in 0.105.0" + ) return get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True) @@ -368,7 +373,6 @@ def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sig shifts : dict Dictionary with unit ids as keys and shifts as values """ - if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): assert peak_sign is None @@ -378,12 +382,13 @@ def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sig if peak_sign is None: warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") peak_sign = "both" - main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", peak_sign=peak_sign, with_dict=False) + main_channels = templates_or_sorting_analyzer.get_main_channels( + outputs="index", peak_sign=peak_sign, with_dict=False + ) unit_ids = templates_or_sorting_analyzer.unit_ids nbefore = _get_nbefore(templates_or_sorting_analyzer) - main_channel_templates = get_main_channel_templates_array(templates_or_sorting_analyzer) shifts = [] @@ -410,15 +415,15 @@ def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sig # TODO remove this in 0.105.0 def get_template_extremum_amplitude( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - peak_mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", + peak_sign: Literal["neg", "pos", "both"] = "neg", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "at_index", abs_value: bool = True, operator: Literal["average", "median"] = "average", ): """ Depracted will be removed in 0.105.0. Use get_template_main_channel_amplitude() instead. - + Computes amplitudes on the best channel. Parameters @@ -445,21 +450,22 @@ def get_template_extremum_amplitude( Dictionary with unit ids as keys and amplitudes as values """ - warnings.warn("get_template_extremum_amplitude() is deprecated use get_template_main_channel_amplitude() instead" - "Will be removed in 0.105.0" - ) + warnings.warn( + "get_template_extremum_amplitude() is deprecated use get_template_main_channel_amplitude() instead" + "Will be removed in 0.105.0" + ) return get_template_main_channel_amplitude( templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=peak_mode, abs_value=abs_value, - ) + def get_template_main_channel_amplitude( templates_or_sorting_analyzer, - peak_sign: None | "neg" | "pos" | "both" = None, - peak_mode: None | "extremum" | "at_index" | "peak_to_peak" = None, + peak_sign: None | Literal["neg", "pos", "both"] = None, + peak_mode: None | Literal["extremum", "at_index", "peak_to_peak"] = None, abs_value: bool = True, with_dict=True, ): @@ -492,7 +498,9 @@ def get_template_main_channel_amplitude( if peak_sign is None: warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") peak_sign = "both" - main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", peak_sign=peak_sign, with_dict=False) + main_channels = templates_or_sorting_analyzer.get_main_channels( + outputs="index", peak_sign=peak_sign, with_dict=False + ) extremum_amplitudes = get_template_amplitudes( templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=peak_mode, abs_value=abs_value @@ -507,4 +515,4 @@ def get_template_main_channel_amplitude( if with_dict: unit_amplitudes = dict(zip()) - return unit_amplitudes \ No newline at end of file + return unit_amplitudes diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 071de348e0..90ee3867c1 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -303,14 +303,11 @@ def test_compute_sparsity(): # this is needed for method="energy" # using object SortingAnalyzer - sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2) sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0) sparsity = compute_sparsity(sorting_analyzer, method="closest_channels", num_channels=2) - sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") - sparsity = compute_sparsity( - sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" - ) - sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index f457bd9250..645314237a 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,7 +8,7 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", peak_sign="both"), + dict(method="center_of_mass"), dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=True)), dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass"), From a51a24fddec731fea5a7c81225b7d095bf4b4d5a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 16:44:45 +0100 Subject: [PATCH 07/20] Make main_channel_index a property --- .../core/analyzer_extension_core.py | 2 +- src/spikeinterface/core/sortinganalyzer.py | 146 +++++++++++------- 2 files changed, 95 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 3eeede22b5..7f8d210209 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -381,7 +381,7 @@ class ComputeTemplates(AnalyzerExtension): extension_name = "templates" depend_on = ["random_spikes|waveforms"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = True need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 70f149bc23..3e9921f0be 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -351,6 +351,7 @@ def __init__( self.return_in_uV = return_in_uV self.main_channel_peak_sign = main_channel_peak_sign self.main_channel_peak_mode = main_channel_peak_mode + self._main_channel_index = None # For backward compatibility self.return_scaled = return_in_uV @@ -663,71 +664,112 @@ def _handle_backward_compatibility_settings_pre_init(cls, settings, sorting, spa return new_settings - def _handle_backward_compatibility_settings_post_init(self): + @property + def main_channel_index(self): + if self._main_channel_index is None: + sorting_main_channel_index = self.get_sorting_property("main_channel_index") + if sorting_main_channel_index is not None: + self._main_channel_index = sorting_main_channel_index + else: + self._main_channel_index = self.compute_main_channel_backwards_compatibility() + + return self._main_channel_index + + def compute_main_channel_backwards_compatibility(self): """ - backward compatibility after the __init__ to : - * main_channel_index + Computes the `main_channel_index` for an old analyzer. Logic is: - Note : - * see also _handle_backward_compatibility_settings_pre_init - * there is also something at extension level to handle changes in paramaters with deferents mechanism + 1) If you have the `templates` extension: use this to find the main_channel_index, with + default settings and restricted to the sparsity of the analyzer. + + 2) If you do not have the `templates` extension but do have `waveforms`: compute + `templates`, then go to 1) + + 3) If you do not have `templates` or `waveforms`, but do have `sparsity = True`, we + will take the "average" channel as the `main_channel_index` + + 4) Failing that, if you have an attached `recording`, compute the `main_channel_index` + using the accumulator. + + 5) If you have a dense analyzer with no `templates`, `waveforms` or `recording`, your + analyzer is not compatible with newer versions of SpikeInterface. Raise an error + and ask the user to attach a recording. """ - if "main_channel_index" not in self.sorting.get_property_keys(): + warnings.warn( + "This sorting analyzer is from an an older version of spikeinterface. " + "For future compatibility we will compute the `main_channel_indices`." + ) - warnings.warn( - "This loaded analyzer is from an older verion main_channel_index need to be computed from templates" + main_channel_index = None + + templates_array = None + peak_sign = "both" + peak_mode = "extremum" + + # Case 1 + if self.has_extension("templates"): + templates = self.get_extension("templates") + for k in ("average", "median"): + if k in templates.data: + from .template_tools import _get_main_channel_from_template_array + + templates_array = templates.data[k] + break + else: + # Case 2 - from waveforms + if self.has_extension("waveforms") and self.has_extension("random_spikes"): + from spikeinterface.core.analyzer_extension_core import ComputeTemplates + + templates = ComputeTemplates(self) + templates_array = templates.data["average"] + + if templates_array is not None: + from .template_tools import _get_main_channel_from_template_array + + main_channel_index = _get_main_channel_from_template_array( + templates_array, peak_mode, peak_sign, templates.nbefore ) - main_channel_index = None - if self.has_extension("templates"): - # first try to load templates extension - ext = self.get_extension("templates") - - for k in ("average", "median"): - if k in ext.data: - from .template_tools import _get_main_channel_from_template_array - - templates_array = ext.data[k] - # TODO @alessio @chris : we need to discuss this - peak_sign = "both" # or "neg" ????? - peak_mode = "extremum" - main_channel_index = _get_main_channel_from_template_array( - templates_array, peak_mode, peak_sign, ext.nbefore - ) - break + # Case 3 + if self.is_sparse(): + channel_locations = self.get_channel_locations() + sparsity = self.sparsity + main_channel_index = [] + for channel_indices in sparsity.unit_id_to_channel_indices.values(): + unit_channel_locations = channel_locations[channel_indices] + average_unit_channel_location = np.average(unit_channel_locations, axis=0) + distance_from_average_channel = np.linalg.norm( + channel_locations - average_unit_channel_location, axis=1 + ) + closest_channel_index = np.argmin(distance_from_average_channel) + main_channel_index.append(closest_channel_index) - if main_channel_index is None: - if not self.has_recording(): - # TODO @alessio @chris : we need to discuss this - # what to do in this case ??????? - raise ValueError( - "This analyzer cannot be load and is from an old version, the recording is not available" - ) - else: + return main_channel_index - # otherwise we need to estimate the + # Case 4 + if self.has_recording() or self.has_temporary_recording(): + from .template_tools import estimate_main_channel_from_recording - from .template_tools import estimate_main_channel_from_recording + main_channel_index = estimate_main_channel_from_recording( + self.recording, + self.sorting, + peak_sign=peak_sign, + peak_mode=peak_mode, + num_spikes_for_main_channel=100, + ) - # TODO @alessio @chris : we need to discuss this - peak_sign = "both" # or "neg" ????? - peak_mode = "extremum" + return main_channel_index - main_channel_index = estimate_main_channel_from_recording( - self.recording, - self.sorting, - peak_sign=peak_sign, - peak_mode=peak_mode, - num_spikes_for_main_channel=100, - seed=None, - ) + raise ValueError( + "This analyzer is dense, and has no attached recording, waveforms or templates. Hence we cannot estimate the `main_channel_indices`, making the analyzer incompatible with newer versions of spikeinterface. Please attach a recording to continue, or re-create your analyzer from scratch." + ) - # this is only in memory - self.sorting.set_property("main_channel_index", main_channel_index) - # TODO @alessio @chris : we need to discuss this - # this save also to disk but maybe there is no write for the analyzer... - self.set_sorting_property("main_channel_index", main_channel_index, save=True) + # # this is only in memory + # self.sorting.set_property("main_channel_index", main_channel_index) + # # TODO @alessio @chris : we need to discuss this + # # this save also to disk but maybe there is no write for the analyzer... + # self.set_sorting_property("main_channel_index", main_channel_index, save=True) @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): From 0b81eb0f802dd831147799277fdc166c8ecead64 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 16:50:36 +0100 Subject: [PATCH 08/20] get waveform extractor working --- .../core/waveforms_extractor_backwards_compatibility.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index b1510fb27d..43ed16fbd8 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -493,7 +493,13 @@ def _read_old_waveforms_extractor_binary(folder, sorting): sorting = load(folder / "sorting.pickle", base_folder=folder) sorting_analyzer = SortingAnalyzer.create_memory( - sorting, recording, sparsity=sparsity, return_in_uV=return_in_uV, rec_attributes=rec_attributes + sorting, + recording, + sparsity=sparsity, + main_channel_peak_mode="extremum", + main_channel_peak_sign="neg", + return_in_uV=return_in_uV, + rec_attributes=rec_attributes, ) # waveforms From 49b519978fd651b30a2a2c512b9840e8f4ecf5f8 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 16:57:44 +0100 Subject: [PATCH 09/20] fix 3d channel stuff --- src/spikeinterface/core/generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index e3551d1b11..95b1f37347 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2427,10 +2427,10 @@ def generate_ground_truth_recording( **generate_templates_kwargs, ) sorting.set_property("gt_unit_locations", unit_locations) - distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2) + distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :2], axis=2) main_channel_index = np.argmin(distances, axis=1) sorting.set_property("main_channel_index", main_channel_index) - + else: assert templates.shape[0] == num_units From 72e2ced6ccf65e99832389c6eec281052c50baa3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:01:47 +0000 Subject: [PATCH 10/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/__init__.py | 5 ++--- src/spikeinterface/core/basesorting.py | 1 + src/spikeinterface/core/node_pipeline.py | 1 + src/spikeinterface/core/template.py | 10 +++++----- src/spikeinterface/core/tests/test_loading.py | 1 - src/spikeinterface/core/tests/test_node_pipeline.py | 4 ++-- src/spikeinterface/core/tests/test_sortinganalyzer.py | 4 ++-- src/spikeinterface/core/tests/test_template_tools.py | 3 +-- src/spikeinterface/generation/hybrid_tools.py | 1 - src/spikeinterface/metrics/quality/misc_metrics.py | 1 - src/spikeinterface/metrics/quality/quality_metrics.py | 1 - .../metrics/template/template_metrics.py | 2 -- .../postprocessing/amplitude_scalings.py | 3 +-- .../postprocessing/localization_tools.py | 1 - src/spikeinterface/postprocessing/spike_amplitudes.py | 2 +- .../sortingcomponents/matching/nearest.py | 4 +++- .../sortingcomponents/matching/tdc_peeler.py | 5 +++-- src/spikeinterface/widgets/unit_depths.py | 4 +--- 18 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index d9017421a2..7a683a3118 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -168,11 +168,10 @@ get_template_amplitudes, get_template_main_channel_peak_shift, get_template_main_channel_amplitude, - # this is not needed anymore get_template_extremum_channel, # keep for backward compatibility can be removed in 0.105 - get_template_extremum_channel_peak_shift, # keep for backward compatibility can be removed in 0.105 - get_template_extremum_amplitude, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_channel_peak_shift, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_amplitude, # keep for backward compatibility can be removed in 0.105 ) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index ac328f27f4..46da558388 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -12,6 +12,7 @@ class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. """ + _main_properties = [ "main_channel_index", ] diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 13f93e4df7..4871568ead 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -147,6 +147,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea # TODO sam replace extremum_channels_indices by main_channel_index + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 619ee22021..6005fcec69 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -484,9 +484,10 @@ def get_channel_locations(self) -> np.ndarray: channel_locations = self.probe.contact_positions return channel_locations - def get_main_channels(self, - peak_sign : Literal["neg","both","pos"] = "both", - peak_mode : Literal["extremum","at_index","peak_to_peak"] = "extremum", + def get_main_channels( + self, + peak_sign: Literal["neg", "both", "pos"] = "both", + peak_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", outputs="index", with_dict=False, ): @@ -501,9 +502,8 @@ def get_main_channels(self, main_chans = self.channel_ids[main_channel_index] else: raise ValueError("wrong outputs") - + if with_dict: return dict(zip(self.unit_ids, main_chans)) else: return main_chans - diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index 7bcd2a78de..979113a2aa 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -232,4 +232,3 @@ def test_remote_analyzer(): if __name__ == "__main__": test_remote_analyzer() - \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 14df2f86d7..d3c3b719c5 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 765d5bec7e..7a30ac87b2 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -56,7 +56,7 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): # Note the sorting contain already main_channel_index sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) - assert np.array_equal(sorting_analyzer.get_main_channels() , sorting.get_property("main_channel_index")) + assert np.array_equal(sorting_analyzer.get_main_channels(), sorting.get_property("main_channel_index")) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -634,7 +634,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = self.sorting_analyzer.get_main_channels( outputs="index", with_dict=True) + extremum_channel_inds = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 5e4cd1f12c..e78bed0bd0 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -56,11 +56,10 @@ def test_get_template_amplitudes(sorting_analyzer): peak_values = get_template_amplitudes(sorting_analyzer) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="extremum", abs_value=True) - peak_to_peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="peak_to_peak") + peak_to_peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="peak_to_peak") assert np.all(ptp > p for ptp, p in zip(peak_to_peak_values.values(), peak_values.values())) - def test_get_template_main_channel_peak_shift(sorting_analyzer): shifts = get_template_main_channel_peak_shift(sorting_analyzer) print(shifts) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 25e95bb321..1776f60ec8 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -128,7 +128,6 @@ def select_templates( # get template amplitudes and depth main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) - mask = np.ones(templates.num_units, dtype=bool) if min_amplitude is not None or max_amplitude is not None: # filter amplitudes diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index d5c3e037da..9df38ef25c 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -175,7 +175,6 @@ def compute_snrs( channel_ids = sorting_analyzer.channel_ids - main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=True) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 110e4fd53f..04b7e12e5d 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -162,7 +162,6 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): # Get extremum channels for neighbor selection in sparse mode - main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index dc7395ba13..925852983a 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -255,13 +255,11 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) operator = self.params["template_operator"] all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True, operator=operator) - analyzer_channel_locations = sorting_analyzer.get_channel_locations() # the template metrics only work for 2D probes. We warn users with 3D locations above. channel_locations = analyzer_channel_locations[:, :2] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 29faaff860..948878028e 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -1,7 +1,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore +from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -104,7 +104,6 @@ def _get_pipeline_nodes(self): extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) - # collisions handle_collisions = self.params["handle_collisions"] delta_collision_ms = self.params["delta_collision_ms"] diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 14f1b4fbb9..68b2ac3408 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -102,7 +102,6 @@ def compute_monopolar_triangulation( best_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) - unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): chan_inds = sparsity.unit_id_to_channel_indices[unit_id] diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 8d2c394e9b..75c794298a 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -15,7 +15,7 @@ class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): Parameters ---------- - + """ extension_name = "spike_amplitudes" diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 7c59b857ac..9d847ec1f5 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,7 +53,9 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - main_channels = self.templates.get_main_channels(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) + main_channels = self.templates.get_main_channels( + main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False + ) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index b6182ba471..f8276e62ef 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -221,9 +221,10 @@ def __init__( self.sparse_templates_array_static = templates.templates_array self.dtype = self.sparse_templates_array_static.dtype - # as numpy vector - self.main_channels = templates.get_main_channels(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) + self.main_channels = templates.get_main_channels( + main_channel_peak_sign=peak_sign, outputs="index", with_dict=False + ) channel_locations = templates.probe.contact_positions unit_locations = channel_locations[self.main_channels] diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 378948817b..6551f6e7b0 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -22,9 +22,7 @@ class UnitDepthsWidget(BaseWidget): The dimension of unit_locations that is depth """ - def __init__( - self, sorting_analyzer, unit_colors=None, depth_axis=1, backend=None, **backend_kwargs - ): + def __init__(self, sorting_analyzer, unit_colors=None, depth_axis=1, backend=None, **backend_kwargs): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) From 01e2b22ad77ad8d4fbafeb973fa18db2c350affa Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 17:30:08 +0100 Subject: [PATCH 11/20] 2d 3d channel locations fix --- src/spikeinterface/postprocessing/localization_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 68b2ac3408..fb5c6764fb 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -688,7 +688,7 @@ def compute_location_max_channel( unit_ids = np.asarray(unit_ids) unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") for i, unit_id in enumerate(unit_ids): - unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id], :2] return unit_locations From 1895255e07a8096a2e544829301daa85d097855e Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 23 Jun 2026 18:50:24 +0100 Subject: [PATCH 12/20] fix more bugs --- src/spikeinterface/core/sortinganalyzer.py | 2 +- src/spikeinterface/core/template_tools.py | 6 +++--- .../metrics/quality/misc_metrics.py | 17 ++++++----------- .../tests/test_multi_extensions.py | 2 +- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 3e9921f0be..724c88dc50 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1139,7 +1139,7 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n def get_main_channels(self, outputs="index", with_dict=False): """ """ - main_channel_index = self.get_sorting_property("main_channel_index") + main_channel_index = self.main_channel_index if outputs == "index": main_chans = main_channel_index elif outputs == "id": diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 6228390932..e3ec916685 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -393,8 +393,7 @@ def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sig shifts = [] for unit_ind, unit_id in enumerate(unit_ids): - chan_ind = main_channels[unit_ind] - template = main_channel_templates[chan_ind] + template = main_channel_templates[unit_ind] if peak_sign == "both": peak_pos = np.argmax(np.abs(template)) elif peak_sign == "neg": @@ -513,6 +512,7 @@ def get_template_main_channel_amplitude( unit_amplitudes.append(extremum_amplitudes[unit_id][chan_ind]) if with_dict: - unit_amplitudes = dict(zip()) + unit_amplitudes_dict = dict(zip(unit_ids, unit_amplitudes)) + return unit_amplitudes_dict return unit_amplitudes diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 9df38ef25c..394d02f435 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -173,19 +173,14 @@ def compute_snrs( noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() - channel_ids = sorting_analyzer.channel_ids - - main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) - unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=True) - - # make a dict to access by chan_id - noise_levels = dict(zip(channel_ids, noise_levels)) + main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=False) snrs = {} - for unit_id in unit_ids: - chan_ind = main_channel_index[unit_id] + for unit_index, unit_id in enumerate(unit_ids): + chan_ind = main_channel_index[unit_index] noise = noise_levels[chan_ind] - amplitude = unit_amplitudes[unit_id] + amplitude = unit_amplitudes[unit_index] snrs[unit_id] = np.abs(amplitude) / noise return snrs @@ -194,7 +189,7 @@ def compute_snrs( class SNR(BaseMetric): metric_name = "snr" metric_function = compute_snrs - metric_params = {"peak_sign": "both", "peak_mode": "extremum"} + metric_params = {} metric_columns = {"snr": float} metric_descriptions = {"snr": "Signal to noise ratio for each unit."} depend_on = ["noise_levels", "templates"] diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index 16650a09ec..df80a2a34c 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -116,7 +116,7 @@ def get_dataset_to_split(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_main_channel_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(get_template_main_channel_amplitude(analyzer_raw, with_dict=False))[::-1] large_units = sorting.unit_ids[sort_by_amp][:2] return recording, sorting, large_units From 978d4908125fdb076569517b465a040edf94078e Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 24 Jun 2026 11:29:12 +0100 Subject: [PATCH 13/20] more bug fixes --- src/spikeinterface/core/template_tools.py | 20 +++++++++++-------- .../curation/remove_redundant.py | 6 +----- .../metrics/quality/misc_metrics.py | 3 ++- .../quality/tests/test_metrics_functions.py | 15 +++++++++----- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index e3ec916685..3d39a4abb2 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -98,6 +98,7 @@ def get_template_amplitudes( # return_in_uV: bool = True, abs_value: bool = True, operator: str = "average", + override_peak_error: bool = False, ): """ Get amplitude per channel for each unit. @@ -128,14 +129,17 @@ def get_template_amplitudes( """ if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - assert ( - peak_sign is None - ), "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" - assert ( - peak_mode is None - ), "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" - peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign - peak_mode = templates_or_sorting_analyzer.main_channel_peak_mode + if not override_peak_error: + assert ( + peak_sign is None + ), "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" + assert ( + peak_mode is None + ), "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" + if peak_sign is None: + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + if peak_mode is None: + peak_mode = templates_or_sorting_analyzer.main_channel_peak_mode return_in_uV = templates_or_sorting_analyzer.return_in_uV elif isinstance(templates_or_sorting_analyzer, Templates): return_in_uV = templates_or_sorting_analyzer.is_in_uV diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index 9675335848..f998bcb8d4 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -17,7 +17,6 @@ def remove_redundant_units( agreement_threshold: float = 0.2, duplicate_threshold: float = 0.8, remove_strategy: str = "minimum_shift", - peak_sign: str = "neg", extra_outputs: bool = False, ) -> BaseSorting | tuple[BaseSorting, list[tuple[int, int]]]: """ @@ -52,9 +51,6 @@ def remove_redundant_units( If shifts are equal then the "highest_amplitude" is used * "highest_amplitude" : keep the unit with the best amplitude on unshifted max. * "max_spikes" : keep the unit with more spikes - - peak_sign : "neg" | "pos" | "both", default: "neg" - Used when remove_strategy="highest_amplitude" extra_outputs : bool, default: False If True, will return the redundant pairs. unit_peak_shifts : dict @@ -98,7 +94,7 @@ def remove_redundant_units( if remove_strategy in ("minimum_shift", "highest_amplitude"): # this is the values at spike index ! - peak_values = get_template_amplitudes(sorting_analyzer, peak_sign=peak_sign, mode="at_index") + peak_values = get_template_amplitudes(sorting_analyzer, peak_mode="at_index", override_peak_error=True) peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()} if remove_strategy == "minimum_shift": diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 394d02f435..e451a39f4d 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -177,7 +177,8 @@ def compute_snrs( unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=False) snrs = {} - for unit_index, unit_id in enumerate(unit_ids): + for unit_id in unit_ids: + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) chan_ind = main_channel_index[unit_index] noise = noise_levels[chan_ind] amplitude = unit_amplitudes[unit_index] diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index e267b176ce..8a288d21fe 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -639,15 +639,20 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): ) # check that, when parameters are changed, the data and metadata are updated - old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + old_presence_ratio_data = deepcopy(quality_metric_extension.get_data()["presence_ratio"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + { + "quality_metrics": { + "metric_names": ["presence_ratio"], + "metric_params": {"presence_ratio": {"bin_duration_s": "10"}}, + } + } ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - new_snr_data = new_quality_metric_extension.get_data()["snr"].values + new_presence_ratio_data = new_quality_metric_extension.get_data()["presence_ratio"].values - assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert np.all(old_presence_ratio_data != new_presence_ratio_data) + assert new_quality_metric_extension.params["metric_params"]["presence_ratio"]["bin_duration_s"] == "10" def test_metric_names_in_same_order(small_sorting_analyzer): From c09f8cbe6aaf54916fac0d3d09235b6aba072ee4 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 24 Jun 2026 11:55:26 +0100 Subject: [PATCH 14/20] deprecation warnings and return_to_uv None in template functions --- src/spikeinterface/core/template_tools.py | 47 +++++++++++++---------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 3d39a4abb2..13c90d5332 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -2,6 +2,7 @@ import numpy as np +from spikeinterface.core.core_tools import ms_to_samples from .template import Templates from .waveform_tools import estimate_templates_with_accumulator from .sorting_tools import random_spikes_selection @@ -11,7 +12,7 @@ def get_dense_templates_array( - one_object: Templates | SortingAnalyzer, return_in_uV: bool = True, operator="average" + one_object: Templates | SortingAnalyzer, return_in_uV: None | bool = None, operator="average" ) -> np.ndarray: """ Return dense templates as numpy array from either a Templates object or a SortingAnalyzer. @@ -32,13 +33,13 @@ def get_dense_templates_array( The dense templates (num_units, num_samples, num_channels) """ if isinstance(one_object, Templates): - if return_in_uV != one_object.is_in_uV: + if return_in_uV is not None and return_in_uV != one_object.is_in_uV: raise ValueError( f"get_dense_templates_array: return_in_uV={return_in_uV} is not possible Templates has the reverse" ) templates_array = one_object.get_dense_templates() elif isinstance(one_object, SortingAnalyzer): - if return_in_uV != one_object.return_in_uV: + if return_in_uV is not None and return_in_uV != one_object.return_in_uV: raise ValueError( f"get_dense_templates_array: return_in_uV={return_in_uV} is not possible SortingAnalyzer has the reverse" ) @@ -52,9 +53,9 @@ def get_dense_templates_array( return templates_array -def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True): +def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: None | bool = None): """ - Return dense templates as numpy array from either a Templates object or a SortingAnalyzer. + Return template on main channel for each unit, as numpy array from either a Templates object or a SortingAnalyzer. Parameters ---------- @@ -95,13 +96,12 @@ def get_template_amplitudes( templates_or_sorting_analyzer, peak_sign: None | Literal["neg", "pos", "both"] = None, peak_mode: None | Literal["extremum", "at_index", "peak_to_peak"] = None, - # return_in_uV: bool = True, abs_value: bool = True, operator: str = "average", override_peak_error: bool = False, ): """ - Get amplitude per channel for each unit. + Get amplitude on every channel for each unit. Parameters ---------- @@ -193,6 +193,10 @@ def get_template_amplitudes( def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore): + """ + Get the main channel for each template in a `templates_array`, which has dimensions + (num_units) x (num time samples) x (num channels) + """ # Step1 : max on time axis if peak_mode == "extremum": if peak_sign == "both": @@ -237,8 +241,9 @@ def estimate_main_channel_from_recording( "should revert the traces instead" ) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + sampling_frequency = recording.sampling_frequency + nbefore = ms_to_samples(ms_before, sampling_frequency) + nafter = ms_to_samples(ms_after, sampling_frequency) num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] random_spikes_indices = random_spikes_selection( @@ -268,7 +273,7 @@ def estimate_main_channel_from_recording( return main_channel_index -# TODO remove this in 0.105.0 +# TODO remove this in 0.106.0 def get_template_extremum_channel( templates_or_sorting_analyzer, peak_sign: Literal["neg", "pos", "both"] = "neg", @@ -277,7 +282,7 @@ def get_template_extremum_channel( operator: str = "average", ): """ - Depracted will be removed in 0.105.0. + Deprecated will be removed in 0.106.0. Use analyzer.get_main_channels() or tempates.get_main_channels(peak_sign=...) instead. @@ -307,7 +312,10 @@ def get_template_extremum_channel( Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ - warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channels() instead") + warnings.warn( + "get_template_extremum_channel() is deprecated and will be removed in version 0.106.0 use analyzer.get_main_channels() instead", + category=DeprecationWarning, + ) if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): assert ( @@ -325,10 +333,10 @@ def get_template_extremum_channel( return main_channels -# TODO remove this in 0.105.0 +# TODO remove this in 0.106.0 def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None): """ - Depracted will be removed in 0.105.0. + Depracted will be removed in 0.106.0. Use get_template_main_channel_peak_shift() instead. In some situations spike sorters could return a spike index with a small shift related to the waveform peak. @@ -350,9 +358,9 @@ def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak warnings.warn( "get_template_extremum_channel_peak_shift() is deprecated use get_template_main_channel_peak_shift() instead" - "Will be removed in 0.105.0" + "Will be removed in 0.106.0", + category=DeprecationWarning, ) - return get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True) @@ -381,14 +389,10 @@ def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sig if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): assert peak_sign is None peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign - main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=False) elif isinstance(templates_or_sorting_analyzer, Templates): if peak_sign is None: warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") peak_sign = "both" - main_channels = templates_or_sorting_analyzer.get_main_channels( - outputs="index", peak_sign=peak_sign, with_dict=False - ) unit_ids = templates_or_sorting_analyzer.unit_ids nbefore = _get_nbefore(templates_or_sorting_analyzer) @@ -455,7 +459,8 @@ def get_template_extremum_amplitude( warnings.warn( "get_template_extremum_amplitude() is deprecated use get_template_main_channel_amplitude() instead" - "Will be removed in 0.105.0" + "Will be removed in 0.106.0", + category=DeprecationWarning, ) return get_template_main_channel_amplitude( templates_or_sorting_analyzer, From 2df4187212eab8ab785fe02ff400e91ae9e9490d Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 24 Jun 2026 12:08:49 +0100 Subject: [PATCH 15/20] change main_channel_peak_sign to peak_sign --- src/spikeinterface/sortingcomponents/matching/nearest.py | 4 +--- src/spikeinterface/sortingcomponents/matching/tdc_peeler.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 9d847ec1f5..a6eb7adf29 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,9 +53,7 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - main_channels = self.templates.get_main_channels( - main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False - ) + main_channels = self.templates.get_main_channels(peak_sign=self.peak_sign, outputs="index", with_dict=False) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index f8276e62ef..f96ee035e9 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -222,9 +222,7 @@ def __init__( self.dtype = self.sparse_templates_array_static.dtype # as numpy vector - self.main_channels = templates.get_main_channels( - main_channel_peak_sign=peak_sign, outputs="index", with_dict=False - ) + self.main_channels = templates.get_main_channels(peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions unit_locations = channel_locations[self.main_channels] From 94b4372164fad947fa982fa6ed8b731c727dd808 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 24 Jun 2026 12:53:12 +0100 Subject: [PATCH 16/20] good grouping for make_sorting_analyzer in tests --- src/spikeinterface/exporters/tests/common.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index d86df0c6c8..c9c0879331 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -23,14 +23,21 @@ def make_sorting_analyzer(sparse=True, with_group=False): ) if with_group: + + # this is a bit painful. We need to figure out which units belong to which group + initial_sorting_analyzer = create_sorting_analyzer( + sorting=sorting, recording=recording, format="memory", sparse=True + ) + main_channels = initial_sorting_analyzer.main_channel_index + sorting_groups = [0 if main_channel < 4 else 1 for main_channel in main_channels] recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1]) - sorting.set_property("group", [0, 0, 1, 1]) + sorting.set_property("group", sorting_groups) + # Now we know that, we can make an analyzer with a working sparsity_group sorting_analyzer_unused = create_sorting_analyzer( sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=None ) sparsity_group = compute_sparsity(sorting_analyzer_unused, method="by_property", by_property="group") - sorting_analyzer = create_sorting_analyzer( sorting=sorting, recording=recording, format="memory", sparse=False, sparsity=sparsity_group ) From b61bebfd7ccf3a2d1d73cca7c36436eda10fec9e Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 25 Jun 2026 17:15:29 +0100 Subject: [PATCH 17/20] respond to Erick --- src/spikeinterface/core/sparsity.py | 35 +++++-------------- src/spikeinterface/core/template_tools.py | 17 +++++++-- .../metrics/quality/misc_metrics.py | 14 ++++---- .../metrics/template/template_metrics.py | 1 - .../postprocessing/localization_tools.py | 12 +++---- .../postprocessing/spike_amplitudes.py | 4 --- 6 files changed, 35 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 457717ad6c..d5129e8c6d 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -644,12 +644,10 @@ def compute_sparsity( method: Literal[ "radius", "best_channels", "closest_channels", "snr", "amplitude", "energy", "by_property" ] = "radius", - peak_sign: Literal["neg", "pos", "both"] = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - amplitude_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", ) -> ChannelSparsity: """ Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. @@ -739,7 +737,7 @@ def estimate_sparsity( ms_before: float = 1.0, ms_after: float = 2.5, method: Literal["radius", "best_channels", "closest_channels", "amplitude", "snr", "by_property"] = "radius", - peak_sign: Literal["neg", "pos", "both"] = "neg", + peak_sign: Literal["neg", "pos", "both"] = "both", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, @@ -789,6 +787,7 @@ def estimate_sparsity( """ # Can't be done at module because this is a cyclic import, too bad from .template import Templates + from .template_tools import get_templates_array_from_recording_and_sorting assert method in ("radius", "best_channels", "closest_channels", "snr", "amplitude", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " @@ -812,31 +811,13 @@ def estimate_sparsity( ) elif method != "by_property": - nbefore = ms_to_samples(ms_before, recording.sampling_frequency) - nafter = ms_to_samples(ms_after, recording.sampling_frequency) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, - ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_in_uV=False, - job_name="estimate_sparsity", - **job_kwargs, + + templates_array = get_templates_array_from_recording_and_sorting( + recording, sorting, ms_before, ms_after, num_spikes_for_sparsity, 2205, **job_kwargs ) + + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) + templates = Templates( templates_array=templates_array, sampling_frequency=recording.sampling_frequency, diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 13c90d5332..c8e6d3bd2e 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -241,6 +241,19 @@ def estimate_main_channel_from_recording( "should revert the traces instead" ) + templates_array = get_templates_array_from_recording_and_sorting( + recording, sorting, ms_before, ms_after, num_spikes_for_main_channel, seed + ) + + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore) + + return main_channel_index + + +def get_templates_array_from_recording_and_sorting( + recording, sorting, ms_before, ms_after, num_spikes_for_main_channel, seed, **job_kwargs +): + sampling_frequency = recording.sampling_frequency nbefore = ms_to_samples(ms_before, sampling_frequency) nafter = ms_to_samples(ms_after, sampling_frequency) @@ -268,9 +281,7 @@ def estimate_main_channel_from_recording( **job_kwargs, ) - main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore) - - return main_channel_index + return templates_array # TODO remove this in 0.106.0 diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index e451a39f4d..e633b8429a 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -152,12 +152,14 @@ def compute_snrs( A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the SNR. If None, all units are used. - peak_sign : "neg" | "pos" | "both", default: "neg" - The sign of the template to compute best channels. - peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" - How to compute the amplitude. - Extremum takes the maxima/minima - At_index takes the value at t=sorting_analyzer.nbefore. + peak_sign_for_signal : "neg" | "pos" | "both", default: "neg" + Which extremum of the template to use for the "signal". Use either the minimum ("neg"), + maximum ("pos") or absmax ("both") + peak_mode_for_signal : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + How to compute the amplitude: + "extremum" takes the maxima/minima + "at_index" takes the value at t=sorting_analyzer.nbefore. + "peak_to_peak" takes the difference between the maximum and minimum operator : "median" | "average", default: "median" The operator to apply to retrieve templates and amplitudes. diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 925852983a..37f00d4376 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -238,7 +238,6 @@ def _prepare_data(self, sorting_analyzer, unit_ids): if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] min_thresh_detect_peaks_troughs = self.params.get("min_thresh_detect_peaks_troughs", 0.3) edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.09) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index fb5c6764fb..4d07eb8a5b 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -162,9 +162,7 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" if sorting_analyzer_or_templates.sparsity is None: - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) else: sparsity = sorting_analyzer_or_templates.sparsity @@ -206,7 +204,7 @@ def compute_center_of_mass( def compute_grid_convolution( sorting_analyzer_or_templates: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: str = "neg", + invert_prototype_waveform: bool = True, radius_um: float = 40.0, upsampling_um: float = 5, sigma_ms: float = 0.25, @@ -224,8 +222,8 @@ def compute_grid_convolution( A SortingAnalyzer or Templates object unit_ids: str | int | None A list of unit_id to restrci the computation - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute best channels + invert_prototype_waveform: bool, default: True + If True, will multiply the prototype waveform by -1 radius_um : float, default: 40.0 Radius to consider for the fake templates upsampling_um : float, default: 5 @@ -270,7 +268,7 @@ def compute_grid_convolution( time_axis = np.arange(-nbefore, nafter) * 1000 / fs if prototype is None: prototype = np.exp(-(time_axis**2) / (2 * (sigma_ms**2))) - if peak_sign == "neg": + if invert_prototype_waveform: prototype *= -1 prototype = prototype[:, np.newaxis] diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 75c794298a..eed0314816 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -12,10 +12,6 @@ class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): Needs "templates" to be computed first. Computes spike amplitudes from the template's peak channel for every spike. - - Parameters - ---------- - """ extension_name = "spike_amplitudes" From a39155808fbdcccf0c0e2accde21e4441032d804 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 25 Jun 2026 17:34:00 +0100 Subject: [PATCH 18/20] oups --- src/spikeinterface/core/template_tools.py | 1 + src/spikeinterface/core/tests/test_sparsity.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index c8e6d3bd2e..bed53139ed 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -244,6 +244,7 @@ def estimate_main_channel_from_recording( templates_array = get_templates_array_from_recording_and_sorting( recording, sorting, ms_before, ms_after, num_spikes_for_main_channel, seed ) + nbefore = ms_to_samples(ms_before, recording.sampling_frequency) main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 90ee3867c1..372cc2094d 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -314,10 +314,10 @@ def test_compute_sparsity(): # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() - sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") + sparsity = compute_sparsity(templates, method="best_channels", num_channels=2) sparsity = compute_sparsity(templates, method="radius", radius_um=50.0) - sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5) sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2) From 61b790b70d6c0dd7bab3b52e19fcc56aac06f33a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 25 Jun 2026 17:56:29 +0100 Subject: [PATCH 19/20] Remove peak_sign and peak_mode from compute_sparsity (were not being used) --- src/spikeinterface/benchmark/tests/test_benchmark_matching.py | 2 +- src/spikeinterface/metrics/quality/pca_metrics.py | 4 ++-- src/spikeinterface/sortingcomponents/matching/nearest.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py index 22b519ec78..a529141323 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py @@ -28,7 +28,7 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_in_uV=False, **job_kwargs ) noise_levels = get_noise_levels(recording, **job_kwargs) - sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25) + sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/metrics/quality/pca_metrics.py b/src/spikeinterface/metrics/quality/pca_metrics.py index bede198016..26bed55ac5 100644 --- a/src/spikeinterface/metrics/quality/pca_metrics.py +++ b/src/spikeinterface/metrics/quality/pca_metrics.py @@ -653,7 +653,7 @@ def nearest_neighbors_isolation( if sorting_analyzer.is_sparse(): sparsity = sorting_analyzer.sparsity else: - sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=radius_um) closest_chans_target_unit = sparsity.unit_id_to_channel_indices[this_unit_id] n_channels_target_unit = len(closest_chans_target_unit) # select other units that have a minimum spatial overlap with target unit @@ -867,7 +867,7 @@ def nearest_neighbors_noise_overlap( if sorting_analyzer.is_sparse(): sparsity = sorting_analyzer.sparsity else: - sparsity = compute_sparsity(sorting_analyzer, method="radius", peak_sign=peak_sign, radius_um=radius_um) + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=radius_um) noise_cluster = noise_cluster[:, :, sparsity.unit_id_to_channel_indices[this_unit_id]] # compute weighted noise snippet (Z) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index a6eb7adf29..f3b29fa541 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -68,7 +68,9 @@ def __init__( from spikeinterface.core.sparsity import compute_sparsity sparsity = compute_sparsity( - templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign + templates, + method="radius", + radius_um=sparsity_radius_um, ) else: sparsity = templates.sparsity From ec6fac310e4d7a7ec38d14c255530993039e60b5 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 25 Jun 2026 18:08:42 +0100 Subject: [PATCH 20/20] Remove more peak_amplitude from compute_sparsity --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 - src/spikeinterface/sortingcomponents/tools.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 85e0da7220..f164facf1f 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -376,7 +376,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sparsity_snr = compute_sparsity( dense_templates, method="snr", - amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=threshold, ) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index b8d430163e..26a28b67f1 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -602,7 +602,6 @@ def clean_templates( sparsity = compute_sparsity( templates, method="snr", - amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=sparsify_threshold, ) @@ -637,7 +636,6 @@ def clean_templates( sparsity = compute_sparsity( templates.to_dense(), method="snr", - amplitude_mode="peak_to_peak", noise_levels=noise_levels, threshold=min_snr, )