
hinanohart/circuitbench


circuitbench


Mechanistic interpretability + sparse autoencoder framework for Hybrid SSM-Attention models, with first-class support for pure SSMs.

Where TransformerLens / SAELens dominate Transformer interpretability, circuitbench aims to fill the corresponding gap for post-Transformer architectures: Mamba-2, Hymba, Jamba, Falcon-H1, RWKV-7.

v0.1.x scope. v0.1.x ships the API surface plus a CPU MockSSMAdapter, so the harness runs end-to-end without GPUs or model downloads. Real model weights, JumpReLU SAEs, and step-wise h_t patching are planned for v0.2. See Status for the precise per-component split.


What is this?

circuitbench is a research harness for understanding how state-space models (SSMs) compute. It provides four integrated operations over a common hook-point abstraction that maps onto Mamba-2's internal tensor sites:

  • load_model — adapter registry; v0.1.x ships MockSSMAdapter (CPU); real weights in v0.2
  • train_sae — TopK sparse autoencoders trained on SSM-specific hook points
  • extract_circuit — coarse layer-level mean-ablation activation patching
  • steer — additive feature-direction intervention during the forward pass

The API surface is designed to be identical for the CPU mock (available now) and for real model weights (planned for v0.2).
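The train_sae bullet names TopK sparse autoencoders, and the Status table lists the v0.1.x defaults (k=32, 8× expansion, unit-norm decoder). A minimal NumPy sketch of a TopK SAE forward pass under those settings, assuming the common TopK formulation (subtract a decoder bias, encode, keep the k largest pre-activations, ReLU them, decode). The helpers topk_sae_forward and normalize_decoder are hypothetical, not circuitbench's API, and the training loop is omitted.

```python
import numpy as np


def normalize_decoder(W_dec: np.ndarray) -> np.ndarray:
    """Rescale each latent's decoder row to unit L2 norm."""
    return W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)


def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Hypothetical TopK SAE forward pass.

    x: (n, d) activations, W_enc: (d, d_sae), W_dec: (d_sae, d).
    Returns (latents, reconstruction) with at most k non-zero latents per row.
    """
    pre = (x - b_dec) @ W_enc + b_enc                     # (n, d_sae) pre-activations
    top_idx = np.argpartition(pre, -k, axis=1)[:, -k:]    # indices of the k largest per row
    rows = np.arange(pre.shape[0])[:, None]
    latents = np.zeros_like(pre)
    latents[rows, top_idx] = np.maximum(pre[rows, top_idx], 0.0)  # ReLU on survivors only
    recon = latents @ W_dec + b_dec                       # (n, d) reconstruction
    return latents, recon


rng = np.random.default_rng(0)
d, expansion, k = 16, 8, 32          # k and expansion mirror the documented defaults
d_sae = d * expansion                # 128 latents
W_enc = rng.normal(size=(d, d_sae)) * 0.1
W_dec = normalize_decoder(rng.normal(size=(d_sae, d)))
b_enc, b_dec = np.zeros(d_sae), np.zeros(d)

x = rng.normal(size=(64, d))         # stand-in for H1 activations flattened over (B, L)
latents, recon = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
print("max active latents per row:", int((latents != 0).sum(axis=1).max()))
print("reconstruction MSE:", float(np.mean((x - recon) ** 2)))
```

In a training loop, a common choice is to minimize the reconstruction error and reapply normalize_decoder after each optimizer step so the unit-norm constraint holds.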


Why circuitbench

SSMs and hybrid SSM-attention models have grown into a serious alternative to pure Transformers, but mechanistic interpretability tooling has not caught up. Existing libraries either:

  • Hard-bake Transformer-only assumptions (residual streams indexed by layer × position), or
  • Provide raw hooks without SAE training, circuit discovery, or steering glue.

circuitbench provides one integrated harness covering all four operations: model loading, SAE training, circuit extraction, and steering.


Install

v0.1 is alpha and not yet on PyPI (publishing via PyPI trusted publishing is planned for v0.2). Install from source:

git clone https://github.com/hinanohart/circuitbench.git
cd circuitbench
pip install -e .                       # core only (torch + numpy + einops + jaxtyping + pydantic)
pip install -e ".[ssm,hf]"             # placeholders for v0.2: mamba-ssm + HF transformers (GPU)
pip install -e ".[sae]"                # placeholder for v0.2: SAELens interop
pip install -e ".[dev]"                # development (pytest, ruff, mypy)

The [ssm], [hf], and [sae] extras do not yet unlock real backends in v0.1.x; they are reserved so the install commands stay stable across the v0.1 → v0.2 transition.


Quick start

from circuitbench import load_model, train_sae, extract_circuit, steer

# v0.1.x ships a CPU-only MockSSMAdapter; real Mamba-2 / Hymba weights arrive
# in v0.2 (need `mamba-ssm` + GPU). The API surface is identical either way.
model = load_model("mock://mamba2-tiny", hook_point="out_proj_in")
sae = train_sae(model, layer=1, k=32, expansion=8, tokens=2048, batch_size=64)
circuit = extract_circuit(model, prompt="Paris is the capital of", target="France")
out = steer(model, prompt="Hello", feature_id=42, strength=2.0, sae=sae, layer=1)

print(circuit.top_layers(n=3))           # [(layer, ablation effect), ...]
print(out.delta_norm)                    # L2 shift in final output under steering
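steer applies an additive feature-direction intervention and reports out.delta_norm, the L2 shift in the final output. The sketch below illustrates that idea on a toy residual model in NumPy: at one layer, strength times the SAE decoder row for feature_id is added to the hook activation at every position, and the shift is the L2 norm of steered minus clean output. The toy model, the helper names forward and steer_delta_norm, and the assumption that hook activations share the output's dimensionality are all hypothetical; in an SSM block an H1 direction would additionally pass through out_proj.

```python
import numpy as np


def forward(layers, x, steer_layer=None, direction=None, strength=0.0):
    """Toy residual model: each block adds f(h) (its hook activation) to h."""
    h = x
    for i, f in enumerate(layers):
        act = f(h)                                # hook-point activation of block i
        if i == steer_layer:
            act = act + strength * direction      # additive intervention at every position
        h = h + act
    return h


def steer_delta_norm(layers, x, layer, W_dec, feature_id, strength):
    """L2 norm of (steered output - clean output), a stand-in for out.delta_norm."""
    direction = W_dec[feature_id]                 # decoder row for the chosen feature
    clean = forward(layers, x)
    steered = forward(layers, x, layer, direction, strength)
    return float(np.linalg.norm(steered - clean))


rng = np.random.default_rng(0)
d = 8
W1 = rng.normal(size=(d, d)) * 0.3
W2 = rng.normal(size=(d, d)) * 0.3
layers = [lambda h: np.tanh(h @ W1), lambda h: np.tanh(h @ W2)]
W_dec = rng.normal(size=(64, d))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)   # unit-norm feature directions

x = rng.normal(size=(1, 5, d))                           # (B, L, d) toy "prompt"
print(steer_delta_norm(layers, x, layer=0, W_dec=W_dec, feature_id=42, strength=2.0))
```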

See examples/:

  • 01_load.py, 02_train_sae.py, 03_steer.py — runnable on CPU in seconds
  • titans_hook.py — v0.2 contrib stub (prints a marker; raises NotImplementedError when called)

How it works

Hook points (SSM-specific)

circuitbench defines five hook sites that map onto Mamba-2's internal computation path. The data flow inside each SSM block is:

x → x_proj → split(u, z, s)
              └── u → conv1d ──→ c
                                 └── ssm(c, s) → ssm_y           [H3]
                                                  └── gate(z)
                                                      └── post_gate [H1] → out_proj → +x → output
| ID | Location | Shape | Capture | Additive Intervention | Substitution |
|---|---|---|---|---|---|
| H1 | out_proj_in (post-gate, pre-projection) | (B, L, d_inner) | | | |
| H2 | x_proj (gate/input/dt projection) | (B, L, 2*d_inner + d_state) | | | |
| H3 | ssm_y (SSM output, pre-gate) | (B, L, d_inner) | | | |
| H4 | hidden_state_h (the SSM state itself) | (B, L, D, N) | | | v0.2 |
| H5 | conv1d (short-conv branch output) | (B, L, d_inner) | | | |

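Default for SAE training: H1 (post-gate, pre-projection), i.e. the signal each block feeds to out_proj just before its contribution is added back to the residual stream (analogous to a Transformer's residual-stream input).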

H1 and H3 are distinct tensors: H3 is the raw SSM output y, and H1 is the gated result y * sigmoid(z).
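To make the diagram and hook table concrete, here is a toy NumPy block that records all five hook sites with the listed shapes (B = batch, L = sequence length, taking D = d_inner and N = d_state). It is a sketch under loudly stated assumptions: the split sizes follow the H2 shape, s is used as the input-dependent state-input vector, the scan is a plain diagonal recurrence read out with an all-ones C, and the gate is y * sigmoid(z). It is neither circuitbench's MockSSMAdapter nor the mamba-ssm kernel, and every name in it is hypothetical.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def toy_ssm_block(x, p, d_inner, d_state):
    """Toy SSM block returning (output, hooks) with hooks keyed H1..H5."""
    n_batch, seq_len, _ = x.shape
    hooks = {}

    xp = x @ p["W_in"]                                   # H2: (B, L, 2*d_inner + d_state)
    hooks["H2"] = xp
    u, z, s = np.split(xp, [d_inner, 2 * d_inner], axis=-1)

    # H5: causal depthwise conv1d over time; conv_w[:, j] weights the input j steps back
    K = p["conv_w"].shape[1]
    u_pad = np.pad(u, ((0, 0), (K - 1, 0), (0, 0)))
    c = np.zeros_like(u)
    for j in range(K):
        c += p["conv_w"][:, j] * u_pad[:, K - 1 - j : K - 1 - j + seq_len, :]
    hooks["H5"] = c

    # H4: diagonal recurrence h_t = a * h_{t-1} + outer(c_t, s_t), one (D, N) state per step
    h = np.zeros((n_batch, d_inner, d_state))
    states = np.zeros((n_batch, seq_len, d_inner, d_state))
    for t in range(seq_len):
        h = p["a"] * h + c[:, t, :, None] * s[:, t, None, :]
        states[:, t] = h
    hooks["H4"] = states

    y = states.sum(axis=-1)                              # H3: read-out with an all-ones C
    hooks["H3"] = y
    post_gate = y * sigmoid(z)                           # H1: gated, pre-projection
    hooks["H1"] = post_gate
    return post_gate @ p["W_out"] + x, hooks             # out_proj, then residual +x


rng = np.random.default_rng(0)
d_model, d_inner, d_state, K = 8, 16, 4, 4
params = {
    "W_in": rng.normal(size=(d_model, 2 * d_inner + d_state)) * 0.1,
    "conv_w": rng.normal(size=(d_inner, K)) * 0.5,
    "a": rng.uniform(0.5, 0.95, size=(d_inner, d_state)),  # decay factors in (0, 1)
    "W_out": rng.normal(size=(d_inner, d_model)) * 0.1,
}
x = rng.normal(size=(2, 5, d_model))
out, hooks = toy_ssm_block(x, params, d_inner, d_state)
for name in sorted(hooks):
    print(name, hooks[name].shape)
```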

Circuit extraction (v0.1.x)

For each candidate layer L, circuitbench:

  1. Runs a clean forward pass and captures the activation at (L, hook_point).
  2. Replaces that activation with its sequence mean (mean-ablation) and re-runs the forward.
  3. Records ‖clean_output − ablated_output‖₂ as the layer's effect score.

Layers are ranked by this score; under this coarse measure, the layer with the largest shift is the one the prompt's output depends on most.
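The three steps can be sketched in NumPy on a toy residual model whose blocks each add a hook activation f(h) to the running state. Treating each block's contribution as the hook point, replacing it with its sequence mean, and comparing final outputs yields one effect score per layer, returned as (layer, effect) pairs in the same form as the top_layers() comment in Quick start. The toy model and the helper names run and layer_effects are hypothetical; circuitbench captures the activation at the configured hook point (e.g. H1) inside a real or mock SSM block.

```python
import numpy as np


def run(layers, x, ablate_layer=None):
    """Toy residual forward; optionally mean-ablate one block's hook activation."""
    h = x
    for i, f in enumerate(layers):
        act = f(h)                                       # hook activation, shape (B, L, d)
        if i == ablate_layer:
            # Step 2: replace the activation with its mean over the sequence axis
            act = np.broadcast_to(act.mean(axis=1, keepdims=True), act.shape)
        h = h + act
    return h


def layer_effects(layers, x):
    """Steps 1-3 for every layer; returns [(layer, effect), ...] largest first."""
    clean = run(layers, x)                               # Step 1: clean forward pass
    effects = []
    for i in range(len(layers)):
        ablated = run(layers, x, ablate_layer=i)
        effects.append((i, float(np.linalg.norm(clean - ablated))))  # Step 3: L2 shift
    return sorted(effects, key=lambda pair: pair[1], reverse=True)


rng = np.random.default_rng(0)
d = 8
W = rng.normal(size=(d, d)) * 0.5
layers = [
    lambda h: 0.0 * h,                 # contributes nothing: effect should be 0
    lambda h: np.tanh(h @ W),          # position-dependent: non-zero effect expected
    lambda h: np.ones_like(h),         # constant over positions: mean-ablation is a no-op
]
x = rng.normal(size=(1, 6, d))
print(layer_effects(layers, x))
```

The third toy block illustrates a property of this score: a contribution that is identical at every position scores zero even though it does shift the output, because its sequence mean equals itself.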

Step-wise h_t patching and target-logit projection are planned for v0.2.

Architecture

[Figure: circuitbench architecture]

Differentiation (design goals — implementation status)

| Axis | Status |
|---|---|
| Hybrid head separation SAE (Hymba / Jamba attention vs SSM heads trained separately) | design goal, v0.2 |
| ssm_state direct SAE (SAE over h_t ∈ (B, L, D, N)) | capture shipped (mock); full SAE training v0.2 |
| State-propagation circuit (step-wise patching for recurrent models) | coarse layer-level mean-ablation shipped; step-wise h_t v0.2 |
| RWKV-7 first-class (loader + hook points for RWKV-7 Goose) | design goal, v0.2 |

Status

| Component | v0.1.x (shipped) | v0.2 (planned) |
|---|---|---|
| load_model (registry + MockSSMAdapter CPU backend) | ✅ shipped | + real Mamba-2 / Hymba / Jamba / Falcon-H1 / RWKV-7 weights |
| train_sae (TopK, k=32, 8× expansion, decoder unit-norm) | ✅ shipped | + JumpReLU, dead-feature resample (200k step) |
| extract_circuit (coarse layer-level mean-ablation patching) | ✅ shipped | + step-wise h_t patching, target-logit projection, hybrid head separation |
| steer (additive feature intervention during forward) | ✅ shipped | + composable interventions, beam search |
| HF Hub SAE distribution (SAELens-compatible) | planned | shipped |
| Multi-Agent SAE namespace | reserved | shipped |
| PyPI publish | install from source | shipped (trusted publisher) |
| arXiv preprint (v0.1 harness paper) | deferred | shipped |

Acknowledgments

Inspired by the following projects. None of them is a runtime dependency in v0.1.x (none is imported); the [sae] extra reserves SAELens interop for v0.2.

  • SAELens — production SAE library; circuitbench's v0.2 will export SAEs in a SAELens-compatible format
  • TransformerLens — hook-based interpretability primitives
  • MambaLens — early Mamba interpretability work
  • mamba-ssm — official Mamba/Mamba-2 reference implementation

Related projects

Part of hinanohart's open-source portfolio:

  • transduce — composable transducer streams
  • exitkit — Nozick closest-continuer model identity over PAM snapshots
  • subjunctor — Nozick-grounded LLM agent gate

License

MIT — see LICENSE.

About

Integrated mechanistic interpretability + sparse autoencoder framework for Hybrid SSM-Attention models (Mamba-2, Hymba, RWKV-7). v0.1.2 alpha: real forward-pass intervention and mean-ablation patching shipped, with CPU smoke tests; GPU/real-model adapters planned for v0.2.
