Quantization-aware training for LLMs with per-block learned codebooks. At 2.2B parameters and 3-bit precision, a fine-tuned NativeBit model matches its float counterpart on WikiText-103 (30.50 vs 30.51 perplexity) while beating post-hoc RTN quantization (31.33).
WikiText-103 perplexity, 2.2B model (26 layers, 2560 hidden, 6912 FFN; hidden and FFN sizes match BitNet b1.58-2B-4T).
| Method | 3-bit PPL | vs Float |
|---|---|---|
| Float baseline | 30.51 | 0% |
| Post-hoc RTN | 31.33 | +2.7% |
| NativeBit from-scratch (20K steps) | 34.23 | +12.2% |
| NativeBit via QAT (5K steps) | 30.50 | −0.03% |
The QAT recipe — load a trained float checkpoint, then fine-tune with NativeBit active for 5K steps using a commitment loss and canonical VQ-VAE EMA — is what closes the gap. Training NativeBit from scratch at this scale leaves a sizeable gap to float that post-hoc RTN doesn't have.
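For contrast with learned codebooks, here is a minimal NumPy sketch of what 3-bit post-hoc RTN does to a weight matrix. Each 128-weight block is snapped onto a fixed, uniform grid of 8 levels that spans the block's min and max. The asymmetric min/max grid and the `rtn_quantize` name are assumptions for illustration. `benchmarks/benchmark_posthoc_2b.py` may use a different grid or scaling.

```python
import numpy as np


def rtn_quantize(w: np.ndarray, bits: int = 3, block_size: int = 128) -> np.ndarray:
    """Post-hoc round-to-nearest onto a uniform per-block grid (illustrative)."""
    levels = 2 ** bits                          # 8 levels at 3-bit
    blocks = w.reshape(-1, block_size)          # assumes w.size % block_size == 0
    lo = blocks.min(axis=1, keepdims=True)
    hi = blocks.max(axis=1, keepdims=True)
    scale = (hi - lo) / (levels - 1)            # grid spacing per block
    scale = np.where(scale == 0, 1.0, scale)    # constant block: avoid divide-by-zero
    q = np.clip(np.round((blocks - lo) / scale), 0, levels - 1)
    return (q * scale + lo).reshape(w.shape)    # dequantized weights
```

The 8 levels are fixed by each block's range and are never trained. NativeBit instead learns where the 8 values sit.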
Compression: packed NativeBit 2.2B is 1.70 GB vs 8.76 GB float, about 5.1× smaller on disk.
Weights are quantized to one of K learned values per block (default K=8, block_size=128, i.e. 3-bit). The codebook entries are trainable and are updated via an EMA of raw sums and counts (canonical VQ-VAE style). The forward pass uses the quantized values, and gradients flow to the latent float weights via the straight-through estimator (STE).
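The per-block lookup described above can be sketched in NumPy, assuming each block of 128 weights owns one row of K=8 codebook values and each weight snaps to its nearest entry. `quantize_blocks` is a hypothetical helper, not the repo's `NativeBitDense`. NumPy has no autodiff, so the straight-through estimator appears only as a comment, written in the common JAX idiom.

```python
import numpy as np


def quantize_blocks(w, codebooks, block_size=128):
    """Snap each weight to the nearest of its block's K codebook values.

    w:          array whose size is a multiple of block_size
    codebooks:  (n_blocks, K) learned values, one row per block
    returns:    (idx, w_q) with idx uint8 of shape (n_blocks, block_size)
                and w_q the dequantized weights, shaped like w
    """
    blocks = w.reshape(-1, block_size)                          # (n_blocks, block_size)
    dist = np.abs(blocks[:, :, None] - codebooks[:, None, :])   # (n_blocks, block_size, K)
    idx = dist.argmin(axis=-1)                                  # nearest entry per weight
    w_q = np.take_along_axis(codebooks, idx, axis=1)            # (n_blocks, block_size)
    return idx.astype(np.uint8), w_q.reshape(w.shape)


# Straight-through estimator in JAX form (a common idiom, not necessarily the repo's):
#   w_fwd = w + jax.lax.stop_gradient(w_q - w)
# The forward value equals w_q, and d(w_fwd)/dw is the identity. Gradients computed
# for the quantized weights are therefore applied unchanged to the latent floats.
```

With K=8, every index fits in 3 bits, which is where the "3-bit" label comes from.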
Three ingredients make the training converge to float quality:
- Commitment loss $\lambda \cdot \left(\sum \lVert w - \mathrm{sg}(Q(w)) \rVert^2\right) / n_{\text{layers}}$, where $\mathrm{sg}$ is stop-gradient and $Q(w)$ is the quantized weight, pulls latent weights onto codebook entries. This reduces STE bias, so the forward and backward passes agree near the fixed point. Normalizing by layer count rather than weight count is why $\lambda \approx 1$ works; a per-weight mean would need $\lambda$ on the order of $10^4$.
- Canonical VQ-VAE EMA updates each codebook entry from raw per-entry sums and counts rather than from EMA-ed batch means. This weights the update by count when clusters have uneven populations, instead of treating a 1-sample batch and a 100-sample batch alike.
- QAT from a trained float checkpoint instead of from scratch. This decouples "learn the task" from "adapt to the quantization constraint."
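The first two ingredients can be sketched in NumPy. `commitment_loss` and `canonical_ema_update` are hypothetical names, not the repo's `compute_quant_reg` or its EMA helpers. The EMA follows the VQ-VAE recipe: keep an EMA of per-entry sums and counts, and set entry = sum / count. Leaving an entry unchanged when its EMA count is near zero is an assumption about how dead entries are handled.

```python
import numpy as np


def commitment_loss(latents, quantized, lam=1.0):
    """lam * (sum over layers of ||w - sg(Q(w))||^2) / n_layers.

    NumPy has no gradients. In JAX the quantized term would sit inside
    jax.lax.stop_gradient, so the gradient 2 * lam * (w - Q(w)) / n_layers
    pulls only the latent weights toward their codebook entries.
    """
    total = sum(float(np.sum((w - q) ** 2)) for w, q in zip(latents, quantized))
    return lam * total / len(latents)


def canonical_ema_update(codebooks, ema_sums, ema_counts, blocks, idx,
                         decay=0.99, eps=1e-5):
    """VQ-VAE-style EMA: track per-entry sums and counts, entry = sum / count.

    codebooks, ema_sums, ema_counts: (n_blocks, K)
    blocks: (n_blocks, block_size) latent weights; idx: their assigned entries
    """
    K = codebooks.shape[1]
    onehot = (idx[..., None] == np.arange(K)).astype(blocks.dtype)  # (n_blocks, bs, K)
    batch_counts = onehot.sum(axis=1)                               # weights per entry
    batch_sums = (onehot * blocks[..., None]).sum(axis=1)           # sum of those weights
    ema_counts = decay * ema_counts + (1 - decay) * batch_counts
    ema_sums = decay * ema_sums + (1 - decay) * batch_sums
    # Assumption: an entry whose EMA count is ~0 keeps its previous value.
    new_cb = np.where(ema_counts > eps, ema_sums / np.maximum(ema_counts, eps), codebooks)
    return new_cb, ema_sums, ema_counts
```

Consider `decay=0.5`, a prior entry at 0.0 with EMA count 1, and a batch of three weights {1, 2, 3}. The canonical update moves the entry to 1.5 (EMA sum 3.0 over EMA count 2.0). An EMA of the batch mean (2.0) would give only 1.0. The count-weighted form lets the three-sample batch outweigh the one-sample history.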
Embeddings, LM head, and RMSNorm parameters stay in float. Only the dense matmuls in attention and SwiGLU are quantized.
```
nativebit_jax/          JAX/Flax backend (TPU training at 2.2B scale)
  layers.py             NativeBitDense + compute_quant_reg + canonical EMA
  model.py              LLaMA-style transformer (RoPE, RMSNorm, SwiGLU)
  train.py              Training loop with QAT init, commitment loss,
                        periodic validation, full-config JSONL logging
  codebook_utils.py     Codebook init + EMA helpers
nativebit/              PyTorch backend (local GPU development)
inference/              Packed inference (Triton + CUDA dequant kernels)
configs/tpu.py          Model configs (25M → 2.2B)
benchmarks/             Post-hoc quantization baselines for comparison
tests/                  Unit tests (attention, quant-reg, canonical EMA, QAT init)
infra/                  TPU provisioning + training launch scripts
```
The two key files are `nativebit_jax/layers.py` (NativeBitDense, requantize_params, compute_quant_reg, compute_quant_diagnostics) and `nativebit_jax/train.py` (the training loop). The JAX backend is the primary one: it is what produced the paper results on a v6e-8 TPU. The PyTorch backend is maintained for local GPU iteration.
```bash
python -m nativebit_jax.train \
  --config tpu-2b --dataset openwebtext \
  --no-nativebit --name 2b_float
```

```bash
python -m nativebit_jax.train \
  --config tpu-2b --dataset openwebtext --name 2b_nb
```

This run lands at +12.2% above float on WikiText-103 cross-eval at 3-bit, which is not competitive with post-hoc RTN. It is used here as an ablation; for real quality, use QAT.
```bash
python -m nativebit_jax.train \
  --config tpu-2b --dataset openwebtext --name 2b_nb_qat \
  --max-steps 5000 \
  --init-from logs/2b_float_params.npz \
  --lr 1e-4 --weight-decay 0.01 \
  --warmup-steps 200 --delay-quant-steps 0 --ema-decay 0.99 \
  --quant-reg-weight 1.0 --use-canonical-ema \
  --val-every 500
```

This run matches float PPL on WikiText-103.
```bash
python benchmarks/benchmark_posthoc_2b.py --ckpt logs/2b_float_params.npz
```

```bash
pip install -e ".[ci]"
JAX_PLATFORMS=cpu pytest tests/
```

The tests cover the cross-position attention regression, canonical VQ-VAE EMA, the commitment-loss gradient, and QAT init key translation (JAX), plus NativeBitLinear/NativeBitGPT/pack/generate/data/training-step (PyTorch).
CI runs the equivalent on every push: `pytest tests/ -v` with `JAX_PLATFORMS=cpu` across Python 3.10–3.12 (`.github/workflows/tests.yml`).

Standalone diagnostic scripts that need real checkpoints or TPU hardware (stale-cache drift analysis, FSDP sharding, Pallas/Triton kernels) live in `scripts/debug/` and are run manually, not under pytest.
Each run writes a JSONL log. Example records (one of each type):
```
{"type": "header", "schema_version": 2, "git_hash": "658c381", "argv": [...], "config": {...}, "init_from": "logs/2b_float_params.npz"}
{"type": "init_eval", "step": 0, "val_ppl": 18.98, "val_loss": 2.94, "val_batches": 64}
{"step": 100, "loss": 2.94, "perplexity": 18.82, "quant_err_rms": 0.0036, "cb_utilization": 1.0, "dead_frac": 0.0, "quant_reg": 206.5, "lambda": 0.08}
{"step": 500, "loss": 2.89, "perplexity": 17.98, "val_loss": 2.93, "val_ppl": 18.67, "val_batches": 32}
{"type": "eval", "test_ppl": 18.40, "wt103_test_ppl": 30.50, "total_time_s": 7400.0}
```

- `quant_reg` and `lambda` appear only when commitment loss is on.
- `val_ppl`, `val_loss`, and `val_batches` appear on per-step records at `step % val_every == 0`.
- Everything in the header is JSON-serialisable, including the full dataclass config.
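Here is a minimal sketch of reading such a log back. It assumes one JSON object per line, and it assumes that per-step training records are the only ones without a `"type"` key, which holds for the example records. `split_log` and `val_curve` are hypothetical helpers, not part of the repo.

```python
import json


def split_log(lines):
    """Group records into header / init_eval / eval / per-step training records."""
    log = {"header": None, "init_eval": None, "eval": None, "steps": []}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        kind = rec.get("type")
        if kind is None:
            log["steps"].append(rec)       # training records: "step" but no "type"
        else:
            log[kind] = rec                # assumes one record per typed kind
    return log


def val_curve(steps):
    """(step, val_ppl) pairs; val_* keys appear only when step % val_every == 0."""
    return [(r["step"], r["val_ppl"]) for r in steps if "val_ppl" in r]
```

Usage: `with open(path) as f: log = split_log(f)`, where `path` is the JSONL file your run wrote.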
compute_quant_diagnostics(params) in layers.py is a pure function returning {quant_error_rms, codebook_utilization, dead_entries_frac} for any params tree. Useful for ad-hoc monitoring outside the training loop.
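The exact definitions behind `compute_quant_diagnostics` are not spelled out here, so this NumPy sketch makes its own assumptions:

- RMS of `w - Q(w)` over all weights.
- Utilization as the per-block usage perplexity divided by K, which equals 1.0 when all K entries are used equally.
- Dead fraction as the share of entries that no weight in their block maps to.

`quant_diagnostics` is a hypothetical stand-in, not the repo function.

```python
import numpy as np


def quant_diagnostics(w, codebooks, block_size=128):
    """RMS quant error, usage-perplexity utilization, dead-entry share (assumed defs).

    codebooks: (n_blocks, K), one row of learned values per block of w.
    """
    blocks = w.reshape(-1, block_size)
    K = codebooks.shape[1]
    idx = np.abs(blocks[:, :, None] - codebooks[:, None, :]).argmin(axis=-1)
    w_q = np.take_along_axis(codebooks, idx, axis=1)
    rms = float(np.sqrt(np.mean((blocks - w_q) ** 2)))
    counts = (idx[..., None] == np.arange(K)).sum(axis=1)   # (n_blocks, K) usage
    p = counts / block_size
    with np.errstate(divide="ignore", invalid="ignore"):
        plogp = np.where(p > 0, p * np.log(p), 0.0)          # 0 * log 0 := 0
    perplexity = np.exp(-plogp.sum(axis=1))                  # effective entries used
    return {
        "quant_error_rms": rms,
        "codebook_utilization": float(np.mean(perplexity / K)),
        "dead_entries_frac": float(np.mean(counts == 0)),
    }
```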
Training checkpoints pack into a minimal format: per-block 3-bit-packed indices + fp32 codebook tables + unquantized embeddings/norms. The 2.2B model packs to 1.70 GB, about 5.1× smaller than the 8.76 GB fp32 float checkpoint.
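The 3-bit figure counts only the index; the per-block fp32 codebooks add overhead on top. Under the layout described (K=8 fp32 entries per 128-weight block), the cost is $3 + 8 \cdot 32 / 128 = 5$ bits per quantized weight, which is 6.4× smaller than fp32 for the quantized matrices. The whole-model ratio is lower because the unquantized embeddings and norms stay fp32.

The sketch below checks that arithmetic and shows one way to pack eight 3-bit indices into three bytes. The little-endian bit order and the `bits_per_weight`/`pack_3bit`/`unpack_3bit` names are assumptions, and the real `.nbpack.npz` layout may differ.

```python
import numpy as np


def bits_per_weight(index_bits=3, K=8, block_size=128, codebook_bits=32):
    """Index bits plus each block's codebook amortized over its weights."""
    return index_bits + K * codebook_bits / block_size


def pack_3bit(idx):
    """Pack indices in [0, 8) so every 8 indices (24 bits) occupy 3 bytes."""
    idx = np.asarray(idx, dtype=np.uint32)
    if idx.size % 8 or idx.max(initial=0) >= 8:
        raise ValueError("need a multiple of 8 indices, each in [0, 8)")
    shifts = 3 * np.arange(8, dtype=np.uint32)           # index j -> bits 3j..3j+2
    words = np.bitwise_or.reduce(idx.reshape(-1, 8) << shifts, axis=1)
    packed = np.stack([(words >> s) & 0xFF for s in (0, 8, 16)], axis=1)
    return packed.astype(np.uint8).reshape(-1)


def unpack_3bit(packed):
    """Inverse of pack_3bit: 3 bytes -> one 24-bit word -> 8 indices."""
    b = np.asarray(packed, dtype=np.uint32).reshape(-1, 3)
    words = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
    shifts = 3 * np.arange(8, dtype=np.uint32)
    return ((words[:, None] >> shifts) & 0x7).astype(np.uint8).reshape(-1)
```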
```bash
# Pack a trained NativeBit checkpoint
python inference/pack.py logs/2b_nb_qat_params.npz --out inference/2b_nb.nbpack.npz

# Generate (JAX / TPU / CPU)
python inference/generate.py inference/2b_nb.nbpack.npz --packed --benchmark

# Generate (PyTorch + Triton dequant kernel on GPU)
python inference/generate_torch.py inference/2b_nb.nbpack.npz --benchmark
```

The packed `generate_torch.py` path uses a fused dequant-matvec kernel that reads uint8 codebook indices directly from VRAM. This avoids materializing the full weight matrix, which is the bottleneck of single-token decode.
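The idea behind the fused kernel can be shown with a NumPy dequant-matvec that never builds the full weight matrix. It dequantizes one `(out, block_size)` tile at a time from indices plus per-block codebooks and accumulates `y = W @ x`. Two things here are assumptions: the name `dequant_matvec`, and the layout, in which blocks run along each row's input dimension and codebooks are shaped `(out, in // block_size, K)`. A real GPU kernel would fuse the lookup into the matvec instead of looping in Python.

```python
import numpy as np


def dequant_matvec(idx, codebooks, x, block_size=128):
    """y = W @ x without materializing W.

    idx:       (out, in) integer codebook indices
    codebooks: (out, in // block_size, K) values for each row-wise block
    x:         (in,) activation vector
    """
    out_dim, in_dim = idx.shape
    y = np.zeros(out_dim, dtype=np.result_type(codebooks, x))
    for b in range(in_dim // block_size):
        cols = slice(b * block_size, (b + 1) * block_size)
        # Dequantize one (out, block_size) tile, use it, and discard it.
        tile = np.take_along_axis(codebooks[:, b, :], idx[:, cols].astype(np.intp), axis=1)
        y += tile @ x[cols]
    return y
```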
JAX's `jax.nn.dot_product_attention(q, k, v, is_causal=True)` expects inputs laid out as (batch, seq_len, num_heads, head_dim). Called with a (B, H, T, D) layout, it treats the head axis as the sequence axis and computes attention across heads, not positions. An early version of this code shipped that bug. The current code uses an explicit einsum with an fp32 softmax for stability under NativeBit quantization. If you swap attention, add a context-sensitivity probe: KL(predict(full_context), predict(single_token)) > 0.1 confirms the model actually uses context. The broken version gave 0.0001.
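Here is a minimal NumPy sketch of position-wise causal attention and the context probe. It assumes (B, H, T, D) inputs, einsums that contract over the per-head feature dimension D, a causal mask over key positions, and a softmax computed in fp32. `causal_attention`, `kl`, and `context_sensitivity` are hypothetical names. `predict` stands for any callable that returns a next-token probability vector.

```python
import numpy as np


def causal_attention(q, k, v):
    """Attention over positions for (B, H, T, D) inputs, softmax in fp32."""
    q32, k32, v32 = (a.astype(np.float32) for a in (q, k, v))
    T, D = q.shape[2], q.shape[3]
    scores = np.einsum("bhqd,bhkd->bhqk", q32, k32) / np.sqrt(np.float32(D))
    causal = np.tril(np.ones((T, T), dtype=bool))          # query t sees keys <= t
    scores = np.where(causal, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)           # stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)             # normalize over key positions
    return np.einsum("bhqk,bhkd->bhqd", probs, v32).astype(q.dtype)


def kl(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))


def context_sensitivity(predict, tokens):
    """KL(predict(full context) || predict(last token only)); > 0.1 suggests context is used."""
    return kl(predict(tokens), predict(tokens[-1:]))
```

The softmax normalizes over the last axis of the `(B, H, T, T)` score tensor, which indexes key positions. With the layout bug, that axis would have indexed heads instead.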
The forward pass does not requantize on every step. It reads quantized weights from a cache that is refreshed every `requantize_every` steps (default 10), which matters for throughput. Commitment loss keeps the latent weights close enough to their codebook entries that stale-cache drift stays small between refreshes.
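Stale-cache drift can be quantified as the fraction of latent weights whose nearest codebook entry has changed since the cached indices were computed. `stale_fraction` is a hypothetical metric, sketched in NumPy, and is not necessarily what the `scripts/debug/` analysis measures.

```python
import numpy as np


def stale_fraction(w_latent, cached_idx, codebooks, block_size=128):
    """Share of weights whose nearest codebook entry differs from the cached index.

    w_latent:   current latent weights, size a multiple of block_size
    cached_idx: (n_blocks, block_size) indices from the last cache refresh
    codebooks:  (n_blocks, K) per-block values
    """
    blocks = w_latent.reshape(-1, block_size)
    fresh_idx = np.abs(blocks[:, :, None] - codebooks[:, None, :]).argmin(axis=-1)
    return float(np.mean(fresh_idx != cached_idx))
```

Take a toy block with integer-spaced entries and apply a +0.02 update. Latents that sit on their entries, which commitment loss encourages, keep every index. Latents that sit just inside a decision boundary (0.49 above an entry) flip every index.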
The biggest matrix, the embedding, is left unquantized partly because it's tied to the LM head, and a tied embedding needs different quantization treatment than a plain dense layer. That was a scope decision, not a fundamental limit.
Validated at 2.2B only. The 125M and 350M results in git history used an earlier JAX implementation with the cross-head/cross-position attention bug; they don't reproduce and shouldn't be cited.
QAT is the recommended recipe. NativeBit from-scratch at 2.2B lands +12.2% above float at 3-bit, which is worse than post-hoc RTN. The method's claim is "trainable quantization that matches float via short fine-tuning," not "matches float from random init."
2-bit is untested with the fixed attention. The earlier 2-bit wins over post-hoc k-means came from pre-fix runs.
Single training dataset (OpenWebText), single evaluation point (WikiText-103 cross-eval), single seed. Other domains (code, multilingual) and multi-seed variance are not measured.
Baseline numbers in the table come from a 2026-04-18 validation. The float, RTN, and k-means rows come from benchmarks/benchmark_posthoc_2b.py run on 2b_float_fixed_params.npz, and the from-scratch and QAT WikiText-103 cross-evals come from their respective training logs. A fresh single-CPU re-eval of the float baseline lands at 30.50, consistent with 30.51 within eval-protocol noise. Running the full post-hoc sweep at 2.2B needs more than ~32 GB of host RAM or a TPU, because the script holds the float and quantized params in memory simultaneously.
- BitNet b1.58 (the ternary-LLM method NativeBit is benchmarked against): Ma et al., "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits," 2024. arxiv.org/abs/2402.17764
- BitNet b1.58 2B4T (the 2B model whose hidden/FFN width this config matches): Ma et al., "BitNet b1.58 2B4T Technical Report," 2025. arxiv.org/abs/2504.12285
- VQ-VAE (origin of the EMA codebook update used here, Appendix A.1): van den Oord et al., "Neural Discrete Representation Learning," 2017. arxiv.org/abs/1711.00937
- Straight-through estimator: Bengio et al., "Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation," 2013. arxiv.org/abs/1308.3432
MIT