From 31f177ce386128801fcdc3c54cd4413d870a5ef7 Mon Sep 17 00:00:00 2001 From: Ian <53369610+ueeseer@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:40:26 +0000 Subject: [PATCH 1/3] Fix closest_channels sparsity when n_units exceed channels --- src/spikeinterface/core/sparsity.py | 6 ++++-- src/spikeinterface/core/tests/test_sparsity.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index cda52df1e3..7cccac2a8b 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -348,16 +348,18 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): sparsity : ChannelSparsity The estimated sparsity """ - from .template_tools import get_template_amplitudes + from .template_tools import get_template_extremum_channel mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) channel_locations = templates_or_sorting_analyzer.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) + best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign="neg", outputs="index") for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_inds = np.argsort(distances[unit_ind]) + chan_ind = best_chan[unit_id] + chan_inds = np.argsort(distances[chan_ind]) chan_inds = chan_inds[:num_channels] mask[unit_ind, chan_inds] = True return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index c865068e4a..f71b64c359 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -210,6 +210,23 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # regression for #4126: closest_channels must work even when there are more units than channels + recording_small, sorting_many_units = generate_ground_truth_recording(num_channels=4, num_units=6, durations=[10], seed=0) + sparsity = estimate_sparsity( + sorting_many_units, + recording_small, + num_spikes_for_sparsity=20, + ms_before=1.0, + ms_after=2.0, + method="closest_channels", + num_channels=2, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + assert sparsity.mask.shape == (sorting_many_units.unit_ids.size, recording_small.channel_ids.size) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(sorting_many_units.unit_ids.size) * 2) + # by_property sparsity = estimate_sparsity( sorting, From e22a2c5e3f56868f3d312791457094722750e489 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Mar 2026 03:11:57 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_sparsity.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index f71b64c359..6e85221621 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -211,7 +211,9 @@ def test_estimate_sparsity(): assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) # regression for #4126: closest_channels must work even when there are more units than channels - recording_small, sorting_many_units = generate_ground_truth_recording(num_channels=4, num_units=6, durations=[10], seed=0) + recording_small, sorting_many_units = generate_ground_truth_recording( + num_channels=4, num_units=6, durations=[10], seed=0 + ) sparsity = estimate_sparsity( sorting_many_units, recording_small, From 94e48929797dbbc8fd3f0925e9e69f9f177f024a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Mar 2026 10:34:49 +0100 Subject: [PATCH 3/3] fix: expose peak_sign in from_closest_channels --- src/spikeinterface/core/sparsity.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 7cccac2a8b..91eb7df864 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -331,7 +331,7 @@ def from_best_channels( ## Some convinient function to compute sparsity from several strategy @classmethod - def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): + def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): """ Construct sparsity from N closest channels Use the "num_channels" argument to specify the number of channels. @@ -342,6 +342,8 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): A Templates or a SortingAnalyzer object. num_channels : int Number of channels for "best_channels" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. Returns ------- @@ -355,13 +357,13 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): ) channel_locations = templates_or_sorting_analyzer.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign="neg", outputs="index") + best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_ind = best_chan[unit_id] chan_inds = np.argsort(distances[chan_ind]) - chan_inds = chan_inds[:num_channels] - mask[unit_ind, chan_inds] = True + closest_chan_inds = chan_inds[:num_channels] + mask[unit_ind, closest_chan_inds] = True return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod