diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index cda52df1e3..91eb7df864 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -331,7 +331,7 @@ def from_best_channels( ## Some convinient function to compute sparsity from several strategy @classmethod - def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): + def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): """ Construct sparsity from N closest channels Use the "num_channels" argument to specify the number of channels. @@ -342,24 +342,28 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels): A Templates or a SortingAnalyzer object. num_channels : int Number of channels for "best_channels" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. Returns ------- sparsity : ChannelSparsity The estimated sparsity """ - from .template_tools import get_template_amplitudes + from .template_tools import get_template_extremum_channel mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) channel_locations = templates_or_sorting_analyzer.get_channel_locations() distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) + best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_inds = np.argsort(distances[unit_ind]) - chan_inds = chan_inds[:num_channels] - mask[unit_ind, chan_inds] = True + chan_ind = best_chan[unit_id] + chan_inds = np.argsort(distances[chan_ind]) + closest_chan_inds = chan_inds[:num_channels] + mask[unit_ind, closest_chan_inds] = True return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index c865068e4a..6e85221621 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -210,6 +210,25 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # regression for #4126: closest_channels must work even when there are more units than channels + recording_small, sorting_many_units = generate_ground_truth_recording( + num_channels=4, num_units=6, durations=[10], seed=0 + ) + sparsity = estimate_sparsity( + sorting_many_units, + recording_small, + num_spikes_for_sparsity=20, + ms_before=1.0, + ms_after=2.0, + method="closest_channels", + num_channels=2, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + assert sparsity.mask.shape == (sorting_many_units.unit_ids.size, recording_small.channel_ids.size) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(sorting_many_units.unit_ids.size) * 2) + # by_property sparsity = estimate_sparsity( sorting,