Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 59 additions & 23 deletions src/spikeinterface/postprocessing/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class ComputeCorrelograms(AnalyzerExtension):
fast_mode : "auto" | "on" | "off", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
This uses job_kwargs to determine the number of threads to use.

Returns
-------
Expand Down Expand Up @@ -88,7 +89,7 @@ class ComputeCorrelograms(AnalyzerExtension):
depend_on = []
need_recording = False
use_nodepipeline = False
need_job_kwargs = False
need_job_kwargs = True

def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode)
Expand Down Expand Up @@ -139,7 +140,9 @@ def _merge_extension_data(
break

if can_apply_soft_method is False:
new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
new_ccgs, new_bins = _compute_correlograms_on_sorting(
new_sorting_analyzer.sorting, **self.params, **job_kwargs
)
new_data = dict(ccgs=new_ccgs, bins=new_bins)
else:
# Make a transformation dict, which tells us how unit_indices from the
Expand Down Expand Up @@ -195,12 +198,12 @@ def _merge_extension_data(

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# for splits, we need to recompute correlograms
new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params, **job_kwargs)
new_data = dict(ccgs=new_ccgs, bins=new_bins)
return new_data

def _run(self, verbose=False):
ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
def _run(self, verbose=False, **job_kwargs):
ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params, **job_kwargs)
self.data["ccgs"] = ccgs
self.data["bins"] = bins

Expand All @@ -226,6 +229,7 @@ class ComputeAutoCorrelograms(AnalyzerExtension):
fast_mode : "auto" | "off" | "on", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
This uses job_kwargs to determine the number of threads to use.

Returns
-------
Expand Down Expand Up @@ -261,7 +265,7 @@ class ComputeAutoCorrelograms(AnalyzerExtension):
depend_on = []
need_recording = False
use_nodepipeline = False
need_job_kwargs = False
need_job_kwargs = True

def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode: str = "auto"):
params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode)
Expand All @@ -284,7 +288,7 @@ def _merge_extension_data(

# compute all new acgs at once
new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids)
only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params)
only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params, **job_kwargs)
new_acgs = np.zeros((len(all_new_units), only_new_acgs.shape[1]), dtype=np.int64)

for unit_ind, unit_id in enumerate(all_new_units):
Expand All @@ -308,7 +312,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
# compute all new isi at once
new_unit_ids_f = list(chain(*new_unit_ids))
new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids_f)
only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params)
only_new_acgs, new_bins = _compute_auto_correlograms_on_sorting(new_sorting, **self.params, **job_kwargs)

for unit_ind, unit_id in enumerate(all_new_units):
if unit_id not in new_unit_ids_f:
Expand All @@ -321,8 +325,8 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
new_extension_data = dict(acgs=new_acgs, bins=new_bins)
return new_extension_data

def _run(self, verbose=False):
acgs, bins = _compute_auto_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
def _run(self, verbose=False, **job_kwargs):
acgs, bins = _compute_auto_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params, **job_kwargs)
self.data["acgs"] = acgs
self.data["bins"] = bins

Expand All @@ -342,6 +346,7 @@ def compute_correlograms(
bin_ms: float = 1.0,
method: str = "auto",
fast_mode: str = "auto",
**job_kwargs,
):
"""
Compute correlograms using Numba or Numpy.
Expand All @@ -352,11 +357,21 @@ def compute_correlograms(

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
return compute_correlograms_sorting_analyzer(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting,
window_ms=window_ms,
bin_ms=bin_ms,
method=method,
fast_mode=fast_mode,
**job_kwargs,
)
else:
return _compute_correlograms_on_sorting(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting,
window_ms=window_ms,
bin_ms=bin_ms,
method=method,
fast_mode=fast_mode,
**job_kwargs,
)


Expand Down Expand Up @@ -422,7 +437,7 @@ def _compute_num_bins(window_size, bin_size):
return num_bins, num_half_bins


def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto"):
def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode="auto", **job_kwargs):
"""
Computes cross-correlograms from multiple units.

Expand All @@ -444,6 +459,7 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto",
fast_mode : "auto" | "on" | "off", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
This uses job_kwargs to determine the number of threads to use.

Returns
-------
Expand Down Expand Up @@ -472,7 +488,7 @@ def _compute_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto",
if method == "numpy":
correlograms = _compute_correlograms_numpy(sorting, window_size, bin_size)
if method == "numba":
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode)
correlograms = _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode=fast_mode, **job_kwargs)

