Add spyx.phasor — phasor and spiking phasor networks#41
Open
kmheckel wants to merge 9 commits into
…; Added validation tests
Polishing pass on the NNX-migration branch ahead of merge:
* README.md / docs/index.md / AGENTS.md no longer claim Spyx is built on
  Haiku and accurately describe the new axn/fn/nn surface (no scan_snn or
  Axon class). Add a note that research/ notebooks predate the migration.
* nn.ActivityRegularization required no shape and used a Variable(None)
  with lazy init that won't trace under nnx.jit. Take an explicit
  hidden_shape/batch_size, expose a reset() helper, and use the
  variable[...] API.
* nir.from_nir now wires recurrent_w + tau parameters when importing RNN
  subgraphs (previously a `pass` placeholder), and to_nir grows a matching
  helper that emits an inner NIRGraph for RIF/RLIF/RCuBaLIF exports so the
  roundtrip is symmetric.
* tests/test_nn_nnx covers IF/ALIF/CuBaLIF/RIF/RCuBaLIF/LI/SumPool plus
  ActivityRegularization under nnx.jit. New tests/test_axn.py and
  tests/test_fn.py exercise every surrogate gradient and every loss /
  regularizer. tests/test_nir.py adds RLIF and RCuBaLIF roundtrips.
After this change `uv run pytest` reports 33 passed, 2 failed (only the
tonic-gated loader tests, which need `uv sync --extra loaders`).
`uv run ruff check` is clean.
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Rewrote the five docs/examples notebooks that were still on the legacy
Haiku + spyx.loaders stack. Each notebook now uses the same modernised
pattern: nnx.Module + nn.Sequential, optimisation via nnx.Optimizer +
nnx.value_and_grad, data via spyx.data.SHD_loader (or NMNIST_loader)
streamed through Google Grain, with jnp.unpackbits applied inline to
recover dense spikes from the bit-packed time axis.
* docs/examples/neuroevolution/cartpole_evo.ipynb - rebuilt the spiking
controller as an nnx.Module; rolled out per-episode under jax.lax.scan;
evosax CMA-ES drives a flat parameter vector via ravel_pytree +
nnx.merge.
* docs/examples/surrogate_gradient/SurrogateGradientTutorial.ipynb -
Linear+LIF+LIF+LI BPTT trainer on SHD using the Lion optimizer.
* docs/examples/surrogate_gradient/shd_sg_template.ipynb - reusable
template that taps per-layer spike trains as scan outputs and applies
silence_reg + sparsity_reg in the loss (avoiding mutable activity
variables under nnx.jit).
* docs/examples/surrogate_gradient/shd_sg_neuron_model_comparison.ipynb -
sweeps LIF / IF / ALIF / RLIF / mixed variants through a shared
training helper.
* docs/examples/surrogate_gradient/shd_sg_surrogate_comparison.ipynb -
same shared trainer parameterised by surrogate gradient (arctan,
superspike, tanh, boxcar, triangular, STE).
Mixed precision (haiku.mixed_precision) has no first-class NNX analog
yet and has been dropped, with a note in each tutorial pointing users at
jax.config.update('jax_default_matmul_precision', 'bfloat16') as a
substitute. Cell outputs are intentionally cleared since CI does not
execute notebooks; rerun locally with the [loaders] extra installed
before tagging a release.
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
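The notebook rewrite above mentions applying jnp.unpackbits inline to recover dense spikes from a bit-packed time axis. A minimal sketch of that round-trip follows; the array shapes and the choice of axis 1 as the time axis are assumptions for illustration, not the layout the Spyx loaders necessarily use.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical dense spike tensor: (batch=2, time=16, channels=3), values in {0, 1}.
rng = np.random.default_rng(0)
dense = (rng.random((2, 16, 3)) < 0.3).astype(np.uint8)

# Pack 8 time steps into each byte along the (assumed) time axis.
packed = np.packbits(dense, axis=1)  # shape (2, 2, 3)

# Inside a training step the packed batch can be expanded back to dense spikes.
unpacked = jnp.unpackbits(jnp.asarray(packed), axis=1)  # shape (2, 16, 3)
```

Packing cuts the stored time axis by a factor of eight, and the unpack is a cheap elementwise operation that can sit at the top of the jitted step.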
Closes #36 and addresses #37.
* spyx.quant: thin SNN-aware wrapper around Google's qwix library.
  - quantize(model, *example_inputs, rules=, mode="qat"|"ptq")
  - linear_only_rules() / weights_only_rules() shortcuts
  - available() guard + actionable ImportError when qwix is missing
  Default rules quantize only nnx.Linear / nnx.Conv (int8 W+A); spiking
  dynamics (LIF, CuBaLIF, ALIF, IF, LI) stay in fp32 because their
  recurrences are sensitive to integer rounding.
* New optional [quant] extra in pyproject.toml; qwix has no PyPI release
  so it's wired in via tool.uv.sources to install from GitHub.
* Bump flax >= 0.11 (which forced python >= 3.11). flax 0.11 changed
  nnx.Optimizer to require wrt= and made optimizer.update(model, grads)
  the canonical signature.
* Update all five docs/examples notebooks to the new optimizer API: every
  nnx.Optimizer(...) gains wrt=nnx.Param and every optimizer.update(grads)
  becomes optimizer.update(model, grads).
* cartpole_evo: switch to evosax >= 0.2 API (CMA_ES(population_size=,
  solution=) + init/ask/tell with key args). Smoke-tested end-to-end with
  8-pop / 3-gen / 50-step rollouts; CMA-ES was already improving rewards
  by gen 2.
* New docs/examples/quantization/qat_intro.ipynb tutorial walking through
  QAT, PTQ, custom rules, and the int4 weights-only path. Wired into
  mkdocs.yml.
* tests/test_quant.py covers the available() guard, smoke quantize, full
  nnx.Optimizer + value_and_grad QAT loop, the rules helpers, and the
  unknown-mode error path. 6 passed (1 conditionally skipped when qwix is
  missing).
* AGENTS.md: document the new module, extra, and dependency version.
uv run ruff check -> All checks passed!
uv run pytest -> 39 passed, 2 failed (only the SHD/NMNIST loader tests,
which need network access to download the datasets).
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Initial implementation for #38. Provides a Flax NNX module set for the
deep phasor / spiking phasor architecture (Bybee, Frady & Sommer 2022,
arXiv 2106.11908) that takes advantage of JAX's native complex64 dtype:
* PhasorLinear - complex-valued dense layer (Glorot-init complex weights).
* PhasorActivation - projects activations onto the unit circle (the TPAM
  threshold function).
* PhasorReadout - real-valued logits via Re(W @ z).
* PhasorMLP - convenience constructor; encodes real x -> phasors, stacks
  N (Linear, Activation) blocks, then a real readout.
* phase_to_spikes / spikes_to_phase - one-spike-per-cycle codec with
  centroid-based decoding.
* SpikingPhasor - drop-in spiking-inference wrapper around a single
  PhasorLinear; takes a spike train, recovers phases, runs the complex
  matmul, and re-emits spikes.
Tests (tests/test_phasor.py, 10 cases) cover dtype/shape contracts, the
phase<->spike round-trip within the bin-quantisation tolerance, the
SpikingPhasor end-to-end path, and that PhasorMLP forward+backward+optax
stays numerically finite over many steps.
Caveat documented in the module docstring and in #38: convergence with
plain optax.adam on complex parameters is fragile because JAX returns the
*conjugate* Wirtinger derivative on real-valued losses. Optax does not
unwind the conjugation, so the imaginary axis can drift. A faithful
forward/backward path is in place; tuned training (manual Wirtinger
updates or real/imag parameter splitting) is left as the next step.
scripts/phasor_demo.py walks through the continuous and spiking forward
passes and prints round-trip statistics; a useful sanity check without
needing the tutorial notebook runtime.
uv run ruff check -> All checks passed!
uv run pytest -> 49 passed, 2 failed, 1 skipped (the failures are the
pre-existing tonic-gated SHD/NMNIST loader tests that need network
access).
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
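To make the continuous path in the commit above concrete, here is a rough sketch of one complex dense layer, the unit-circle projection, and the Re(W @ z) readout in plain jax.numpy. Every function name with a _sketch suffix is hypothetical, the input encoding exp(iπx) is an assumption, and the random-normal weights stand in for the Glorot initialisation; this is not the spyx.phasor implementation.

```python
import jax
import jax.numpy as jnp

def real_to_phasor_sketch(x):
    # Assumption: real inputs in [-1, 1] are mapped to phases in [-pi, pi].
    return jnp.exp(1j * jnp.pi * x)

def phasor_activation_sketch(z, eps=1e-8):
    # Project onto the unit circle: keep the phase, discard the magnitude.
    return z / (jnp.abs(z) + eps)

def phasor_readout_sketch(w, z):
    # Real-valued logits: real part of the complex product (batch of row vectors).
    return jnp.real(z @ w)

k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(0), 5)
# Placeholder complex weights (hidden: 4 -> 8, readout: 8 -> 2).
w_hidden = 0.5 * jax.lax.complex(jax.random.normal(k1, (4, 8)),
                                 jax.random.normal(k2, (4, 8)))
w_out = 0.5 * jax.lax.complex(jax.random.normal(k3, (8, 2)),
                              jax.random.normal(k4, (8, 2)))

x = jax.random.uniform(kx, (3, 4), minval=-1.0, maxval=1.0)
z = real_to_phasor_sketch(x)                  # complex64, shape (3, 4)
h = phasor_activation_sketch(z @ w_hidden)    # unit-magnitude, shape (3, 8)
logits = phasor_readout_sketch(w_out, h)      # float32, shape (3, 2)
```

Because the activation discards magnitude, all information between layers is carried by phase, which is what lets the spiking wrapper re-express each activation as a spike time.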
Issue #38 flagged a training convergence bug in the first iteration of
PhasorLinear: JAX returns the conjugate Wirtinger derivative for a
real-valued loss with respect to a complex parameter, and optax's real
arithmetic doesn't unwind the conjugation, so adam steps drifted sideways
on the imaginary axis and the network couldn't learn an int(x[0] > 0.5)
task.
Fix: split the complex kernel (and bias) into separate float32 kernel_re
and kernel_im (plus bias_re / bias_im) nnx.Params. The forward pass
reconstructs the complex kernel on demand through @property accessors, so
the public API (PhasorLinear.kernel / .bias) is unchanged. Gradients optax
sees are now always real.
The test_phasor_mlp_training_reduces_loss check is tightened from "no NaN"
to "loss drops by > 40% on 200 steps of adam"; it was previously weakened
around the bug. Module docstring replaces the Wirtinger warning with a
note explaining the real/imag storage strategy.
uv run ruff check -> All checks passed!
uv run pytest tests/test_phasor.py -> 10 passed.
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
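The conjugation issue described in this commit can be reproduced on a one-parameter toy loss. The sketch below is illustrative only (the loss and the dict layout are assumptions, not the PhasorLinear code): it compares jax.grad on a complex scalar with jax.grad on a real/imag pair that is recombined inside the loss.

```python
import jax
import jax.numpy as jnp

def loss(z):
    # Real-valued loss of a complex parameter: |z|^2 = x^2 + y^2.
    return jnp.real(z) ** 2 + jnp.imag(z) ** 2

# Complex parameter: JAX returns the conjugate of the steepest-ascent direction.
g_complex = jax.grad(loss)(jnp.complex64(3.0 + 4.0j))

def loss_split(params):
    # Rebuild the complex value from two real parameters on every call.
    z = jax.lax.complex(params["re"], params["im"])
    return loss(z)

# Real/imag split: both gradients are ordinary real partial derivatives.
g_split = jax.grad(loss_split)({"re": jnp.float32(3.0), "im": jnp.float32(4.0)})
```

Here g_complex has a negative imaginary part while g_split["im"] is positive, so a naive z - lr * g_complex step pushes the imaginary component the wrong way; the split version hands the optimizer plain real gradients.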
Same change as e1aba9f on claude/review-spyx-updates-TSiVb: the
pre-existing workflow tested Python 3.9 / 3.10, linted with flake8, and
installed via pip / requirements.txt. It didn't install spyx, so pytest
failed immediately. Rewriting to:
- Python 3.11 / 3.12 (matches pyproject.toml requires-python).
- Use astral-sh/setup-uv + `uv sync --all-extras` so qwix and tonic are
  available for the quant / loader tests.
- `uv run ruff check` for lint and
  `uv run pytest -v -k "not test_loaders_grain"` for tests (the loader
  tests hit an external dataset endpoint and flake on CI).
Verified locally on this branch:
uv run ruff check -> All checks passed!
uv run pytest -k "not test_loaders_grain" -> 49 passed, 1 skipped.
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
`docs/examples/phasor/phasor_intro.ipynb` walks through the three parts
of spyx.phasor:
1. Continuous: build a PhasorMLP and train it on a linearly-separable toy
   task with a stock optax.adam loop (regression guard against the
   convergence caveat that's now fixed via the real/imag param split).
2. Codec: phase_to_spikes / spikes_to_phase round-trip, with a
   visualisation showing the one-spike-per-cycle encoding.
3. Spiking: SpikingPhasor end-to-end: feed a spike train, recover phases,
   run the complex matmul, re-emit spikes.
mkdocs.yml picks up the new notebook under "Tutorials > Phasor Networks".
The notebook doesn't depend on the review-branch docs overhaul and renders
on its own.
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
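The codec step can be sketched directly from the spike-time rule quoted in the PR description, t = round((θ + π)/(2π) · T). The functions below are hypothetical stand-ins with a _sketch suffix; the time-major layout, the wrap of t modulo T, and the circular-mean decoder are assumptions about what "centroid-based decoding" means, not the spyx.phasor source.

```python
import jax.numpy as jnp

def phase_to_spikes_sketch(theta, T):
    # One spike per cycle: bin index from the phase, wrapped into [0, T).
    t = jnp.round((theta + jnp.pi) / (2 * jnp.pi) * T).astype(jnp.int32) % T
    # One-hot along a leading time axis -> shape (T, *theta.shape).
    time = jnp.arange(T).reshape((T,) + (1,) * theta.ndim)
    return (time == t).astype(jnp.float32)

def spikes_to_phase_sketch(spikes, T):
    # Circular centroid: sum unit phasors at each spike's bin phase, take the angle.
    bin_phase = 2 * jnp.pi * jnp.arange(T) / T - jnp.pi
    bin_phase = bin_phase.reshape((T,) + (1,) * (spikes.ndim - 1))
    return jnp.angle(jnp.sum(spikes * jnp.exp(1j * bin_phase), axis=0))

T = 64
theta = jnp.array([-2.0, -0.5, 0.0, 1.0, 2.5])
spikes = phase_to_spikes_sketch(theta, T)      # shape (64, 5)
decoded = spikes_to_phase_sketch(spikes, T)    # shape (5,)
# Wrapped angular error; rounding to the nearest bin bounds it by half a bin.
err = jnp.abs(jnp.angle(jnp.exp(1j * (decoded - theta))))
```

With a single spike per cycle the centroid reduces to the bin centre, so the round-trip error stays within half a bin (π/T), inside the 2π/T tolerance the tests mention.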
What's in this PR
First-pass implementation of phasor and spiking-phasor networks (issue #38; closes the duplicate #13). Brings the 2022 Bybee/Frady/Sommer deep-phasor architecture to Spyx on top of Flax NNX and JAX's native complex dtype.
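src/spyx/phasor.py
* PhasorLinear(in, out, *, rngs): complex-valued dense layer; the weights are stored as kernel_re / kernel_im float32 nnx.Params (see the convergence note below).
* PhasorActivation(): projects activations onto the unit circle.
* PhasorReadout(in, out, *, rngs): real-valued logits via Re(W·z).
* PhasorMLP(in, hidden, out, depth=2, *, rngs): convenience constructor that encodes real inputs as phasors, stacks (Linear, Activation) blocks, then applies a real readout.
* phase_to_spikes(θ, T): one-spike-per-cycle encoding, t = round((θ + π)/(2π) · T).
* spikes_to_phase(spikes, T): centroid-based decoding.
* SpikingPhasor(layer, period_T): spiking-inference wrapper around a single PhasorLinear; drop-in for spyx.nn.Sequential / nn.run.
* real_to_phasor, phasor_to_real, phase_of: helper functions.
The convergence fix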
The first iteration of this module stored the kernel as a single complex64 nnx.Param. JAX returns the conjugate Wirtinger derivative for real-valued losses w.r.t. a complex parameter; optax is built around real arithmetic and doesn't unwind the conjugation, so optax.adam drifted sideways on the imaginary axis and training couldn't converge.
Fix: split the kernel (and bias) into kernel_re + kernel_im as separate float32 params and reconstruct the complex tensor on every forward pass via @property accessors. The public .kernel / .bias surface is unchanged; gradients optax sees are now always real. Commit ef08817 has the detail.
Tests (tests/test_phasor.py, 10 cases, all green)
* real_to_phasor round-trip.
* PhasorActivation unit-magnitude.
* PhasorMLP forward + finite gradients.
* PhasorMLP loss drops ≥ 40% on 200 adam steps of an int(x[0] > 0.5) task (was previously "doesn't NaN").
* Phase↔spike round-trip within the 2π/T quantisation error.
* SpikingPhasor end-to-end.
Demo
scripts/phasor_demo.py walks the continuous and spiking forward passes plus the phase↔spike round-trip. Useful without needing the tutorial notebook runtime.
What's not in this PR (tracked on #38)
* spyx.quant integration for the complex weights (treats each float32 pair as a regular dense layer; straightforward, deferred).
Verification
Test plan
* … PhasorLinear.
* … tests/test_phasor.py locally.
* … PhasorMLP into one of the published SHD tutorials and compare.
Dependencies on other PRs
Independent: can land before or after the review-cleanup PR. Only shares src/spyx/__init__.py and AGENTS.md with the NNX cleanup PR, and those conflicts are one-line (just the extra module name in __all__).
Closes #38. Related: closes #13 (duplicate).
https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE