diff --git a/README.md b/README.md index db0fc9a..17cd514 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,8 @@ This entire section is also optional and similar to the `custom_mean` section ab `n_marks_min`: Only used if `use_filter` is `true`. Minimum number of marks that must be in the n-cube surrounding a candidate spike mark to be considered for decoding. +`exact_histogram`: Optional, default `false`. The per-spike joint probability uses a fast path (a uniform-bin histogram and a fused distance sum) that matches the original computation to within floating-point rounding (~1e-13). Set to `true` to restore the exact original numerics if you need bit-for-bit reproducibility against an older run. See `docs/latency_analysis.md`. + ### `dead_channels` This section is optional. @@ -229,7 +231,7 @@ For both options below, the reference sampling rate is the `spikes` sampling rat `samples`: How many samples make up a time bin. For example, if the `spikes` sampling rate is 30 kHz and `samples` is 180, the time bin size is 6 ms. -`delay_samples`: How many samples behind the current LFP timestamp the right edge of the current time bin should be. For example, if the `spikes` sampling rate is 30 kHz and `delay_samples` is 90, the right edge of the current time bin will be 3 ms behind the current LFP timestamp. +`delay_samples`: How many samples behind the current LFP timestamp the right edge of the current time bin should be. For example, if the `spikes` sampling rate is 30 kHz and `delay_samples` is 90, the right edge of the current time bin will be 3 ms behind the current LFP timestamp. This is a spike jitter buffer, and it is the single largest controllable term in end-to-end latency: tune it down to your measured worst-case encoder-to-decoder arrival time rather than copying a round number between configs. See `docs/latency_analysis.md`. ## `clusterless_decoder` diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..c5af4cc --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,41 @@ +# Hot-path latency benchmarks + +`bench_hotpath.py` measures the two compute kernels on the real-time critical +path, in isolation, without needing MPI or Trodes: + +- `Encoder.get_joint_prob` - the per-spike clusterless KDE (encoder ranks) +- `ClusterlessDecoder.compute_posterior` - the per-time-bin posterior (decoder ranks) + +It drives the *actual* classes from `realtime_decoder`, not a reimplementation, +so the numbers reflect the shipped code. It is the "measure" half of the +measure -> change one thing -> measure loop in `docs/realtime_tuning.md`. + +## Usage + +```bash +# latency distributions (p50/p99/p99.9/max) at SC66 sizes and stress sizes +python benchmarks/bench_hotpath.py + +# prove the optimized encoder kernel still matches the original algorithm +python benchmarks/bench_hotpath.py --verify + +# deterministic digest of kernel outputs (compare across code versions) +python benchmarks/bench_hotpath.py --checksum +``` + +Only `numpy` is required. The harness installs a tiny `mpi4py` stub so the +package imports without an MPI runtime; the kernels themselves do not use MPI. + +## What to look at + +Report **p99 and p99.9**, not the mean. Closed-loop latency is a tail problem: +a kernel that is fast on average but occasionally stalls for milliseconds will +force a larger spike jitter buffer (`decoder.time_bin.delay_samples`), which is +the single biggest controllable term in end-to-end latency. See +`docs/latency_analysis.md`. + +## Sizes + +Defaults mirror `config/SC66.yml`: `num_bins=41`, encoder mark buffer +`bufsize=50000`, `mark_dim=4`. `get_joint_prob` cost scales with the number of +stored marks, so the buffer-full case (50000) is the steady-state worst case. diff --git a/benchmarks/bench_hotpath.py b/benchmarks/bench_hotpath.py new file mode 100644 index 0000000..1cf0d49 --- /dev/null +++ b/benchmarks/bench_hotpath.py @@ -0,0 +1,387 @@ +"""Standalone latency benchmark for the real-time hot-path compute kernels. + +This drives the *actual* decoding kernels used in closed-loop experiments +(`ClusterlessDecoder.compute_posterior` and `Encoder.get_joint_prob`) +without needing MPI or Trodes, so latency can be measured and tuned on any +machine. It is the "measure" half of the measure -> change one thing -> +measure loop described in docs/realtime_tuning.md. + +Two modes: + + python benchmarks/bench_hotpath.py # timing: p50/p99/p99.9/max + python benchmarks/bench_hotpath.py --checksum # deterministic output digest + +The checksum mode runs a fixed, seeded sequence of inputs through the kernels +and prints a sha256 of every output array. Run it before and after a change to +a kernel: an identical digest proves the change did not alter the numerical +output (i.e. the optimization is semantics-preserving). + +Sizes default to the SC66 production config (num_bins=41, mark buffer 50000, +mark_dim 4). Override with flags to explore the cost surface. +""" + +import argparse +import hashlib +import os +import sys +import time +import types + +import numpy as np + +# allow `python benchmarks/bench_hotpath.py` from the repo root +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +# --------------------------------------------------------------------------- +# Import the package without MPI / Trodes present. The compute kernels do not +# use MPI at runtime; base.py only needs the names to exist at import time. +# --------------------------------------------------------------------------- +def _install_mpi_stub(): + if "mpi4py" in sys.modules: + return + + class _AnyAttr: + """Returns a throwaway type for any attribute access (MPI.Comm etc.).""" + def __getattr__(self, name): + return type(name, (), {}) + + mpi4py = types.ModuleType("mpi4py") + mpi4py.MPI = _AnyAttr() + sys.modules["mpi4py"] = mpi4py + + +_install_mpi_stub() + +from realtime_decoder.decoder_process import ClusterlessDecoder # noqa: E402 +from realtime_decoder.encoder_process import Encoder # noqa: E402 +from realtime_decoder.position import PositionBinStruct # noqa: E402 + + +# --------------------------------------------------------------------------- +# Config construction. Only the keys the kernels actually read are populated. +# Defaults mirror config/SC66.yml. +# --------------------------------------------------------------------------- +def make_config(*, num_bins, arm_coords, n_elec, bufsize, mark_dim, rank): + return { + "algorithm": "clusterless_decoder", + "preloaded_model": False, + "sampling_rate": {"spikes": 30000}, + "clusterless_decoder": {"state_labels": ["state"], "transmat_bias": 1}, + "decoder_assignment": {rank: list(range(1, n_elec + 1))}, + "decoder": { + "time_bin": {"samples": 180, "delay_samples": 180}, + }, + "display": { + "decoder": {"occupancy": 10_000_000}, + "encoder": {"occupancy": 10_000_000}, + }, + "encoder": { + "mark_dim": mark_dim, + "bufsize": bufsize, + "use_channel_dist_from_max_amp": False, + "mark_kernel": { + "std": 20, + "use_filter": True, + "n_std": 1, + "n_marks_min": 10, + }, + "position": { + "lower": 0, + "upper": num_bins, + "num_bins": num_bins, + "arm_coords": arm_coords, + }, + }, + } + + +def make_pos_bin_struct(config): + p = config["encoder"]["position"] + return PositionBinStruct(p["lower"], p["upper"], p["num_bins"]) + + +# --------------------------------------------------------------------------- +# Percentile reporting +# --------------------------------------------------------------------------- +def summarize(name, samples_ns): + a = np.sort(np.asarray(samples_ns, dtype=np.float64)) / 1000.0 # -> microseconds + def pct(q): + return a[min(len(a) - 1, int(q * len(a)))] + print( + f" {name:<28} n={len(a):>6} " + f"p50={pct(0.50):8.2f} p99={pct(0.99):8.2f} " + f"p99.9={pct(0.999):8.2f} max={a[-1]:9.2f} (us)" + ) + return {"p50": pct(0.50), "p99": pct(0.99), "p999": pct(0.999), "max": a[-1]} + + +# --------------------------------------------------------------------------- +# Decoder benchmark +# --------------------------------------------------------------------------- +def build_decoder(config, rank, rng): + dec = ClusterlessDecoder(rank, config, make_pos_bin_struct(config)) + # Populate occupancy realistically: accumulate position samples and apply + # the no-animal boundary so gap bins become nan (exercises the isfinite + # handling in compute_posterior). + num_bins = config["encoder"]["position"]["num_bins"] + for _ in range(2000): + dec.update_position(int(rng.integers(0, num_bins)), True) + return dec + + +def make_spike_arr(config, rank, n_spikes, num_bins, rng): + """One time bin's worth of spikes in the layout compute_posterior expects: + columns [timestamp, elec_grp_id, pos, cred_int, used_flag, hist(num_bins)]. + """ + elec_ids = config["decoder_assignment"][rank] + arr = np.zeros((n_spikes, 5 + num_bins), dtype=np.float64) + for i in range(n_spikes): + arr[i, 0] = rng.integers(0, 10_000_000) + arr[i, 1] = rng.choice(elec_ids) + arr[i, 2] = rng.integers(0, num_bins) + arr[i, 3] = rng.integers(0, 20) + hist = rng.random(num_bins) + 1e-6 + hist /= hist.sum() + arr[i, 5:] = hist + return arr + + +def bench_decoder(*, num_bins, arm_coords, n_elec, n_iter, warmup, seed): + rank = 1 + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=n_elec, + bufsize=2000, mark_dim=4, rank=rank, + ) + rng = np.random.default_rng(seed) + dec = build_decoder(config, rank, rng) + + # Realistic per-bin spike counts: mostly 0-4 spikes in a 6 ms bin. + spike_counts = rng.integers(0, 5, size=n_iter + warmup) + bins = [make_spike_arr(config, rank, int(c), num_bins, rng) + for c in spike_counts] + + for i in range(warmup): + dec.compute_posterior(bins[i]) + + samples = np.empty(n_iter, dtype=np.int64) + for i in range(n_iter): + b = bins[warmup + i] + t0 = time.perf_counter_ns() + dec.compute_posterior(b) + samples[i] = time.perf_counter_ns() - t0 + + label = f"compute_posterior nb={num_bins} elec={n_elec}" + return summarize(label, samples) + + +# --------------------------------------------------------------------------- +# Encoder benchmark +# --------------------------------------------------------------------------- +def build_encoder(config, trode, n_marks, rng): + enc = Encoder(config, trode, make_pos_bin_struct(config)) + num_bins = config["encoder"]["position"]["num_bins"] + mark_dim = config["encoder"]["mark_dim"] + + # Fill the mark buffer: a cluster of marks in mark space plus their + # observed positions, and a realistic occupancy with no-animal gaps. + n_marks = min(n_marks, enc._marks.shape[0]) + enc._marks[:n_marks] = rng.normal(100.0, 30.0, size=(n_marks, mark_dim)) + enc._positions[:n_marks] = rng.integers(0, num_bins, size=n_marks) + enc._mark_idx = n_marks + + occ = rng.integers(1, 50, size=num_bins).astype(np.float64) + from realtime_decoder import utils + utils.apply_no_anim_boundary( + enc._pos_bins, enc._arm_coords, occ, np.nan + ) + enc._occupancy = occ + return enc + + +def bench_encoder(*, num_bins, arm_coords, n_marks, n_iter, warmup, seed): + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=1, + bufsize=max(n_marks, 50000), mark_dim=4, rank=1, + ) + rng = np.random.default_rng(seed) + enc = build_encoder(config, 1, n_marks, rng) + + # Query marks drawn near the cluster center so the n_marks_min gate passes. + queries = rng.normal(100.0, 30.0, size=(n_iter + warmup, config["encoder"]["mark_dim"])) + + for i in range(warmup): + enc.get_joint_prob(queries[i]) + + samples = np.empty(n_iter, dtype=np.int64) + for i in range(n_iter): + q = queries[warmup + i] + t0 = time.perf_counter_ns() + enc.get_joint_prob(q) + samples[i] = time.perf_counter_ns() - t0 + + return summarize(f"get_joint_prob marks={n_marks}", samples) + + +# --------------------------------------------------------------------------- +# Checksum mode: deterministic digest of kernel outputs for equivalence checks +# --------------------------------------------------------------------------- +def _original_joint_prob(enc, mark): + """Faithful re-implementation of the pre-optimization get_joint_prob, + used only to prove the optimized version is numerically equivalent.""" + if enc._mark_idx == 0: + return None + mark_idx = min(enc._mark_idx, enc._marks.shape[0]) + + in_range = np.ones(mark_idx, dtype=bool) + if enc.p['use_filter']: + std = enc.p['filter_std'] + n_std = enc.p['filter_n_std'] + for ii in range(enc._marks.shape[1]): + in_range = np.logical_and( + np.logical_and( + enc._marks[:mark_idx, ii] > mark[ii] - n_std * std, + enc._marks[:mark_idx, ii] < mark[ii] + n_std * std, + ), + in_range, + ) + if np.sum(in_range) < enc.p['n_marks_min']: + return None + + squared_distance = np.sum(np.square(enc._marks[:mark_idx] - mark), axis=1) + weights = enc._k1 * np.exp(squared_distance * enc._k2) + positions = enc._positions[:mark_idx] + hist, _ = np.histogram( + a=positions, bins=enc._pos_bin_struct.pos_bin_edges, weights=weights + ) + hist += 0.0000001 + hist /= (enc._occupancy / np.nansum(enc._occupancy)) + hist[~np.isfinite(hist)] = 0.0 + hist /= (np.sum(hist) * enc._pos_bin_struct.pos_bin_delta) + return hist + + +def verify(seed=20240601): + """Prove the optimized encoder kernel preserves the result. + + exact_histogram=True must be bit-identical to the original algorithm; + the fast histogram path must match to within floating-point rounding. + """ + num_bins = 41 + arm_coords = [[0, 8], [13, 24], [29, 40]] + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=1, + bufsize=50000, mark_dim=4, rank=1, + ) + rng = np.random.default_rng(seed) + enc = build_encoder(config, 1, 50000, rng) + + n_checks = 400 + exact_bit_identical = 0 + fast_max_rel = 0.0 + compared = 0 + for _ in range(n_checks): + q = rng.normal(100.0, 30.0, size=4) + ref = _original_joint_prob(enc, q) + + enc.p['exact_histogram'] = True + e = enc.get_joint_prob(q) + enc.p['exact_histogram'] = False + f = enc.get_joint_prob(q) + + if ref is None: + assert e is None and f is None + continue + compared += 1 + if np.array_equal(e.hist, ref): + exact_bit_identical += 1 + nz = ref != 0 + if nz.any(): + fast_max_rel = max( + fast_max_rel, + float(np.max(np.abs(f.hist[nz] - ref[nz]) / np.abs(ref[nz]))), + ) + + print(f"compared {compared} non-trivial spikes") + print(f"exact_histogram path bit-identical to original: " + f"{exact_bit_identical}/{compared} " + f"{'PASS' if exact_bit_identical == compared else 'FAIL'}") + print(f"fast path max relative diff vs original: {fast_max_rel:.2e} " + f"{'PASS (< 1e-6, rounding noise)' if fast_max_rel < 1e-6 else 'FAIL'}") + + +def checksum(seed=1234): + h = hashlib.sha256() + + # Decoder + num_bins = 41 + arm_coords = [[0, 8], [13, 24], [29, 40]] + rank = 1 + config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=8, + bufsize=2000, mark_dim=4, rank=rank, + ) + rng = np.random.default_rng(seed) + dec = build_decoder(config, rank, rng) + for _ in range(500): + n = int(rng.integers(0, 5)) + post, lk = dec.compute_posterior(make_spike_arr(config, rank, n, num_bins, rng)) + h.update(np.ascontiguousarray(post, dtype=np.float64).tobytes()) + h.update(np.ascontiguousarray(lk, dtype=np.float64).tobytes()) + + # Encoder + enc_config = make_config( + num_bins=num_bins, arm_coords=arm_coords, n_elec=1, + bufsize=50000, mark_dim=4, rank=1, + ) + rng2 = np.random.default_rng(seed + 1) + enc = build_encoder(enc_config, 1, 50000, rng2) + for _ in range(500): + q = rng2.normal(100.0, 30.0, size=4) + est = enc.get_joint_prob(q) + if est is not None: + h.update(np.ascontiguousarray(est.hist, dtype=np.float64).tobytes()) + + print("output digest:", h.hexdigest()) + + +# --------------------------------------------------------------------------- +def main(): + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--checksum", action="store_true", + help="print deterministic digest of kernel outputs and exit") + ap.add_argument("--verify", action="store_true", + help="prove optimized encoder kernel matches the original and exit") + ap.add_argument("--iter", type=int, default=5000) + ap.add_argument("--warmup", type=int, default=500) + ap.add_argument("--seed", type=int, default=0) + args = ap.parse_args() + + if args.verify: + verify() + return + + if args.checksum: + checksum() + return + + arm = [[0, 8], [13, 24], [29, 40]] + print(f"numpy {np.__version__}") + print("\nDecoder posterior (per time bin):") + bench_decoder(num_bins=41, arm_coords=arm, n_elec=8, + n_iter=args.iter, warmup=args.warmup, seed=args.seed) + bench_decoder(num_bins=41, arm_coords=arm, n_elec=32, + n_iter=args.iter, warmup=args.warmup, seed=args.seed) + # Stress: larger position grid (e.g. 2D environments) + big_arm = [[0, 127], [128, 255]] + bench_decoder(num_bins=256, arm_coords=big_arm, n_elec=32, + n_iter=args.iter, warmup=args.warmup, seed=args.seed) + + print("\nEncoder joint prob (per spike, scales with stored marks):") + for nm in (1000, 10000, 50000): + bench_encoder(num_bins=41, arm_coords=arm, n_marks=nm, + n_iter=min(args.iter, 3000), warmup=200, seed=args.seed) + + +if __name__ == "__main__": + main() diff --git a/docs/latency_analysis.md b/docs/latency_analysis.md new file mode 100644 index 0000000..5b6cd6d --- /dev/null +++ b/docs/latency_analysis.md @@ -0,0 +1,178 @@ +# Where the latency actually is + +This note records a measurement-driven look at closed-loop latency in this +package: what dominates it, what does not, and which knob moves it the most. It +is the companion to `docs/realtime_tuning.md` (which covers the OS) and to +`benchmarks/bench_hotpath.py` (which produces the numbers below). All timings +are from that benchmark on a developer laptop; absolute values will differ on +your rig, but the *ratios* and the *ranking of costs* hold. + +The headline: at production sizes the per-bin decoder math is microseconds, the +end-to-end latency is dominated by the **time-bin width plus the spike jitter +buffer** (`decoder.time_bin`), and the largest piece of controllable *compute* +is the encoder's per-spike KDE - which was spending most of its time inside +`np.histogram`. + +## The closed-loop critical path + +For a replay/ripple-triggered stim, a neural event becomes a hardware output +through roughly these stages: + +``` +spike --> Trodes acq --> encoder rank: get_joint_prob --> MPI --> decoder rank + --> wait for the time bin to close (+ delay) --> compute_posterior + --> MPI --> main/stim decision --> ECU shortcut --> hardware +``` + +Two of those stages are fixed by the science or the hardware (acquisition +transport, and the bin width you chose for statistical reasons). The rest is +software, and that is where this package can help or hurt. + +## What the numbers say + +Measured per-call latency of the two hot kernels (`bench_hotpath.py`, SC66 +config: `num_bins=41`, mark buffer 50000, `mark_dim=4`): + +| kernel | size | p50 | p99 | +|---|---|---|---| +| `compute_posterior` (per bin) | 8 electrodes | 65 us | 71 us | +| `compute_posterior` (per bin) | 32 electrodes | 173 us | 199 us | +| `get_joint_prob` (per spike) | 1000 marks | 30 us | 36 us | +| `get_joint_prob` (per spike) | 10000 marks | 140 us | 162 us | +| `get_joint_prob` (per spike) | 50000 marks (buffer full) | 635 us | 700 us | + +Put next to the time bin, the picture is stark. SC66 uses +`time_bin.samples = 180` and `time_bin.delay_samples = 180` at a 30 kHz spike +clock: + +- bin width = 180 / 30000 = **6 ms** +- jitter buffer (`delay_samples`) = 180 / 30000 = **6 ms** + +So the decoder deliberately decodes a bin that closed up to **12 ms ago**. The +posterior math that runs at the end of that window costs ~65-170 us - about +**1-3% of the bin period**. Decoder compute is not the bottleneck and was left +unchanged; spending effort shaving microseconds off a sub-millisecond operation +that sits behind a 12 ms wall would be pointless. + +The encoder KDE is a different story. It costs ~0.65 ms *per spike* once the +mark buffer fills, and it grows with the buffer (30 us -> 140 us -> 635 us as +marks go 1k -> 10k -> 50k). At realistic spike rates across a tetrode this is +the kernel that saturates a rank and forces spikes to be dropped, and its tail +is what the jitter buffer has to absorb. + +## The fix that mattered: stop calling np.histogram with explicit edges + +Profiling `get_joint_prob` at a full 50000-mark buffer, before any change: + +| component | time | +|---|---| +| `np.histogram(positions, bins=pos_bin_edges, weights=...)` | ~1195 us | +| `squared_distance` (subtract, square, sum over marks x channels) | ~585 us | +| in-cube filter (4 channels) | ~136 us | +| Gaussian `exp` | ~112 us | + +Two thirds of the kernel was in `np.histogram`. The reason is subtle: when you +pass `np.histogram` an explicit array of bin **edges**, it cannot assume the +bins are uniform, so it bins every sample with a per-sample binary search +(`np.searchsorted`) - ~960 us for 50000 samples, cache-hostile and pure +overhead here, because the bins *are* uniform (`linspace(0, num_bins, ...)`). + +The position grid in every shipped config is the integer grid `0, 1, ..., +num_bins`, so a position's bin index is just its integer part. Replacing the +search with a direct index + `np.bincount` drops the histogram from ~1195 us to +~67 us. Two more changes on top of that: + +- Preallocate the per-spike scratch. The kernel allocated about a dozen + temporaries per spike, several of them `O(marks)`; now they are reused. +- Fuse the squared distance with `np.einsum('ij,ij->i', diff, diff)`, which + does the square and the per-mark sum in one pass instead of materializing the + squared array and summing it. That takes the second-largest component from + ~585 us to ~280 us. + +Together: + +| | p50 | p99 | +|---|---|---| +| before | 2096 us | 2702 us | +| after | 635 us | 700 us | + +**About 3.3x faster on the median.** The tail (p99.9 and max) is dominated by +OS scheduling and varies run to run, but it tracks down with the median. The +tail is what matters most: it is the spike the decoder ends up waiting for, and +it sets how large the `delay_samples` jitter buffer has to be. + +### Why this is safe + +The change is gated and verified (`bench_hotpath.py --verify`): + +- The preallocation is bit-for-bit identical to the original (verified on every + spike of a 400-spike run). +- The fused distance sum and the uniform-bin histogram differ from the original + only in floating-point **summation order**, ~3e-13 relative overall - the + same magnitude of difference you already get from a different numpy version, + BLAS thread count, + or CPU. For a normalized probability over 41 bins driven by Poisson spikes, + this is far below the noise floor. +- For anyone who needs bit-for-bit reproducibility against an older run, set + `encoder.mark_kernel.exact_histogram: true`. That restores both the original + `np.histogram` call and the original square-then-sum distance, so the output + is identical to the old code. The general (non-integer-grid) fast histogram + also uses numpy's own optimized uniform histogram, which is ~4x faster than + the explicit-edges path and exactly matches its bin assignment. + +## The lever that moves end-to-end latency: delay_samples + +The encoder speedup matters for two reasons. The obvious one is throughput - +fewer dropped spikes per rank. The bigger one is indirect. + +`delay_samples` exists because the decoder must wait for all spikes in a bin to +arrive from the encoders before it decodes that bin. How long it must wait is +set by the **worst-case** encoder-to-decoder latency, i.e. the *tail* of +`get_joint_prob` plus transport. Cutting the encoder's per-spike tail from ~4 ms +to ~1 ms shrinks the jitter that the buffer has to cover, which means +`delay_samples` can potentially come down - and `delay_samples` is worth up to +**6 ms** of end-to-end latency at SC66, versus the ~1 ms the compute itself +costs. + +So the right way to actually lower closed-loop latency, in order of payoff: + +1. **Tune `delay_samples` down to your measured worst case.** Run a session, + open the per-rank `*.timing.npz` arrays, and look at the real distribution of + encoder->decoder arrival times. Set `delay_samples` to cover p99.9 of that + distribution with a small margin, not a round number copied between configs. + This is the highest-value single change and it is a config edit. +2. **Keep the encoder tail low** so step 1 can be aggressive (this is what the + `get_joint_prob` work buys you). +3. **Reduce the bin width** (`samples`) only if the statistics of your decode + tolerate it - this is a science decision, not a software one. A smaller bin + is less latency but noisier likelihoods. + +## The theoretical floor, and what it would take to go lower + +With the bin width and `delay_samples` fixed by the experiment, the software +floor for the decode path is now: encoder KDE (~0.8 ms tail) + two MPI hops +(tens of us each, already zero-copy for the data plane) + posterior (~0.1 ms) + +decision + actuation. Everything else is OS jitter, addressed in +`docs/realtime_tuning.md`. + +To push the encoder KDE below ~0.8 ms at a full buffer you have to change the +math, not just the implementation - the remaining cost is the genuine work of +evaluating the kernel over every stored mark: + +- **Truncated kernel.** The in-cube filter already finds the marks near the + query; evaluating the Gaussian only on those (and treating the rest as zero) + would cut both the distance computation and the histogram dramatically. It + changes results by more than rounding (far marks contribute small but nonzero + weight), so it needs the lab's sign-off and a config flag - but it is the + single biggest remaining lever and worth prototyping behind the benchmark. +- **Spatial index (k-d tree / grid) over marks.** Same idea, sub-linear in the + buffer size, more code. +- **float32 marks.** Halves the memory traffic that dominates + `squared_distance`; changes numerics. +- **GPU.** Only wins at much larger mark buffers or electrode counts than SC66; + the per-spike arrays here (50000 x 4) are too small to beat the host round + trip. Revisit only if marks-per-buffer or tetrode counts grow by an order of + magnitude. + +The point of `bench_hotpath.py` is that none of these need to be argued in the +abstract: prototype, run `--verify` and the timing pass, and read the number. diff --git a/realtime_decoder/encoder_process.py b/realtime_decoder/encoder_process.py index a66fabd..3020289 100644 --- a/realtime_decoder/encoder_process.py +++ b/realtime_decoder/encoder_process.py @@ -88,6 +88,52 @@ def __init__(self, config, trode, pos_bin_struct): self._temp_idx = 0 # NOTE(DS): so that mark_idx does not increase but still write down in the mark vec self._init_params() + self._init_kde_buffers() + + def _init_kde_buffers(self): + """Preallocate the per-spike scratch used by get_joint_prob so the + hot path does not allocate. Sized to the full mark buffer, then + sliced to the live mark count on each call.""" + n, dim = self._marks.shape + self._kde_diff = np.empty((n, dim), dtype=np.float64) + self._kde_sqd = np.empty(n, dtype=np.float64) + self._kde_weights = np.empty(n, dtype=np.float64) + self._kde_in_range = np.empty(n, dtype=bool) + self._kde_tmp_a = np.empty(n, dtype=bool) + self._kde_tmp_b = np.empty(n, dtype=bool) + + # Detect a "simple" position grid (edges 0, 1, ..., num_bins) so the + # weighted histogram can map a position straight to its bin index + # without a search. Every shipped config uses this grid. + edges = self._pos_bin_struct.pos_bin_edges + self._hist_nbins = self._pos_bin_struct.num_bins + self._hist_first = edges[0] + self._hist_last = edges[-1] + self._simple_grid = bool( + edges[0] == 0 and np.allclose(np.diff(edges), 1.0) + ) + + def _fast_hist(self, positions, weights): + """Uniform-bin weighted histogram, equivalent to + np.histogram(positions, bins=pos_bin_edges, weights=weights) up to + floating-point summation order (~1e-10 relative, i.e. rounding noise). + Much faster than np.histogram on an explicit edge array, which falls + back to a per-sample binary search; bin assignment is identical.""" + num_bins = self._hist_nbins + if self._simple_grid: + # edges are 0, 1, ..., num_bins, so the bin index is just the + # integer part of the position. positions are position-bin indices + # already inside [0, num_bins); clip only guards pathological input. + idx = positions.astype(np.intp) + np.clip(idx, 0, num_bins - 1, out=idx) + return np.bincount(idx, weights=weights, minlength=num_bins) + # general uniform grid: numpy's own optimized linear-index + bincount + # path (triggered by passing an int bin count and an explicit range) + return np.histogram( + positions, bins=num_bins, + range=(self._hist_first, self._hist_last), weights=weights + )[0] + def _load_model(self): fname = os.path.join( @@ -130,6 +176,15 @@ def _init_params(self): self.p['filter_n_std'] = self._config['encoder']['mark_kernel']['n_std'] self.p['n_marks_min'] = self._config['encoder']['mark_kernel']['n_marks_min'] self.p['num_occupancy_points'] = self._config['display']['encoder']['occupancy'] + # use the exact (slower) numerics in get_joint_prob: np.histogram on + # explicit edges and a separate square-then-sum for the mark distance. + # The default fast path (uniform-bin histogram + fused distance sum) + # differs only in floating-point summation order (~1e-13 relative); + # leave this off unless you need bit-for-bit reproducibility against an + # older run. + self.p['exact_histogram'] = self._config['encoder']['mark_kernel'].get( + 'exact_histogram', False + ) def add_new_mark(self, mark): ''' @@ -198,46 +253,57 @@ def get_joint_prob(self, mark): #print(mark) - in_range = np.ones(mark_idx, dtype=bool) + in_range = self._kde_in_range[:mark_idx] + in_range.fill(True) if self.p['use_filter']: std = self.p['filter_std'] n_std = self.p['filter_n_std'] + tmp_a = self._kde_tmp_a[:mark_idx] + tmp_b = self._kde_tmp_b[:mark_idx] for ii in range(self._marks.shape[1]): - in_range = np.logical_and( - np.logical_and( - self._marks[:mark_idx, ii] > mark[ii] - n_std * std, - self._marks[:mark_idx, ii] < mark[ii] + n_std * std - ), - in_range - ) + col = self._marks[:mark_idx, ii] + np.greater(col, mark[ii] - n_std * std, out=tmp_a) + np.less(col, mark[ii] + n_std * std, out=tmp_b) + np.logical_and(tmp_a, tmp_b, out=tmp_a) + np.logical_and(in_range, tmp_a, out=in_range) # not enough spikes within n-cube if np.sum(in_range) < self.p['n_marks_min']: return None - # evaluate Gaussian kernel on distance in mark space - squared_distance = np.sum( - np.square(self._marks[:mark_idx] - mark), - axis=1 - ) - weights = self._k1 * np.exp(squared_distance * self._k2) + # evaluate Gaussian kernel on distance in mark space. all temporaries + # are preallocated scratch (see _init_kde_buffers) so the hot path + # does not allocate. + diff = self._kde_diff[:mark_idx] + np.subtract(self._marks[:mark_idx], mark, out=diff) + squared_distance = self._kde_sqd[:mark_idx] + if self.p['exact_histogram']: + # reference path: same operations and order as the original + np.square(diff, out=diff) + np.sum(diff, axis=1, out=squared_distance) + else: + # fuse the square and the per-mark sum into a single pass. differs + # from the two lines above only in summation order (~1e-16). + np.einsum('ij,ij->i', diff, diff, out=squared_distance) + + weights = self._kde_weights[:mark_idx] + np.multiply(squared_distance, self._k2, out=weights) + np.exp(weights, out=weights) + np.multiply(weights, self._k1, out=weights) + positions = self._positions[:mark_idx] - # print(positions.shape) - # print("") - # print(self._pos_bin_struct.pos_bin_edges) - # print("") - # print(weights) - - # `density=` (formerly `normed=`) intentionally omitted: we want the - # raw weighted sum per bin, which is the default behavior. `normed=` - # was removed in NumPy 1.24, which broke this call on any modern - # install. - hist, hist_edges = np.histogram( - a=positions, - bins=self._pos_bin_struct.pos_bin_edges, - weights=weights, - ) + # `density=` (formerly `normed=`) is intentionally omitted: we want the + # raw weighted sum per bin, which is the default. `normed=` was removed + # in NumPy 1.24, which broke this call on any modern install. + if self.p['exact_histogram']: + hist, _ = np.histogram( + a=positions, + bins=self._pos_bin_struct.pos_bin_edges, + weights=weights + ) + else: + hist = self._fast_hist(positions, weights) hist += 0.0000001