Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 60 additions & 55 deletions src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ class ComputeCorrelograms(AnalyzerExtension):
bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
method : "auto" | "numpy" | "numba", default: "auto"
If "auto" and numba is installed, numba is used, otherwise numpy is used.
fast_mode : "auto" | "on" | "off", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
n_jobs : int | float, default: 1.0
Number of workers requested for the parallel computation. The OS decides how
the workers are scheduled across cores. For convenience, one can also pass:
* -1 : request as many workers as there are cores (from os.cpu_count())
* a float between 0 and 1 : request that fraction of the total cores (from os.cpu_count())

Returns
-------
Expand Down Expand Up @@ -90,8 +92,10 @@ class ComputeCorrelograms(AnalyzerExtension):
use_nodepipeline = False
need_job_kwargs = False

def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode)
def _set_params(
self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", n_jobs: int | float = 1.0
):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs)

return params

Expand Down Expand Up @@ -223,9 +227,11 @@ class ComputeAutoCorrelograms(AnalyzerExtension):
bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ...
method : "auto" | "numpy" | "numba", default: "auto"
If "auto" and numba is installed, numba is used, otherwise numpy is used.
fast_mode : "auto" | "off" | "on", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
n_jobs : int | float, default: 1.0
Number of workers requested for the parallel computation. The OS decides how
the workers are scheduled across cores. For convenience, one can also pass:
* -1 : request as many workers as there are cores (from os.cpu_count())
* a float between 0 and 1 : request that fraction of the total cores (from os.cpu_count())

Returns
-------
Expand Down Expand Up @@ -263,8 +269,10 @@ class ComputeAutoCorrelograms(AnalyzerExtension):
use_nodepipeline = False
need_job_kwargs = False

def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode)
def _set_params(
self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", n_jobs: int | float = 1.0
):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs)
return params

def _select_extension_data(self, unit_ids):
Expand Down Expand Up @@ -341,7 +349,7 @@ def compute_correlograms(
window_ms: float = 50.0,
bin_ms: float = 1.0,
method: str = "auto",
fast_mode: str = "auto",
n_jobs: int | float = 1.0,
):
"""
Compute correlograms using Numba or Numpy.
Expand All @@ -352,11 +360,11 @@ def compute_correlograms(

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
return compute_correlograms_sorting_analyzer(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs
)
else:
return _compute_correlograms_on_sorting(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs
)


Expand Down Expand Up @@ -422,7 +430,7 @@ def _compute_num_bins(window_size, bin_size):
return num_bins, num_half_bins


def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto"):
def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", n_jobs=1.0):
"""
Computes cross-correlograms from multiple units.

Expand All @@ -441,9 +449,11 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto",
method : str
To use "numpy" or "numba". "auto" will use numba if available,
otherwise numpy.
fast_mode : "auto" | "on" | "off", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
n_jobs : int | float, default: 1.0
Number of workers requested for the parallel computation. The OS decides how
the workers are scheduled across cores. For convenience, one can also pass:
* -1 : request as many workers as there are cores (from os.cpu_count())
* a float between 0 and 1 : request that fraction of the total cores (from os.cpu_count())
Only forwarded to the "numba" implementation; the "numpy" implementation is
called without it.

Returns
-------
Expand All @@ -459,20 +469,12 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto",
if method == "auto":
method = "numba" if HAVE_NUMBA else "numpy"

if method == "numba" and fast_mode == "auto":
num_units = len(sorting.unit_ids)
fast_mode = num_units > 300
elif fast_mode == "off":
fast_mode = False
elif fast_mode == "on":
fast_mode = True

bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms)

if method == "numpy":
correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size)
if method == "numba":
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode)
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, n_jobs=n_jobs)

return correlograms, bins

Expand Down Expand Up @@ -617,7 +619,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi
return correlograms


