From f68a1caa838ae2578e797984ee2384a2aa68eafa Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 16 Mar 2026 10:06:05 +0000 Subject: [PATCH 1/5] wrong indentation --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 91b3d977fe..8e16757bcc 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2198,7 +2198,7 @@ def get_extension_class(extension_name: str, auto_import=True): warnings.warn( f"Extension '{extension_name}' is unknown. Maybe this is an external extension, a typo or was computed by a different version of SpikeInterface." ) - return None + return None ext_class = extensions_dict[extension_name] return ext_class From d8cab2950cb61703b2f8bf2fadf842a8fee2aadd Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 16 Mar 2026 10:55:12 +0000 Subject: [PATCH 2/5] make trained results more robust --- src/spikeinterface/curation/tests/common.py | 28 +++++++++++------ .../tests/test_model_based_curation.py | 30 ++++++++++++------- .../tests/test_unitrefine_curation.py | 4 ++- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index d7fe7b4262..dcf3b12343 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -1,6 +1,6 @@ import pytest -from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer +from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer, aggregate_units from spikeinterface.core.generate import inject_some_split_units from spikeinterface.curation import train_model from pathlib import Path @@ -68,7 +68,15 @@ def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num @pytest.fixture(scope="module") def sorting_analyzer_for_curation(): - return make_sorting_analyzer(sparse=True) + """Makes an analyzer whose first 10 units are good normal units, and 10 which are noise. We make them + noise by using a spike trains which are uncorrelated with the recording for `sorting2`.""" + + recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=10) + _, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=10) + both_sortings = aggregate_units([sorting_1, sorting_2]) + analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording) + analyzer.compute(["random_spikes", "noise_levels", "templates"]) + return analyzer @pytest.fixture(scope="module") @@ -83,7 +91,7 @@ def sorting_analyzer_with_splits(): @pytest.fixture(scope="module") -def trained_pipeline_path(): +def trained_pipeline_path(sorting_analyzer_for_curation): """ Makes a model saved at "./trained_pipeline" which will be used by other tests in the module. If the model already exists, this function does nothing. @@ -92,20 +100,22 @@ def trained_pipeline_path(): if trained_model_folder.is_dir(): yield trained_model_folder else: - analyzer = make_sorting_analyzer(sparse=True) + analyzer = sorting_analyzer_for_curation analyzer.compute( { - "quality_metrics": {"metric_names": ["snr", "num_spikes"]}, - "template_metrics": {"metric_names": ["half_width"]}, + "quality_metrics": {"metric_names": ["snr"]}, + "template_metrics": {"metric_names": ["half_width", "peak_to_trough_duration", "number_of_peaks"]}, } ) train_model( - analyzers=[analyzer] * 5, - labels=[[1, 0, 1, 0, 1]] * 5, + analyzers=[analyzer], folder=trained_model_folder, - classifiers=["RandomForestClassifier"], + labels=[[1] * 10 + [0] * 10], imputation_strategies=["median"], scaling_techniques=["standard_scaler"], + classifiers=["RandomForestClassifier"], + overwrite=True, + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 2}, ) yield trained_model_folder diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 39a67bcde8..eb7cc7f12c 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -43,8 +43,10 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip """The function `model_based_label_units` needs the correct metrics to have been computed. However, it should be independent of the order of computation. We test this here.""" - sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + sorting_analyzer_for_curation.compute( + "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] + ) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr"]) prediction_prob_dataframe_1 = model_based_label_units( sorting_analyzer=sorting_analyzer_for_curation, @@ -52,7 +54,9 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip trusted=["numpy.dtype"], ) - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"]) + sorting_analyzer_for_curation.compute( + "template_metrics", metric_names=["peak_to_trough_duration", "half_width", "number_of_peaks"] + ) prediction_prob_dataframe_2 = model_based_label_units( sorting_analyzer=sorting_analyzer_for_curation, @@ -118,7 +122,9 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation we expect these labels to be outputted. The test checks this, and also checks that label conversion works as expected.""" - sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute( + "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] + ) sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) # Test the predict_labels() method of ModelBasedClassification @@ -126,12 +132,14 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation classified_units = model_based_classification.predict_labels() predictions = classified_units["prediction"].values - assert np.all(predictions == np.array([1, 0, 1, 0, 1])) + expected_result = np.array([1] * 10 + [0] * 10) + assert np.all(predictions == expected_result) conversion = {0: "noise", 1: "good"} + expected_result_converted = np.array(["good"] * 10 + ["noise"] * 10) classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion) predictions_labelled = classified_units_labelled["prediction"] - assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) + assert np.all(predictions_labelled == expected_result_converted) @pytest.mark.skip(reason="We need to retrain the model to reflect any changes in metric computation") @@ -142,9 +150,11 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here.""" sorting_analyzer_for_curation.compute( - "quality_metrics", metric_names=["num_spikes", "snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} + "quality_metrics", metric_names=["snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} + ) + sorting_analyzer_for_curation.compute( + "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] ) - sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) @@ -160,10 +170,10 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura # Now test the positive case. Recompute using the default parameters sorting_analyzer_for_curation.compute( "quality_metrics", - metric_names=["num_spikes", "snr"], + metric_names=["snr"], metric_params={"snr": {"peak_sign": "neg", "peak_mode": "extremum"}}, ) - sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width", "peak_to_trough_duration"]) model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) diff --git a/src/spikeinterface/curation/tests/test_unitrefine_curation.py b/src/spikeinterface/curation/tests/test_unitrefine_curation.py index b14326f4ea..9fe949d4e3 100644 --- a/src/spikeinterface/curation/tests/test_unitrefine_curation.py +++ b/src/spikeinterface/curation/tests/test_unitrefine_curation.py @@ -68,7 +68,9 @@ def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path): # test with trained local models - sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) + sorting_analyzer_for_curation.compute( + "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] + ) sorting_analyzer_for_curation.compute("quality_metrics") # test passing model folder From 9ef09f8785e76ab52fe250b417d5a210b4484613 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 16 Mar 2026 11:11:36 +0000 Subject: [PATCH 3/5] fix unitrefine_curation tests --- .../curation/tests/test_unitrefine_curation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_unitrefine_curation.py b/src/spikeinterface/curation/tests/test_unitrefine_curation.py index 9fe949d4e3..bdb8f41f65 100644 --- a/src/spikeinterface/curation/tests/test_unitrefine_curation.py +++ b/src/spikeinterface/curation/tests/test_unitrefine_curation.py @@ -6,8 +6,13 @@ def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): """Test the `unitrefine_label_units` function.""" - sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) - sorting_analyzer_for_curation.compute("quality_metrics") + sorting_analyzer_for_curation.compute( + { + "spike_amplitudes": {}, + "template_metrics": {"include_multi_channel_metrics": True}, + "quality_metrics": {}, + } + ) # test passing both classifiers labels = unitrefine_label_units( From 775e9e01b8166a76c169318173ec28e74d9cb6ea Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 16 Mar 2026 11:17:25 +0000 Subject: [PATCH 4/5] change analyzer for unitrefine curation --- src/spikeinterface/curation/tests/common.py | 9 ++- .../tests/test_model_based_curation.py | 72 ++++++++++--------- .../tests/test_unitrefine_curation.py | 34 ++++----- 3 files changed, 61 insertions(+), 54 deletions(-) diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index dcf3b12343..dce3cb1722 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -68,6 +68,11 @@ def make_sorting_analyzer_with_splits(sorting_analyzer, num_unit_splitted=1, num @pytest.fixture(scope="module") def sorting_analyzer_for_curation(): + return make_sorting_analyzer(sparse=True) + + +@pytest.fixture(scope="module") +def sorting_analyzer_for_unitrefine_curation(): """Makes an analyzer whose first 10 units are good normal units, and 10 which are noise. We make them noise by using a spike trains which are uncorrelated with the recording for `sorting2`.""" @@ -91,7 +96,7 @@ def sorting_analyzer_with_splits(): @pytest.fixture(scope="module") -def trained_pipeline_path(sorting_analyzer_for_curation): +def trained_pipeline_path(sorting_analyzer_for_unitrefine_curation): """ Makes a model saved at "./trained_pipeline" which will be used by other tests in the module. If the model already exists, this function does nothing. @@ -100,7 +105,7 @@ def trained_pipeline_path(sorting_analyzer_for_curation): if trained_model_folder.is_dir(): yield trained_model_folder else: - analyzer = sorting_analyzer_for_curation + analyzer = sorting_analyzer_for_unitrefine_curation analyzer.compute( { "quality_metrics": {"metric_names": ["snr"]}, diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index eb7cc7f12c..36bd36e8c6 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -1,7 +1,7 @@ import pytest from pathlib import Path -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path +from spikeinterface.curation.tests.common import sorting_analyzer_for_unitrefine_curation, trained_pipeline_path from spikeinterface.curation.model_based_curation import ModelBasedClassification from spikeinterface.curation import model_based_label_units, load_model @@ -16,8 +16,8 @@ @pytest.fixture def model(trained_pipeline_path): - """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. - It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with + """A toy model, created using the `sorting_analyzer_for_unitrefine_curation` from `spikeinterface.curation.tests.common`. + It has been trained locally and, when applied to `sorting_analyzer_for_unitrefine_curation` will label its 5 units with the following labels: [1,0,1,0,1].""" model = load_model(trained_pipeline_path, trusted=["numpy.dtype"]) @@ -30,36 +30,36 @@ def required_metrics_and_columns(): return ["num_spikes", "snr", "half_width"], ["num_spikes", "snr", "trough_half_width", "peak_half_width"] -def test_model_based_classification_init(sorting_analyzer_for_curation, model): +def test_model_based_classification_init(sorting_analyzer_for_unitrefine_curation, model): """Test that the ModelBasedClassification attributes are correctly initialised""" - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) - assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation + model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0]) + assert model_based_classification.sorting_analyzer == sorting_analyzer_for_unitrefine_curation assert model_based_classification.pipeline == model[0] assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) -def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path): +def test_metric_ordering_independence(sorting_analyzer_for_unitrefine_curation, trained_pipeline_path): """The function `model_based_label_units` needs the correct metrics to have been computed. However, it should be independent of the order of computation. We test this here.""" - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] ) - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=["snr"]) prediction_prob_dataframe_1 = model_based_label_units( - sorting_analyzer=sorting_analyzer_for_curation, + sorting_analyzer=sorting_analyzer_for_unitrefine_curation, model_folder=trained_pipeline_path, trusted=["numpy.dtype"], ) - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "template_metrics", metric_names=["peak_to_trough_duration", "half_width", "number_of_peaks"] ) prediction_prob_dataframe_2 = model_based_label_units( - sorting_analyzer=sorting_analyzer_for_curation, + sorting_analyzer=sorting_analyzer_for_unitrefine_curation, model_folder=trained_pipeline_path, trusted=["numpy.dtype"], ) @@ -68,40 +68,40 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip def test_model_based_classification_get_metrics_for_classification( - sorting_analyzer_for_curation, model, required_metrics_and_columns + sorting_analyzer_for_unitrefine_curation, model, required_metrics_and_columns ): """If the user has not computed the required metrics, an error should be returned. This test checks that an error occurs when the required metrics have not been computed, and that no error is returned when the required metrics have been computed. """ - sorting_analyzer_for_curation.delete_extension("quality_metrics") - sorting_analyzer_for_curation.delete_extension("template_metrics") + sorting_analyzer_for_unitrefine_curation.delete_extension("quality_metrics") + sorting_analyzer_for_unitrefine_curation.delete_extension("template_metrics") required_metric_names, required_metric_columns = required_metrics_and_columns - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0]) # Compute some (but not all) of the required metrics in sorting_analyzer, should still error - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metric_names[0]]) - computed_metrics = sorting_analyzer_for_curation.get_metrics_extension_data() + sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=[required_metric_names[0]]) + computed_metrics = sorting_analyzer_for_unitrefine_curation.get_metrics_extension_data() with pytest.raises(ValueError): model_based_classification._check_required_metrics_are_present(computed_metrics) # Compute all of the required metrics in sorting_analyzer, no more error - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metric_names[0:2]) - sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metric_names[2]]) + sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=required_metric_names[0:2]) + sorting_analyzer_for_unitrefine_curation.compute("template_metrics", metric_names=[required_metric_names[2]]) - metrics_data = sorting_analyzer_for_curation.get_metrics_extension_data() - assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) + metrics_data = sorting_analyzer_for_unitrefine_curation.get_metrics_extension_data() + assert metrics_data.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.get_unit_ids()) assert set(metrics_data.columns.to_list()) == set(required_metric_columns) -def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): +def test_model_based_classification_export_to_phy(sorting_analyzer_for_unitrefine_curation, model): import pandas as pd # Test the _export_to_phy() method of ModelBasedClassification - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0]) classified_units = pd.DataFrame.from_dict({0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)}) # Function should fail here @@ -116,19 +116,19 @@ def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, assert (phy_folder / "cluster_prediction.tsv").exists() -def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, model): +def test_model_based_classification_predict_labels(sorting_analyzer_for_unitrefine_curation, model): """The model `model` has been trained on the `sorting_analyzer` used in this test with the labels `[1, 0, 1, 0, 1]`. Hence if we apply the model to this `sorting_analyzer` we expect these labels to be outputted. The test checks this, and also checks that label conversion works as expected.""" - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] ) - sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) # Test the predict_labels() method of ModelBasedClassification - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0]) classified_units = model_based_classification.predict_labels() predictions = classified_units["prediction"].values @@ -143,21 +143,21 @@ def test_model_based_classification_predict_labels(sorting_analyzer_for_curation @pytest.mark.skip(reason="We need to retrain the model to reflect any changes in metric computation") -def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation, trained_pipeline_path): +def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_unitrefine_curation, trained_pipeline_path): """We track whether the metric parameters used to compute the metrics used to train a model are the same as the parameters used to compute the metrics in the sorting analyzer which is being curated. If they are different, an error or warning will be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here.""" - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "quality_metrics", metric_names=["snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} ) - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] ) model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model) # an error should be raised if `enforce_metric_params` is True with pytest.raises(Exception): @@ -168,13 +168,15 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) # Now test the positive case. Recompute using the default parameters - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "quality_metrics", metric_names=["snr"], metric_params={"snr": {"peak_sign": "neg", "peak_mode": "extremum"}}, ) - sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width", "peak_to_trough_duration"]) + sorting_analyzer_for_unitrefine_curation.compute( + "template_metrics", metric_names=["half_width", "peak_to_trough_duration"] + ) model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) - model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model) model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/test_unitrefine_curation.py b/src/spikeinterface/curation/tests/test_unitrefine_curation.py index bdb8f41f65..314fcb5878 100644 --- a/src/spikeinterface/curation/tests/test_unitrefine_curation.py +++ b/src/spikeinterface/curation/tests/test_unitrefine_curation.py @@ -1,12 +1,12 @@ import pytest -from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path +from spikeinterface.curation.tests.common import sorting_analyzer_for_unitrefine_curation, trained_pipeline_path from spikeinterface.curation import unitrefine_label_units -def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): +def test_unitrefine_label_units_hf(sorting_analyzer_for_unitrefine_curation): """Test the `unitrefine_label_units` function.""" - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( { "spike_amplitudes": {}, "template_metrics": {"include_multi_channel_metrics": True}, @@ -16,41 +16,41 @@ def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): # test passing both classifiers labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", ) assert "unitrefine_label" in labels.columns assert "unitrefine_probability" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) + assert labels.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.unit_ids) # test only noise neural classifier labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", sua_mua_classifier=None, ) assert "unitrefine_label" in labels.columns assert "unitrefine_probability" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) + assert labels.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.unit_ids) # test only sua mua classifier labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier=None, sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", ) assert "unitrefine_label" in labels.columns assert "unitrefine_probability" in labels.columns - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) + assert labels.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.unit_ids) # test passing none with pytest.raises(ValueError): labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier=None, sua_mua_classifier=None, ) @@ -58,34 +58,34 @@ def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): # test warnings when unexpected labels are returned with pytest.warns(UserWarning): labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", sua_mua_classifier=None, ) with pytest.warns(UserWarning): labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier=None, sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", ) -def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path): +def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_unitrefine_curation, trained_pipeline_path): # test with trained local models - sorting_analyzer_for_curation.compute( + sorting_analyzer_for_unitrefine_curation.compute( "template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"] ) - sorting_analyzer_for_curation.compute("quality_metrics") + sorting_analyzer_for_unitrefine_curation.compute("quality_metrics") # test passing model folder labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier=trained_pipeline_path, ) # test passing model folder labels = unitrefine_label_units( - sorting_analyzer_for_curation, + sorting_analyzer_for_unitrefine_curation, noise_neural_classifier=trained_pipeline_path / "best_model.skops", ) From 78b42f97b75aa49f658776aac32c0339ad40dce6 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 16 Mar 2026 11:18:57 +0000 Subject: [PATCH 5/5] docstring update --- .../curation/tests/test_model_based_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 36bd36e8c6..0c9a67a834 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -17,8 +17,8 @@ @pytest.fixture def model(trained_pipeline_path): """A toy model, created using the `sorting_analyzer_for_unitrefine_curation` from `spikeinterface.curation.tests.common`. - It has been trained locally and, when applied to `sorting_analyzer_for_unitrefine_curation` will label its 5 units with - the following labels: [1,0,1,0,1].""" + It has been trained locally and, when applied to `sorting_analyzer_for_unitrefine_curation` will label its 10 units with + the following labels: [1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0].""" model = load_model(trained_pipeline_path, trusted=["numpy.dtype"]) return model