From 561051882c02f90d87fc459937d4cc7ff344642d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jan 2026 16:58:35 +0100 Subject: [PATCH 01/17] Refactor aftifacts detection. --- src/spikeinterface/preprocessing/__init__.py | 1 + .../preprocessing/detect_artifacts.py | 202 ++++++++++++ .../preprocessing/preprocessing_classes.py | 4 +- .../preprocessing/silence_artifacts.py | 298 +++++------------- .../tests/test_detect_artifacts.py | 13 + .../tests/test_silence_artifacts.py | 16 +- 6 files changed, 309 insertions(+), 225 deletions(-) create mode 100644 src/spikeinterface/preprocessing/detect_artifacts.py create mode 100644 src/spikeinterface/preprocessing/tests/test_detect_artifacts.py diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index de25944bd2..d2d8674168 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,6 +20,7 @@ PreprocessingPipeline, ) +from .detect_artifacts import detect_artifact_periods, detect_period_artifacts_by_envelope # for snippets from .align_snippets import AlignSnippets from warnings import warn diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py new file mode 100644 index 0000000000..6cb22ac49f --- /dev/null +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import numpy as np + +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording +from spikeinterface.preprocessing.rectify import RectifyRecording +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording +from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording +from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs +from 
spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype +import numpy as np + + +artifact_dtype = [ + ("start_index", "int64"), + ("stop_index", "int64"), + ("segment_index", "int64"), +] + +extended_artifact_dtype = artifact_dtype + [ + # TODO +] + + +_internal_dtype = [ + ("sample_index", "int64"), + ("segment_index", "int64"), + ("front", "bool") +] + + +def detect_artifact_periods( + recording, + method="envelope", + method_kwargs=None, + job_kwargs=None, +): + """ + + """ + + if method_kwargs is None: + method_kwargs = dict() + + if method == "envelope": + artifacts, envelope = detect_period_artifacts_by_envelope(recording, **method_kwargs, job_kwargs=job_kwargs) + elif method == "saturation": + raise NotImplementedError("Soon") + + else: + raise ValueError("") + + return artifacts + + + +## detect_period_artifacts_saturation Zone + + + + +## detect_period_artifacts_by_envelope Zone + +class DetectThresholdCrossing(PeakDetector): + + name = "threshold_crossings" + preferred_mp_context = None + + def __init__( + self, + recording, + detect_threshold=5, + noise_levels=None, + seed=None, + noise_levels_kwargs=dict(), + ): + PeakDetector.__init__(self, recording, return_output=True) + if noise_levels is None: + random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + self.abs_thresholds = noise_levels * detect_threshold + self._dtype = np.dtype(_internal_dtype) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return self._dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + z = np.median(traces / self.abs_thresholds, 1) + threshold_mask = np.diff((z > 1) != 0, axis=0) + indices = np.flatnonzero(threshold_mask) + threshold_crossings = np.zeros(indices.size, 
dtype=self._dtype) + threshold_crossings["sample_index"] = indices + threshold_crossings["segment_index"] = segment_index + threshold_crossings["front"][::2] = True + threshold_crossings["front"][1::2] = False + return (threshold_crossings,) + + +def detect_period_artifacts_by_envelope( + recording, + detect_threshold=5, + # min_duration_ms=50, + freq_max=20.0, + seed=None, + job_kwargs=None, + random_slices_kwargs=None, +): + """ + Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of + a global envelope of the channels. + + Parameters + ---------- + recording : RecordingExtractor + The recording extractor to detect putative artifacts + detect_threshold : float, default: 5 + The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` + freq_max : float, default: 20 + The maximum frequency for the low pass filter used + seed : int | None, default: None + Random seed for `get_noise_levels`. + If none, `get_noise_levels` uses `seed=0`. 
+ **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + + """ + + envelope = RectifyRecording(recording) + envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) + envelope = CommonReferenceRecording(envelope) + + from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ) + + # _, job_kwargs = split_job_kwargs(noise_levels_kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) + if random_slices_kwargs is None: + random_slices_kwargs = {} + else: + random_slices_kwargs = random_slices_kwargs.copy() + random_slices_kwargs["seed"] = seed + noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + + node0 = DetectThresholdCrossing( + recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + ) + + threshold_crossings = run_node_pipeline( + envelope, + [node0], + job_kwargs, + job_name="detect threshold crossings", + ) + + order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) + threshold_crossings = threshold_crossings[order] + + artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) + + + return artifacts, envelope + + +# tools + +def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): + + num_seg = recording.get_num_segments() + + final_artifacts = [] + for seg_index in range(num_seg): + mask = artifacts["segment_index"] == seg_index + sub_thr = artifacts[mask] + if len(sub_thr) > 0: + if not sub_thr["front"][0]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = 0 + local_thr["front"] = True + sub_thr = np.hstack((local_thr, sub_thr)) + if sub_thr["front"][-1]: + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr["sample_index"] = recording.get_num_samples(seg_index) + local_thr["front"] = False + sub_thr = np.hstack((sub_thr, 
local_thr)) + + local_artifact = np.zeros(sub_thr.size/2, dtype=artifact_dtype) + local_artifact["start_index"] = sub_thr["sample_index"][::2] + local_artifact["stop_index"] = sub_thr["sample_index"][1::2] + local_artifact["segment_index"] = seg_index + final_artifacts.append(local_artifact) + + if len(final_artifacts) > 0: + final_artifacts = np.concatenate(final_artifacts) + else: + final_artifacts = np.zeros(0, dtype=artifact_dtype) + return final_artifacts \ No newline at end of file diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index fe9d95c506..47839db7a0 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -50,7 +50,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts +# from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { # filter stuff @@ -90,7 +90,7 @@ DirectionalDerivativeRecording: directional_derivative, AstypeRecording: astype, UnsignedToSignedRecording: unsigned_to_signed, - SilencedArtifactsRecording: silence_artifacts, + # SilencedArtifactsRecording: silence_artifacts, } # we control import in the preprocessing init by setting an __all__ diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index b1ae00b64c..8006342847 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -4,221 +4,89 @@ from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -from 
spikeinterface.preprocessing.rectify import RectifyRecording -from spikeinterface.preprocessing.common_reference import CommonReferenceRecording -from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype import numpy as np -class DetectThresholdCrossing(PeakDetector): - - name = "threshold_crossings" - preferred_mp_context = None - - def __init__( - self, - recording, - detect_threshold=5, - noise_levels=None, - seed=None, - noise_levels_kwargs=dict(), - ): - PeakDetector.__init__(self, recording, return_output=True) - if noise_levels is None: - random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("front", "bool")]) - - def get_trace_margin(self): - return 0 - - def get_dtype(self): - return self._dtype - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces / self.abs_thresholds, 1) - threshold_mask = np.diff((z > 1) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) - threshold_crossings = np.zeros(indices.size, dtype=self._dtype) - threshold_crossings["sample_index"] = indices - threshold_crossings["front"][::2] = True - threshold_crossings["front"][1::2] = False - return (threshold_crossings,) - - -def detect_period_artifacts_by_envelope( - recording, - detect_threshold=5, - min_duration_ms=50, - freq_max=20.0, - seed=None, - noise_levels=None, - **noise_levels_kwargs, -): - """ - Docstring for detect_period_artifacts. 
Function to detect putative artifact periods as threshold crossings of - a global envelope of the channels. - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to detect putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels`. - If none, `get_noise_levels` uses `seed=0`. - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - """ - - envelope = RectifyRecording(recording) - envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) - envelope = CommonReferenceRecording(envelope) - - from spikeinterface.core.node_pipeline import ( - run_node_pipeline, - ) - - _, job_kwargs = split_job_kwargs(noise_levels_kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) - - node0 = DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - threshold_crossings = run_node_pipeline( - recording, - [node0], - job_kwargs, - job_name="detect threshold crossings", - ) - - order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) - threshold_crossings = threshold_crossings[order] - - periods = [] - fs = recording.sampling_frequency - max_duration_samples = int(min_duration_ms * fs / 1000) - num_seg = recording.get_num_segments() - - for seg_index in range(num_seg): - sub_periods = [] - mask = threshold_crossings["segment_index"] == seg_index - sub_thr = threshold_crossings[mask] - if len(sub_thr) > 0: - local_thr = np.zeros(1, 
dtype=np.dtype(base_peak_dtype + [("front", "bool")])) - if not sub_thr["front"][0]: - local_thr["sample_index"] = 0 - local_thr["front"] = True - sub_thr = np.hstack((local_thr, sub_thr)) - if sub_thr["front"][-1]: - local_thr["sample_index"] = recording.get_num_samples(seg_index) - local_thr["front"] = False - sub_thr = np.hstack((sub_thr, local_thr)) - - indices = np.flatnonzero(np.diff(sub_thr["front"])) - for i, j in zip(indices[:-1], indices[1:]): - if sub_thr["front"][i]: - start = sub_thr["sample_index"][i] - end = sub_thr["sample_index"][j] - if end - start > max_duration_samples: - sub_periods.append((start, end)) - - periods.append(sub_periods) - - return periods, envelope - - -class SilencedArtifactsRecording(SilencedPeriodsRecording): - """ - Silence user-defined periods from recording extractor traces. The code will construct - an enveloppe of the recording (as a low pass filtered version of the traces) and detect - threshold crossings to identify the periods to silence. The periods are then silenced either - on a per channel basis or across all channels by replacing the values by zeros or by - adding gaussian noise with the same variance as the one in the recordings - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to silence putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. 
- mode : "zeros" | "noise", default: "zeros" - Determines what periods are replaced by. Can be one of the following: - - - "zeros": Artifacts are replaced by zeros. - - - "noise": The periods are filled with a gaussion noise that has the - same variance that the one in the recordings, on a per channel - basis - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - Returns - ------- - silenced_recording : SilencedArtifactsRecording - The recording extractor after silencing detected artifacts - """ - - _precomputable_kwarg_names = ["list_periods"] - - def __init__( - self, - recording, - detect_threshold=5, - verbose=False, - freq_max=20.0, - min_duration_ms=50, - mode="zeros", - noise_levels=None, - seed=None, - list_periods=None, - **noise_levels_kwargs, - ): - - if list_periods is None: - list_periods, _ = detect_period_artifacts_by_envelope( - recording, - detect_threshold=detect_threshold, - min_duration_ms=min_duration_ms, - freq_max=freq_max, - seed=seed, - noise_levels=noise_levels, - **noise_levels_kwargs, - ) - - if verbose: - for i, periods in enumerate(list_periods): - total_time = np.sum([end - start for start, end in periods]) - percentage = 100 * total_time / recording.get_num_samples(i) - print(f"{percentage}% of segment {i} has been flagged as artifactual") - - SilencedPeriodsRecording.__init__( - self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - -# function for API -silence_artifacts = define_function_handling_dict_from_class( - source_class=SilencedArtifactsRecording, name="silence_artifacts" -) +# class SilencedArtifactsRecording(SilencedPeriodsRecording): +# """ +# Silence user-defined periods from recording extractor traces. The code will construct +# an enveloppe of the recording (as a low pass filtered version of the traces) and detect +# threshold crossings to identify the periods to silence. 
The periods are then silenced either +# on a per channel basis or across all channels by replacing the values by zeros or by +# adding gaussian noise with the same variance as the one in the recordings + +# Parameters +# ---------- +# recording : RecordingExtractor +# The recording extractor to silence putative artifacts +# artifacts : np.array, None +# The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` +# freq_max : float, default: 20 +# The maximum frequency for the low pass filter used +# min_duration_ms : float, default: 50 +# The minimum duration for a threshold crossing to be considered as an artefact. +# noise_levels : array +# Noise levels if already computed +# seed : int | None, default: None +# Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. +# If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. +# mode : "zeros" | "noise", default: "zeros" +# Determines what periods are replaced by. Can be one of the following: + +# - "zeros": Artifacts are replaced by zeros. 
+ +# - "noise": The periods are filled with a gaussion noise that has the +# same variance that the one in the recordings, on a per channel +# basis +# **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function + +# Returns +# ------- +# silenced_recording : SilencedArtifactsRecording +# The recording extractor after silencing detected artifacts +# """ + +# _precomputable_kwarg_names = ["artifacts"] + +# def __init__( +# self, +# recording, +# artifacts=None, +# detect_threshold=5, +# verbose=False, +# freq_max=20.0, +# min_duration_ms=50, +# mode="zeros", +# noise_levels=None, +# seed=None, +# list_periods=None, +# **noise_levels_kwargs, +# ): + +# if artifacts is None: +# from spikeinterface.preprocessing import detect_artifacts +# artifacts = detect_artifact_periods( +# recording, +# detect_threshold=detect_threshold, +# min_duration_ms=min_duration_ms, +# freq_max=freq_max, +# seed=seed, +# noise_levels=noise_levels, +# **noise_levels_kwargs, +# ) + +# if verbose: +# for i, periods in enumerate(artifacts): +# total_time = np.sum([end - start for start, end in periods]) +# percentage = 100 * total_time / recording.get_num_samples(i) +# print(f"{percentage}% of segment {i} has been flagged as artifactual") + +# SilencedPeriodsRecording.__init__( +# self, recording, artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs +# ) + + +# # function for API +# silence_artifacts = define_function_handling_dict_from_class( +# source_class=SilencedArtifactsRecording, name="silence_artifacts" +# ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py new file mode 100644 index 0000000000..52e8d927f9 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -0,0 +1,13 @@ +from spikeinterface.core import generate_recording +from spikeinterface.preprocessing import detect_artifact_periods + + +def 
test_detect_artifact_periods(): + # one segment only + rec = generate_recording(durations=[10.0, 10]) + artifacts = detect_artifact_periods(rec, method="envelope", + method_kwargs=dict(detect_threshold=5, freq_max=5.0), + ) + +if __name__ == "__main__": + test_detect_artifact_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py index 2baa4bf1b3..ad70540f40 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py @@ -2,15 +2,15 @@ import numpy as np -from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import silence_artifacts +# from spikeinterface.core import generate_recording +# from spikeinterface.preprocessing import silence_artifacts -def test_silence_artifacts(): - # one segment only - rec = generate_recording(durations=[10.0, 10]) - new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) +# def test_silence_artifacts(): +# # one segment only +# rec = generate_recording(durations=[10.0, 10]) +# new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) -if __name__ == "__main__": - test_silence_artifacts() +# if __name__ == "__main__": +# test_silence_artifacts() From a71dade8b1e3a852fc8ca35df97960125dd7080c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Jan 2026 14:11:49 +0100 Subject: [PATCH 02/17] Propagate new periods dtype to SilencedPeriodsRecording with backward compatibility --- src/spikeinterface/core/node_pipeline.py | 4 +- .../preprocessing/detect_artifacts.py | 79 ++++----- .../preprocessing/silence_artifacts.py | 96 ----------- .../preprocessing/silence_periods.py | 151 ++++++++++++++---- .../tests/test_detect_artifacts.py | 3 +- .../tests/test_silence_artifacts.py | 16 -- ...est_silence.py => test_silence_periods.py} | 20 ++- 7 files changed, 166 
insertions(+), 203 deletions(-) delete mode 100644 src/spikeinterface/preprocessing/silence_artifacts.py delete mode 100644 src/spikeinterface/preprocessing/tests/test_silence_artifacts.py rename src/spikeinterface/preprocessing/tests/{test_silence.py => test_silence_periods.py} (76%) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index f1efe7a035..10e4885606 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,7 +489,9 @@ def check_graph(nodes): Check that node list is orderd in a good (parents are before children) """ - node0 = nodes[0] + # Do not remove this, this is to remember that in previous version the first node needed to be + # a detector but not anymore + # node0 = nodes[0] # if not isinstance(node0, PeakSource): # raise ValueError( # "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2ff3f8a78f..2a1ea069f0 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -14,22 +14,17 @@ import numpy as np -# artifact_dtype = [ -# ("start_index", "int64"), -# ("stop_index", "int64"), -# ("segment_index", "int64"), -# ] artifact_dtype = base_period_dtype +# this will be extended with channel boundaries if needed # extended_artifact_dtype = artifact_dtype + [ # # TODO # ] - def detect_artifact_periods( recording, method="envelope", @@ -37,7 +32,11 @@ def detect_artifact_periods( job_kwargs=None, ): """ - + Detect artifacts with several possible methods: + * 'saturation' using detect_saturation_periods() + * 'envelope' using detect_artifact_periods_by_envelope() + + See sub methods for more information on parameters. 
""" if method_kwargs is None: @@ -48,7 +47,7 @@ def detect_artifact_periods( elif method == "saturation": artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs) else: - raise ValueError("") + raise ValueError(f"detect_artifact_periods() method='{method}' is not valid") return artifact_periods @@ -56,13 +55,10 @@ def detect_artifact_periods( ## detect_period_artifacts_saturation Zone - def _collapse_events(events): """ If events are detected at a chunk edge, they will be split in two. - This detects such cases and collapses them in a single record instead - :param events: - :return: + This detects such cases and collapses them in a single record instead. """ order = np.lexsort((events["start_sample_index"], events["segment_index"])) events = events[order] @@ -87,21 +83,24 @@ class _DetectSaturation(PipelineNode): def __init__( self, recording, - saturation_threshold_uV, # 1200 uV - voltage_per_sec_threshold, # 1e-8 V.s-1 + saturation_threshold_uV, + voltage_per_sec_threshold, proportion, - mute_window_samples, ): PipelineNode.__init__(self, recording, return_output=True) - self.gains = recording.get_channel_gains() - self.offsets = recording.get_channel_offsets() + gains = recording.get_channel_gains() + offsets = recording.get_channel_offsets() + num_chans = recording.get_num_channels() self.voltage_per_sec_threshold = voltage_per_sec_threshold - self.saturation_threshold_uV = saturation_threshold_uV + thresh = np.full((num_chans, ), saturation_threshold_uV) + # 0.98 is empirically determined as the true saturating point is + # slightly lower than the documented saturation point of the probe + self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 + self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion - self.mute_window_samples = mute_window_samples self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = 
recording.get_channel_offsets() @@ -114,16 +113,7 @@ def get_dtype(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # @olivier @joe we can avoid this by making - traces = traces * self.gains[np.newaxis, :] + self.offsets[np.newaxis, :] - - - # first computes the saturated samples - max_voltage = np.atleast_1d(self.saturation_threshold_uV)[:, np.newaxis] - - # 0.98 is empirically determined as the true saturating point is - # slightly lower than the documented saturation point of the probe - saturation = np.mean(np.abs(traces) > max_voltage * 0.98, axis=1) + saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) if self.voltage_per_sec_threshold is not None: fs = self.sampling_frequency @@ -138,7 +128,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): else: saturation = saturation > self.proportion - intervals = np.where(np.diff(saturation, prepend=False, append=False))[0] + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) n_events = len(intervals) // 2 # Number of saturation periods events = np.zeros(n_events, dtype=artifact_dtype) @@ -146,7 +136,6 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): events[i]["start_sample_index"] = start + start_frame events[i]["end_sample_index"] = stop + start_frame events[i]["segment_index"] = segment_index - # events[i]["method_id"] = "saturation_detection" return (events, ) @@ -154,9 +143,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV - voltage_per_sec_threshold, # 1e-8 V.s-1 + voltage_per_sec_threshold=None, # 1e-8 V.s-1 proportion=0.5, - mute_window_samples=7, job_kwargs=None, ): """ @@ -174,7 +162,7 @@ def detect_saturation_periods( The recording on which to detect the saturation events. saturation_threshold_uV : float The voltage saturation threshold in volts. 
This will depend on the recording - probe and amplifier gain settings. For NP1 the value of 1200 * 1e-6 is recommended (IBL). + probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). Note that NP2 probes are more difficult to saturate than NP1. voltage_per_sec_threshold : None | float The first-derivative threshold in volts per second. Periods of the data over which the change @@ -207,10 +195,9 @@ def detect_saturation_periods( saturation_threshold_uV=saturation_threshold_uV, voltage_per_sec_threshold=voltage_per_sec_threshold, proportion=proportion, - mute_window_samples=mute_window_samples, ) - saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation events") + saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts") return _collapse_events(saturation_periods) @@ -218,13 +205,7 @@ def detect_saturation_periods( ## detect_artifact_periods_by_envelope Zone -# _internal_dtype = [ -# ("sample_index", "int64"), -# ("segment_index", "int64"), -# ("front", "bool") -# ] - -class DetectThresholdCrossing(PeakDetector): +class _DetectThresholdCrossing(PeakDetector): name = "threshold_crossings" preferred_mp_context = None @@ -243,6 +224,7 @@ def __init__( random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) self.abs_thresholds = noise_levels * detect_threshold + # internal dtype self._dtype = np.dtype([ ("sample_index", "int64"), ("segment_index", "int64"), @@ -278,7 +260,7 @@ def detect_artifact_periods_by_envelope( random_slices_kwargs=None, ): """ - Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of + Function to detect putative artifact periods as threshold crossings of a global envelope of the channels. 
Parameters @@ -300,8 +282,6 @@ def detect_artifact_periods_by_envelope( envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) envelope = CommonReferenceRecording(envelope) - - # _, job_kwargs = split_job_kwargs(noise_levels_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) if random_slices_kwargs is None: random_slices_kwargs = {} @@ -310,7 +290,7 @@ def detect_artifact_periods_by_envelope( random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - node0 = DetectThresholdCrossing( + node0 = _DetectThresholdCrossing( recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, ) @@ -318,7 +298,7 @@ def detect_artifact_periods_by_envelope( envelope, [node0], job_kwargs, - job_name="detect threshold crossings", + job_name="detect artifact on envelope", ) order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) @@ -326,12 +306,9 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - return artifacts, envelope -# tools - def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): num_seg = recording.get_num_segments() diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py deleted file mode 100644 index 241fe0f915..0000000000 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from spikeinterface.core.base import base_peak_dtype -from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector -from 
spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -import numpy as np - - -# class SilencedArtifactsRecording(SilencedPeriodsRecording): -# """ -# Silence user-defined periods from recording extractor traces. The code will construct -# an enveloppe of the recording (as a low pass filtered version of the traces) and detect -# threshold crossings to identify the periods to silence. The periods are then silenced either -# on a per channel basis or across all channels by replacing the values by zeros or by -# adding gaussian noise with the same variance as the one in the recordings - -# Parameters -# ---------- -# recording : RecordingExtractor -# The recording extractor to silence putative artifacts -# artifacts : np.array, None -# The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` -# freq_max : float, default: 20 -# The maximum frequency for the low pass filter used -# min_duration_ms : float, default: 50 -# The minimum duration for a threshold crossing to be considered as an artefact. -# noise_levels : array -# Noise levels if already computed -# seed : int | None, default: None -# Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. -# If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. -# mode : "zeros" | "noise", default: "zeros" -# Determines what periods are replaced by. Can be one of the following: - -# - "zeros": Artifacts are replaced by zeros. 
- -# - "noise": The periods are filled with a gaussion noise that has the -# same variance that the one in the recordings, on a per channel -# basis -# **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - -# Returns -# ------- -# silenced_recording : SilencedArtifactsRecording -# The recording extractor after silencing detected artifacts -# """ - -# _precomputable_kwarg_names = ["artifacts"] - -# def __init__( -# self, -# recording, -# artifacts=None, -# detect_threshold=5, -# verbose=False, -# freq_max=20.0, -# min_duration_ms=50, -# mode="zeros", -# noise_levels=None, -# seed=None, -# list_periods=None, -# **noise_levels_kwargs, -# ): - -# if artifacts is None: -# from spikeinterface.preprocessing import detect_artifacts -# artifacts = detect_artifact_periods( -# recording, -# detect_threshold=detect_threshold, -# min_duration_ms=min_duration_ms, -# freq_max=freq_max, -# seed=seed, -# noise_levels=noise_levels, -# **noise_levels_kwargs, -# ) - -# if verbose: -# for i, periods in enumerate(artifacts): -# total_time = np.sum([end - start for start, end in periods]) -# percentage = 100 * total_time / recording.get_num_samples(i) -# print(f"{percentage}% of segment {i} has been flagged as artifactual") - -# SilencedPeriodsRecording.__init__( -# self, recording, artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs -# ) - - -# # function for API -# silence_artifacts = define_function_handling_dict_from_class( -# source_class=SilencedArtifactsRecording, name="silence_artifacts" -# ) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c9b6e2abe4..040e1275be 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -8,6 +8,8 @@ from spikeinterface.core import get_noise_levels from spikeinterface.core.generate import NoiseGeneratorRecording from 
spikeinterface.core.job_tools import split_job_kwargs +from spikeinterface.core.base import base_period_dtype + class SilencedPeriodsRecording(BasePreprocessor): @@ -48,7 +50,9 @@ class SilencedPeriodsRecording(BasePreprocessor): def __init__( self, recording, - list_periods, + periods=None, + # this is keep for backward compatibility + list_periods=None, mode="zeros", noise_levels=None, seed=None, @@ -56,25 +60,27 @@ def __init__( ): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() - if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: - # when unique segment accept list instead of list of list/arrays - list_periods = [list_periods] - # some checks - assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" - assert isinstance(list_periods, list), "'list_periods' must be a list (one per segment)" - assert len(list_periods) == num_seg, "'list_periods' must have the same length as the number of segments" - assert all( - isinstance(list_periods[i], (list, np.ndarray)) for i in range(num_seg) - ), "Each element of 'list_periods' must be array-like" + # handle backward compatibility with previous version + if list_periods is not None: + assert periods is None + periods = _all_period_list_to_periods_vec(list_periods, num_seg) + else: + assert list_periods is None + if not isinstance(periods, np.ndarray): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + if periods.dtype.fields is None: + # this is the old format : list[list[int]] + periods = _all_period_list_to_periods_vec(periods, num_seg) - for periods in list_periods: - if len(periods) > 0: - assert np.all(np.diff(np.array(periods), axis=1) > 0), "t_stops should be larger than t_starts" - assert np.all( - periods[i][1] < periods[i + 1][0] for i in np.arange(len(periods) - 1) - ), "Intervals should not overlap" + # force order + order = 
np.lexsort((periods["start_sample_index"], periods["segment_index"])) + periods = periods[order] + _check_periods(periods, num_seg) + + # some checks + assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" if mode in ["noise"]: if noise_levels is None: @@ -98,18 +104,57 @@ def __init__( noise_generator = None BasePreprocessor.__init__(self, recording) + + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for seg_index, parent_segment in enumerate(recording._recording_segments): - periods = list_periods[seg_index] - periods = np.asarray(periods, dtype="int64") - periods = np.sort(periods, axis=0) - rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) + i0 = seg_limits[seg_index] + i1 = seg_limits[seg_index+1] + periods_in_seg = periods[i0:i1] + rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods_in_seg, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) self._kwargs = dict( - recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels + recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels ) +def _all_period_list_to_periods_vec(list_periods, num_seg): + if num_seg == 1: + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: + # when unique segment accept list instead of list of list/arrays + list_periods = [list_periods] + size = sum(len(p) for p in list_periods) + periods = np.zeros(size, dtype=base_period_dtype) + start = 0 + for i in range(num_seg): + periods_in_seg = list_periods[i] + stop = start + periods_in_seg.shape[0] + periods[start:stop]["segment_index"] = i + periods[start:stop]["start_sample_index"] = periods_in_seg[:, 0] + periods[start:stop]["end_sample_index"] = periods_in_seg[:, 1] + start = stop + return periods + +def _check_periods(periods, num_seg): + # check dtype + if any(col not in 
np.dtype(base_period_dtype).fields for col in periods.dtype.fields): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + # check non overlap and non negative + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) + for i in range(num_seg): + i0 = seg_limits[i] + i1 = seg_limits[i+1] + periods_in_seg = periods[i0:i1] + if periods_in_seg.size == 0: + continue + if len(periods) > 0: + if np.any(periods_in_seg["start_sample_index"] > periods_in_seg["end_sample_index"]): + raise ValueError("end_sample_index should be larger than start_sample_index") + if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): + raise ValueError("Intervals should not overlap") + + class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -120,18 +165,20 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) - traces = traces.copy() + if self.periods.size > 0: new_interval = np.array([start_frame, end_frame]) - lower_index = np.searchsorted(self.periods[:, 1], new_interval[0]) - upper_index = np.searchsorted(self.periods[:, 0], new_interval[1]) + + lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) + upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) if upper_index > lower_index: - periods_in_interval = self.periods[lower_index:upper_index] + traces = traces.copy() + periods_in_interval = self.periods[lower_index:upper_index] for period in periods_in_interval: - onset = max(0, period[0] - start_frame) - offset = min(period[1] - start_frame, end_frame) + onset = max(0, 
period["start_sample_index"] - start_frame) + offset = min(period["end_sample_index"] - start_frame, end_frame) if self.mode == "zeros": traces[onset:offset, :] = 0 @@ -143,8 +190,52 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces - # function for API silence_periods = define_function_handling_dict_from_class( source_class=SilencedPeriodsRecording, name="silence_periods" ) + + + +class DetectArtifactAndSilentPeriodsRecording(SilencedPeriodsRecording): + """ + Class doing artifact detection and lient at the same time. + + See SilencedPeriodsRecording and detect_artifact_periods for details. + """ + + _precomputable_kwarg_names = ["artifacts"] + + def __init__( + self, + recording, + detect_artifact_method="envelope", + detect_artifact_kwargs=dict(), + periods=None, + mode="zeros", + noise_levels=None, + seed=None, + **noise_levels_kwargs, + ): + + if artifacts is None: + from spikeinterface.preprocessing import detect_artifact_periods + artifacts = detect_artifact_periods( + recording, + method=detect_artifact_method, + method_kwargs=detect_artifact_kwargs, + job_kwargs=None, + ) + + SilencedPeriodsRecording.__init__( + self, recording, periods=artifacts, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs + ) + # note self._kwargs["periods"] is done by SilencedPeriodsRecording and so the computaion is done once + + + +# function for API +detect_artifacts_and_silent_periods = define_function_handling_dict_from_class( + source_class=DetectArtifactAndSilentPeriodsRecording, name="silence_artifacts" +) + diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index cfb32254f1..50003487d0 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -26,7 +26,8 @@ def test_detect_saturation_periods(): # cross a chunk boundary. 
Do not change without changing the below. sat_value = 1200 - data = np.random.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 + rng = np.random.default_rng() + data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 # Design the Butterworth filter sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py deleted file mode 100644 index ad70540f40..0000000000 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -import numpy as np - -# from spikeinterface.core import generate_recording -# from spikeinterface.preprocessing import silence_artifacts - - -# def test_silence_artifacts(): -# # one segment only -# rec = generate_recording(durations=[10.0, 10]) -# new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) - - -# if __name__ == "__main__": -# test_silence_artifacts() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py similarity index 76% rename from src/spikeinterface/preprocessing/tests/test_silence.py rename to src/spikeinterface/preprocessing/tests/test_silence_periods.py index e7aee1a84d..ffba9059a0 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -1,11 +1,12 @@ import pytest from spikeinterface.core import generate_recording - +from spikeinterface.core import get_noise_levels +from spikeinterface.core.base import base_period_dtype from spikeinterface.preprocessing import silence_periods -from spikeinterface.core import get_noise_levels + import numpy as np @@ -18,17 +19,20 @@ def test_silence(create_cache_folder): rec = generate_recording() - rec0 = silence_periods(rec, 
list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros", seed=2308) - rec0.save(verbose=False) + periods = np.array([(0, 0, 1000), (0, 5000, 6000)], dtype=base_period_dtype) + rec0 = silence_periods(rec, periods=periods, mode="zeros", seed=2308) + rec0.save(format="memory", verbose=False) traces_in0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000) - traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) - traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert np.all(traces_in0 == 0) + traces_half0 = rec0.get_traces(segment_index=0, start_frame=900, end_frame=1100) + assert np.all(traces_half0[:100] == 0) + traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) assert np.all(traces_in1 == 0) + traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert not np.all(traces_out0 == 0) - rec1 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="noise", seed=2308) - rec1 = rec1.save(folder=cache_folder / "rec_w_noise", verbose=False, overwrite=True) + rec1 = silence_periods(rec, periods=periods, mode="noise", seed=2308) + rec1 = rec1.save(format="memory", verbose=False, overwrite=True) noise_levels = get_noise_levels(rec, return_in_uV=False) traces_in0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000) traces_in1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000) From 0d709e69b09b4b9e02315e5e55ee79623d8df14d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Jan 2026 14:16:15 +0100 Subject: [PATCH 03/17] oups --- .../preprocessing/tests/test_detect_artifacts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 50003487d0..b5d9a18a9b 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ 
b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -1,7 +1,7 @@ from spikeinterface.core import generate_recording, NumpyRecording from spikeinterface.preprocessing import detect_artifact_periods, detect_saturation_periods import numpy as np -import scipy.signal + def test_detect_artifact_periods(): # one segment only @@ -13,6 +13,9 @@ def test_detect_artifact_periods(): def test_detect_saturation_periods(): + + import scipy.signal + """ TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity we have an extra sample after due to taking the diff on the final saturation mask From 7ab75c033e67710c85147d13112240aab0001f98 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 21 Jan 2026 19:21:42 +0100 Subject: [PATCH 04/17] oups --- src/spikeinterface/preprocessing/detect_artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2a1ea069f0..3e42facdc5 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -291,7 +291,7 @@ def detect_artifact_periods_by_envelope( noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) node0 = _DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + envelope, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, ) threshold_crossings = run_node_pipeline( From 7c7446ea076dd227f6888136ebb9f5ab3bd5af8f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 29 Jan 2026 15:40:42 +0000 Subject: [PATCH 05/17] Applying extras from other PR, adding voltage_per_sec_threshold. 
--- .../preprocessing/detect_artifacts.py | 34 +++++++++----- .../tests/test_detect_artifacts.py | 47 ++++++++++--------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 3e42facdc5..8720df03d4 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -75,7 +75,12 @@ def _collapse_events(events): class _DetectSaturation(PipelineNode): + """ + A recording node for parallelising saturation detection. + Run with `run_node_pipeline`, this computes saturation events + for a given chunk. See `detect_saturation()` for details. + """ name = "detect_saturation" preferred_mp_context = None _compute_has_extended_signature = True @@ -98,7 +103,8 @@ def __init__( # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 - + self.voltage_per_sec_threshold = (voltage_per_sec_threshold - offsets) / gains + self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion self._dtype = np.dtype(artifact_dtype) @@ -112,7 +118,10 @@ def get_dtype(self): return self._dtype def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - + """ + Compute saturation events for a given chunk of data. + See `detect_saturation()` for details. + """ saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) if self.voltage_per_sec_threshold is not None: @@ -144,7 +153,7 @@ def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV voltage_per_sec_threshold=None, # 1e-8 V.s-1 - proportion=0.5, + proportion=0.2, job_kwargs=None, ): """ @@ -170,24 +179,25 @@ def detect_saturation_periods( skip this method and only use `saturation_threshold_uV` for detection. 
Otherwise, the value should be empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. - proportion : - mute_window_samples : - job_kwargs : + proportion : float + 0 < proportion <1 of channels above threshold to consider the sample as saturated + mute_window_samples : int + TODO: should we scale this based on the fs? + job_kwargs: dict + The classical job_kwargs most useful for NP1 can use ratio as a intuition for the value but dont do it in code Returns - ------- - +------- + collapsed_events : np.recarray + A numpy recarray holding information on each saturation event. Has the fields: + "start_sample_index", "stop_sample_index", "segment_index", "method_id" """ if job_kwargs: job_kwargs = {} - # if saturation_threshold_uV < 0.1: - # raise ValueError(f"The `saturation_threshold_uV` should be in microvolts. " - # f"Your value: {saturation_threshold_uV} is almost certainly in volts.") - job_kwargs = fix_job_kwargs(job_kwargs) node0 = _DetectSaturation( diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index b5d9a18a9b..9c923a5a7b 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -13,24 +13,32 @@ def test_detect_artifact_periods(): def test_detect_saturation_periods(): - - import scipy.signal - """ - TODO: NOTE: we have one sample before the saturation starts as we take the forward derivative for the velocity - we have an extra sample after due to taking the diff on the final saturation mask - this means we always take one sample before and one sample after the diff period, which is fine. + This tests the saturation detection method. First a mock recording is created with + saturation events. Events may be single-sample or a multi-sample period. We create a multi-segment + recording with the stop-sample of each event offset by one, so the segments are distinguishable. 
+ + Saturation detection is performed on chunked data (we set to 30k sample chunks) and so injected + events are hard-coded in order to cross a chunk boundary to test this case. + + The saturation detection function tests both a) saturation threshold exceeded + and b) first derivative (velocity) threshold exceeded. Because the forward + derivative is taken, the sample before the first saturated sample is also flagged. + Also, because of the way the mask is computed in the function, the sample after the + last saturated sample is flagged. """ - # num_chans = 384 + import scipy.signal + num_chans = 32 sample_frequency = 30000 chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below job_kwargs = {"chunk_size": chunk_size} - # cross a chunk boundary. Do not change without changing the below. + # Generate some data in uV sat_value = 1200 + voltage_per_sec_threshold = 12 / sample_frequency rng = np.random.default_rng() - data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 * 1e-6 + data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 # Design the Butterworth filter sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") @@ -57,7 +65,7 @@ def test_detect_saturation_periods(): # this center the int16 around 0 and saturate on positive max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) gain = max_ / 2**15 - offset = 0 + offset = 50 seg_1_int16 = np.clip( np.rint((data_seg_1 - offset) / gain), @@ -68,17 +76,12 @@ def test_detect_saturation_periods(): -32768, 32767 ).astype(np.int16) - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.plot(seg_1_int16[:, 0]) - # plt.show() - recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( - recording, saturation_threshold_uV=sat_value * 0.98, 
voltage_per_sec_threshold=1e-8, job_kwargs=job_kwargs + recording, saturation_threshold_uV=sat_value * 0.98, voltage_per_sec_threshold=voltage_per_sec_threshold, job_kwargs=job_kwargs ) seg_1_periods = periods[np.where(periods["segment_index"] == 0)] @@ -102,23 +105,23 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=1e-8, + voltage_per_sec_threshold=voltage_per_sec_threshold, job_kwargs=job_kwargs, ) assert periods["start_sample_index"][0] == 1000 assert periods["end_sample_index"][0] == 1001 - periods = detect_artifact_periods( + periods_entry_function = detect_artifact_periods( recording, method="saturation", method_kwargs=dict( saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=1e-8, + voltage_per_sec_threshold=voltage_per_sec_threshold, ), - job_kwargs=job_kwargs, - ) - + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_function) if __name__ == "__main__": test_detect_artifact_periods() From 86e6924cd37c44327895bd16c8a9232937cc918b Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 29 Jan 2026 15:41:16 +0000 Subject: [PATCH 06/17] Rename uV_per_sec_threshold. 
--- .../preprocessing/detect_artifacts.py | 16 ++++++++-------- .../preprocessing/tests/test_detect_artifacts.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 8720df03d4..6acad37901 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -89,7 +89,7 @@ def __init__( self, recording, saturation_threshold_uV, - voltage_per_sec_threshold, + uV_per_sec_threshold, proportion, ): PipelineNode.__init__(self, recording, return_output=True) @@ -98,12 +98,12 @@ def __init__( offsets = recording.get_channel_offsets() num_chans = recording.get_num_channels() - self.voltage_per_sec_threshold = voltage_per_sec_threshold + self.uV_per_sec_threshold = uV_per_sec_threshold thresh = np.full((num_chans, ), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 - self.voltage_per_sec_threshold = (voltage_per_sec_threshold - offsets) / gains + self.uV_per_sec_threshold = (uV_per_sec_threshold - offsets) / gains self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion @@ -124,10 +124,10 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): """ saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - if self.voltage_per_sec_threshold is not None: + if self.uV_per_sec_threshold is not None: fs = self.sampling_frequency # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.voltage_per_sec_threshold, axis=1) + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.uV_per_sec_threshold, axis=1) # Note this means the velocity is not 
checked for the last sample in the # check because we are taking the forward derivative n_diff_saturated = np.r_[n_diff_saturated, 0] @@ -152,7 +152,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV - voltage_per_sec_threshold=None, # 1e-8 V.s-1 + uV_per_sec_threshold=None, # 1e-8 V.s-1 proportion=0.2, job_kwargs=None, ): @@ -173,7 +173,7 @@ def detect_saturation_periods( The voltage saturation threshold in volts. This will depend on the recording probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). Note that NP2 probes are more difficult to saturate than NP1. - voltage_per_sec_threshold : None | float + uV_per_sec_threshold : None | float The first-derivative threshold in volts per second. Periods of the data over which the change in velocity is greater than this threshold will be detected as saturation events. Use `None` to skip this method and only use `saturation_threshold_uV` for detection. 
Otherwise, the value should be @@ -203,7 +203,7 @@ def detect_saturation_periods( node0 = _DetectSaturation( recording, saturation_threshold_uV=saturation_threshold_uV, - voltage_per_sec_threshold=voltage_per_sec_threshold, + uV_per_sec_threshold=uV_per_sec_threshold, proportion=proportion, ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 9c923a5a7b..d968382421 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -36,7 +36,7 @@ def test_detect_saturation_periods(): # Generate some data in uV sat_value = 1200 - voltage_per_sec_threshold = 12 / sample_frequency + uV_per_sec_threshold = 12 / sample_frequency rng = np.random.default_rng() data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 @@ -81,7 +81,7 @@ def test_detect_saturation_periods(): recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( - recording, saturation_threshold_uV=sat_value * 0.98, voltage_per_sec_threshold=voltage_per_sec_threshold, job_kwargs=job_kwargs + recording, saturation_threshold_uV=sat_value * 0.98, uV_per_sec_threshold=uV_per_sec_threshold, job_kwargs=job_kwargs ) seg_1_periods = periods[np.where(periods["segment_index"] == 0)] @@ -105,7 +105,7 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=voltage_per_sec_threshold, + uV_per_sec_threshold=uV_per_sec_threshold, job_kwargs=job_kwargs, ) assert periods["start_sample_index"][0] == 1000 @@ -116,7 +116,7 @@ def test_detect_saturation_periods(): method="saturation", method_kwargs=dict( saturation_threshold_uV=sat_value * (1 / 0.98), - voltage_per_sec_threshold=voltage_per_sec_threshold, + uV_per_sec_threshold=uV_per_sec_threshold, ), job_kwargs=job_kwargs, ) From 
8b17890113c86336edba0b1d9ab175b0a5b4c204 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:42:17 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/preprocessing/__init__.py | 6 +- .../preprocessing/detect_artifacts.py | 123 +++++++++--------- .../preprocessing/preprocessing_classes.py | 1 + .../preprocessing/silence_periods.py | 27 ++-- .../tests/test_detect_artifacts.py | 39 +++--- .../tests/test_silence_periods.py | 2 - 7 files changed, 97 insertions(+), 103 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 10e4885606..a78082bc74 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,7 +489,7 @@ def check_graph(nodes): Check that node list is orderd in a good (parents are before children) """ - # Do not remove this, this is to remenber that in previous version the first node needed to be + # Do not remove this, this is to remenber that in previous version the first node needed to be # a detectot but not anymore # node0 = nodes[0] # if not isinstance(node0, PeakSource): diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index ab1adb6942..fd8d8fd787 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,11 +20,7 @@ PreprocessingPipeline, ) -from .detect_artifacts import ( - detect_artifact_periods, - detect_artifact_periods_by_envelope, - detect_saturation_periods -) +from .detect_artifacts import detect_artifact_periods, detect_artifact_periods_by_envelope, detect_saturation_periods # for snippets from .align_snippets import AlignSnippets diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py 
b/src/spikeinterface/preprocessing/detect_artifacts.py index 6acad37901..323a73f734 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -3,6 +3,7 @@ import numpy as np from spikeinterface.core.base import base_period_dtype + # from spikeinterface.core.core_tools import define_function_handling_dict_from_class # from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording from spikeinterface.preprocessing.rectify import RectifyRecording @@ -13,8 +14,6 @@ from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode import numpy as np - - artifact_dtype = base_period_dtype @@ -24,7 +23,6 @@ # ] - def detect_artifact_periods( recording, method="envelope", @@ -35,7 +33,7 @@ def detect_artifact_periods( Detect artifacts with several possible methods: * 'saturation' using detect_artifact_periods_by_envelope() * 'envelope' using detect_saturation_periods() - + See sub methods for more information on parameters. """ @@ -43,18 +41,20 @@ def detect_artifact_periods( method_kwargs = dict() if method == "envelope": - artifact_periods, envelope = detect_artifact_periods_by_envelope(recording, **method_kwargs, job_kwargs=job_kwargs) + artifact_periods, envelope = detect_artifact_periods_by_envelope( + recording, **method_kwargs, job_kwargs=job_kwargs + ) elif method == "saturation": artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs) else: raise ValueError(f"detect_artifact_periods() method='{method}' is not valid") - - return artifact_periods + return artifact_periods ## detect_period_artifacts_saturation Zone + def _collapse_events(events): """ If events are detected at a chunk edge, they will be split in two. @@ -81,6 +81,7 @@ class _DetectSaturation(PipelineNode): Run with `run_node_pipeline`, this computes saturation events for a given chunk. See `detect_saturation()` for details. 
""" + name = "detect_saturation" preferred_mp_context = None _compute_has_extended_signature = True @@ -99,16 +100,16 @@ def __init__( num_chans = recording.get_num_channels() self.uV_per_sec_threshold = uV_per_sec_threshold - thresh = np.full((num_chans, ), saturation_threshold_uV) + thresh = np.full((num_chans,), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is - # slightly lower than the documented saturation point of the probe + # slightly lower than the documented saturation point of the probe self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 self.uV_per_sec_threshold = (uV_per_sec_threshold - offsets) / gains self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion self._dtype = np.dtype(artifact_dtype) - self.gain = recording.get_channel_gains() + self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() def get_trace_margin(self): @@ -146,7 +147,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): events[i]["end_sample_index"] = stop + start_frame events[i]["segment_index"] = segment_index - return (events, ) + return (events,) def detect_saturation_periods( @@ -157,43 +158,43 @@ def detect_saturation_periods( job_kwargs=None, ): """ - Detect amplifier saturation events (either single sample or multi-sample periods) in the data. - Saturation detection with this function should be applied to the raw data, before preprocessing. - However, saturation periods detected should be zeroed out after preprocessing has been performed. - - Saturation is detected by a voltage threshold, and optionally a derivative threshold that - flags periods of high velocity changes in the voltage. See _DetectSaturation.compute() - for details on the algorithm. - - Parameters - ---------- - recording : BaseRecording - The recording on which to detect the saturation events. 
- saturation_threshold_uV : float - The voltage saturation threshold in volts. This will depend on the recording - probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). - Note that NP2 probes are more difficult to saturate than NP1. - uV_per_sec_threshold : None | float - The first-derivative threshold in volts per second. Periods of the data over which the change - in velocity is greater than this threshold will be detected as saturation events. Use `None` to - skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be - empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. - - proportion : float - 0 < proportion <1 of channels above threshold to consider the sample as saturated - mute_window_samples : int - TODO: should we scale this based on the fs? - job_kwargs: dict - The classical job_kwargs - - most useful for NP1 - can use ratio as a intuition for the value but dont do it in code - - Returns -------- - collapsed_events : np.recarray - A numpy recarray holding information on each saturation event. Has the fields: - "start_sample_index", "stop_sample_index", "segment_index", "method_id" + Detect amplifier saturation events (either single sample or multi-sample periods) in the data. + Saturation detection with this function should be applied to the raw data, before preprocessing. + However, saturation periods detected should be zeroed out after preprocessing has been performed. + + Saturation is detected by a voltage threshold, and optionally a derivative threshold that + flags periods of high velocity changes in the voltage. See _DetectSaturation.compute() + for details on the algorithm. + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect the saturation events. + saturation_threshold_uV : float + The voltage saturation threshold in volts. This will depend on the recording + probe and amplifier gain settings. 
For NP1 the value of 1200 uV is recommended (IBL). + Note that NP2 probes are more difficult to saturate than NP1. + uV_per_sec_threshold : None | float + The first-derivative threshold in volts per second. Periods of the data over which the change + in velocity is greater than this threshold will be detected as saturation events. Use `None` to + skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be + empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. + + proportion : float + 0 < proportion <1 of channels above threshold to consider the sample as saturated + mute_window_samples : int + TODO: should we scale this based on the fs? + job_kwargs: dict + The classical job_kwargs + + most useful for NP1 + can use ratio as a intuition for the value but dont do it in code + + Returns + ------- + collapsed_events : np.recarray + A numpy recarray holding information on each saturation event. Has the fields: + "start_sample_index", "stop_sample_index", "segment_index", "method_id" """ if job_kwargs: job_kwargs = {} @@ -207,14 +208,16 @@ def detect_saturation_periods( proportion=proportion, ) - saturation_periods = run_node_pipeline(recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts") + saturation_periods = run_node_pipeline( + recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts" + ) return _collapse_events(saturation_periods) - ## detect_artifact_periods_by_envelope Zone + class _DetectThresholdCrossing(PeakDetector): name = "threshold_crossings" @@ -235,12 +238,7 @@ def __init__( noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) self.abs_thresholds = noise_levels * detect_threshold # internal dtype - self._dtype = np.dtype([ - ("sample_index", "int64"), - ("segment_index", "int64"), - ("front", "bool") - ] - ) + self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", 
"bool")]) def get_trace_margin(self): return 0 @@ -301,7 +299,10 @@ def detect_artifact_periods_by_envelope( noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) node0 = _DetectThresholdCrossing( - envelope, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, + envelope, + detect_threshold=detect_threshold, + noise_levels=noise_levels, + seed=seed, ) threshold_crossings = run_node_pipeline( @@ -338,15 +339,15 @@ def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): local_thr["sample_index"] = recording.get_num_samples(seg_index) local_thr["front"] = False sub_thr = np.hstack((sub_thr, local_thr)) - - local_artifact = np.zeros(sub_thr.size/2, dtype=artifact_dtype) + + local_artifact = np.zeros(sub_thr.size / 2, dtype=artifact_dtype) local_artifact["start_index"] = sub_thr["sample_index"][::2] local_artifact["stop_index"] = sub_thr["sample_index"][1::2] local_artifact["segment_index"] = seg_index final_artifacts.append(local_artifact) - + if len(final_artifacts) > 0: final_artifacts = np.concatenate(final_artifacts) else: final_artifacts = np.zeros(0, dtype=artifact_dtype) - return final_artifacts \ No newline at end of file + return final_artifacts diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 47839db7a0..ff07b5b3c6 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -50,6 +50,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed + # from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 
040e1275be..a9a1ac06d3 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -11,7 +11,6 @@ from spikeinterface.core.base import base_period_dtype - class SilencedPeriodsRecording(BasePreprocessor): """ Silence user-defined periods from recording extractor traces. By default, @@ -104,18 +103,18 @@ def __init__( noise_generator = None BasePreprocessor.__init__(self, recording) - + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for seg_index, parent_segment in enumerate(recording._recording_segments): i0 = seg_limits[seg_index] - i1 = seg_limits[seg_index+1] + i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] - rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods_in_seg, mode, noise_generator, seg_index) + rec_segment = SilencedPeriodsRecordingSegment( + parent_segment, periods_in_seg, mode, noise_generator, seg_index + ) self.add_recording_segment(rec_segment) - self._kwargs = dict( - recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels - ) + self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -135,6 +134,7 @@ def _all_period_list_to_periods_vec(list_periods, num_seg): start = stop return periods + def _check_periods(periods, num_seg): # check dtype if any(col not in np.dtype(base_period_dtype).fields for col in periods.dtype.fields): @@ -144,14 +144,14 @@ def _check_periods(periods, num_seg): seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for i in range(num_seg): i0 = seg_limits[i] - i1 = seg_limits[i+1] + i1 = seg_limits[i + 1] periods_in_seg = periods[i0:i1] if periods_in_seg.size == 0: continue if len(periods) > 0: if np.any(periods_in_seg["start_sample_index"] > periods_in_seg["end_sample_index"]): raise ValueError("end_sample_index should be larger 
than start_sample_index") - if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): + if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): raise ValueError("Intervals should not overlap") @@ -165,10 +165,10 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) - + if self.periods.size > 0: new_interval = np.array([start_frame, end_frame]) - + lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) @@ -190,13 +190,13 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces + # function for API silence_periods = define_function_handling_dict_from_class( source_class=SilencedPeriodsRecording, name="silence_periods" ) - class DetectArtifactAndSilentPeriodsRecording(SilencedPeriodsRecording): """ Class doing artifact detection and lient at the same time. 
@@ -220,6 +220,7 @@ def __init__( if artifacts is None: from spikeinterface.preprocessing import detect_artifact_periods + artifacts = detect_artifact_periods( recording, method=detect_artifact_method, @@ -233,9 +234,7 @@ def __init__( # note self._kwargs["periods"] is done by SilencedPeriodsRecording and so the computaion is done once - # function for API detect_artifacts_and_silent_periods = define_function_handling_dict_from_class( source_class=DetectArtifactAndSilentPeriodsRecording, name="silence_artifacts" ) - diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index d968382421..c812f6a83c 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -6,10 +6,11 @@ def test_detect_artifact_periods(): # one segment only rec = generate_recording(durations=[10.0, 10]) - artifacts = detect_artifact_periods(rec, method="envelope", - method_kwargs=dict(detect_threshold=5, freq_max=5.0), - ) - + artifacts = detect_artifact_periods( + rec, + method="envelope", + method_kwargs=dict(detect_threshold=5, freq_max=5.0), + ) def test_detect_saturation_periods(): @@ -53,35 +54,32 @@ def test_detect_saturation_periods(): # exactly on the border, as it makes testing complex # This was checked manually and any future breaking change # on this function would be extremely unlikely only to break this case. 
- all_starts = np.array([0, 29950, 45123, 90005, 149500]) - all_stops = np.array([1001, 30011, 45126, 90006, 149999]) + all_starts = np.array([0, 29950, 45123, 90005, 149500]) + all_stops = np.array([1001, 30011, 45126, 90006, 149999]) second_seg_offset = 1 for start, stop in zip(all_starts, all_stops): - data_seg_1[start : stop, :] = sat_value + data_seg_1[start:stop, :] = sat_value # differentiate the second segment for testing purposes data_seg_2[start : stop + second_seg_offset, :] = sat_value # this center the int16 around 0 and saturate on positive max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) - gain = max_ / 2**15 + gain = max_ / 2**15 offset = 50 - seg_1_int16 = np.clip( - np.rint((data_seg_1 - offset) / gain), - -32768, 32767 - ).astype(np.int16) - seg_2_int16 = np.clip( - np.rint((data_seg_2 - offset) / gain), - -32768, 32767 - ).astype(np.int16) + seg_1_int16 = np.clip(np.rint((data_seg_1 - offset) / gain), -32768, 32767).astype(np.int16) + seg_2_int16 = np.clip(np.rint((data_seg_2 - offset) / gain), -32768, 32767).astype(np.int16) recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( - recording, saturation_threshold_uV=sat_value * 0.98, uV_per_sec_threshold=uV_per_sec_threshold, job_kwargs=job_kwargs + recording, + saturation_threshold_uV=sat_value * 0.98, + uV_per_sec_threshold=uV_per_sec_threshold, + job_kwargs=job_kwargs, ) seg_1_periods = periods[np.where(periods["segment_index"] == 0)] @@ -115,14 +113,15 @@ def test_detect_saturation_periods(): recording, method="saturation", method_kwargs=dict( - saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_sec_threshold=uV_per_sec_threshold, - ), + saturation_threshold_uV=sat_value * (1 / 0.98), + uV_per_sec_threshold=uV_per_sec_threshold, + ), job_kwargs=job_kwargs, ) assert np.array_equal(periods, periods_entry_function) + if __name__ == 
"__main__": test_detect_artifact_periods() test_detect_saturation_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index ffba9059a0..44bd205f1b 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -6,8 +6,6 @@ from spikeinterface.preprocessing import silence_periods - - import numpy as np from pathlib import Path From 4d507adef33859a7fd2d8d06003764f0f1121821 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 6 Feb 2026 16:01:52 +0000 Subject: [PATCH 08/17] Updates testing converting to int. --- .../preprocessing/detect_artifacts.py | 26 +++++++------ .../tests/test_detect_artifacts.py | 37 +++++++++++++++---- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 323a73f734..695932d27c 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -90,28 +90,28 @@ def __init__( self, recording, saturation_threshold_uV, - uV_per_sec_threshold, + uV_per_ms_threshold, proportion, ): PipelineNode.__init__(self, recording, return_output=True) - gains = recording.get_channel_gains() - offsets = recording.get_channel_offsets() num_chans = recording.get_num_channels() - self.uV_per_sec_threshold = uV_per_sec_threshold + self.uV_per_ms_threshold = uV_per_ms_threshold thresh = np.full((num_chans,), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe - self.saturation_threshold_unscaled = (thresh - offsets) / gains * 0.98 - self.uV_per_sec_threshold = (uV_per_sec_threshold - offsets) / gains - self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion 
self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() + self.saturation_threshold_unscaled = (thresh - self.offset) / self.gain * 0.98 + + # do not apply offset when dealing with the derivative + self.uV_per_ms_threshold = (uV_per_ms_threshold * self.sampling_frequency / 1e3) / self.gain + def get_trace_margin(self): return 0 @@ -125,10 +125,12 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): """ saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - if self.uV_per_sec_threshold is not None: + if self.uV_per_ms_threshold is not None: fs = self.sampling_frequency # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) / fs >= self.uV_per_sec_threshold, axis=1) + + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= self.uV_per_ms_threshold, axis=1) + # Note this means the velocity is not checked for the last sample in the # check because we are taking the forward derivative n_diff_saturated = np.r_[n_diff_saturated, 0] @@ -153,7 +155,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( recording, saturation_threshold_uV, # 1200 uV - uV_per_sec_threshold=None, # 1e-8 V.s-1 + uV_per_ms_threshold=None, # 1e-8 V.s-1 proportion=0.2, job_kwargs=None, ): @@ -174,7 +176,7 @@ def detect_saturation_periods( The voltage saturation threshold in volts. This will depend on the recording probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). Note that NP2 probes are more difficult to saturate than NP1. - uV_per_sec_threshold : None | float + uV_per_ms_threshold : None | float The first-derivative threshold in volts per second. Periods of the data over which the change in velocity is greater than this threshold will be detected as saturation events. 
Use `None` to skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be @@ -204,7 +206,7 @@ def detect_saturation_periods( node0 = _DetectSaturation( recording, saturation_threshold_uV=saturation_threshold_uV, - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, proportion=proportion, ) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index c812f6a83c..aa4e8876e9 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -36,8 +36,8 @@ def test_detect_saturation_periods(): job_kwargs = {"chunk_size": chunk_size} # Generate some data in uV - sat_value = 1200 - uV_per_sec_threshold = 12 / sample_frequency + sat_value = 12 + uV_per_ms_threshold = 12 / sample_frequency / 1e3 rng = np.random.default_rng() data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 @@ -64,13 +64,34 @@ def test_detect_saturation_periods(): data_seg_2[start : stop + second_seg_offset, :] = sat_value # this center the int16 around 0 and saturate on positive - max_ = np.max(np.r_[data_seg_1.flatten(), data_seg_2.flatten()]) - gain = max_ / 2**15 - offset = 50 + combined = np.r_[data_seg_1.flatten(), data_seg_2.flatten()] + max_ = np.max(combined) + # min_ = np.min(combined) + gain = max_ / 2**15 # (max_ - min_) / 65535 + offset = 0 # min_ + 32768 * gain + + PLOT = True + if PLOT: + import matplotlib + import matplotlib.pyplot as plt + plt.plot(data_seg_1) + plt.title("data float") + plt.show() + plt.plot(np.diff(data_seg_1, axis=0)) + plt.title("diff float") + plt.show() seg_1_int16 = np.clip(np.rint((data_seg_1 - offset) / gain), -32768, 32767).astype(np.int16) seg_2_int16 = np.clip(np.rint((data_seg_2 - offset) / gain), -32768, 32767).astype(np.int16) + if PLOT: + plt.plot(seg_1_int16) + plt.title("data int") + 
plt.show() + plt.plot(np.diff(seg_1_int16, axis=0)) + plt.title("diff int") + plt.show() + recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) @@ -78,7 +99,7 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * 0.98, - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, job_kwargs=job_kwargs, ) @@ -103,7 +124,7 @@ def test_detect_saturation_periods(): periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, job_kwargs=job_kwargs, ) assert periods["start_sample_index"][0] == 1000 @@ -114,7 +135,7 @@ def test_detect_saturation_periods(): method="saturation", method_kwargs=dict( saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_sec_threshold=uV_per_sec_threshold, + uV_per_ms_threshold=uV_per_ms_threshold, ), job_kwargs=job_kwargs, ) From 654616661b00dc69db96f5d305aa8289145df278 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:02:26 +0000 Subject: [PATCH 09/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/tests/test_detect_artifacts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index aa4e8876e9..72350096b1 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -37,7 +37,7 @@ def test_detect_saturation_periods(): # Generate some data in uV sat_value = 12 - uV_per_ms_threshold = 12 / sample_frequency / 1e3 + 
uV_per_ms_threshold = 12 / sample_frequency / 1e3 rng = np.random.default_rng() data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 @@ -67,13 +67,14 @@ def test_detect_saturation_periods(): combined = np.r_[data_seg_1.flatten(), data_seg_2.flatten()] max_ = np.max(combined) # min_ = np.min(combined) - gain = max_ / 2**15 # (max_ - min_) / 65535 - offset = 0 # min_ + 32768 * gain + gain = max_ / 2**15 # (max_ - min_) / 65535 + offset = 0 # min_ + 32768 * gain PLOT = True if PLOT: import matplotlib import matplotlib.pyplot as plt + plt.plot(data_seg_1) plt.title("data float") plt.show() From 4df3ff090d2bd359aec1094155c8562a8d8660d8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Mar 2026 15:44:56 +0100 Subject: [PATCH 10/17] Use diff_threshold_uV, clean up, extend tests --- conftest.py | 15 + .../preprocessing/detect_artifacts.py | 563 +++++++++++++----- .../preprocessing/silence_periods.py | 2 +- .../tests/test_detect_artifacts.py | 201 +++++-- .../tests/test_grouped_preprocessing.py | 4 +- 5 files changed, 597 insertions(+), 188 deletions(-) diff --git a/conftest.py b/conftest.py index 5bf7d74527..e326c60d71 100644 --- a/conftest.py +++ b/conftest.py @@ -8,6 +8,21 @@ def create_cache_folder(tmp_path_factory): return cache_folder +@pytest.fixture(scope="module") +def debug_plots(request): + """Return True if debug plots should be shown.""" + return request.config.getoption("--debug-plots") + + +def pytest_addoption(parser): + parser.addoption( + "--debug-plots", + action="store_true", + default=False, + help="Enable debug plots during tests", + ) + + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 695932d27c..314f7abce5 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ 
b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -1,64 +1,50 @@ from __future__ import annotations +from typing import Literal + import numpy as np +from spikeinterface.core import BaseRecording from spikeinterface.core.base import base_period_dtype - -# from spikeinterface.core.core_tools import define_function_handling_dict_from_class -# from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode -import numpy as np artifact_dtype = base_period_dtype - # this will be extend with channel boundaries if needed # extended_artifact_dtype = artifact_dtype + [ # # TODO # ] -def detect_artifact_periods( - recording, - method="envelope", - method_kwargs=None, - job_kwargs=None, -): - """ - Detect artifacts with several possible methods: - * 'saturation' using detect_artifact_periods_by_envelope() - * 'envelope' using detect_saturation_periods() - - See sub methods for more information on parameters. +def _collapse_events(events: np.ndarray, num_samples: list[int], mute_samples: int | None = None) -> np.ndarray: """ + Collapse artifact events that were split across chunk boundaries. 
- if method_kwargs is None: - method_kwargs = dict() - - if method == "envelope": - artifact_periods, envelope = detect_artifact_periods_by_envelope( - recording, **method_kwargs, job_kwargs=job_kwargs - ) - elif method == "saturation": - artifact_periods = detect_saturation_periods(recording, **method_kwargs, job_kwargs=job_kwargs) - else: - raise ValueError(f"detect_artifact_periods() method='{method}' is not valid") - - return artifact_periods - + When a chunk boundary falls within an artifact period the period is emitted + as two adjacent events whose ``end_sample_index`` / ``start_sample_index`` + values are equal. This function merges such pairs into a single record. -## detect_period_artifacts_saturation Zone - - -def _collapse_events(events): - """ - If events are detected at a chunk edge, they will be split in two. - This detects such cases and collapses them in a single record instead. + Parameters + ---------- + events : np.ndarray + Array of artifact events with dtype ``artifact_dtype``, containing + ``"start_sample_index"``, ``"end_sample_index"``, and + ``"segment_index"`` fields. + num_samples : list[int] + List of the number of samples in each segment of the recording. + Used to handle edge cases where an artifact extends to the very end of a segment. + mute_samples : int | None, default: None + Reserved for future use (not yet implemented). + + Returns + ------- + np.ndarray + Array of collapsed artifact events with the same dtype as ``events``. 
""" order = np.lexsort((events["start_sample_index"], events["segment_index"])) events = events[order] @@ -70,16 +56,29 @@ def _collapse_events(events): if same: to_drop[i] = True events["start_sample_index"][i + 1] = events["start_sample_index"][i] - - return events[~to_drop].copy() - - + collapsed_events = events[~to_drop] + if mute_samples is not None: + collapsed_events["start_sample_index"] -= mute_samples + collapsed_events["end_sample_index"] += mute_samples + collapsed_events["start_sample_index"] = np.maximum(collapsed_events["start_sample_index"], 0) + for seg_index in np.unique(collapsed_events["segment_index"]): + mask = collapsed_events["segment_index"] == seg_index + collapsed_events["end_sample_index"][mask] = np.minimum( + collapsed_events["end_sample_index"][mask], num_samples[seg_index] + ) + # Rerun the collapsing in case the mute window caused new overlaps + collapsed_events = _collapse_events(collapsed_events, num_samples, mute_samples=None) + return collapsed_events + + +## detect_period_artifacts_saturation zone class _DetectSaturation(PipelineNode): """ - A recording node for parallelising saturation detection. + A pipeline node for parallelised amplifier-saturation detection. - Run with `run_node_pipeline`, this computes saturation events - for a given chunk. See `detect_saturation()` for details. + When run with :func:`run_node_pipeline`, this node computes saturation + events for a given data chunk. See :func:`detect_saturation_periods` for + the full algorithm description and parameter semantics. """ name = "detect_saturation" @@ -88,16 +87,30 @@ class _DetectSaturation(PipelineNode): def __init__( self, - recording, - saturation_threshold_uV, - uV_per_ms_threshold, - proportion, - ): + recording: BaseRecording, + saturation_threshold_uV: float, + diff_threshold_uV: float | None, + proportion: float, + ) -> None: + """ + Parameters + ---------- + recording : BaseRecording + The recording to process. 
+ saturation_threshold_uV : float + Voltage saturation threshold in μV. + diff_threshold_uV : float | None + First-derivative threshold in μV/sample, or ``None`` to disable + derivative-based detection. + proportion : float + Fraction of channels that must exceed the threshold for a sample to + be labelled as saturated (0 < proportion < 1). + """ PipelineNode.__init__(self, recording, return_output=True) num_chans = recording.get_num_channels() - self.uV_per_ms_threshold = uV_per_ms_threshold + self.diff_threshold_uV = diff_threshold_uV thresh = np.full((num_chans,), saturation_threshold_uV) # 0.98 is empirically determined as the true saturating point is # slightly lower than the documented saturation point of the probe @@ -107,38 +120,78 @@ def __init__( self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() - self.saturation_threshold_unscaled = (thresh - self.offset) / self.gain * 0.98 + self.saturation_threshold_unscaled = (thresh - self.offset) / self.gain # do not apply offset when dealing with the derivative - self.uV_per_ms_threshold = (uV_per_ms_threshold * self.sampling_frequency / 1e3) / self.gain + if self.diff_threshold_uV is not None: + self.diff_threshold_unscaled = diff_threshold_uV / self.gain + else: + self.diff_threshold_unscaled = None - def get_trace_margin(self): + def get_trace_margin(self) -> int: + """Return the number of margin samples required on each side of a chunk.""" return 0 - def get_dtype(self): + def get_dtype(self) -> np.dtype: + """Return the NumPy dtype of the output array produced by :meth:`compute`.""" return self._dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute( + self, + traces: np.ndarray, + start_frame: int, + end_frame: int, + segment_index: int, + max_margin: int, + ) -> tuple[np.ndarray]: """ - Compute saturation events for a given chunk of data. - See `detect_saturation()` for details. 
+ Detect saturation events within a single chunk of raw traces. + + A sample is labelled as *saturated by value* when the fraction of + channels whose absolute amplitude exceeds + ``saturation_threshold_unscaled`` is greater than ``proportion``. + + Optionally, a sample is also labelled as *saturated by derivative* when + the fraction of channels whose forward-difference amplitude exceeds + ``diff_threshold_unscaled`` is greater than ``proportion``. + + Consecutive saturated samples are grouped into contiguous period events. + + Parameters + ---------- + traces : np.ndarray + Raw trace data for the current chunk, shape ``(n_samples, n_channels)``. + start_frame : int + Index of the first sample of this chunk within its segment. + end_frame : int + Index one past the last sample of this chunk within its segment. + segment_index : int + Index of the segment to which this chunk belongs. + max_margin : int + Maximum trace margin (unused; kept for API compatibility). + + Returns + ------- + tuple[np.ndarray] + A one-element tuple containing an array of saturation events with + dtype ``artifact_dtype``. 
""" saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) + detected_by_value = saturation > self.proportion - if self.uV_per_ms_threshold is not None: - fs = self.sampling_frequency + if self.diff_threshold_unscaled is not None: # then compute the derivative of the voltage saturation - - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= self.uV_per_ms_threshold, axis=1) + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= self.diff_threshold_unscaled, axis=1) # Note this means the velocity is not checked for the last sample in the # check because we are taking the forward derivative n_diff_saturated = np.r_[n_diff_saturated, 0] # if either of those reaches more than the proportion of channels labels the sample as saturated - saturation = np.logical_or(saturation > self.proportion, n_diff_saturated > self.proportion) + detected_by_diff = n_diff_saturated > self.proportion + saturation = np.logical_or(detected_by_value, detected_by_diff) else: - saturation = saturation > self.proportion + saturation = detected_by_value intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) n_events = len(intervals) // 2 # Number of saturation periods @@ -153,86 +206,139 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_saturation_periods( - recording, - saturation_threshold_uV, # 1200 uV - uV_per_ms_threshold=None, # 1e-8 V.s-1 - proportion=0.2, - job_kwargs=None, -): + recording: BaseRecording, + saturation_threshold_uV: float | None = None, + diff_threshold_uV: float | None = None, + proportion: float = 0.2, + mute_window_ms: float = 0, + job_kwargs: dict | None = None, +) -> np.ndarray: """ - Detect amplifier saturation events (either single sample or multi-sample periods) in the data. - Saturation detection with this function should be applied to the raw data, before preprocessing. 
- However, saturation periods detected should be zeroed out after preprocessing has been performed. + Detect amplifier saturation events (single- or multi-sample periods) in raw data. - Saturation is detected by a voltage threshold, and optionally a derivative threshold that - flags periods of high velocity changes in the voltage. See _DetectSaturation.compute() - for details on the algorithm. + Saturation detection should be applied to the **raw** recording, before any + preprocessing. The returned periods can then be used to zero out (silence) + the corresponding samples **after** preprocessing has been performed. - Parameters - ---------- - recording : BaseRecording - The recording on which to detect the saturation events. - saturation_threshold_uV : float - The voltage saturation threshold in volts. This will depend on the recording - probe and amplifier gain settings. For NP1 the value of 1200 uV is recommended (IBL). - Note that NP2 probes are more difficult to saturate than NP1. - uV_per_ms_threshold : None | float - The first-derivative threshold in volts per second. Periods of the data over which the change - in velocity is greater than this threshold will be detected as saturation events. Use `None` to - skip this method and only use `saturation_threshold_uV` for detection. Otherwise, the value should be - empirically determined (IBL use 1e-8 V.s-1) for NP1 probes. + Saturation is identified in two complementary ways: - proportion : float - 0 < proportion <1 of channels above threshold to consider the sample as saturated - mute_window_samples : int - TODO: should we scale this based on the fs? - job_kwargs: dict - The classical job_kwargs + 1. **By value**: a sample is saturated when the fraction of channels whose + absolute amplitude exceeds ``saturation_threshold_uV`` is greater than + ``proportion``. + 2. 
**By derivative**: a sample is saturated when the fraction of channels + whose forward-difference amplitude exceeds ``diff_threshold_uV`` is + greater than ``proportion``. - most useful for NP1 - can use ratio as a intuition for the value but dont do it in code + If ``diff_threshold_uV`` is not ``None``, a sample is marked as saturated + if *either* criterion is met. - Returns + Parameters + ---------- + recording : BaseRecording + The recording on which to detect saturation events. + saturation_threshold_uV : float | None, default: None + Voltage saturation threshold in μV. The appropriate value depends on + the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL + recommend **1200 μV**. NP2 probes are harder to saturate than NP1. + If ``None``, the value is read from the ``"saturation_threshold_uV"`` + annotation of ``recording``. + diff_threshold_uV : float | None, default: None + First-derivative threshold in μV/sample. Periods where the + sample-to-sample voltage change exceeds this value in the required + fraction of channels are flagged as saturation. Pass ``None`` to + disable derivative-based detection and rely solely on + ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. + proportion : float, default: 0.2 + Fraction of channels (0 < proportion < 1) that must exceed the + threshold for a sample to be considered saturated. + mute_window_ms : float, default: 0 + Additional silence window in milliseconds to add symmetrically before + and after each detected saturation period. Useful for capturing + ringing or other artefacts that immediately surround a saturation event. + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns ------- - collapsed_events : np.recarray - A numpy recarray holding information on each saturation event. 
Has the fields: - "start_sample_index", "stop_sample_index", "segment_index", "method_id" + np.ndarray + Array with dtype ``artifact_dtype`` describing each saturation period. + Fields: ``"start_sample_index"``, ``"end_sample_index"``, + ``"segment_index"``. """ - if job_kwargs: + if job_kwargs is None: job_kwargs = {} job_kwargs = fix_job_kwargs(job_kwargs) + # The saturation threshold can be specified in the recording annotations and loaded automatically + # for some acquisition systems (e.g., Neuropixels) + if "saturation_threshold_uV" in recording.get_annotation_keys() and saturation_threshold_uV is None: + saturation_threshold_uV = recording.get_annotation("saturation_threshold_uV") + node0 = _DetectSaturation( recording, saturation_threshold_uV=saturation_threshold_uV, - uV_per_ms_threshold=uV_per_ms_threshold, + diff_threshold_uV=diff_threshold_uV, proportion=proportion, ) saturation_periods = run_node_pipeline( recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts" ) - - return _collapse_events(saturation_periods) + if mute_window_ms is not None: + mute_samples = int(mute_window_ms * recording.get_sampling_frequency() / 1000) + else: + mute_samples = None + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + return _collapse_events(saturation_periods, num_samples, mute_samples) -## detect_artifact_periods_by_envelope Zone +## detect_artifact_periods_by_envelope zone +class _DetectThresholdCrossing(PeakDetector): + """ + A pipeline node that detects threshold crossings of a channel-aggregated envelope. + Each crossing of the global median z-score above 1 is returned as an event + with a ``"front"`` flag indicating whether the crossing is a rising edge + (``True``) or a falling edge (``False``). Used internally by + :func:`detect_artifact_periods_by_envelope`. 
-class _DetectThresholdCrossing(PeakDetector): + Attributes + ---------- + abs_thresholds : np.ndarray + Per-channel absolute amplitude thresholds in raw ADC units. + """ name = "threshold_crossings" preferred_mp_context = None def __init__( self, - recording, - detect_threshold=5, - noise_levels=None, - seed=None, - noise_levels_kwargs=dict(), - ): + recording: BaseRecording, + detect_threshold: float = 5, + noise_levels: np.ndarray | None = None, + seed: int | None = None, + noise_levels_kwargs: dict = dict(), + ) -> None: + """ + Parameters + ---------- + recording : BaseRecording + The (pre-processed envelope) recording to process. + detect_threshold : float, default: 5 + Detection threshold expressed as a multiple of the estimated noise + level per channel. + noise_levels : np.ndarray | None, default: None + Pre-computed per-channel noise levels in raw ADC units. If + ``None``, they are estimated via + :func:`~spikeinterface.core.get_noise_levels`. + seed : int | None, default: None + Random seed used when estimating noise levels. + noise_levels_kwargs : dict, default: {} + Additional keyword arguments forwarded to + :func:`~spikeinterface.core.get_noise_levels`. 
+ """ PeakDetector.__init__(self, recording, return_output=True) if noise_levels is None: random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() @@ -242,15 +348,53 @@ def __init__( # internal dtype self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", "bool")]) - def get_trace_margin(self): + def get_trace_margin(self) -> int: + """Return the number of margin samples required on each side of a chunk.""" return 0 - def get_dtype(self): + def get_dtype(self) -> np.dtype: + """Return the NumPy dtype of the output array produced by :meth:`compute`.""" return self._dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute( + self, + traces: np.ndarray, + start_frame: int, + end_frame: int, + segment_index: int, + max_margin: int, + ) -> tuple[np.ndarray]: + """ + Detect threshold crossings in a single chunk of envelope traces. + + The per-sample signal is the median z-score across channels: + ``z = median(traces / abs_thresholds, axis=1)``. Transitions of + ``z > 1`` are located and returned as crossing events. + + Parameters + ---------- + traces : np.ndarray + Envelope trace data for the current chunk, + shape ``(n_samples, n_channels)``. + start_frame : int + Index of the first sample of this chunk within its segment. + end_frame : int + Index one past the last sample of this chunk within its segment. + segment_index : int + Index of the segment to which this chunk belongs. + max_margin : int + Maximum trace margin (unused; kept for API compatibility). + + Returns + ------- + tuple[np.ndarray] + A one-element tuple containing an array of threshold-crossing + events with fields ``"sample_index"``, ``"segment_index"``, and + ``"front"`` (``True`` for rising edge, ``False`` for falling edge). 
+ """ z = np.median(traces / self.abs_thresholds, 1) threshold_mask = np.diff((z > 1) != 0, axis=0) + indices = np.flatnonzero(threshold_mask) threshold_crossings = np.zeros(indices.size, dtype=self._dtype) threshold_crossings["sample_index"] = indices @@ -261,36 +405,72 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def detect_artifact_periods_by_envelope( - recording, - detect_threshold=5, - # min_duration_ms=50, - freq_max=20.0, - seed=None, - job_kwargs=None, - random_slices_kwargs=None, -): + recording: BaseRecording, + detect_threshold: float = 5, + apply_envelope_common_reference: bool = False, + mute_window_ms: float | None = None, + freq_max: float = 20.0, + seed: int | None = None, + job_kwargs: dict | None = None, + random_slices_kwargs: dict | None = None, + return_envelope: bool = False, +) -> np.ndarray | tuple[np.ndarray, BaseRecording]: """ - Function to detect putative artifact periods as threshold crossings of - a global envelope of the channels. + Detect putative artifact periods as threshold crossings of a global channel envelope. + + The pipeline is: + + 1. Rectify the raw recording. + 2. Low-pass filter with a Gaussian filter up to ``freq_max`` Hz to produce + a smooth per-channel amplitude envelope. + 3. Apply a common-average reference so that only signals correlated across + channels (i.e. artefacts) survive. + 4. Estimate per-channel noise levels on the envelope. + 5. Detect samples where the median channel z-score exceeds + ``detect_threshold``, and convert contiguous runs into period records. Parameters ---------- - recording : RecordingExtractor - The recording extractor to detect putative artifacts + recording : BaseRecording + The recording extractor from which to detect artefact periods. detect_threshold : float, default: 5 - The threshold to detect artifacts. 
The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used + Detection threshold as a multiple of the estimated per-channel noise + level of the envelope. + freq_max : float, default: 20.0 + Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the + rectified signal when building the envelope. + mute_window_ms : float | None, default: None + Additional silence window in milliseconds to add symmetrically before + and after each detected artifact period. Useful for capturing ringing or + other artefacts that immediately surround a detected event. + Pass ``None`` to disable muting. seed : int | None, default: None - Random seed for `get_noise_levels`. - If none, `get_noise_levels` uses `seed=0`. - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - + Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. + If ``None``, ``get_noise_levels`` uses ``seed=0``. + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + random_slices_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the ``random_slices_kwargs`` + argument of :func:`~spikeinterface.core.get_noise_levels`. + return_envelope : bool, default: False + If ``True``, also return the intermediate envelope recording so that it + can be inspected or plotted. + + Returns + ------- + artifacts : np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. Fields: ``"start_sample_index"``, ``"end_sample_index"``, + ``"segment_index"``. + envelope : BaseRecording + Only returned when ``return_envelope=True``. The processed envelope + recording (rectified → Gaussian-filtered → common-average referenced). 
""" - envelope = RectifyRecording(recording) envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) - envelope = CommonReferenceRecording(envelope) + if apply_envelope_common_reference: + envelope = CommonReferenceRecording(envelope) job_kwargs = fix_job_kwargs(job_kwargs) if random_slices_kwargs is None: @@ -319,11 +499,52 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - return artifacts, envelope + if mute_window_ms is not None: + mute_samples = int(mute_window_ms * recording.get_sampling_frequency() / 1000) + else: + mute_samples = None + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + artifacts = _collapse_events(artifacts, num_samples, mute_samples) + + if return_envelope: + return artifacts, envelope + else: + return artifacts + + +def _transform_internal_dtype_to_artifact_dtype( + artifacts: np.ndarray, + recording: BaseRecording, +) -> np.ndarray: + """ + Convert threshold-crossing events to the standard ``artifact_dtype`` format. + Threshold-crossing events are stored as individual rising/falling edge + records. This function pairs them up segment by segment to produce + contiguous period records. Edge cases at segment boundaries are handled: -def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): + * If the first event in a segment is a falling edge, an implicit rising + edge at sample 0 is prepended. + * If the last event in a segment is a rising edge, an implicit falling edge + at the last sample of the segment is appended. + Parameters + ---------- + artifacts : np.ndarray + Array of threshold-crossing events with fields ``"sample_index"``, + ``"segment_index"``, and ``"front"`` (``True`` = rising edge). + Must be sorted by ``(segment_index, sample_index)``. 
+ recording : BaseRecording + The original recording, used to determine the number of segments and + the number of samples per segment. + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` containing the merged artifact + periods. Returns an empty array if no crossings are found. + """ num_seg = recording.get_num_segments() final_artifacts = [] @@ -332,19 +553,19 @@ def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): sub_thr = artifacts[mask] if len(sub_thr) > 0: if not sub_thr["front"][0]: - local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) local_thr["sample_index"] = 0 local_thr["front"] = True sub_thr = np.hstack((local_thr, sub_thr)) if sub_thr["front"][-1]: - local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) + local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) local_thr["sample_index"] = recording.get_num_samples(seg_index) local_thr["front"] = False sub_thr = np.hstack((sub_thr, local_thr)) - local_artifact = np.zeros(sub_thr.size / 2, dtype=artifact_dtype) - local_artifact["start_index"] = sub_thr["sample_index"][::2] - local_artifact["stop_index"] = sub_thr["sample_index"][1::2] + local_artifact = np.zeros(int(sub_thr.size / 2), dtype=artifact_dtype) + local_artifact["start_sample_index"] = sub_thr["sample_index"][::2] + local_artifact["stop_sample_index"] = sub_thr["sample_index"][1::2] local_artifact["segment_index"] = seg_index final_artifacts.append(local_artifact) @@ -353,3 +574,57 @@ def _transform_internal_dtype_to_artifact_dtype(artifacts, recording): else: final_artifacts = np.zeros(0, dtype=artifact_dtype) return final_artifacts + + +_method_to_function = { + "envelope": detect_artifact_periods_by_envelope, + "saturation": detect_saturation_periods, +} + + +def detect_artifact_periods( + recording: BaseRecording, + method: 
Literal["envelope", "saturation"] = "envelope", + method_kwargs: dict | None = None, + job_kwargs: dict | None = None, +) -> np.ndarray: + """ + Detect artifact periods using one of several available methods. + + Available methods: + + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + + See the documentation of each sub-function for a full description of their + parameters, which can be forwarded via ``method_kwargs``. + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : {"envelope", "saturation"}, default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. + """ + assert ( + method in _method_to_function + ), f"Method {method} not recognized. 
Valid methods are: {_method_to_function.keys()}" + if method_kwargs is None: + method_kwargs = dict() + + artifact_periods = _method_to_function[method](recording, job_kwargs=job_kwargs, **method_kwargs) + + return artifact_periods diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index d7cada1b27..189b97ec87 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -124,7 +124,7 @@ def _all_period_list_to_periods_vec(list_periods, num_seg): periods = np.zeros(size, dtype=base_period_dtype) start = 0 for i in range(num_seg): - periods_in_seg = list_periods[i] + periods_in_seg = np.array(list_periods[i]) stop = start + periods_in_seg.shape[0] periods[start:stop]["segment_index"] = i periods[start:stop]["start_sample_index"] = periods_in_seg[:, 0] diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 72350096b1..7ebfe7d6fe 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -1,19 +1,49 @@ -from spikeinterface.core import generate_recording, NumpyRecording -from spikeinterface.preprocessing import detect_artifact_periods, detect_saturation_periods import numpy as np +from spikeinterface.core import generate_recording, NumpyRecording +from spikeinterface.preprocessing import ( + detect_artifact_periods, + detect_saturation_periods, + detect_artifact_periods_by_envelope, +) + -def test_detect_artifact_periods(): +def test_detect_artifact_by_envelope(debug_plots): # one segment only - rec = generate_recording(durations=[10.0, 10]) - artifacts = detect_artifact_periods( - rec, - method="envelope", - method_kwargs=dict(detect_threshold=5, freq_max=5.0), + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 # This value is critical to 
ensure hard-coded start / stops below + + # Generate some data in uV + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng() + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 + + artifact_starts = rng.choice(np.arange(0, data.shape[0] - 1000), size=10, replace=False) + artifact_stops = artifact_starts + 100 + for start, stop in zip(artifact_starts, artifact_stops): + data[start:stop, :] = sat_value + + recording = NumpyRecording(data, sampling_frequency) + + artifacts, envelope = detect_artifact_periods_by_envelope( + recording, apply_envelope_common_reference=False, return_envelope=True ) + if debug_plots: + import matplotlib + import matplotlib.pyplot as plt + + plt.plot(envelope.get_traces(), color="r", lw=3) + plt.title("data float") + plt.show() + + # TODO: investigate why not detecting any artifacts in this tests, despite very peaky envelopes! + # assert len(artifacts) == len(artifact_starts) -def test_detect_saturation_periods(): + +def test_detect_saturation_periods(debug_plots): """ This tests the saturation detection method. First a mock recording is created with saturation events. Events may be single-sample or a multi-sample period. 
We create a multi-segment @@ -31,18 +61,19 @@ def test_detect_saturation_periods(): import scipy.signal num_chans = 32 - sample_frequency = 30000 + sampling_frequency = 30000 chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below job_kwargs = {"chunk_size": chunk_size} # Generate some data in uV - sat_value = 12 - uV_per_ms_threshold = 12 / sample_frequency / 1e3 + sat_value = 1200 + diff_threshold_uV = 200 # 200 uV/sample + noise_level = 10 rng = np.random.default_rng() - data = rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 # Design the Butterworth filter - sos = scipy.signal.butter(N=3, Wn=12000 / (sample_frequency / 2), btype="low", output="sos") + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") # Apply the filter to the data data_seg_1 = scipy.signal.sosfiltfilt(sos, data, axis=0) @@ -57,21 +88,42 @@ def test_detect_saturation_periods(): all_starts = np.array([0, 29950, 45123, 90005, 149500]) all_stops = np.array([1001, 30011, 45126, 90006, 149999]) - second_seg_offset = 1 + second_seg_stop_offset = 10 for start, stop in zip(all_starts, all_stops): data_seg_1[start:stop, :] = sat_value # differentiate the second segment for testing purposes - data_seg_2[start : stop + second_seg_offset, :] = sat_value + data_seg_2[start : stop + second_seg_stop_offset, :] = sat_value + + # Add slow artifact + start_slow_artifact = 6100 + stop_slow_artifact = 6300 + accepted_slope = diff_threshold_uV * 0.9 + start_rising_sample = int(np.floor(start_slow_artifact - sat_value / accepted_slope)) + stop_falling_sample = int(np.ceil(stop_slow_artifact + sat_value / accepted_slope)) + + offsets = [0, second_seg_stop_offset] + data_segs = [data_seg_1, data_seg_2] + for offset, data_seg in zip(offsets, data_segs): + start_rising = start_rising_sample + stop_rising = start_slow_artifact + start_falling = 
stop_slow_artifact + offset + stop_falling = stop_falling_sample + offset + data_seg[stop_rising:start_falling, :] = sat_value + data_seg[start_rising:stop_rising, :] = np.tile( + (accepted_slope * np.arange(stop_rising - start_rising))[:, None], (1, num_chans) + ) + data_seg[start_falling:stop_falling, :] = np.tile( + (sat_value - accepted_slope * np.arange(stop_falling - start_falling))[:, None], (1, num_chans) + ) - # this center the int16 around 0 and saturate on positive - combined = np.r_[data_seg_1.flatten(), data_seg_2.flatten()] - max_ = np.max(combined) - # min_ = np.min(combined) - gain = max_ / 2**15 # (max_ - min_) / 65535 - offset = 0 # min_ + 32768 * gain + # Add start and stop of slow artifact to start/stops + all_starts = np.sort(np.append(all_starts, start_slow_artifact)) + all_stops = np.clip(np.sort(np.append(all_stops, stop_slow_artifact)), a_min=0, a_max=data_seg_1.shape[0] - 1) - PLOT = True - if PLOT: + gain = 2.34 # mimic NP1.0 + offset = 0 + + if debug_plots: import matplotlib import matplotlib.pyplot as plt @@ -85,7 +137,7 @@ def test_detect_saturation_periods(): seg_1_int16 = np.clip(np.rint((data_seg_1 - offset) / gain), -32768, 32767).astype(np.int16) seg_2_int16 = np.clip(np.rint((data_seg_2 - offset) / gain), -32768, 32767).astype(np.int16) - if PLOT: + if debug_plots: plt.plot(seg_1_int16) plt.title("data int") plt.show() @@ -93,14 +145,14 @@ def test_detect_saturation_periods(): plt.title("diff int") plt.show() - recording = NumpyRecording([seg_1_int16, seg_2_int16], sample_frequency) + recording = NumpyRecording([seg_1_int16, seg_2_int16], sampling_frequency) recording.set_channel_gains(gain) recording.set_channel_offsets([offset] * num_chans) periods = detect_saturation_periods( recording, saturation_threshold_uV=sat_value * 0.98, - uV_per_ms_threshold=uV_per_ms_threshold, + diff_threshold_uV=diff_threshold_uV, job_kwargs=job_kwargs, ) @@ -112,38 +164,103 @@ def test_detect_saturation_periods(): # sample before the saturation 
starts. Therefore this one-sample-offset # on the start times is an implicit test that the derivative # threshold is working properly. - for seg_periods in [seg_1_periods, seg_2_periods]: - assert seg_periods["start_sample_index"][0] == all_starts[0] - assert np.array_equal(seg_periods["start_sample_index"][1:], np.array(all_starts)[1:] - 1) + tolerance_samples = 1 + offsets = np.array([0, second_seg_stop_offset]) + for seg_periods, offset in zip([seg_1_periods, seg_2_periods], offsets): + starts = seg_periods["start_sample_index"] + stops = seg_periods["end_sample_index"] + start_diffs = np.abs(starts - all_starts) + assert np.all(start_diffs <= tolerance_samples) + stop_diffs = np.abs(stops - np.clip(all_stops + offset, a_min=0, a_max=data_seg_1.shape[0] - 1)) + assert np.all(stop_diffs <= tolerance_samples) - assert np.array_equal(seg_1_periods["end_sample_index"], np.array(all_stops)) - assert np.array_equal(seg_2_periods["end_sample_index"], np.array(all_stops) + second_seg_offset) + # Check that slow rising and falling phases are not in periods + # The ramp slope is 90% of diff_threshold_uV, so they should not be detected. + for seg_periods, seg_offset in zip([seg_1_periods, seg_2_periods], offsets): + slow_period_idx = np.argmin(np.abs(seg_periods["start_sample_index"] - start_slow_artifact)) + slow_period = seg_periods[slow_period_idx] + assert ( + slow_period["start_sample_index"] >= start_rising_sample + tolerance_samples + ), "Slow artifact period starts in the rising phase" + assert ( + slow_period["end_sample_index"] <= stop_falling_sample + seg_offset - tolerance_samples + ), "Slow artifact period ends in the falling phase" # Just do a quick test that a threshold slightly over the sat value is not detected. # In this case we only see the derivative threshold detection. We do not play around with this # threshold because the derivative threshold is not easy to predict (the baseline sample is random). 
- periods = detect_saturation_periods( + periods_only_diff = detect_saturation_periods( recording, - saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_ms_threshold=uV_per_ms_threshold, + saturation_threshold_uV=sat_value * 1.02, + diff_threshold_uV=diff_threshold_uV, job_kwargs=job_kwargs, ) - assert periods["start_sample_index"][0] == 1000 - assert periods["end_sample_index"][0] == 1001 + assert abs(periods_only_diff["start_sample_index"][0] - 1000) <= tolerance_samples + assert abs(periods_only_diff["end_sample_index"][0] - 1001) <= tolerance_samples + # Test that the same result is obtained with the detect_artifact_periods function with method="saturation" and the + # same parameters. periods_entry_function = detect_artifact_periods( recording, method="saturation", method_kwargs=dict( - saturation_threshold_uV=sat_value * (1 / 0.98), - uV_per_ms_threshold=uV_per_ms_threshold, + saturation_threshold_uV=sat_value * 0.98, + diff_threshold_uV=diff_threshold_uV, ), job_kwargs=job_kwargs, ) - assert np.array_equal(periods, periods_entry_function) + # Test that the same result is obtained with multiple jobs + job_kwargs = {"chunk_size": chunk_size, "n_jobs": 2, "mp_context": "forkserver"} + periods_entry_function_parallel = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=sat_value * 0.98, + diff_threshold_uV=diff_threshold_uV, + ), + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_function_parallel) + + # Test that the same result is obtained with saturation_threshold_uV annotation + recording.annotate(saturation_threshold_uV=sat_value * 0.98) + periods_entry_with_annotation = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=None, + diff_threshold_uV=diff_threshold_uV, + ), + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_with_annotation) + + # Test mute window around detected periods + 
mute_window_ms = 0.1 + mute_samples = int(mute_window_ms * sampling_frequency / 1000) + muted_periods = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=sat_value * 0.98, + diff_threshold_uV=diff_threshold_uV, + mute_window_ms=mute_window_ms, + ), + job_kwargs=job_kwargs, + ) + seg_1_muted_periods = muted_periods[np.where(muted_periods["segment_index"] == 0)] + seg_2_muted_periods = muted_periods[np.where(muted_periods["segment_index"] == 1)] + for seg_periods, offset in zip([seg_1_muted_periods, seg_2_muted_periods], offsets): + starts = seg_periods["start_sample_index"] + stops = seg_periods["end_sample_index"] + start_diffs = np.abs(starts - np.clip(all_starts - mute_samples, a_min=0, a_max=data_seg_1.shape[0] - 1)) + assert np.all(start_diffs <= tolerance_samples) + stop_diffs = np.abs(stops - np.clip(all_stops + offset + mute_samples, a_min=0, a_max=data_seg_1.shape[0] - 1)) + assert np.all(stop_diffs <= tolerance_samples) + if __name__ == "__main__": - test_detect_artifact_periods() - test_detect_saturation_periods() + test_detect_artifact_by_envelope(True) + # test_detect_saturation_periods() diff --git a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py index c2cdbeb3db..861fa23f6f 100644 --- a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py +++ b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py @@ -40,7 +40,9 @@ def test_grouped_preprocessing(): sp_recording_1 = silence_periods(recording_1, list_periods=list_periods, mode=mode, seed=seed) sp_recording_2 = silence_periods(recording_2, list_periods=list_periods, mode=mode, seed=seed) - dict_of_silence_period_recordings = silence_periods(dict_of_recordings, list_periods, mode=mode, seed=seed) + dict_of_silence_period_recordings = silence_periods( + dict_of_recordings, list_periods=list_periods, mode=mode, seed=seed + ) 
check_recordings_equal(dict_of_silence_period_recordings["one"], sp_recording_1) check_recordings_equal(dict_of_silence_period_recordings["two"], sp_recording_2) From fb9867b8dc669c7f688c1d55f1be21016c9ee769 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Mar 2026 16:35:41 +0100 Subject: [PATCH 11/17] fix: get rid of forkserver mp_context --- src/spikeinterface/preprocessing/tests/test_detect_artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 7ebfe7d6fe..f628786ede 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -212,7 +212,7 @@ def test_detect_saturation_periods(debug_plots): assert np.array_equal(periods, periods_entry_function) # Test that the same result is obtained with multiple jobs - job_kwargs = {"chunk_size": chunk_size, "n_jobs": 2, "mp_context": "forkserver"} + job_kwargs = {"chunk_size": chunk_size, "n_jobs": 2} periods_entry_function_parallel = detect_artifact_periods( recording, method="saturation", From a8c3ef58f4e8a7c5615ab8bf5d2de6275ef9d8ab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Mar 2026 16:37:00 +0100 Subject: [PATCH 12/17] cast traces to float32 to avoid overflow issues --- src/spikeinterface/preprocessing/detect_artifacts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 314f7abce5..97531e6ede 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -176,6 +176,9 @@ def compute( A one-element tuple containing an array of saturation events with dtype ``artifact_dtype``. 
""" + # cast to float32 to prevent overflow when applying thresholds in unscaled ADC units + traces = traces.astype("float32") + saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) detected_by_value = saturation > self.proportion From b6544dbafa6af62a85e26f034f49ee1f63b0bf1a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 18 Mar 2026 16:00:20 +0100 Subject: [PATCH 13/17] Fix envelope and remove mute_window_ms --- .../preprocessing/detect_artifacts.py | 82 +++++-------------- .../tests/test_detect_artifacts.py | 35 ++------ 2 files changed, 28 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 97531e6ede..bf7925e68d 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -10,7 +10,7 @@ from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.recording_tools import get_noise_levels, get_random_data_chunks from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode artifact_dtype = base_period_dtype @@ -21,7 +21,7 @@ # ] -def _collapse_events(events: np.ndarray, num_samples: list[int], mute_samples: int | None = None) -> np.ndarray: +def _collapse_events(events: np.ndarray) -> np.ndarray: """ Collapse artifact events that were split across chunk boundaries. @@ -35,11 +35,6 @@ def _collapse_events(events: np.ndarray, num_samples: list[int], mute_samples: i Array of artifact events with dtype ``artifact_dtype``, containing ``"start_sample_index"``, ``"end_sample_index"``, and ``"segment_index"`` fields. 
- num_samples : list[int] - List of the number of samples in each segment of the recording. - Used to handle edge cases where an artifact extends to the very end of a segment. - mute_samples : int | None, default: None - Reserved for future use (not yet implemented). Returns ------- @@ -57,17 +52,6 @@ def _collapse_events(events: np.ndarray, num_samples: list[int], mute_samples: i to_drop[i] = True events["start_sample_index"][i + 1] = events["start_sample_index"][i] collapsed_events = events[~to_drop] - if mute_samples is not None: - collapsed_events["start_sample_index"] -= mute_samples - collapsed_events["end_sample_index"] += mute_samples - collapsed_events["start_sample_index"] = np.maximum(collapsed_events["start_sample_index"], 0) - for seg_index in np.unique(collapsed_events["segment_index"]): - mask = collapsed_events["segment_index"] == seg_index - collapsed_events["end_sample_index"][mask] = np.minimum( - collapsed_events["end_sample_index"][mask], num_samples[seg_index] - ) - # Rerun the collapsing in case the mute window caused new overlaps - collapsed_events = _collapse_events(collapsed_events, num_samples, mute_samples=None) return collapsed_events @@ -213,7 +197,6 @@ def detect_saturation_periods( saturation_threshold_uV: float | None = None, diff_threshold_uV: float | None = None, proportion: float = 0.2, - mute_window_ms: float = 0, job_kwargs: dict | None = None, ) -> np.ndarray: """ @@ -254,10 +237,6 @@ def detect_saturation_periods( proportion : float, default: 0.2 Fraction of channels (0 < proportion < 1) that must exceed the threshold for a sample to be considered saturated. - mute_window_ms : float, default: 0 - Additional silence window in milliseconds to add symmetrically before - and after each detected saturation period. Useful for capturing - ringing or other artefacts that immediately surround a saturation event. job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. 
``n_jobs``, ``chunk_duration``). @@ -289,12 +268,8 @@ def detect_saturation_periods( saturation_periods = run_node_pipeline( recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts" ) - if mute_window_ms is not None: - mute_samples = int(mute_window_ms * recording.get_sampling_frequency() / 1000) - else: - mute_samples = None num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - return _collapse_events(saturation_periods, num_samples, mute_samples) + return _collapse_events(saturation_periods) ## detect_artifact_periods_by_envelope zone @@ -319,10 +294,9 @@ class _DetectThresholdCrossing(PeakDetector): def __init__( self, recording: BaseRecording, + mads: np.ndarray, + medians: np.ndarray, detect_threshold: float = 5, - noise_levels: np.ndarray | None = None, - seed: int | None = None, - noise_levels_kwargs: dict = dict(), ) -> None: """ Parameters @@ -332,22 +306,17 @@ def __init__( detect_threshold : float, default: 5 Detection threshold expressed as a multiple of the estimated noise level per channel. - noise_levels : np.ndarray | None, default: None - Pre-computed per-channel noise levels in raw ADC units. If - ``None``, they are estimated via - :func:`~spikeinterface.core.get_noise_levels`. - seed : int | None, default: None - Random seed used when estimating noise levels. + mads : np.ndarray + Pre-computed per-channel median absolute deviations in raw ADC units. + medians : np.ndarray + Pre-computed per-channel medians in raw ADC units. noise_levels_kwargs : dict, default: {} Additional keyword arguments forwarded to :func:`~spikeinterface.core.get_noise_levels`. 
""" PeakDetector.__init__(self, recording, return_output=True) - if noise_levels is None: - random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - self.abs_thresholds = noise_levels * detect_threshold + self.abs_thresholds = (mads * detect_threshold)[np.newaxis, :] + self.medians = medians[np.newaxis, :] # internal dtype self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", "bool")]) @@ -395,7 +364,7 @@ def compute( events with fields ``"sample_index"``, ``"segment_index"``, and ``"front"`` (``True`` for rising edge, ``False`` for falling edge). """ - z = np.median(traces / self.abs_thresholds, 1) + z = np.median((traces - self.medians) / self.abs_thresholds, axis=1) threshold_mask = np.diff((z > 1) != 0, axis=0) indices = np.flatnonzero(threshold_mask) @@ -411,7 +380,6 @@ def detect_artifact_periods_by_envelope( recording: BaseRecording, detect_threshold: float = 5, apply_envelope_common_reference: bool = False, - mute_window_ms: float | None = None, freq_max: float = 20.0, seed: int | None = None, job_kwargs: dict | None = None, @@ -442,11 +410,6 @@ def detect_artifact_periods_by_envelope( freq_max : float, default: 20.0 Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the rectified signal when building the envelope. - mute_window_ms : float | None, default: None - Additional silence window in milliseconds to add symmetrically before - and after each detected artifact period. Useful for capturing ringing or - other artefacts that immediately surround a detected event. - Pass ``None`` to disable muting. seed : int | None, default: None Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. If ``None``, ``get_noise_levels`` uses ``seed=0``. 
@@ -481,13 +444,16 @@ def detect_artifact_periods_by_envelope( else: random_slices_kwargs = random_slices_kwargs.copy() random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(envelope, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) + random_data = get_random_data_chunks(envelope, **random_slices_kwargs) + medians = np.median(random_data, axis=0) + mad = np.median(np.abs(random_data - medians), axis=0) + mads = mad / 0.6745 node0 = _DetectThresholdCrossing( envelope, detect_threshold=detect_threshold, - noise_levels=noise_levels, - seed=seed, + mads=mads, + medians=medians, ) threshold_crossings = run_node_pipeline( @@ -502,13 +468,8 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - if mute_window_ms is not None: - mute_samples = int(mute_window_ms * recording.get_sampling_frequency() / 1000) - else: - mute_samples = None - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - artifacts = _collapse_events(artifacts, num_samples, mute_samples) + artifacts = _collapse_events(artifacts) if return_envelope: return artifacts, envelope @@ -554,6 +515,7 @@ def _transform_internal_dtype_to_artifact_dtype( for seg_index in range(num_seg): mask = artifacts["segment_index"] == seg_index sub_thr = artifacts[mask] + print(sub_thr) if len(sub_thr) > 0: if not sub_thr["front"][0]: local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) @@ -566,9 +528,9 @@ def _transform_internal_dtype_to_artifact_dtype( local_thr["front"] = False sub_thr = np.hstack((sub_thr, local_thr)) - local_artifact = np.zeros(int(sub_thr.size / 2), dtype=artifact_dtype) + local_artifact = np.zeros(int(np.ceil(sub_thr.size / 2)), dtype=artifact_dtype) local_artifact["start_sample_index"] = sub_thr["sample_index"][::2] - local_artifact["stop_sample_index"] = sub_thr["sample_index"][1::2] + 
local_artifact["end_sample_index"] = sub_thr["sample_index"][1::2] local_artifact["segment_index"] = seg_index final_artifacts.append(local_artifact) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index f628786ede..e0db644d87 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -17,7 +17,7 @@ def test_detect_artifact_by_envelope(debug_plots): # Generate some data in uV sat_value = 1200 noise_level = 10 - rng = np.random.default_rng() + rng = np.random.default_rng(42) data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 artifact_starts = rng.choice(np.arange(0, data.shape[0] - 1000), size=10, replace=False) @@ -28,7 +28,7 @@ def test_detect_artifact_by_envelope(debug_plots): recording = NumpyRecording(data, sampling_frequency) artifacts, envelope = detect_artifact_periods_by_envelope( - recording, apply_envelope_common_reference=False, return_envelope=True + recording, apply_envelope_common_reference=False, return_envelope=True, random_slices_kwargs={"seed": 2308} ) if debug_plots: @@ -39,8 +39,8 @@ def test_detect_artifact_by_envelope(debug_plots): plt.title("data float") plt.show() - # TODO: investigate why not detecting any artifacts in this tests, despite very peaky envelopes! 
- # assert len(artifacts) == len(artifact_starts) + # it finds some artifacts + assert len(artifacts) > 0 def test_detect_saturation_periods(debug_plots): @@ -237,30 +237,7 @@ def test_detect_saturation_periods(debug_plots): ) assert np.array_equal(periods, periods_entry_with_annotation) - # Test mute window around detected periods - mute_window_ms = 0.1 - mute_samples = int(mute_window_ms * sampling_frequency / 1000) - muted_periods = detect_artifact_periods( - recording, - method="saturation", - method_kwargs=dict( - saturation_threshold_uV=sat_value * 0.98, - diff_threshold_uV=diff_threshold_uV, - mute_window_ms=mute_window_ms, - ), - job_kwargs=job_kwargs, - ) - seg_1_muted_periods = muted_periods[np.where(muted_periods["segment_index"] == 0)] - seg_2_muted_periods = muted_periods[np.where(muted_periods["segment_index"] == 1)] - for seg_periods, offset in zip([seg_1_muted_periods, seg_2_muted_periods], offsets): - starts = seg_periods["start_sample_index"] - stops = seg_periods["end_sample_index"] - start_diffs = np.abs(starts - np.clip(all_starts - mute_samples, a_min=0, a_max=data_seg_1.shape[0] - 1)) - assert np.all(start_diffs <= tolerance_samples) - stop_diffs = np.abs(stops - np.clip(all_stops + offset + mute_samples, a_min=0, a_max=data_seg_1.shape[0] - 1)) - assert np.all(stop_diffs <= tolerance_samples) - if __name__ == "__main__": - test_detect_artifact_by_envelope(True) - # test_detect_saturation_periods() + # test_detect_artifact_by_envelope(True) + test_detect_saturation_periods(True) From 6032924c39e46ebfa3394182491977ff7dd19172 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 19 Mar 2026 16:44:39 +0100 Subject: [PATCH 14/17] Add check_for_peak_source in run_node_pipeline and update api.rst --- doc/api.rst | 3 +++ src/spikeinterface/core/node_pipeline.py | 20 ++++++++++--------- .../preprocessing/detect_artifacts.py | 3 ++- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 
fd71e06622..fc55017606 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -213,6 +213,9 @@ spikeinterface.preprocessing .. autofunction:: detect_bad_channels .. autofunction:: detect_and_interpolate_bad_channels .. autofunction:: detect_and_remove_bad_channels + .. autofunction:: detect_artifact_periods + .. autofunction:: detect_artifact_periods_by_envelope + .. autofunction:: detect_saturation_periods .. autofunction:: directional_derivative .. autofunction:: filter .. autofunction:: gaussian_filter diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 3b3908f174..2c38248c1a 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -495,18 +495,17 @@ def find_parents_of_type(list_of_parents, parent_type): return parents -def check_graph(nodes): +def check_graph(nodes, check_for_peak_source=True): """ Check that node list is orderd in a good (parents are before children) """ - # Do not remove this, this is to remenber that in previous version the first node needed to be - # a detectot but not anymore - # node0 = nodes[0] - # if not isinstance(node0, PeakSource): - # raise ValueError( - # "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" - # ) + if check_for_peak_source: + node0 = nodes[0] + if not isinstance(node0, PeakSource): + raise ValueError( + "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" + ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode" @@ -534,6 +533,7 @@ def run_node_pipeline( verbose=False, skip_after_n_peaks=None, recording_slices=None, + check_for_peak_source=True, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -589,6 +589,8 @@ def run_node_pipeline( Optionaly give a list of slices to run the pipeline only on some chunks of the recording. 
It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. + check_for_peak_source : bool, default True + Whether to check that the first node is a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever). Returns ------- @@ -597,7 +599,7 @@ def run_node_pipeline( If squeeze_output=True and only one output then directly np.array. """ - check_graph(nodes) + check_graph(nodes, check_for_peak_source=check_for_peak_source) job_kwargs = fix_job_kwargs(job_kwargs) assert all(isinstance(node, PipelineNode) for node in nodes) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index bf7925e68d..f9978a6397 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -266,7 +266,7 @@ def detect_saturation_periods( ) saturation_periods = run_node_pipeline( - recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts" + recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts", check_for_peak_source=False ) num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] return _collapse_events(saturation_periods) @@ -461,6 +461,7 @@ def detect_artifact_periods_by_envelope( [node0], job_kwargs, job_name="detect artifact on envelope", + check_for_peak_source=False, ) order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) From 0b1a098708fea7536738d5d95f025b36e3e3d04e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 20 Mar 2026 17:12:53 +0100 Subject: [PATCH 15/17] raise if saturation is None --- src/spikeinterface/preprocessing/detect_artifacts.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index
f9978a6397..9fd6a3a85b 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -258,6 +258,11 @@ def detect_saturation_periods( if "saturation_threshold_uV" in recording.get_annotation_keys() and saturation_threshold_uV is None: saturation_threshold_uV = recording.get_annotation("saturation_threshold_uV") + if saturation_threshold_uV is None: + raise ValueError( + "Cannot read `saturation_threshold_uV` from recording. Please pass `saturation_threshold_uV` manually." + ) + node0 = _DetectSaturation( recording, saturation_threshold_uV=saturation_threshold_uV, @@ -268,7 +273,6 @@ def detect_saturation_periods( saturation_periods = run_node_pipeline( recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts", check_for_peak_source=False ) - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] return _collapse_events(saturation_periods) From c7fa4ffd245c6c2cc29e9e6df06c815929e32d04 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 20 Mar 2026 17:21:15 +0100 Subject: [PATCH 16/17] fix: collapsing events 1-sample off --- src/spikeinterface/preprocessing/detect_artifacts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 9fd6a3a85b..6499d1163e 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -47,7 +47,10 @@ def _collapse_events(events: np.ndarray) -> np.ndarray: # compute if duplicate for i in np.arange(events.size - 1): - same = events["end_sample_index"][i] == events["start_sample_index"][i + 1] + # We use the + 1 because the end sample index is inclusive + # so if the next start sample index is exactly 1 more than the end sample index, + # then they are part of the same artifact period + same = 
events["end_sample_index"][i] + 1 >= events["start_sample_index"][i + 1] if same: to_drop[i] = True events["start_sample_index"][i + 1] = events["start_sample_index"][i] From c5c81346f61a5f5dd750957c8e294de13d89182b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 20 Mar 2026 17:27:52 +0100 Subject: [PATCH 17/17] fix overlapping collapsed events: they must be from the same segment --- src/spikeinterface/preprocessing/detect_artifacts.py | 5 +++-- .../preprocessing/tests/test_detect_artifacts.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 6499d1163e..adcd1d80f8 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -50,8 +50,9 @@ def _collapse_events(events: np.ndarray) -> np.ndarray: # We use the + 1 because the end sample index is inclusive # so if the next start sample index is exactly 1 more than the end sample index, # then they are part of the same artifact period - same = events["end_sample_index"][i] + 1 >= events["start_sample_index"][i + 1] - if same: + overlapping = events["end_sample_index"][i] + 1 >= events["start_sample_index"][i + 1] + same_segment = events["segment_index"][i] == events["segment_index"][i + 1] + if overlapping and same_segment: to_drop[i] = True events["start_sample_index"][i + 1] = events["start_sample_index"][i] collapsed_events = events[~to_drop] diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index e0db644d87..1d99206bd2 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -240,4 +240,4 @@ def test_detect_saturation_periods(debug_plots): if __name__ == "__main__": # test_detect_artifact_by_envelope(True) -
test_detect_saturation_periods(True) + test_detect_saturation_periods(False)