diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 9b4081060e..e1422cbceb 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -44,9 +44,11 @@ class ComputeCorrelograms(AnalyzerExtension): bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ... method : "auto" | "numpy" | "numba", default: "auto" If "auto" and numba is installed, numba is used, otherwise numpy is used. - fast_mode : "auto" | "on" | "off", default: "auto" - If "auto", a faster multithreaded implementations is used if method is "numba" and - if the number of units is greater than 300. + n_jobs : int | float, default: 1.0 + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) Returns ------- @@ -90,8 +92,10 @@ class ComputeCorrelograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"): - params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode) + def _set_params( + self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", n_jobs: int | float = 1.0 + ): + params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs) return params @@ -223,9 +227,11 @@ class ComputeAutoCorrelograms(AnalyzerExtension): bin size 1 ms, the correlation will be binned as -25 ms, -24 ms, ... method : "auto" | "numpy" | "numba", default: "auto" If "auto" and numba is installed, numba is used, otherwise numpy is used. - fast_mode : "auto" | "off" | "on", default: "auto" - If "auto", a faster multithreaded implementations is used if method is "numba" and - if the number of units is greater than 300. + n_jobs : int | float, default: 1.0 + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) Returns ------- @@ -263,8 +269,10 @@ class ComputeAutoCorrelograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"): - params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode) + def _set_params( + self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", n_jobs: int | float = 1.0 + ): + params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs) return params def _select_extension_data(self, unit_ids): @@ -341,7 +349,7 @@ def compute_correlograms( window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", - fast_mode: str = "auto", + n_jobs: int | float = 1.0, ): """ Compute correlograms using Numba or Numpy. @@ -352,11 +360,11 @@ def compute_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs ) else: return _compute_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs ) @@ -422,7 +430,7 @@ def _compute_num_bins(window_size, bin_size): return num_bins, num_half_bins -def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto"): +def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", n_jobs=1.0): """ Computes cross-correlograms from multiple units. @@ -441,9 +449,11 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", method : str To use "numpy" or "numba". "auto" will use numba if available, otherwise numpy. - fast_mode : "auto" | "on" | "off", default: "auto" - If "auto", a faster multithreaded implementations is used if method is "numba" and - if the number of units is greater than 300. + n_jobs : int | float, default: 1.0 + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) Returns ------- @@ -459,20 +469,12 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", if method == "auto": method = "numba" if HAVE_NUMBA else "numpy" - if method == "numba" and fast_mode == "auto": - num_units = len(sorting.unit_ids) - fast_mode = num_units > 300 - elif fast_mode == "off": - fast_mode = False - elif fast_mode == "on": - fast_mode = True - bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) if method == "numpy": correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size) if method == "numba": - correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode) + correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, n_jobs=n_jobs) return correlograms, bins @@ -617,7 +619,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi return correlograms -def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode): +def _compute_correlograms_numba(sorting, window_size, bin_size, n_jobs=1.0): """ Computes cross-correlograms between all units in `sorting`. @@ -633,9 +635,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode): The window size over which to perform the cross-correlation, in samples bin_size : int The size of which to bin lags, in samples. - fast_mode : bool - If True, use faster implementations (currently only if method is 'numba'), - at the cost of possible minor numerical differences. + n_jobs : int | float, default: 1.0 + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) Returns ------- @@ -653,10 +657,12 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode): spikes = sorting.to_spike_vector(concatenated=False) correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64) - if fast_mode: - num_threads = mp.cpu_count() + if isinstance(n_jobs, float) and 0 < n_jobs <= 1: + import os + + num_threads = int(n_jobs * os.cpu_count()) else: - num_threads = 1 + num_threads = n_jobs for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"] @@ -829,7 +835,7 @@ def _compute_auto_correlograms_one_segment_numba( def compute_auto_correlograms( - sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode="auto" + sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", n_jobs=1.0 ): """ Compute correlograms using Numba or Numpy. @@ -840,15 +846,15 @@ def compute_auto_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_auto_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs ) else: return _compute_auto_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, n_jobs=n_jobs ) -def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False): +def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", n_jobs=1.0): """ Computes auto-correlograms from multiple units. @@ -867,9 +873,11 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au method : str To use "numpy" or "numba". "auto" will use numba if available, otherwise numpy. - fast_mode : "auto" | "off" | "on", default: "auto" - If "auto", a faster multithreaded implementations is used if method is "numba" and - if the number of units is greater than 300. + n_jobs : int | float, default: 1.0 + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) Returns ------- @@ -884,20 +892,13 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au if method == "auto": method = "numba" if HAVE_NUMBA else "numpy" - if method == "numba" and fast_mode == "auto": - num_units = len(sorting.unit_ids) - fast_mode = num_units > 300 - elif fast_mode == "off": - fast_mode = False - elif fast_mode == "on": - fast_mode = True bins, window_size, bin_size = _make_bins(sorting, window_ms, bin_ms) if method == "numpy": correlograms = _compute_auto_correlograms_numpy(sorting, window_size, bin_size) if method == "numba": - correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode) + correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, n_jobs=n_jobs) return correlograms, bins @@ -1036,7 +1037,7 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz return correlograms -def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=False): +def _compute_auto_correlograms_numba(sorting, window_size, bin_size, n_jobs=1.0): """ Computes auto-correlograms between all units in `sorting`. @@ -1052,9 +1053,11 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F The window size over which to perform the cross-correlation, in samples bin_size : int The size of which to bin lags, in samples. - fast_mode : bool - If True, use faster implementations (currently only if method is 'numba'), - at the cost of possible minor numerical differences. + n_jobs : int | float, default: 1.0 + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) Returns ------- @@ -1072,10 +1075,12 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F spikes = sorting.to_spike_vector(concatenated=False) correlograms = np.zeros((num_units, num_bins), dtype=np.int64) - if fast_mode: - num_threads = mp.cpu_count() + if isinstance(n_jobs, float) and 0 < n_jobs <= 1: + import os + + num_threads = int(n_jobs * os.cpu_count()) else: - num_threads = 1 + num_threads = n_jobs for seg_index in range(sorting.get_num_segments()): spike_times = spikes[seg_index]["sample_index"] diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index afe41dd057..a166e79d53 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -147,10 +147,14 @@ def test_equal_results_fast_correlograms(window_and_bin_ms): sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) result_numba_fast, bins_numba_fast = _compute_correlograms_on_sorting( - sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", n_jobs=2 ) result_numba, bins_numba = _compute_correlograms_on_sorting( - sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False + sorting, + window_ms=window_ms, + bin_ms=bin_ms, + method="numba", + n_jobs=1, ) from numpy.testing import assert_almost_equal @@ -169,10 +173,10 @@ def test_equal_results_fast_auto_correlograms(window_and_bin_ms): sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) result_numba_fast, bins_numba_fast = _compute_auto_correlograms_on_sorting( - sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", n_jobs=2 ) result_numba, bins_numba = _compute_auto_correlograms_on_sorting( - sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", n_jobs=1 ) from numpy.testing import assert_almost_equal diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 98d7cdbf7d..66ffd46aae 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -308,10 +308,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs=job_kwargs, ) - if more_outs["time_shifts"] is not None: - time_shifts = more_outs["time_shifts"] - peaks["sample_index"] += time_shifts - mask = clustering_label >= 0 kept_peaks = peaks[mask] kept_labels = clustering_label[mask] diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index fb833b1a46..2369e1caad 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -348,7 +348,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids) auto_merge = True - analyzer_final = None if auto_merge: from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py index 57d1f738e6..35b01f8fd6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_hdbscan.py @@ -164,15 +164,13 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): print("Kept %d raw clusters" % len(labels)) if params["merge_from_templates"] is not None: - peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = ( - merge_peak_labels_from_templates( - peaks, - peak_labels, - templates.unit_ids, - templates.templates_array, - new_sparse_mask, - **params["merge_from_templates"], - ) + peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates( + peaks, + peak_labels, + templates.unit_ids, + templates.templates_array, + new_sparse_mask, + **params["merge_from_templates"], ) templates = Templates( @@ -185,8 +183,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): probe=recording.get_probe(), is_in_uV=False, ) - else: - time_shifts = None # clean very small cluster before peeler if ( @@ -214,7 +210,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): more_outs = dict( svd_model=svd_model, peaks_svd=peaks_svd, - time_shifts=time_shifts, peak_svd_sparse_mask=sparse_mask, ) return labels, peak_labels, more_outs diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index a8d13dd9df..507611a515 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -299,19 +299,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): num_shifts = params_merge_from_templates["num_shifts"] num_shifts = min((num_shifts, nbefore, nafter)) params_merge_from_templates["num_shifts"] = num_shifts - post_merge_label2, templates_array, template_sparse_mask, unit_ids, time_shifts = ( - merge_peak_labels_from_templates( - peaks, - post_merge_label1, - unit_ids, - templates_array, - template_sparse_mask, - **params_merge_from_templates, - ) + post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates( + peaks, + post_merge_label1, + unit_ids, + templates_array, + template_sparse_mask, + **params_merge_from_templates, ) else: post_merge_label2 = post_merge_label1.copy() - time_shifts = None dense_templates = Templates( templates_array=templates_array, @@ -346,5 +343,7 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels_set = templates.unit_ids - more_outs = dict(templates=templates, time_shifts=time_shifts) + more_outs = dict( + templates=templates, + ) return labels_set, final_peak_labels, more_outs diff --git a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py index 6f5ebc6627..64da6abf21 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merging_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/merging_tools.py @@ -556,13 +556,13 @@ def merge_peak_labels_from_templates( if not use_lags: lags = None - clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts = ( + clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids = ( _apply_pair_mask_on_labels_and_recompute_templates( pair_mask, peak_labels, unit_ids, templates_array, template_sparse_mask, lags ) ) - return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts + return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids def _apply_pair_mask_on_labels_and_recompute_templates( @@ -580,10 +580,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates( clean_labels = peak_labels.copy() n_components, group_labels = connected_components(pair_mask, directed=False, return_labels=True) - if lags is not None: - time_shifts = np.zeros(len(peak_labels), dtype=np.int32) - else: - time_shifts = None + # print("merges", templates_array.shape[0], "to", n_components) merge_template_array = templates_array.copy() merge_sparsity_mask = template_sparse_mask.copy() @@ -606,15 +603,10 @@ def _apply_pair_mask_on_labels_and_recompute_templates( for i, l in enumerate(merge_group): label = unit_ids[l] - mask = peak_labels == label - weights[i] = np.sum(mask) + weights[i] = np.sum(peak_labels == label) if i > 0: - clean_labels[mask] = unit_ids[g0] + clean_labels[peak_labels == label] = unit_ids[g0] keep_template[l] = False - if lags is not None: - shift = lags[l, g0] # which is the same as -lags[g0, l] - time_shifts[mask] += shift - weights /= weights.sum() if lags is None: @@ -625,7 +617,7 @@ def _apply_pair_mask_on_labels_and_recompute_templates( # with shifts accumulated_template = np.zeros_like(merge_template_array[g0, :, :]) for i, l in enumerate(merge_group): - shift = lags[l, g0] # which is the same as -lags[g0, l] + shift = -lags[g0, l] if shift > 0: # template is shifted to right temp = np.zeros_like(accumulated_template) @@ -645,4 +637,4 @@ def _apply_pair_mask_on_labels_and_recompute_templates( merge_template_array = merge_template_array[keep_template, :, :] merge_sparsity_mask = merge_sparsity_mask[keep_template, :] - return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids, time_shifts + return clean_labels, merge_template_array, merge_sparsity_mask, new_unit_ids diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index df4c4ae39d..a003e31e87 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -132,15 +132,13 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): print("Kept %d raw clusters" % len(labels)) if params["merge_from_templates"] is not None: - peak_labels, merge_template_array, new_sparse_mask, new_unit_ids, time_shifts = ( - merge_peak_labels_from_templates( - peaks, - peak_labels, - unit_ids, - templates_array, - np.ones((len(unit_ids), num_chans), dtype=bool), - **params["merge_from_templates"], - ) + peak_labels, merge_template_array, new_sparse_mask, new_unit_ids = merge_peak_labels_from_templates( + peaks, + peak_labels, + unit_ids, + templates_array, + np.ones((len(unit_ids), num_chans), dtype=bool), + **params["merge_from_templates"], ) templates = Templates( @@ -153,8 +151,6 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): probe=recording.get_probe(), is_in_uV=False, ) - else: - time_shifts = None labels = templates.unit_ids @@ -164,4 +160,4 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): if verbose: print("Kept %d non-duplicated clusters" % len(labels)) - return labels, peak_labels, dict(time_shifts=time_shifts, templates=templates) + return labels, peak_labels, dict()