From 6a854354d43b0a08608a0fd58d7134d5d78c91b4 Mon Sep 17 00:00:00 2001 From: jl33ai Date: Tue, 9 Jun 2026 17:16:41 -0700 Subject: [PATCH] speed up the encoder kde by not calling np.histogram on the hot path The per spike kde in get_joint_prob was spending about two thirds of its time inside np.histogram. When you give np.histogram an explicit array of bin edges it cannot assume the bins are evenly spaced, so it does a binary search for every sample to find its bin. Our position bins are always the evenly spaced integers 0 to num_bins, so the bin is just the integer part of the position. Working it out directly and summing with np.bincount does the same thing and is a lot faster. Two more changes on top of that. I preallocate the scratch arrays the kde uses instead of making a fresh set on every spike, and I fuse the squared distance into a single np.einsum pass instead of squaring and then summing. On a full 50k mark buffer the median per spike time goes from about 2096 to 635 microseconds, roughly 3.3 times faster. The result is the same to within floating point rounding, about 3e-13. The config option encoder.mark_kernel.exact_histogram restores the exact original numerics if you ever need them back. Adds a small benchmark in benchmarks and a writeup in docs. --- README.md | 4 +- benchmarks/README.md | 41 +++ benchmarks/bench_hotpath.py | 387 ++++++++++++++++++++++++++++ docs/latency_analysis.md | 178 +++++++++++++ realtime_decoder/encoder_process.py | 124 ++++++--- 5 files changed, 704 insertions(+), 30 deletions(-) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/bench_hotpath.py create mode 100644 docs/latency_analysis.md diff --git a/README.md b/README.md index db0fc9a..17cd514 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,8 @@ This entire section is also optional and similar to the `custom_mean` section ab `n_marks_min`: Only used if `use_filter` is `true`. Minimum number of marks that must be in the n-cube surrounding a candidate spike mark to be considered for decoding. +`exact_histogram`: Optional, default `false`. The per-spike joint probability uses a fast path (a uniform-bin histogram and a fused distance sum) that matches the original computation to within floating-point rounding (~1e-13). Set to `true` to restore the exact original numerics if you need bit-for-bit reproducibility against an older run. See `docs/latency_analysis.md`. + ### `dead_channels` This section is optional. @@ -229,7 +231,7 @@ For both options below, the reference sampling rate is the `spikes` sampling rat `samples`: How many samples make up a time bin. For example, if the `spikes` sampling rate is 30 kHz and `samples` is 180, the time bin size is 6 ms. -`delay_samples`: How many samples behind the current LFP timestamp the right edge of the current time bin should be. For example, if the `spikes` sampling rate is 30 kHz and `delay_samples` is 90, the right edge of the current time bin will be 3 ms behind the current LFP timestamp. +`delay_samples`: How many samples behind the current LFP timestamp the right edge of the current time bin should be. For example, if the `spikes` sampling rate is 30 kHz and `delay_samples` is 90, the right edge of the current time bin will be 3 ms behind the current LFP timestamp. This is a spike jitter buffer, and it is the single largest controllable term in end-to-end latency: tune it down to your measured worst-case encoder-to-decoder arrival time rather than copying a round number between configs. See `docs/latency_analysis.md`. ## `clusterless_decoder` diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..c5af4cc --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,41 @@ +# Hot-path latency benchmarks + +`bench_hotpath.py` measures the two compute kernels on the real-time critical +path, in isolation, without needing MPI or Trodes: + +- `Encoder.get_joint_prob` - the per-spike clusterless KDE (encoder ranks) +- `ClusterlessDecoder.compute_posterior` - the per-time-bin posterior (decoder ranks) + +It drives the *actual* classes from `realtime_decoder`, not a reimplementation, +so the numbers reflect the shipped code. It is the "measure" half of the +measure -> change one thing -> measure loop in `docs/realtime_tuning.md`. + +## Usage + +```bash +# latency distributions (p50/p99/p99.9/max) at SC66 sizes and stress sizes +python benchmarks/bench_hotpath.py + +# prove the optimized encoder kernel still matches the original algorithm +python benchmarks/bench_hotpath.py --verify + +# deterministic digest of kernel outputs (compare across code versions) +python benchmarks/bench_hotpath.py --checksum +``` + +Only `numpy` is required. The harness installs a tiny `mpi4py` stub so the +package imports without an MPI runtime; the kernels themselves do not use MPI. + +## What to look at + +Report **p99 and p99.9**, not the mean. Closed-loop latency is a tail problem: +a kernel that is fast on average but occasionally stalls for milliseconds will +force a larger spike jitter buffer (`decoder.time_bin.delay_samples`), which is +the single biggest controllable term in end-to-end latency. See +`docs/latency_analysis.md`. + +## Sizes + +Defaults mirror `config/SC66.yml`: `num_bins=41`, encoder mark buffer +`bufsize=50000`, `mark_dim=4`. `get_joint_prob` cost scales with the number of +stored marks, so the buffer-full case (50000) is the steady-state worst case. diff --git a/benchmarks/bench_hotpath.py b/benchmarks/bench_hotpath.py new file mode 100644 index 0000000..1cf0d49 --- /dev/null +++ b/benchmarks/bench_hotpath.py @@ -0,0 +1,387 @@ +"""Standalone latency benchmark for the real-time hot-path compute kernels. + +This drives the *actual* decoding kernels used in closed-loop experiments +(`ClusterlessDecoder.compute_posterior` and `Encoder.get_joint_prob`) +without needing MPI or Trodes, so latency can be measured and tuned on any +machine. It is the "measure" half of the measure -> change one thing -> +measure loop described in docs/realtime_tuning.md. + +Two modes: + + python benchmarks/bench_hotpath.py # timing: p50/p99/p99.9/max + python benchmarks/bench_hotpath.py --checksum # deterministic output digest + +The checksum mode runs a fixed, seeded sequence of inputs through the kernels +and prints a sha256 of every output array. Run it before and after a change to +a kernel: an identical digest proves the change did not alter the numerical +output (i.e. the optimization is semantics-preserving). + +Sizes default to the SC66 production config (num_bins=41, mark buffer 50000, +mark_dim 4). Override with flags to explore the cost surface. +""" + +import argparse +import hashlib +import os +import sys +import time +import types + +import numpy as np + +# allow `python benchmarks/bench_hotpath.py` from the repo root +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +# --------------------------------------------------------------------------- +# Import the package without MPI / Trodes present. The compute kernels do not +# use MPI at runtime; base.py only needs the names to exist at import time. +# --------------------------------------------------------------------------- +def _install_mpi_stub(): + if "mpi4py" in sys.modules: + return + + class _AnyAttr: + """Returns a throwaway type for any attribute access (MPI.Comm etc.).""" + def __getattr__(self, name): + return type(name, (), {}) + + mpi4py = types.ModuleType("mpi4py") + mpi4py.MPI = _AnyAttr() + sys.modules["mpi4py"] = mpi4py + + +_install_mpi_stub() + +from realtime_decoder.decoder_process import ClusterlessDecoder # noqa: E402 +from realtime_decoder.encoder_process import Encoder # noqa: E402 +from realtime_decoder.position import PositionBinStruct # noqa: E402 + + +# --------------------------------------------------------------------------- +# Config construction. Only the keys the kernels actually read are populated. +# Defaults mirror config/SC66.yml. +# --------------------------------------------------------------------------- +def make_config(*, num_bins, arm_coords, n_elec, bufsize, mark_dim, rank): + return { + "algorithm": "clusterless_decoder", + "preloaded_model": False, + "sampling_rate": {"spikes": 30000}, + "clusterless_decoder": {"state_labels": ["state"], "transmat_bias": 1}, + "decoder_assignment": {rank: list(range(1, n_elec + 1))}, + "decoder": { + "time_bin": {"samples": 180, "delay_samples": 180}, + }, + "display": { + "decoder": {"occupancy": 10_000_000}, + "encoder": {"occupancy": 10_000_000}, + }, + "encoder": { + "mark_dim": mark_dim, + "bufsize": bufsize, + "use_channel_dist_from_max_amp": False, + "mark_kernel": { + "std": 20, + "use_filter": True, + "n_std": 1, + "n_marks_min": 10, + }, + "position": { + "lower": 0, + "upper": num_bins, + "num_bins": num_bins, + "arm_coords": arm_coords, + }, + }, + } + + +def make_pos_bin_struct(config): + p = config["encoder"]["position"] + return PositionBinStruct(p["lower"], p["upper"], p["num_bins"]) + + +# --------------------------------------------------------------------------- +# Percentile reporting +# --------------------------------------------------------------------------- +def summarize(name, samples_ns): + a = np.sort(np.asarray(samples_ns, dtype=np.float64)) / 1000.0 # -> microseconds + def pct(q): + return a[min(len(a) - 1, int(q * len(a)))] + print( + f" {name:<28} n={len(a):>6} " + f"p50={pct(0.50):8.2f} p99={pct(0.99):8.2f} " + f"p99.9={pct(0.999):8.2f} max={a[-1]:9.2f} (us)" + ) + return {"p50": pct(0.50), "p99": pct(0.99), "p999": pct(0.999), "max": a[-1]} + + +# --------------------------------------------------------------------------- +# Decoder benchmark +# --------------------------------------------------------------------------- +def build_decoder(config, rank, rng): + dec = ClusterlessDecoder(rank, config, make_pos_bin_struct(config)) + # Populate occupancy realistically: accumulate position samples and apply + # the no-animal boundary so gap bins become nan (exercises the isfinite + # handling in compute_posterior). + num_bins = config["encoder"]["position"]["num_bins"] + for _ in range(2000): + dec.update_position(int(rng.integers(0, num_bins)), True) + return dec + + +def make_spike_arr(config, rank, n_spikes, num_bins, rng): + """One time bin's worth of spikes in the layout compute_posterior expects: + columns [timestamp, elec_grp_id, pos, cred_int, used_flag, hist(num_bins)]. + """ + elec_ids = config["decoder_assignment"][rank] + arr = np.zeros((n_spikes, 5 + num_bins), dtype=np.float64) + for i in range(n_spikes): + arr[i, 0] = rng.integers(0, 10_000_000) + arr[i, 1] = rng.choice(elec_ids) + arr[i, 2] = rng.integers(0, num_bins) + arr[i, 3] = rng.integers(0, 20) + hist = rng.random(num_bins) + 1e-6 + hist /= hist.sum() + arr[i, 5:] = hist + return arr + + +def bench_decoder(*, num_bins, arm_coords, n_elec, n_iter, warmup, seed): + rank = 1 + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=n_elec, + bufsize=2000, mark_dim=4, rank=rank, + ) + rng = np.random.default_rng(seed) + dec = build_decoder(config, rank, rng) + + # Realistic per-bin spike counts: mostly 0-4 spikes in a 6 ms bin. + spike_counts = rng.integers(0, 5, size=n_iter + warmup) + bins = [make_spike_arr(config, rank, int(c), num_bins, rng) + for c in spike_counts] + + for i in range(warmup): + dec.compute_posterior(bins[i]) + + samples = np.empty(n_iter, dtype=np.int64) + for i in range(n_iter): + b = bins[warmup + i] + t0 = time.perf_counter_ns() + dec.compute_posterior(b) + samples[i] = time.perf_counter_ns() - t0 + + label = f"compute_posterior nb={num_bins} elec={n_elec}" + return summarize(label, samples) + + +# --------------------------------------------------------------------------- +# Encoder benchmark +# --------------------------------------------------------------------------- +def build_encoder(config, trode, n_marks, rng): + enc = Encoder(config, trode, make_pos_bin_struct(config)) + num_bins = config["encoder"]["position"]["num_bins"] + mark_dim = config["encoder"]["mark_dim"] + + # Fill the mark buffer: a cluster of marks in mark space plus their + # observed positions, and a realistic occupancy with no-animal gaps. + n_marks = min(n_marks, enc._marks.shape[0]) + enc._marks[:n_marks] = rng.normal(100.0, 30.0, size=(n_marks, mark_dim)) + enc._positions[:n_marks] = rng.integers(0, num_bins, size=n_marks) + enc._mark_idx = n_marks + + occ = rng.integers(1, 50, size=num_bins).astype(np.float64) + from realtime_decoder import utils + utils.apply_no_anim_boundary( + enc._pos_bins, enc._arm_coords, occ, np.nan + ) + enc._occupancy = occ + return enc + + +def bench_encoder(*, num_bins, arm_coords, n_marks, n_iter, warmup, seed): + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=1, + bufsize=max(n_marks, 50000), mark_dim=4, rank=1, + ) + rng = np.random.default_rng(seed) + enc = build_encoder(config, 1, n_marks, rng) + + # Query marks drawn near the cluster center so the n_marks_min gate passes. + queries = rng.normal(100.0, 30.0, size=(n_iter + warmup, config["encoder"]["mark_dim"])) + + for i in range(warmup): + enc.get_joint_prob(queries[i]) + + samples = np.empty(n_iter, dtype=np.int64) + for i in range(n_iter): + q = queries[warmup + i] + t0 = time.perf_counter_ns() + enc.get_joint_prob(q) + samples[i] = time.perf_counter_ns() - t0 + + return summarize(f"get_joint_prob marks={n_marks}", samples) + + +# --------------------------------------------------------------------------- +# Checksum mode: deterministic digest of kernel outputs for equivalence checks +# --------------------------------------------------------------------------- +def _original_joint_prob(enc, mark): + """Faithful re-implementation of the pre-optimization get_joint_prob, + used only to prove the optimized version is numerically equivalent.""" + if enc._mark_idx == 0: + return None + mark_idx = min(enc._mark_idx, enc._marks.shape[0]) + + in_range = np.ones(mark_idx, dtype=bool) + if enc.p['use_filter']: + std = enc.p['filter_std'] + n_std = enc.p['filter_n_std'] + for ii in range(enc._marks.shape[1]): + in_range = np.logical_and( + np.logical_and( + enc._marks[:mark_idx, ii] > mark[ii] - n_std * std, + enc._marks[:mark_idx, ii] < mark[ii] + n_std * std, + ), + in_range, + ) + if np.sum(in_range) < enc.p['n_marks_min']: + return None + + squared_distance = np.sum(np.square(enc._marks[:mark_idx] - mark), axis=1) + weights = enc._k1 * np.exp(squared_distance * enc._k2) + positions = enc._positions[:mark_idx] + hist, _ = np.histogram( + a=positions, bins=enc._pos_bin_struct.pos_bin_edges, weights=weights + ) + hist += 0.0000001 + hist /= (enc._occupancy / np.nansum(enc._occupancy)) + hist[~np.isfinite(hist)] = 0.0 + hist /= (np.sum(hist) * enc._pos_bin_struct.pos_bin_delta) + return hist + + +def verify(seed=20240601): + """Prove the optimized encoder kernel preserves the result. + + exact_histogram=True must be bit-identical to the original algorithm; + the fast histogram path must match to within floating-point rounding. + """ + num_bins = 41 + arm_coords = [[0, 8], [13, 24], [29, 40]] + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=1, + bufsize=50000, mark_dim=4, rank=1, + ) + rng = np.random.default_rng(seed) + enc = build_encoder(config, 1, 50000, rng) + + n_checks = 400 + exact_bit_identical = 0 + fast_max_rel = 0.0 + compared = 0 + for _ in range(n_checks): + q = rng.normal(100.0, 30.0, size=4) + ref = _original_joint_prob(enc, q) + + enc.p['exact_histogram'] = True + e = enc.get_joint_prob(q) + enc.p['exact_histogram'] = False + f = enc.get_joint_prob(q) + + if ref is None: + assert e is None and f is None + continue + compared += 1 + if np.array_equal(e.hist, ref): + exact_bit_identical += 1 + nz = ref != 0 + if nz.any(): + fast_max_rel = max( + fast_max_rel, + float(np.max(np.abs(f.hist[nz] - ref[nz]) / np.abs(ref[nz]))), + ) + + print(f"compared {compared} non-trivial spikes") + print(f"exact_histogram path bit-identical to original: " + f"{exact_bit_identical}/{compared} " + f"{'PASS' if exact_bit_identical == compared else 'FAIL'}") + print(f"fast path max relative diff vs original: {fast_max_rel:.2e} " + f"{'PASS (< 1e-6, rounding noise)' if fast_max_rel < 1e-6 else 'FAIL'}") + + +def checksum(seed=1234): + h = hashlib.sha256() + + # Decoder + num_bins = 41 + arm_coords = [[0, 8], [13, 24], [29, 40]] + rank = 1 + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=8, + bufsize=2000, mark_dim=4, rank=rank, + ) + rng = np.random.default_rng(seed) + dec = build_decoder(config, rank, rng) + for _ in range(500): + n = int(rng.integers(0, 5)) + post, lk = dec.compute_posterior(make_spike_arr(config, rank, n, num_bins, rng)) + h.update(np.ascontiguousarray(post, dtype=np.float64).tobytes()) + h.update(np.ascontiguousarray(lk, dtype=np.float64).tobytes()) + + # Encoder + enc_config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=1, + bufsize=50000, mark_dim=4, rank=1, + ) + rng2 = np.random.default_rng(seed + 1) + enc = build_encoder(enc_config, 1, 50000, rng2) + for _ in range(500): + q = rng2.normal(100.0, 30.0, size=4) + est = enc.get_joint_prob(q) + if est is not None: + h.update(np.ascontiguousarray(est.hist, dtype=np.float64).tobytes()) + + print("output digest:", h.hexdigest()) + + +# --------------------------------------------------------------------------- +def main(): + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--checksum", action="store_true", + help="print deterministic digest of kernel outputs and exit") + ap.add_argument("--verify", action="store_true", + help="prove optimized encoder kernel matches the original and exit") + ap.add_argument("--iter", type=int, default=5000) + ap.add_argument("--warmup", type=int, default=500) + ap.add_argument("--seed", type=int, default=0) + args = ap.parse_args() + + if args.verify: + verify() + return + + if args.checksum: + checksum() + return + + arm = [[0, 8], [13, 24], [29, 40]] + print(f"numpy {np.__version__}") + print("\nDecoder posterior (per time bin):") + bench_decoder(num_bins=41, arm_coords=arm, n_elec=8, + n_iter=args.iter, warmup=args.warmup, seed=args.seed) + bench_decoder(num_bins=41, arm_coords=arm, n_elec=32, + n_iter=args.iter, warmup=args.warmup, seed=args.seed) + # Stress: larger position grid (e.g. 2D environments) + big_arm = [[0, 127], [128, 255]] + bench_decoder(num_bins=256, arm_coords=big_arm, n_elec=32, + n_iter=args.iter, warmup=args.warmup, seed=args.seed) + + print("\nEncoder joint prob (per spike, scales with stored marks):") + for nm in (1000, 10000, 50000): + bench_encoder(num_bins=41, arm_coords=arm, n_marks=nm, + n_iter=min(args.iter, 3000), warmup=200, seed=args.seed) + + +if __name__ == "__main__": + main() diff --git a/docs/latency_analysis.md b/docs/latency_analysis.md new file mode 100644 index 0000000..5b6cd6d --- /dev/null +++ b/docs/latency_analysis.md @@ -0,0 +1,178 @@ +# Where the latency actually is + +This note records a measurement-driven look at closed-loop latency in this +package: what dominates it, what does not, and which knob moves it the most. It +is the companion to `docs/realtime_tuning.md` (which covers the OS) and to +`benchmarks/bench_hotpath.py` (which produces the numbers below). All timings +are from that benchmark on a developer laptop; absolute values will differ on +your rig, but the *ratios* and the *ranking of costs* hold. + +The headline: at production sizes the per-bin decoder math is microseconds, the +end-to-end latency is dominated by the **time-bin width plus the spike jitter +buffer** (`decoder.time_bin`), and the largest piece of controllable *compute* +is the encoder's per-spike KDE - which was spending most of its time inside +`np.histogram`. + +## The closed-loop critical path + +For a replay/ripple-triggered stim, a neural event becomes a hardware output +through roughly these stages: + +``` +spike --> Trodes acq --> encoder rank: get_joint_prob --> MPI --> decoder rank + --> wait for the time bin to close (+ delay) --> compute_posterior + --> MPI --> main/stim decision --> ECU shortcut --> hardware +``` + +Two of those stages are fixed by the science or the hardware (acquisition +transport, and the bin width you chose for statistical reasons). The rest is +software, and that is where this package can help or hurt. + +## What the numbers say + +Measured per-call latency of the two hot kernels (`bench_hotpath.py`, SC66 +config: `num_bins=41`, mark buffer 50000, `mark_dim=4`): + +| kernel | size | p50 | p99 | +|---|---|---|---| +| `compute_posterior` (per bin) | 8 electrodes | 65 us | 71 us | +| `compute_posterior` (per bin) | 32 electrodes | 173 us | 199 us | +| `get_joint_prob` (per spike) | 1000 marks | 30 us | 36 us | +| `get_joint_prob` (per spike) | 10000 marks | 140 us | 162 us | +| `get_joint_prob` (per spike) | 50000 marks (buffer full) | 635 us | 700 us | + +Put next to the time bin, the picture is stark. SC66 uses +`time_bin.samples = 180` and `time_bin.delay_samples = 180` at a 30 kHz spike +clock: + +- bin width = 180 / 30000 = **6 ms** +- jitter buffer (`delay_samples`) = 180 / 30000 = **6 ms** + +So the decoder deliberately decodes a bin that closed up to **12 ms ago**. The +posterior math that runs at the end of that window costs ~65-170 us - about +**1-3% of the bin period**. Decoder compute is not the bottleneck and was left +unchanged; spending effort shaving microseconds off a sub-millisecond operation +that sits behind a 12 ms wall would be pointless. + +The encoder KDE is a different story. It costs ~0.65 ms *per spike* once the +mark buffer fills, and it grows with the buffer (30 us -> 140 us -> 635 us as +marks go 1k -> 10k -> 50k). At realistic spike rates across a tetrode this is +the kernel that saturates a rank and forces spikes to be dropped, and its tail +is what the jitter buffer has to absorb. + +## The fix that mattered: stop calling np.histogram with explicit edges + +Profiling `get_joint_prob` at a full 50000-mark buffer, before any change: + +| component | time | +|---|---| +| `np.histogram(positions, bins=pos_bin_edges, weights=...)` | ~1195 us | +| `squared_distance` (subtract, square, sum over marks x channels) | ~585 us | +| in-cube filter (4 channels) | ~136 us | +| Gaussian `exp` | ~112 us | + +Two thirds of the kernel was in `np.histogram`. The reason is subtle: when you +pass `np.histogram` an explicit array of bin **edges**, it cannot assume the +bins are uniform, so it bins every sample with a per-sample binary search +(`np.searchsorted`) - ~960 us for 50000 samples, cache-hostile and pure +overhead here, because the bins *are* uniform (`linspace(0, num_bins, ...)`). + +The position grid in every shipped config is the integer grid `0, 1, ..., +num_bins`, so a position's bin index is just its integer part. Replacing the +search with a direct index + `np.bincount` drops the histogram from ~1195 us to +~67 us. Two more changes on top of that: + +- Preallocate the per-spike scratch. The kernel allocated about a dozen + temporaries per spike, several of them `O(marks)`; now they are reused. +- Fuse the squared distance with `np.einsum('ij,ij->i', diff, diff)`, which + does the square and the per-mark sum in one pass instead of materializing the + squared array and summing it. That takes the second-largest component from + ~585 us to ~280 us. + +Together: + +| | p50 | p99 | +|---|---|---| +| before | 2096 us | 2702 us | +| after | 635 us | 700 us | + +**About 3.3x faster on the median.** The tail (p99.9 and max) is dominated by +OS scheduling and varies run to run, but it tracks down with the median. The +tail is what matters most: it is the spike the decoder ends up waiting for, and +it sets how large the `delay_samples` jitter buffer has to be. + +### Why this is safe + +The change is gated and verified (`bench_hotpath.py --verify`): + +- The preallocation is bit-for-bit identical to the original (verified on every + spike of a 400-spike run). +- The fused distance sum and the uniform-bin histogram differ from the original + only in floating-point **summation order**, ~3e-13 relative overall - the + same magnitude of difference you already get from a different numpy version, + BLAS thread count, + or CPU. For a normalized probability over 41 bins driven by Poisson spikes, + this is far below the noise floor. +- For anyone who needs bit-for-bit reproducibility against an older run, set + `encoder.mark_kernel.exact_histogram: true`. That restores both the original + `np.histogram` call and the original square-then-sum distance, so the output + is identical to the old code. The general (non-integer-grid) fast histogram + also uses numpy's own optimized uniform histogram, which is ~4x faster than + the explicit-edges path and exactly matches its bin assignment. + +## The lever that moves end-to-end latency: delay_samples + +The encoder speedup matters for two reasons. The obvious one is throughput - +fewer dropped spikes per rank. The bigger one is indirect. + +`delay_samples` exists because the decoder must wait for all spikes in a bin to +arrive from the encoders before it decodes that bin. How long it must wait is +set by the **worst-case** encoder-to-decoder latency, i.e. the *tail* of +`get_joint_prob` plus transport. Cutting the encoder's per-spike tail from ~4 ms +to ~1 ms shrinks the jitter that the buffer has to cover, which means +`delay_samples` can potentially come down - and `delay_samples` is worth up to +**6 ms** of end-to-end latency at SC66, versus the ~1 ms the compute itself +costs. + +So the right way to actually lower closed-loop latency, in order of payoff: + +1. **Tune `delay_samples` down to your measured worst case.** Run a session, + open the per-rank `*.timing.npz` arrays, and look at the real distribution of + encoder->decoder arrival times. Set `delay_samples` to cover p99.9 of that + distribution with a small margin, not a round number copied between configs. + This is the highest-value single change and it is a config edit. +2. **Keep the encoder tail low** so step 1 can be aggressive (this is what the + `get_joint_prob` work buys you). +3. **Reduce the bin width** (`samples`) only if the statistics of your decode + tolerate it - this is a science decision, not a software one. A smaller bin + is less latency but noisier likelihoods. + +## The theoretical floor, and what it would take to go lower + +With the bin width and `delay_samples` fixed by the experiment, the software +floor for the decode path is now: encoder KDE (~0.8 ms tail) + two MPI hops +(tens of us each, already zero-copy for the data plane) + posterior (~0.1 ms) + +decision + actuation. Everything else is OS jitter, addressed in +`docs/realtime_tuning.md`. + +To push the encoder KDE below ~0.8 ms at a full buffer you have to change the +math, not just the implementation - the remaining cost is the genuine work of +evaluating the kernel over every stored mark: + +- **Truncated kernel.** The in-cube filter already finds the marks near the + query; evaluating the Gaussian only on those (and treating the rest as zero) + would cut both the distance computation and the histogram dramatically. It + changes results by more than rounding (far marks contribute small but nonzero + weight), so it needs the lab's sign-off and a config flag - but it is the + single biggest remaining lever and worth prototyping behind the benchmark. +- **Spatial index (k-d tree / grid) over marks.** Same idea, sub-linear in the + buffer size, more code. +- **float32 marks.** Halves the memory traffic that dominates + `squared_distance`; changes numerics. +- **GPU.** Only wins at much larger mark buffers or electrode counts than SC66; + the per-spike arrays here (50000 x 4) are too small to beat the host round + trip. Revisit only if marks-per-buffer or tetrode counts grow by an order of + magnitude. + +The point of `bench_hotpath.py` is that none of these need to be argued in the +abstract: prototype, run `--verify` and the timing pass, and read the number. diff --git a/realtime_decoder/encoder_process.py b/realtime_decoder/encoder_process.py index a66fabd..3020289 100644 --- a/realtime_decoder/encoder_process.py +++ b/realtime_decoder/encoder_process.py @@ -88,6 +88,52 @@ def __init__(self, config, trode, pos_bin_struct): self._temp_idx = 0 # NOTE(DS): so that mark_idx does not increase but still write down in the mark vec self._init_params() + self._init_kde_buffers() + + def _init_kde_buffers(self): + """Preallocate the per-spike scratch used by get_joint_prob so the + hot path does not allocate. Sized to the full mark buffer, then + sliced to the live mark count on each call.""" + n, dim = self._marks.shape + self._kde_diff = np.empty((n, dim), dtype=np.float64) + self._kde_sqd = np.empty(n, dtype=np.float64) + self._kde_weights = np.empty(n, dtype=np.float64) + self._kde_in_range = np.empty(n, dtype=bool) + self._kde_tmp_a = np.empty(n, dtype=bool) + self._kde_tmp_b = np.empty(n, dtype=bool) + + # Detect a "simple" position grid (edges 0, 1, ..., num_bins) so the + # weighted histogram can map a position straight to its bin index + # without a search. Every shipped config uses this grid. + edges = self._pos_bin_struct.pos_bin_edges + self._hist_nbins = self._pos_bin_struct.num_bins + self._hist_first = edges[0] + self._hist_last = edges[-1] + self._simple_grid = bool( + edges[0] == 0 and np.allclose(np.diff(edges), 1.0) + ) + + def _fast_hist(self, positions, weights): + """Uniform-bin weighted histogram, equivalent to + np.histogram(positions, bins=pos_bin_edges, weights=weights) up to + floating-point summation order (~1e-10 relative, i.e. rounding noise). + Much faster than np.histogram on an explicit edge array, which falls + back to a per-sample binary search; bin assignment is identical.""" + num_bins = self._hist_nbins + if self._simple_grid: + # edges are 0, 1, ..., num_bins, so the bin index is just the + # integer part of the position. positions are position-bin indices + # already inside [0, num_bins); clip only guards pathological input. + idx = positions.astype(np.intp) + np.clip(idx, 0, num_bins - 1, out=idx) + return np.bincount(idx, weights=weights, minlength=num_bins) + # general uniform grid: numpy's own optimized linear-index + bincount + # path (triggered by passing an int bin count and an explicit range) + return np.histogram( + positions, bins=num_bins, + range=(self._hist_first, self._hist_last), weights=weights + )[0] + def _load_model(self): fname = os.path.join( @@ -130,6 +176,15 @@ def _init_params(self): self.p['filter_n_std'] = self._config['encoder']['mark_kernel']['n_std'] self.p['n_marks_min'] = self._config['encoder']['mark_kernel']['n_marks_min'] self.p['num_occupancy_points'] = self._config['display']['encoder']['occupancy'] + # use the exact (slower) numerics in get_joint_prob: np.histogram on + # explicit edges and a separate square-then-sum for the mark distance. + # The default fast path (uniform-bin histogram + fused distance sum) + # differs only in floating-point summation order (~1e-13 relative); + # leave this off unless you need bit-for-bit reproducibility against an + # older run. + self.p['exact_histogram'] = self._config['encoder']['mark_kernel'].get( + 'exact_histogram', False + ) def add_new_mark(self, mark): ''' @@ -198,46 +253,57 @@ def get_joint_prob(self, mark): #print(mark) - in_range = np.ones(mark_idx, dtype=bool) + in_range = self._kde_in_range[:mark_idx] + in_range.fill(True) if self.p['use_filter']: std = self.p['filter_std'] n_std = self.p['filter_n_std'] + tmp_a = self._kde_tmp_a[:mark_idx] + tmp_b = self._kde_tmp_b[:mark_idx] for ii in range(self._marks.shape[1]): - in_range = np.logical_and( - np.logical_and( - self._marks[:mark_idx, ii] > mark[ii] - n_std * std, - self._marks[:mark_idx, ii] < mark[ii] + n_std * std - ), - in_range - ) + col = self._marks[:mark_idx, ii] + np.greater(col, mark[ii] - n_std * std, out=tmp_a) + np.less(col, mark[ii] + n_std * std, out=tmp_b) + np.logical_and(tmp_a, tmp_b, out=tmp_a) + np.logical_and(in_range, tmp_a, out=in_range) # not enough spikes within n-cube if np.sum(in_range) < self.p['n_marks_min']: return None - # evaluate Gaussian kernel on distance in mark space - squared_distance = np.sum( - np.square(self._marks[:mark_idx] - mark), - axis=1 - ) - weights = self._k1 * np.exp(squared_distance * self._k2) + # evaluate Gaussian kernel on distance in mark space. all temporaries + # are preallocated scratch (see _init_kde_buffers) so the hot path + # does not allocate. + diff = self._kde_diff[:mark_idx] + np.subtract(self._marks[:mark_idx], mark, out=diff) + squared_distance = self._kde_sqd[:mark_idx] + if self.p['exact_histogram']: + # reference path: same operations and order as the original + np.square(diff, out=diff) + np.sum(diff, axis=1, out=squared_distance) + else: + # fuse the square and the per-mark sum into a single pass. differs + # from the two lines above only in summation order (~1e-16). + np.einsum('ij,ij->i', diff, diff, out=squared_distance) + + weights = self._kde_weights[:mark_idx] + np.multiply(squared_distance, self._k2, out=weights) + np.exp(weights, out=weights) + np.multiply(weights, self._k1, out=weights) + positions = self._positions[:mark_idx] - # print(positions.shape) - # print("") - # print(self._pos_bin_struct.pos_bin_edges) - # print("") - # print(weights) - - # `density=` (formerly `normed=`) intentionally omitted: we want the - # raw weighted sum per bin, which is the default behavior. `normed=` - # was removed in NumPy 1.24, which broke this call on any modern - # install. - hist, hist_edges = np.histogram( - a=positions, - bins=self._pos_bin_struct.pos_bin_edges, - weights=weights, - ) + # `density=` (formerly `normed=`) is intentionally omitted: we want the + # raw weighted sum per bin, which is the default. `normed=` was removed + # in NumPy 1.24, which broke this call on any modern install. + if self.p['exact_histogram']: + hist, _ = np.histogram( + a=positions, + bins=self._pos_bin_struct.pos_bin_edges, + weights=weights + ) + else: + hist = self._fast_hist(positions, weights) hist += 0.0000001