def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
def _compute_correlograms_numba(sorting, window_size, bin_size, n_jobs=1.0):
"""
Computes cross-correlograms between all units in `sorting`.

Expand All @@ -633,9 +635,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
The window size over which to perform the cross-correlation, in samples
bin_size : int
The size of which to bin lags, in samples.
fast_mode : bool
If True, use faster implementations (currently only if method is 'numba'),
at the cost of possible minor numerical differences.
n_jobs : int | float, default: 1.0
Number of workers (threads) requested for the parallel computation. The OS
decides how they are scheduled across cores. For convenience, one can also pass:
* -1 : request as many workers as there are cores (from os.cpu_count())
* a float between 0 and 1 : request that fraction of the total cores (from os.cpu_count())
A float in (0, 1] is converted with int(n_jobs * os.cpu_count()), so the
default of 1.0 requests every core.

Returns
-------
Expand All @@ -653,10 +657,12 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
spikes = sorting.to_spike_vector(concatenated=False)
correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)

if fast_mode:
num_threads = mp.cpu_count()
if isinstance(n_jobs, float) and 0 < n_jobs <= 1:
import os

num_threads = int(n_jobs * os.cpu_count())
else:
num_threads = 1
num_threads = n_jobs

for seg_index in range(sorting.get_num_segments()):
spike_times = spikes[seg_index]["sample_index"]
Expand Down Expand Up @@ -829,7 +835,7 @@ def _compute_auto_correlograms_one_segment_numba(


def compute_auto_correlograms(
sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode="auto"
sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", n_jobs=1.0
):
"""
Compute correlograms using Numba or Numpy.
Expand All @@ -840,15 +846,15 @@ def compute_auto_correlograms(

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
return compute_auto_correlograms_sorting_analyzer(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs
)
else:
return _compute_auto_correlograms_on_sorting(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs
)


def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False):
def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", n_jobs=1.0):
"""
Computes auto-correlograms from multiple units.

Expand All @@ -867,9 +873,11 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au
method : str
To use "numpy" or "numba". "auto" will use numba if available,
otherwise numpy.
fast_mode : "auto" | "off" | "on", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
n_jobs : int | float, default: 1.0
Number of workers requested for the parallel computation. The OS decides how
the workers are scheduled across cores. For convenience, one can also pass:
* -1 : request as many workers as there are cores (from os.cpu_count())
* a float between 0 and 1 : request that fraction of the total cores (from os.cpu_count())
Only forwarded to the "numba" implementation; the "numpy" implementation is
called without it.

Returns
-------
Expand All @@ -884,20 +892,13 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au

if method == "auto":
method = "numba" if HAVE_NUMBA else "numpy"
if method == "numba" and fast_mode == "auto":
num_units = len(sorting.unit_ids)
fast_mode = num_units > 300
elif fast_mode == "off":
fast_mode = False
elif fast_mode == "on":
fast_mode = True

bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms)

if method == "numpy":
correlograms = _compute_auto_correlograms_numpy(sorting, window_size, bin_size)
if method == "numba":
correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode)
correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, n_jobs=n_jobs)

return correlograms, bins

Expand Down Expand Up @@ -1036,7 +1037,7 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz
return correlograms


def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=False):
def _compute_auto_correlograms_numba(sorting, window_size, bin_size, n_jobs=1.0):
"""
Computes auto-correlograms between all units in `sorting`.

Expand All @@ -1052,9 +1053,11 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F
The window size over which to perform the cross-correlation, in samples
bin_size : int
The size of which to bin lags, in samples.
fast_mode : bool
If True, use faster implementations (currently only if method is 'numba'),
at the cost of possible minor numerical differences.
n_jobs : int | float, default: 1.0
Number of workers (threads) requested for the parallel computation. The OS
decides how they are scheduled across cores. For convenience, one can also pass:
* -1 : request as many workers as there are cores (from os.cpu_count())
* a float between 0 and 1 : request that fraction of the total cores (from os.cpu_count())
A float in (0, 1] is converted with int(n_jobs * os.cpu_count()), so the
default of 1.0 requests every core.