return correlograms, bins

Expand Down Expand Up @@ -617,7 +633,7 @@ def correlogram_for_one_segment(spike_times, spike_unit_indices, window_size, bi
return correlograms


def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode, **job_kwargs):
"""
Computes cross-correlograms between all units in `sorting`.

Expand Down Expand Up @@ -654,7 +670,8 @@ def _compute_correlograms_numba(sorting, window_size, bin_size, fast_mode):
correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)

if fast_mode:
num_threads = mp.cpu_count()
job_kwargs = fix_job_kwargs(job_kwargs)
num_threads = job_kwargs["n_jobs"]
else:
num_threads = 1

Expand Down Expand Up @@ -829,7 +846,12 @@ def _compute_auto_correlograms_one_segment_numba(


def compute_auto_correlograms(
sorting_analyzer_or_sorting, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto", fast_mode="auto"
sorting_analyzer_or_sorting,
window_ms: float = 50.0,
bin_ms: float = 1.0,
method: str = "auto",
fast_mode="auto",
**job_kwargs,
):
"""
Compute correlograms using Numba or Numpy.
Expand All @@ -840,15 +862,25 @@ def compute_auto_correlograms(

if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
return compute_auto_correlograms_sorting_analyzer(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting,
window_ms=window_ms,
bin_ms=bin_ms,
method=method,
fast_mode=fast_mode,
**job_kwargs,
)
else:
return _compute_auto_correlograms_on_sorting(
sorting_analyzer_or_sorting, window_ms=window_ms, bin_ms=bin_ms, method=method, fast_mode=fast_mode
sorting_analyzer_or_sorting,
window_ms=window_ms,
bin_ms=bin_ms,
method=method,
fast_mode=fast_mode,
**job_kwargs,
)


def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False):
def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="auto", fast_mode=False, **job_kwargs):
"""
Computes auto-correlograms from multiple units.

Expand All @@ -870,6 +902,7 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au
fast_mode : "auto" | "off" | "on", default: "auto"
If "auto", a faster multithreaded implementations is used if method is "numba" and
if the number of units is greater than 300.
This uses job_kwargs to determine the number of threads to use.

Returns
-------
Expand Down Expand Up @@ -897,7 +930,7 @@ def _compute_auto_correlograms_on_sorting(sorting, window_ms, bin_ms, method="au
if method == "numpy":
correlograms = _compute_auto_correlograms_numpy(sorting, window_size, bin_size)
if method == "numba":
correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode)
correlograms = _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode, **job_kwargs)

return correlograms, bins

Expand Down Expand Up @@ -1036,7 +1069,7 @@ def auto_correlogram_for_one_segment(spike_times, spike_unit_indices, window_siz
return correlograms


def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=False):
def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=False, **job_kwargs):
"""
Computes auto-correlograms between all units in `sorting`.

Expand All @@ -1055,6 +1088,8 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F
fast_mode : bool
If True, use faster implementations (currently only if method is 'numba'),
at the cost of possible minor numerical differences.
This uses job_kwargs to determine the number of threads to use.


Returns
-------
Expand All @@ -1073,7 +1108,8 @@ def _compute_auto_correlograms_numba(sorting, window_size, bin_size, fast_mode=F
correlograms = np.zeros((num_units, num_bins), dtype=np.int64)

if fast_mode:
num_threads = mp.cpu_count()
job_kwargs = fix_job_kwargs(job_kwargs)
num_threads = job_kwargs["n_jobs"]
else:
num_threads = 1

Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_equal_results_fast_correlograms(window_and_bin_ms):
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

result_numba_fast, bins_numba_fast = _compute_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True, n_jobs=2
)
result_numba, bins_numba = _compute_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False
Expand All @@ -169,7 +169,7 @@ def test_equal_results_fast_auto_correlograms(window_and_bin_ms):
sorting = generate_sorting(num_units=5, sampling_frequency=30000.0, durations=[10.325, 3.5], seed=0)

result_numba_fast, bins_numba_fast = _compute_auto_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=True, n_jobs=2
)
result_numba, bins_numba = _compute_auto_correlograms_on_sorting(
sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba", fast_mode=False
Expand Down
Loading