From 5e27185acf5c14eefa662b66b7c1ca64a483714a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 13 Mar 2026 11:37:49 +0100 Subject: [PATCH 1/3] WIP --- .../postprocessing/correlograms.py | 49 ++++++++++--------- .../postprocessing/tests/test_correlograms.py | 4 +- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 9b4081060e..4bc4111bde 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -88,7 +88,7 @@ class ComputeCorrelograms(AnalyzerExtension): depend_on = [] need_recording = False use_nodepipeline = False - need_job_kwargs = False + need_job_kwargs = True def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode) @@ -139,7 +139,7 @@ def _merge_extension_data( break if can_apply_soft_method is False: - new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) + new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params, **job_kwargs) new_data = dict(ccgs=new_ccgs, bins=new_bins) else: # Make a transformation dict, which tells us how unit_indices from the @@ -195,12 +195,12 @@ def _merge_extension_data( def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): # for splits, we need to recompute correlograms - new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) + new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params, **job_kwargs) new_data = dict(ccgs=new_ccgs, bins=new_bins) return new_data - def _run(self, verbose=False): - ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) + def _run(self, verbose=False, **job_kwargs): + ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params, **job_kwargs) self.data["ccgs"] = ccgs self.data["bins"] = bins @@ -261,7 +261,7 @@ class ComputeAutoCorrelograms(AnalyzerExtension): depend_on = [] need_recording = False use_nodepipeline = False - need_job_kwargs = False + need_job_kwargs = True def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode) @@ -284,7 +284,7 @@ def _merge_extension_data( # compute all new acgs at once new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids) - only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params) + only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params, **job_kwargs) new_acgs = np.zeros((len(all_new_units), only_new_acgs.shape[1]), dtype=np.int64) for unit_ind, unit_id in enumerate(all_new_units): @@ -308,7 +308,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, # compute all new isi at once new_unit_ids_f = list(chain(*new_unit_ids)) new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids_f) - only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params) + only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params, **job_kwargs) for unit_ind, unit_id in enumerate(all_new_units): if unit_id not in new_unit_ids_f: @@ -321,8 +321,8 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, new_extension_data = dict(acgs=new_acgs, bins=new_bins) return new_extension_data - def _run(self, verbose=False): - acgs, bins = _compute_auto_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) + def _run(self, verbose=False, **job_kwargs): + acgs, bins = _compute_auto_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params, **job_kwargs) self.data["acgs"] = acgs self.data["bins"] = bins @@ -342,6 +342,7 @@ def compute_correlograms( bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto", + **job_kwargs ): """ Compute correlograms using Numba or Numpy. @@ -352,11 +353,11 @@ def compute_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs ) else: return _compute_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs ) @@ -422,7 +423,7 @@ def _compute_num_bins(window_size, bin_size): return num_bins, num_half_bins -def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto"): +def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto", **job_kwargs): """ Computes cross-correlograms from multiple units. @@ -472,7 +473,7 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", if method == "numpy": correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size) if method == "numba": - correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode) + correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode, **job_kwargs) return correlograms, bins @@ -617,7 +618,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi return correlograms -def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode): +def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode, **job_kwargs): """ Computes cross-correlograms between all units in `sorting`. @@ -654,7 +655,8 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode): correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64) if fast_mode: - num_threads = mp.cpu_count() + job_kwargs = fix_job_kwargs(job_kwargs) + num_threads = job_kwargs['n_jobs'] else: num_threads = 1 @@ -829,7 +831,7 @@ def _compute_auto_correlograms_one_segment_numba( def compute_auto_correlograms( - sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode="auto" + sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode="auto", **job_kwargs ): """ Compute correlograms using Numba or Numpy. @@ -840,15 +842,15 @@ def compute_auto_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_auto_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs ) else: return _compute_auto_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode + sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs ) -def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False): +def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False, **job_kwargs): """ Computes auto-correlograms from multiple units. @@ -897,7 +899,7 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au if method == "numpy": correlograms = _compute_auto_correlograms_numpy(sorting, window_size, bin_size) if method == "numba": - correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode) + correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode, **job_kwargs) return correlograms, bins @@ -1036,7 +1038,7 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz return correlograms -def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=False): +def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=False, **job_kwargs): """ Computes auto-correlograms between all units in `sorting`. @@ -1073,7 +1075,8 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F correlograms = np.zeros((num_units, num_bins), dtype=np.int64) if fast_mode: - num_threads = mp.cpu_count() + job_kwargs = fix_job_kwargs(job_kwargs) + num_threads = job_kwargs['n_jobs'] else: num_threads = 1 diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index afe41dd057..c86c933924 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -147,7 +147,7 @@ def test_equal_results_fast_correlograms(window_and_bin_ms): sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) result_numba_fast, bins_numba_fast = _compute_correlograms_on_sorting( - sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True, n_jobs=2 ) result_numba, bins_numba = _compute_correlograms_on_sorting( sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False @@ -169,7 +169,7 @@ def test_equal_results_fast_auto_correlograms(window_and_bin_ms): sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0) result_numba_fast, bins_numba_fast = _compute_auto_correlograms_on_sorting( - sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True + sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True, n_jobs=2 ) result_numba, bins_numba = _compute_auto_correlograms_on_sorting( sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False From 4175d9eabe24526d4f7cbf5c7e10072bec7e95f7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 13 Mar 2026 11:43:23 +0100 Subject: [PATCH 2/3] Docs --- src/spikeinterface/postprocessing/correlograms.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 4bc4111bde..1c38e9140f 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -47,6 +47,7 @@ class ComputeCorrelograms(AnalyzerExtension): fast_mode : "auto" | "on" | "off", default: "auto" If "auto", a faster multithreaded implementations is used if method is "numba" and if the number of units is greater than 300. + This uses job_kwargs to determine the number of threads to use. Returns ------- @@ -226,6 +227,7 @@ class ComputeAutoCorrelograms(AnalyzerExtension): fast_mode : "auto" | "off" | "on", default: "auto" If "auto", a faster multithreaded implementations is used if method is "numba" and if the number of units is greater than 300. + This uses job_kwargs to determine the number of threads to use. Returns ------- @@ -445,6 +447,7 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode : "auto" | "on" | "off", default: "auto" If "auto", a faster multithreaded implementations is used if method is "numba" and if the number of units is greater than 300. + This uses job_kwargs to determine the number of threads to use. Returns ------- @@ -872,6 +875,7 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au fast_mode : "auto" | "off" | "on", default: "auto" If "auto", a faster multithreaded implementations is used if method is "numba" and if the number of units is greater than 300. + This uses job_kwargs to determine the number of threads to use. Returns ------- @@ -1056,7 +1060,9 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F The size of which to bin lags, in samples. fast_mode : bool If True, use faster implementations (currently only if method is 'numba'), - at the cost of possible minor numerical differences. + at the cost of possible minor numerical differences. + This uses job_kwargs to determine the number of threads to use. + Returns ------- From 4284b99d9931406b5af204ba11f9f06904830aa6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 10:45:00 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/correlograms.py | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 1c38e9140f..40ac386ecc 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -140,7 +140,9 @@ def _merge_extension_data( break if can_apply_soft_method is False: - new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params, **job_kwargs) + new_ccgs, new_bins = _compute_correlograms_on_sorting( + new_sorting_analyzer.sorting, **self.params, **job_kwargs + ) new_data = dict(ccgs=new_ccgs, bins=new_bins) else: # Make a transformation dict, which tells us how unit_indices from the @@ -344,7 +346,7 @@ def compute_correlograms( bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto", - **job_kwargs + **job_kwargs, ): """ Compute correlograms using Numba or Numpy. @@ -355,11 +357,21 @@ def compute_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs + sorting_analyzer_or_sorting, + window_ms=window_ms, + bin_ms=bin_ms, + method=method, + fast_mode=fast_mode, + **job_kwargs, ) else: return _compute_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs + sorting_analyzer_or_sorting, + window_ms=window_ms, + bin_ms=bin_ms, + method=method, + fast_mode=fast_mode, + **job_kwargs, ) @@ -659,7 +671,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode, **job if fast_mode: job_kwargs = fix_job_kwargs(job_kwargs) - num_threads = job_kwargs['n_jobs'] + num_threads = job_kwargs["n_jobs"] else: num_threads = 1 @@ -834,7 +846,12 @@ def _compute_auto_correlograms_one_segment_numba( def compute_auto_correlograms( - sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode="auto", **job_kwargs + sorting_analyzer_or_sorting, + window_ms: float = 50.0, + bin_ms: float = 1.0, + method: str = "auto", + fast_mode="auto", + **job_kwargs, ): """ Compute correlograms using Numba or Numpy. @@ -845,11 +862,21 @@ def compute_auto_correlograms( if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): return compute_auto_correlograms_sorting_analyzer( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs + sorting_analyzer_or_sorting, + window_ms=window_ms, + bin_ms=bin_ms, + method=method, + fast_mode=fast_mode, + **job_kwargs, ) else: return _compute_auto_correlograms_on_sorting( - sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode, **job_kwargs + sorting_analyzer_or_sorting, + window_ms=window_ms, + bin_ms=bin_ms, + method=method, + fast_mode=fast_mode, + **job_kwargs, ) @@ -1060,9 +1087,9 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F The size of which to bin lags, in samples. fast_mode : bool If True, use faster implementations (currently only if method is 'numba'), - at the cost of possible minor numerical differences. + at the cost of possible minor numerical differences. This uses job_kwargs to determine the number of threads to use. - + Returns ------- @@ -1082,7 +1109,7 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F if fast_mode: job_kwargs = fix_job_kwargs(job_kwargs) - num_threads = job_kwargs['n_jobs'] + num_threads = job_kwargs["n_jobs"] else: num_threads = 1