Returns
-------
Expand All @@ -1072,10 +1075,12 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F
spikes = sorting.to_spike_vector(concatenated=False)
correlograms = np.zeros((num_units, num_bins), dtype=np.int64)

if fast_mode:
num_threads = mp.cpu_count()
if isinstance(n_jobs, float) and 0 < n_jobs <= 1:
import os

num_threads = int(n_jobs * os.cpu_count())
else:
num_threads = 1
num_threads = n_jobs

for seg_index in range(sorting.get_num_segments()):
spike_times = spikes[seg_index]["sample_index"]
Expand Down
12 changes: 8 additions & 4 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,14 @@ def test_equal_results_fast_correlograms(window_and_bin_ms):
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

result_numba_fast, bins_numba_fast = _compute_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", n_jobs=2
)
result_numba, bins_numba = _compute_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False
sorting,
window_ms=window_ms,
bin_ms=bin_ms,
method="numba",
n_jobs=1,
)
from numpy.testing import assert_almost_equal

Expand All @@ -169,10 +173,10 @@ def test_equal_results_fast_auto_correlograms(window_and_bin_ms):
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

result_numba_fast, bins_numba_fast = _compute_auto_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", n_jobs=2
)
result_numba, bins_numba = _compute_auto_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", n_jobs=1
)
from numpy.testing import assert_almost_equal

Expand Down
4 changes: 0 additions & 4 deletions src/spikeinterface/sorters/internal/lupin.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
job_kwargs=job_kwargs,
)

if more_outs["time_shifts"] is not None:
time_shifts = more_outs["time_shifts"]
peaks["sample_index"] += time_shifts

mask = clustering_label >= 0
kept_peaks = peaks[mask]
kept_labels = clustering_label[mask]
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids)

auto_merge = True

analyzer_final = None
if auto_merge:
from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,13 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
print("Kept %d raw clusters" % len(labels))

if params["merge_from_templates"] is not None:
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = (
merge_peak_labels_from_templates(
peaks,
peak_labels,
templates.unit_ids,
templates.templates_array,
new_sparse_mask,
**params["merge_from_templates"],
)
peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates(
peaks,
peak_labels,
templates.unit_ids,
templates.templates_array,
new_sparse_mask,
**params["merge_from_templates"],
)

templates = Templates(
Expand All @@ -185,8 +183,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
probe=recording.get_probe(),
is_in_uV=False,
)
else:
time_shifts = None

# clean very small cluster before peeler
if (
Expand Down Expand Up @@ -214,7 +210,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
more_outs = dict(
svd_model=svd_model,
peaks_svd=peaks_svd,
time_shifts=time_shifts,
peak_svd_sparse_mask=sparse_mask,
)
return labels, peak_labels, more_outs
Original file line number Diff line number Diff line change
Expand Up @@ -299,19 +299,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
num_shifts = params_merge_from_templates["num_shifts"]
num_shifts = min((num_shifts, nbefore, nafter))
params_merge_from_templates["num_shifts"] = num_shifts
post_merge_label2, templates_array, template_sparse_mask, unit_ids, time_shifts = (
merge_peak_labels_from_templates(
peaks,
post_merge_label1,
unit_ids,
templates_array,
template_sparse_mask,
**params_merge_from_templates,
)
post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates(
peaks,
post_merge_label1,
unit_ids,
templates_array,
template_sparse_mask,
**params_merge_from_templates,
)
else:
post_merge_label2 = post_merge_label1.copy()
time_shifts = None

dense_templates = Templates(
templates_array=templates_array,
Expand Down Expand Up @@ -346,5 +343,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):

labels_set = templates.unit_ids

more_outs = dict(templates=templates, time_shifts=time_shifts)
more_outs = dict(
templates=templates,
)
return labels_set, final_peak_labels, more_outs
Loading
Loading