Speed up template similarity computing using numba #4343

Merged
samuelgarcia merged 20 commits into SpikeInterface:main from tayheau:soft_merges_and_refactor
Mar 13, 2026
Conversation

@tayheau

@tayheau tayheau commented Jan 26, 2026

Contributor

Following #4310, @chrishalcrow showed that, provided the difference between two merged templates is small, we can approximate the distance between a merged template and a new one as a linear function. This should allow us to significantly speed up template similarity computations for merged templates.
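
To make the idea concrete, here is a minimal sketch under explicit assumptions that are not taken from this PR: the merged template is taken to be a spike-count-weighted mean of its two parents, and the "linear function" is taken to be the same weighted mean of the parents' already-known distances to the new template. The helpers `l2_distance`, `merged_template` and `approx_merged_distance` are hypothetical names, and the data is synthetic.

```python
import numpy as np

def l2_distance(a, b):
    """Normalized L2 distance, in [0, 1]."""
    return np.linalg.norm(a - b) / (np.linalg.norm(a) + np.linalg.norm(b))

def merged_template(t_a, t_b, n_a, n_b):
    # Assumption: a merge is the spike-count-weighted mean of the parents
    return (n_a * t_a + n_b * t_b) / (n_a + n_b)

def approx_merged_distance(d_ac, d_bc, n_a, n_b):
    # Assumption: linear interpolation of the two known distances
    return (n_a * d_ac + n_b * d_bc) / (n_a + n_b)

rng = np.random.default_rng(0)
t_a = rng.normal(size=(60, 4))                  # (num_samples, num_channels)
t_b = t_a + 1e-3 * rng.normal(size=t_a.shape)   # nearly identical to t_a
t_c = rng.normal(size=t_a.shape)                # unrelated "new" template
n_a, n_b = 300, 100

# Exact: recompute the distance from the merged waveform
exact = l2_distance(merged_template(t_a, t_b, n_a, n_b), t_c)
# Approximate: reuse the distances that were already computed for the parents
approx = approx_merged_distance(l2_distance(t_a, t_c), l2_distance(t_b, t_c), n_a, n_b)
error = abs(exact - approx)
```

In this toy setting the approximation avoids touching the waveforms again; how well it holds on real data depends on how similar the merged templates actually are.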

@alejoe91 alejoe91 added the postprocessing Related to postprocessing module label Jan 28, 2026
@tayheau tayheau force-pushed the soft_merges_and_refactor branch from ce491e6 to 99985d4 Compare February 16, 2026 16:20
@tayheau

tayheau commented Feb 24, 2026

Contributor Author

So everything is supposed to work fine. I think I have some differences, mostly due to casting, but they are "minimal" in the context of a normalized distance.

Just to quickly sum up: I moved the computation of the support matrix out of the loop. It will cost more memory when the two template arrays differ, but since most of the time they are the same, I think it's a net gain here. I also rely more on numpy views instead of copies.
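
As a small illustration of the view-versus-copy point (synthetic array, not code from this PR): basic slicing along the sample axis shares memory with the original array, while boolean indexing on the channel axis allocates a new one.

```python
import numpy as np

templates = np.arange(2 * 10 * 3, dtype=np.float32).reshape(2, 10, 3)
num_shifts = 2
num_samples = templates.shape[1]

# Basic slicing along the sample axis returns a view: no data is copied
window = templates[:, num_shifts : num_samples - num_shifts]
is_view = np.shares_memory(window, templates)

# Boolean (advanced) indexing on the channel axis always allocates a copy
channel_mask = np.array([True, False, True])
masked = templates[:, :, channel_mask]
is_copy = not np.shares_memory(masked, templates)
```

Hoisting the sliced window out of an inner loop therefore costs nothing, whereas every masked selection inside the loop pays for an allocation.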

I'm not that familiar with numba/numpy efficiency, but for the union it was faster to do it the "dummy" way (mine, lol) than the vectorised one. So if you have some rule-of-thumb tips for Numba, I'm in ;) @samuelgarcia
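
For readers unfamiliar with the two styles being compared, here is a hedged sketch of a channel-support union written both ways. The function names are hypothetical and this is not the PR's implementation; numba is optional here, and a no-op decorator is substituted when it is not installed. No timing is claimed: which variant wins depends on array sizes and should be measured.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    def njit(func):  # no-op fallback so the sketch still runs without numba
        return func

@njit
def union_support_loop(mask_i, mask_j):
    # Explicit loop: the kind of code numba compiles to a tight machine loop
    num_channels = mask_i.shape[0]
    out = np.zeros(num_channels, dtype=np.bool_)
    for c in range(num_channels):
        out[c] = mask_i[c] or mask_j[c]
    return out

def union_support_vectorised(mask_i, mask_j):
    # Vectorised equivalent: allocates the result through a NumPy ufunc
    return np.logical_or(mask_i, mask_j)

mask_i = np.array([True, False, False, True])
mask_j = np.array([False, False, True, True])
loop_result = union_support_loop(mask_i, mask_j)
vec_result = union_support_vectorised(mask_i, mask_j)
```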

@tayheau tayheau changed the title Speed up template similarity for soft merges Speed up template similarity computing using numba Feb 24, 2026
@tayheau tayheau marked this pull request as ready for review February 24, 2026 13:32
@chrishalcrow

Member

Looks great as a user! I get a modest ~30% speed-up on the initial compute and a 10x speed-up on recompute (so trying out different methods in the GUI is super fast, and hard merges are very fast).

And on some real data with 800 neurons, the biggest absmax in difference is...

np.max(np.abs(old_temps - new_temps))
>> np.float32(4.2915344e-06)

Nice!!!
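
The sketch below illustrates how casting alone can produce discrepancies of this kind: the same normalized distance is evaluated once in float64 and once in float32. The data is synthetic and `l2_distance` is a hypothetical helper, so the size of the difference here says nothing about the real dataset above.

```python
import numpy as np

def l2_distance(a, b):
    """Normalized L2 distance, in [0, 1]."""
    return np.linalg.norm(a - b) / (np.linalg.norm(a) + np.linalg.norm(b))

rng = np.random.default_rng(42)
a64 = rng.normal(size=200)
b64 = rng.normal(size=200)

d64 = l2_distance(a64, b64)
# Same inputs, rounded to float32 before any arithmetic
d32 = l2_distance(a64.astype(np.float32), b64.astype(np.float32))

abs_diff = abs(float(d64) - float(d32))
```

A tolerance-based comparison such as `np.allclose` is usually a better regression check than exact equality when the dtype of intermediate arrays has changed.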

Comment thread src/spikeinterface/postprocessing/template_similarity.py
@alejoe91 alejoe91 added this to the 0.104.0 milestone Feb 25, 2026

@yger yger left a comment

Collaborator

I would be careful, because in the meantime we patched a bug in the similarity that was caused by over-optimization. We cannot compute only half of the lags and only the upper part of the matrix and then use symmetry everywhere; otherwise the result is incomplete. If we symmetrize in time, we need to compute all indices.
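
The bookkeeping can be checked on a toy case. The sketch below uses two single-channel templates with one peak each and a hypothetical helper, `shifted_l2`, that follows the slicing convention of this module's numpy implementation. It illustrates that the distance for (i, j) at lag t matches the distance for (j, i) at lag -t, whereas (j, i) at the same lag t need not match.

```python
import numpy as np

def shifted_l2(t_i, t_j, shift, num_shifts):
    """Normalized L2 distance between t_i and t_j read `shift` samples later."""
    n = t_i.shape[0]
    src = t_i[num_shifts : n - num_shifts]
    tgt = t_j[num_shifts + shift : n - num_shifts + shift]
    return np.linalg.norm(src - tgt) / (np.linalg.norm(src) + np.linalg.norm(tgt))

num_shifts = 3
t_i = np.zeros(30)
t_i[14] = 1.0  # peak at sample 14
t_j = np.zeros(30)
t_j[16] = 1.0  # same peak, 2 samples later

d_ij_plus = shifted_l2(t_i, t_j, +2, num_shifts)   # peaks aligned
d_ji_plus = shifted_l2(t_j, t_i, +2, num_shifts)   # peaks end up 4 samples apart
d_ji_minus = shifted_l2(t_j, t_i, -2, num_shifts)  # peaks aligned again
```

Here `d_ij_plus` and `d_ji_minus` agree, while `d_ji_plus` differs, so transposing the matrix is only valid when the sign of the lag is flipped at the same time.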

Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
@yger

yger commented Feb 25, 2026

Collaborator

I would be careful here, because we recently patched a bug in #4345 and I think the fix has not been propagated here.

@tayheau tayheau requested a review from yger March 4, 2026 09:51
@alejoe91

alejoe91 commented Mar 9, 2026

Member

@yger I think that the symmetry issue is fixed. Can you double check?

Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
@samuelgarcia

Member

Hi Theo.
We checked it with Pierre and pushed a small commit.
You made the same mistake Pierre made some time ago. Now it is correct.

@samuelgarcia

Member

This is now OK for me.
I am not sure that the speedup will be as big as @chrishalcrow has seen.
Could you check ?

@tayheau

tayheau commented Mar 11, 2026

Contributor Author

Hi Samuel,
but with the fix you made, won't the array be only half populated in the case of same_array?

@yger

yger commented Mar 11, 2026

Collaborator

No because if this is not the same array, we'll explore all shifts, am I right?

@tayheau

tayheau commented Mar 11, 2026

Contributor Author

To me there are two symmetries: a "spatial" one, where dist(i, j) = dist(j, i) at lag t, and a temporal one, where dist(i, j) at lag t = dist(i, j) at lag -t (we agree that both apply only in the case of same_array). That's why we check the similarity matrix twice. The num_shifts != 0 condition is just there so that nothing gets erased by the transpose, but it should be shift != 0. Here I think this fix erases the "spatial" symmetry optimisation.
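
As a reference point for this discussion, here is a hedged sketch of one way to exploit the lag relation on its own: compute every (i, j) pair for non-positive lags only, then fill each positive lag with the transpose of the matching negative lag, and compare against a brute-force pass over all lags. It is dense (no sparsity masks), `lag_distances` is a hypothetical helper, and the templates are synthetic and forced to zero near the edges so that the relation is exact rather than approximate.

```python
import numpy as np

def lag_distances(templates, shifts, num_shifts):
    """Dense normalized L2 distances, one (n, n) matrix per shift in `shifts`."""
    n_templates, n_samples = templates.shape[:2]
    src = templates[:, num_shifts : n_samples - num_shifts].reshape(n_templates, -1)
    out = np.empty((len(shifts), n_templates, n_templates))
    for count, shift in enumerate(shifts):
        tgt = templates[:, num_shifts + shift : n_samples - num_shifts + shift]
        tgt = tgt.reshape(n_templates, -1)
        diff = np.linalg.norm(src[:, None, :] - tgt[None, :, :], axis=2)
        norms = np.linalg.norm(src, axis=1)[:, None] + np.linalg.norm(tgt, axis=1)[None, :]
        out[count] = diff / norms
    return out

rng = np.random.default_rng(1)
num_shifts = 2
templates = np.zeros((4, 40, 3))
# Assumption: waveforms are zero within 2 * num_shifts samples of each edge
templates[:, 2 * num_shifts : 40 - 2 * num_shifts] = rng.normal(size=(4, 40 - 4 * num_shifts, 3))

# Reference: every lag computed explicitly
full = lag_distances(templates, range(-num_shifts, num_shifts + 1), num_shifts)

# Shortcut: only non-positive lags, but for every (i, j) pair
half = lag_distances(templates, range(-num_shifts, 1), num_shifts)
rebuilt = np.empty_like(full)
rebuilt[: num_shifts + 1] = half
for count in range(num_shifts):
    # lag +t is the transpose of lag -t
    rebuilt[2 * num_shifts - count] = half[count].T

matches_reference = np.allclose(rebuilt, full)
```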

@yger

yger commented Mar 12, 2026

Collaborator

You are right. Can we sort this out together on a short call?

@yger

yger commented Mar 13, 2026

Collaborator

@tayheau if you can finish today and propagate the changes to compute_(...)_numpy, that would be great, since we would like to merge this quickly for the release. Otherwise let me know and I can have a go on the branch.

Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
@yger

yger commented Mar 13, 2026

Collaborator

As I told @samuelgarcia, this looks good to me, except that the optimizations are currently only implemented at the numba level, not yet in pure numpy (for when numba is not installed).

@yger

yger commented Mar 13, 2026

Collaborator

I'm not able to push to your branch, but here is code to replace _compute_similarity_matrix_numpy():

def _compute_similarity_matrix_numpy(
    templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
):

    num_templates = templates_array.shape[0]
    num_samples = templates_array.shape[1]
    other_num_templates = other_templates_array.shape[0]

    num_shifts_both_sides = 2 * num_shifts + 1
    distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
    same_array = np.array_equal(templates_array, other_templates_array)

    # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at lag -t
    # So the matrix can be computed only for negative lags and then transposed

    if same_array:
        # optimisation when the arrays are the same, thanks to the symmetry in shift
        shift_loop = range(-num_shifts, 1)
    else:
        shift_loop = range(-num_shifts, num_shifts + 1)

    for count, shift in enumerate(shift_loop):
        src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
        tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
        for i in range(num_templates):
            src_template = src_sliced_templates[i]
            local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support)
            overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
            tgt_templates = tgt_sliced_templates[overlapping_templates]
            for gcount, j in enumerate(overlapping_templates):
                if j < i and same_array:
                    continue
                src = src_template[:, local_mask[j]].reshape(1, -1)
                tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)

                if method == "l1":
                    norm_i = np.sum(np.abs(src))
                    norm_j = np.sum(np.abs(tgt))
                    distances[count, i, j] = np.sum(np.abs(src - tgt))
                    distances[count, i, j] /= norm_i + norm_j
                elif method == "l2":
                    norm_i = np.linalg.norm(src, ord=2)
                    norm_j = np.linalg.norm(tgt, ord=2)
                    distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
                    distances[count, i, j] /= norm_i + norm_j
                elif method == "cosine":
                    norm_i = np.linalg.norm(src, ord=2)
                    norm_j = np.linalg.norm(tgt, ord=2)
                    distances[count, i, j] = np.sum(src * tgt)
                    distances[count, i, j] /= norm_i * norm_j
                    distances[count, i, j] = 1 - distances[count, i, j]

                if same_array:
                    distances[count, j, i] = distances[count, i, j]

        if same_array and shift != 0:
            distances[num_shifts_both_sides - count - 1] = distances[count].T

    return distances
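
For reference, the three metric branches in the function above can be isolated into a small standalone helper. This is a sketch for illustration only: `normalized_distance` is a hypothetical name, not part of the module. The l1 and l2 variants lie in [0, 1] by the triangle inequality, and the cosine variant lies in [0, 2].

```python
import numpy as np

def normalized_distance(src, tgt, method="l2"):
    """Distance between two flattened waveforms, one branch per metric."""
    src = np.asarray(src, dtype=np.float64).ravel()
    tgt = np.asarray(tgt, dtype=np.float64).ravel()
    if method == "l1":
        return np.sum(np.abs(src - tgt)) / (np.sum(np.abs(src)) + np.sum(np.abs(tgt)))
    if method == "l2":
        return np.linalg.norm(src - tgt) / (np.linalg.norm(src) + np.linalg.norm(tgt))
    if method == "cosine":
        return 1.0 - np.sum(src * tgt) / (np.linalg.norm(src) * np.linalg.norm(tgt))
    raise ValueError(f"unknown method: {method!r}")

x = np.array([1.0, -2.0, 3.0])
methods = ("l1", "l2", "cosine")
# Identical waveforms: distance is zero (up to rounding) for all three metrics
same = [normalized_distance(x, x, m) for m in methods]
# Sign-flipped waveforms: the largest value each metric can take
opposite = [normalized_distance(x, -x, m) for m in methods]
```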

@samuelgarcia

Member

Thanks Theo.
Sorry for the unnecessary intrusion!!
I'm merging.

@samuelgarcia samuelgarcia merged commit 3fa0779 into SpikeInterface:main Mar 13, 2026
15 checks passed
@tayheau

tayheau commented Mar 13, 2026

Contributor Author

No worries!
That's also what git is for, hehe

Labels

postprocessing Related to postprocessing module

5 participants