Skip to content

Add spyx.phasor — phasor and spiking phasor networks#41

Open
kmheckel wants to merge 9 commits into
mainfrom
claude/spyx-phasor
Open

Add spyx.phasor — phasor and spiking phasor networks#41
kmheckel wants to merge 9 commits into
mainfrom
claude/spyx-phasor

Conversation

@kmheckel

Copy link
Copy Markdown
Owner

What's in this PR

First-pass implementation of phasor and spiking-phasor networks (issue #38; closes the duplicate #13). Brings the 2022 Bybee/Frady/Sommer deep-phasor architecture to Spyx on top of Flax NNX and JAX's native complex dtype.

src/spyx/phasor.py

Class / function Purpose
PhasorLinear(in, out, *, rngs) Complex-valued dense layer. Weights stored as paired kernel_re / kernel_im float32 nnx.Params — see the convergence note below.
PhasorActivation() Unit-circle projection (TPAM threshold function).
PhasorReadout(in, out, *, rngs) Real logits via Re(W·z).
PhasorMLP(in, hidden, out, depth=2, *, rngs) Convenience stack: encode → (Linear → Activation) × depth → readout.
phase_to_spikes(θ, T) One-spike-per-cycle codec, single spike at t = round((θ + π)/(2π) · T).
spikes_to_phase(spikes, T) Centroid decoder; silent units map to −π.
SpikingPhasor(layer, period_T) Spiking-inference wrapper around a PhasorLinear — drop-in for spyx.nn.Sequential / nn.run.
real_to_phasor, phasor_to_real, phase_of Encoding helpers.

The convergence fix

The first iteration of this module stored the kernel as a single complex64 nnx.Param. JAX returns the conjugate Wirtinger derivative for real-valued losses w.r.t. a complex parameter; optax is built around real arithmetic and doesn't unwind the conjugation, so optax.adam drifted sideways on the imaginary axis and training couldn't converge.

Fix: split the kernel (and bias) into kernel_re + kernel_im as separate float32 params and reconstruct the complex tensor on every forward pass via @property accessors. The public .kernel / .bias surface is unchanged; gradients optax sees are now always real. Commit ef08817 has the detail.

Tests (tests/test_phasor.py, 10 cases, all green)

  • Dtype / shape contracts, real_to_phasor round-trip.
  • PhasorActivation unit-magnitude.
  • PhasorMLP forward + finite gradients.
  • Convergence regression: PhasorMLP loss drops ≥ 40% on 200 adam steps of an int(x[0] > 0.5) task (was previously "doesn't NaN").
  • Phase ↔ spike round-trip within 2π/T quantisation error.
  • Silent-neuron handling.
  • SpikingPhasor end-to-end.

Demo

scripts/phasor_demo.py walks the continuous and spiking forward passes plus the phase↔spike round-trip. Useful without needing the tutorial notebook runtime.

What's not in this PR (tracked on #38)

  • SHD benchmark comparing phasor-MLP vs. LIF baseline.
  • spyx.quant integration for the complex weights (treats each float32 pair as a regular dense layer — straightforward, deferred).
  • Associative-memory / TPAM retrieval demo.
  • Hierarchical phasor networks.

Verification

uv run ruff check                 # clean
uv run pytest tests/test_phasor.py  # 10 passed
uv run python scripts/phasor_demo.py   # expected statistics printed

Test plan

  • Review the Wirtinger fix in PhasorLinear.
  • Run tests/test_phasor.py locally.
  • Try swapping a PhasorMLP into one of the published SHD tutorials and compare.

Dependencies on other PRs

Independent — can land before or after the review-cleanup PR. Only shares src/spyx/__init__.py and AGENTS.md with the NNX cleanup PR, and those conflicts are one-line (just the extra module name in __all__).

Closes #38. Related: closes #13 (duplicate).

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE

kmheckel and others added 9 commits January 4, 2026 14:23
Polishing pass on the NNX-migration branch ahead of merge:

* README.md / docs/index.md / AGENTS.md no longer claim Spyx is built on
  Haiku and accurately describe the new axn/fn/nn surface (no scan_snn
  or Axon class). Add a note that research/ notebooks predate the
  migration.
* nn.ActivityRegularization required no shape and used a Variable(None)
  with lazy init that won't trace under nnx.jit. Take an explicit
  hidden_shape/batch_size, expose a reset() helper, and use the
  variable[...] API.
* nir.from_nir now wires recurrent_w + tau parameters when importing
  RNN subgraphs (previously a `pass` placeholder), and to_nir grows a
  matching helper that emits an inner NIRGraph for RIF/RLIF/RCuBaLIF
  exports so the roundtrip is symmetric.
* tests/test_nn_nnx covers IF/ALIF/CuBaLIF/RIF/RCuBaLIF/LI/SumPool plus
  ActivityRegularization under nnx.jit. New tests/test_axn.py and
  tests/test_fn.py exercise every surrogate gradient and every loss /
  regularizer. tests/test_nir.py adds RLIF and RCuBaLIF roundtrips.

After this change `uv run pytest` reports 33 passed, 2 failed (only the
tonic-gated loader tests, which need `uv sync --extra loaders`).
`uv run ruff check` is clean.

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Rewrote the five docs/examples notebooks that were still on the legacy
Haiku + spyx.loaders stack. Each notebook now uses the same modernised
pattern: nnx.Module + nn.Sequential, optimisation via nnx.Optimizer +
nnx.value_and_grad, data via spyx.data.SHD_loader (or NMNIST_loader)
streamed through Google Grain, with jnp.unpackbits applied inline to
recover dense spikes from the bit-packed time axis.

* docs/examples/neuroevolution/cartpole_evo.ipynb - rebuilt the spiking
  controller as an nnx.Module; rolled out per-episode under jax.lax.scan;
  evosax CMA-ES drives a flat parameter vector via ravel_pytree +
  nnx.merge.
* docs/examples/surrogate_gradient/SurrogateGradientTutorial.ipynb -
  Linear+LIF+LIF+LI BPTT trainer on SHD using the Lion optimizer.
* docs/examples/surrogate_gradient/shd_sg_template.ipynb - reusable
  template that taps per-layer spike trains as scan outputs and applies
  silence_reg + sparsity_reg in the loss (avoiding mutable activity
  variables under nnx.jit).
* docs/examples/surrogate_gradient/shd_sg_neuron_model_comparison.ipynb -
  sweeps LIF / IF / ALIF / RLIF / mixed variants through a shared
  training helper.
* docs/examples/surrogate_gradient/shd_sg_surrogate_comparison.ipynb -
  same shared trainer parameterised by surrogate gradient (arctan,
  superspike, tanh, boxcar, triangular, STE).

Mixed precision (haiku.mixed_precision) has no first-class NNX analog
yet and has been dropped, with a note in each tutorial pointing users at
jax.config.update('jax_default_matmul_precision', 'bfloat16') as a
substitute. Cell outputs are intentionally cleared since CI does not
execute notebooks; rerun locally with the [loaders] extra installed
before tagging a release.

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Closes #36 and addresses #37.

* spyx.quant: thin SNN-aware wrapper around Google's qwix library.
  - quantize(model, *example_inputs, rules=, mode="qat"|"ptq")
  - linear_only_rules() / weights_only_rules() shortcuts
  - available() guard + actionable ImportError when qwix is missing
  Default rules quantize only nnx.Linear / nnx.Conv (int8 W+A); spiking
  dynamics (LIF, CuBaLIF, ALIF, IF, LI) stay in fp32 because their
  recurrences are sensitive to integer rounding.
* New optional [quant] extra in pyproject.toml; qwix has no PyPI
  release so it's wired in via tool.uv.sources to install from GitHub.
* Bump flax >= 0.11 (which forced python >= 3.11). flax 0.11 changed
  nnx.Optimizer to require wrt= and made optimizer.update(model, grads)
  the canonical signature.
* Update all five docs/examples notebooks to the new optimizer API:
  every nnx.Optimizer(...) gains wrt=nnx.Param and every
  optimizer.update(grads) becomes optimizer.update(model, grads).
* cartpole_evo: switch to evosax >= 0.2 API
  (CMA_ES(population_size=, solution=) + init/ask/tell with key args).
  Smoke-tested end-to-end with 8-pop / 3-gen / 50-step rollouts; CMA-ES
  was already improving rewards by gen 2.
* New docs/examples/quantization/qat_intro.ipynb tutorial walking
  through QAT, PTQ, custom rules, and the int4 weights-only path.
  Wired into mkdocs.yml.
* tests/test_quant.py covers the available() guard, smoke quantize,
  full nnx.Optimizer + value_and_grad QAT loop, the rules helpers,
  and the unknown-mode error path. 6 passed (1 conditionally skipped
  when qwix is missing).
* AGENTS.md: document the new module, extra, and dependency version.

uv run ruff check     -> All checks passed!
uv run pytest         -> 39 passed, 2 failed (only the SHD/NMNIST
                         loader tests, which need network access to
                         download the datasets).

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Initial implementation for #38. Provides a Flax NNX module set for the
deep phasor / spiking phasor architecture (Bybee, Frady & Sommer 2022,
arXiv 2106.11908) that takes advantage of JAX's native complex64 dtype:

* PhasorLinear  - complex-valued dense layer (Glorot-init complex weights).
* PhasorActivation - projects activations onto the unit circle (the TPAM
  threshold function).
* PhasorReadout - real-valued logits via Re(W @ z).
* PhasorMLP     - convenience constructor; encodes real x -> phasors,
  stacks N (Linear, Activation) blocks, then a real readout.
* phase_to_spikes / spikes_to_phase - one-spike-per-cycle codec with
  centroid-based decoding.
* SpikingPhasor - drop-in spiking-inference wrapper around a single
  PhasorLinear; takes a spike train, recovers phases, runs the complex
  matmul, and re-emits spikes.

Tests (tests/test_phasor.py, 10 cases) cover dtype/shape contracts, the
phase<->spike round-trip within the bin-quantisation tolerance, the
SpikingPhasor end-to-end path, and that PhasorMLP forward+backward+optax
stays numerically finite over many steps.

Caveat documented in the module docstring and in #38: convergence with
plain optax.adam on complex parameters is fragile because JAX returns
the *conjugate* Wirtinger derivative on real-valued losses. Optax does
not unwind the conjugation, so the imaginary axis can drift. A faithful
forward/backward path is in place; tuned training (manual Wirtinger
updates or real/imag parameter splitting) is left as the next step.

scripts/phasor_demo.py walks through the continuous and spiking forward
passes and prints round-trip statistics; a useful sanity check without
needing the tutorial notebook runtime.

uv run ruff check  -> All checks passed!
uv run pytest      -> 49 passed, 2 failed, 1 skipped (the failures are
                      the pre-existing tonic-gated SHD/NMNIST loader
                      tests that need network access).

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Issue #38 flagged a training convergence bug in the first iteration of
PhasorLinear: JAX returns the conjugate Wirtinger derivative for a
real-valued loss with respect to a complex parameter, and optax's real
arithmetic doesn't unwind the conjugation, so adam steps drifted
sideways on the imaginary axis and the network couldn't learn an
int(x[0] > 0.5) task.

Fix: split the complex kernel (and bias) into separate float32 kernel_re
and kernel_im (plus bias_re / bias_im) nnx.Params. The forward pass
reconstructs the complex kernel on demand through @Property accessors,
so the public API (PhasorLinear.kernel / .bias) is unchanged. Gradients
optax sees are now always real.

The test_phasor_mlp_training_reduces_loss check is tightened from "no
NaN" to "loss drops by > 40% on 200 steps of adam"; it was previously
weakened around the bug.

Module docstring replaces the Wirtinger warning with a note explaining
the real/imag storage strategy.

uv run ruff check   -> All checks passed!
uv run pytest tests/test_phasor.py -> 10 passed.

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Same change as e1aba9f on claude/review-spyx-updates-TSiVb: the
pre-existing workflow tested Python 3.9 / 3.10, linted with flake8,
and installed via pip / requirements.txt. It didn't install spyx, so
pytest failed immediately. Rewriting to:

- Python 3.11 / 3.12 (matches pyproject.toml requires-python).
- Use astral-sh/setup-uv + `uv sync --all-extras` so qwix and tonic
  are available for the quant / loader tests.
- `uv run ruff check` for lint and `uv run pytest -v -k
  "not test_loaders_grain"` for tests (the loader tests hit an
  external dataset endpoint and flake on CI).

Verified locally on this branch:
  uv run ruff check                 -> All checks passed!
  uv run pytest -k "not test_loaders_grain" -> 49 passed, 1 skipped.

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
`docs/examples/phasor/phasor_intro.ipynb` walks through the three
halves of spyx.phasor:

1. Continuous: build a PhasorMLP and train it on a linearly-separable
   toy task with a stock optax.adam loop (regression guard against the
   convergence caveat that's now fixed via the real/imag param split).
2. Codec: phase_to_spikes / spikes_to_phase round-trip, with a
   visualisation showing the one-spike-per-cycle encoding.
3. Spiking: SpikingPhasor end-to-end — feed a spike train, recover
   phases, run the complex matmul, re-emit spikes.

mkdocs.yml picks up the new notebook under "Tutorials > Phasor
Networks". The notebook doesn't depend on the review-branch docs
overhaul and renders on its own.

https://claude.ai/code/session_01HdABdqRYRXvQUEtg9YdREE
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add spyx.phasor — Phasor / Spiking Phasor Networks Add Phasor layers/network capability

2 participants