Fused CUDA kernels for the Muon optimizer's Newton-Schulz iteration
Drop-in replacement for the Muon optimizer that accelerates Newton-Schulz orthogonalization by exploiting matrix symmetry with custom CuTe SYRK kernels. It achieves a stable ~1.5x speedup over `torch.compile` on a dedicated A800, with training dynamics identical to standard Muon.
- 🔺 SYRK symmetry exploitation: `X @ Xᵀ` is symmetric, so only the lower triangle is computed → 50% FLOP savings
- 🔗 Fused GEMM epilogue: `c·A² + b·A + a·I` in a single kernel, eliminating 2 extra kernel launches
- 🎯 Adaptive dispatch: 64/128 tiles, Split-K, or cuBLAS fallback, auto-selected per shape
- 🔌 Drop-in API: `from muon_fused import FusedMuon`, with the same interface as standard Muon
- 🛡️ Graceful fallback: automatically falls back to pure PyTorch when the CUDA extension is unavailable
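
To see where the 50% figure comes from, the sketch below counts FLOPs for the Gram matrix `X @ Xᵀ` of an m×n input, once for the full GEMM and once for the lower triangle only (diagonal included). It is plain arithmetic under the common convention of 2 FLOPs per multiply-add (an assumption for illustration), not a model of the tiled kernel; `gram_flops` is a hypothetical helper, not part of `muon_fused`.

```python
def gram_flops(m: int, n: int) -> tuple[int, int]:
    """FLOPs for A = X @ X.T with X of shape (m, n): (full GEMM, lower triangle only)."""
    per_entry = 2 * n                        # n multiplies + n adds per dot product
    full = m * m * per_entry                 # every entry of the m x m output
    lower = (m * (m + 1) // 2) * per_entry   # entries with column index <= row index
    return full, lower


for m, n in [(896, 1152), (4096, 4096)]:
    full, lower = gram_flops(m, n)
    print(f"(m={m}, n={n}): lower/full = {lower / full:.4f}")
```

The ratio is (m + 1) / (2m), so the saving approaches 50% as m grows (0.5006 for m = 896).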
```bash
git clone --recursive https://github.com/StarrickLiu/fused-muon.git
cd fused-muon
pip install -e .
```

```python
from muon_fused import FusedMuon

# Drop-in replacement: same API as standard Muon (`model` is your torch.nn.Module)
optimizer = FusedMuon(model.parameters(), lr=0.02, momentum=0.95, ns_steps=5)
```

Or use just the optimized Newton-Schulz function in your own optimizer:

```python
from muon_fused import fused_newton_schulz

# Replace zeropower_via_newtonschulz5() with this (G is a 2D update matrix)
X_ortho = fused_newton_schulz(G, steps=5)
```

All benchmarks were measured on a dedicated NVIDIA A800-SXM4-80GB (no other processes) in BF16, with a 500-iteration warmup per method to ensure a stable GPU frequency.
The baseline is the vanilla Newton-Schulz iteration compiled with `@torch.compile`; the comparison uses steady-state timings after compilation warmup:

| Shape (m, n) | torch.compile (µs) | Fused SYRK (µs) | Speedup |
|---|---|---|---|
| (896, 1152) | 465 | 269 | 1.73x |
| (896, 896) | 478 | 232 | 2.05x |
| (2048, 2560) | 2114 | 1403 | 1.51x |
| (2048, 2048) | 1965 | 1305 | 1.51x |
| (2560, 4096) | 3443 | 2349 | 1.47x |
| (3584, 4608) | 7687 | 5228 | 1.47x |
| (3584, 3584) | 6587 | 4430 | 1.49x |
| (4096, 4096) | 9351 | 6294 | 1.49x |
The speedup over `torch.compile` (post-warmup steady state) is a consistent ~1.5x for m ≥ 2048. The small shapes (m = 896) benefit more (1.73–2.05x) because kernel fusion removes a proportionally larger share of their runtime overhead.
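
The speedup column is the baseline time divided by the fused time. The sketch below recomputes it from the table's microsecond values and summarizes all shapes with a geometric mean; the geometric mean is a choice made here for illustration, not a figure the benchmarks report. Because the inputs are rounded, a recomputed ratio can differ from the table in the last digit.

```python
import math

# (shape (m, n), torch.compile time in µs, fused SYRK time in µs), from the benchmark table
TIMINGS = [
    ((896, 1152), 465, 269),
    ((896, 896), 478, 232),
    ((2048, 2560), 2114, 1403),
    ((2048, 2048), 1965, 1305),
    ((2560, 4096), 3443, 2349),
    ((3584, 4608), 7687, 5228),
    ((3584, 3584), 6587, 4430),
    ((4096, 4096), 9351, 6294),
]

# Speedup = baseline time / fused time.
speedups = {shape: base / fused for shape, base, fused in TIMINGS}

# The geometric mean is the usual way to average ratios.
geomean = math.exp(sum(math.log(s) for s in speedups.values()) / len(speedups))

for shape, s in speedups.items():
    print(f"{shape}: {s:.2f}x")
print(f"geometric mean over all shapes: {geomean:.2f}x")
```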
FusedMuon and VanillaMuon produce matching training curves (loss and accuracy overlap), confirming that the fused kernels preserve training dynamics.
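Per-kernel (GEMM1, GEMM2) and end-to-end (E2E) speedups on layer shapes taken from real models:

| Model | Layer | Shape (m, n) | GEMM1 speedup | GEMM2 speedup | E2E speedup |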
|---|---|---|---|---|---|
| Qwen 3B | QKV | (2048, 2560) | 1.59x | 1.69x | 1.38x |
| Qwen3-4B | QKV | (2560, 6144) | 1.72x | 1.64x | 1.32x |
| Qwen3-4B | Down | (2560, 9728) | 1.80x | 1.64x | 1.31x |
| Qwen 7B | QKV | (3584, 4608) | 1.67x | 1.78x | 1.38x |
| Qwen 7B | O | (3584, 3584) | 1.61x | 1.78x | 1.39x |
| Standard | — | (4096, 4096) | 1.66x | 1.84x | 1.42x |
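
The (m, n) shapes in this table follow from each model's attention and MLP dimensions. The sketch below derives the QKV shapes, assuming the published Hugging Face configs (hidden size, query heads, key/value heads, head dimension 128) and assuming "Qwen 3B" and "Qwen 7B" refer to Qwen2.5-3B and Qwen2.5-7B; `qkv_shape` is a hypothetical helper.

```python
def qkv_shape(hidden: int, n_q_heads: int, n_kv_heads: int, head_dim: int = 128) -> tuple[int, int]:
    """(m, n) of a fused QKV projection, oriented so that m <= n as the NS iteration uses it."""
    out = (n_q_heads + 2 * n_kv_heads) * head_dim   # Q rows + K rows + V rows
    return (min(hidden, out), max(hidden, out))


# Assumed config values: (hidden size, query heads, key/value heads)
print(qkv_shape(2048, 16, 2))   # Qwen2.5-3B -> (2048, 2560)
print(qkv_shape(2560, 32, 8))   # Qwen3-4B   -> (2560, 6144)
print(qkv_shape(3584, 28, 4))   # Qwen2.5-7B -> (3584, 4608)
```

Under the same assumptions, the Qwen3-4B Down projection is hidden × MLP intermediate size = (2560, 9728), and the Qwen 7B O projection is (3584, 28 · 128) = (3584, 3584).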
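Each Newton-Schulz iteration step computes

$$
X_{k+1} = \left(aI + bA + cA^2\right) X_k, \qquad A = X_k X_k^\top,
$$

with fixed scalar coefficients $a$, $b$, $c$. This maps onto three GEMMs:

```
GEMM1: A = X @ Xᵀ              → CuTe SYRK (50% FLOPs saved via symmetry)
GEMM2: B = c·A² + b·A + a·I    → CuTe SYRK + fused polynomial epilogue (1 kernel)
GEMM3: X_new = B @ X           → cuBLAS (standard GEMM, ~80% MFU)
```

GEMM2 can also use SYRK because A is symmetric, so A² = A·Aᵀ.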
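
For reference, here is a minimal pure-Python sketch of the unfused iteration that the three GEMMs implement, using list-of-lists matrices so it runs without a GPU. It mirrors the preprocessing described for `fused_newton_schulz` (Frobenius normalization, transpose when m > n, multi-step iteration) but works in float64 instead of casting to bf16. The coefficient values are an assumption taken from Keller Jordan's reference Muon implementation, and `newton_schulz_reference`, `matmul`, and `transpose` are hypothetical names, not part of `muon_fused`.

```python
import math

Matrix = list[list[float]]


def transpose(M: Matrix) -> Matrix:
    return [list(col) for col in zip(*M)]


def matmul(P: Matrix, Q: Matrix) -> Matrix:
    cols = list(zip(*Q))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in P]


def newton_schulz_reference(G: Matrix, steps: int = 5,
                            coeffs=(3.4445, -4.7750, 2.0315), eps: float = 1e-7) -> Matrix:
    """Unfused NS iteration X <- (a*I + b*A + c*A^2) @ X with A = X @ X.T."""
    a, b, c = coeffs
    X = [list(row) for row in G]
    transposed = len(X) > len(X[0])          # iterate on the wide orientation (m <= n)
    if transposed:
        X = transpose(X)
    norm = math.sqrt(sum(v * v for row in X for v in row)) + eps
    X = [[v / norm for v in row] for row in X]   # Frobenius normalization: singular values <= 1
    m = len(X)
    for _ in range(steps):
        A = matmul(X, transpose(X))          # GEMM1: symmetric m x m Gram matrix
        A2 = matmul(A, A)                    # equals A @ A.T because A is symmetric
        # GEMM2 + fused polynomial epilogue: B = c*A^2 + b*A + a*I
        B = [[c * A2[i][j] + b * A[i][j] + (a if i == j else 0.0) for j in range(m)]
             for i in range(m)]
        X = matmul(B, X)                     # GEMM3: (m x m) @ (m x n)
    return transpose(X) if transposed else X
```

With these coefficients the iteration pushes singular values into a band around 1 rather than exactly to 1; for a diagonal input, each diagonal entry follows the scalar recurrence s ← (a + b·s² + c·s⁴)·s.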
| Technique | Applied To | Benefit |
|---|---|---|
| SYRK lower-triangle | GEMM1, GEMM2 | 50% compute reduction (only lower-triangle tiles are computed) |
| Fused polynomial epilogue | GEMM2 | Merges c·acc + b·A + a·I (acc = the A·Aᵀ accumulator) into the SYRK kernel, eliminating 2 extra kernels |
| Dual-write epilogue | GEMM1, GEMM2 | Outputs the full symmetric matrix via a shared-memory transpose, at +2 µs overhead |
| Bypass L1 (`cp.async.cg`) | All SYRK | Avoids L1 cache pollution for large working sets, +56% speedup |
| Split-K | GEMM1 (n ≫ m) | Splits K-dimension across blocks for better SM utilization |
| Adaptive tile dispatch | All | 64×64 tile for small m (more blocks), 128×128 for large m (higher arithmetic intensity) |
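
The lower-triangle, dual-write, fused-epilogue, and Split-K rows can be illustrated with a scalar Python sketch: it visits only tiles on or below the diagonal, splits each dot product into K-slices that are then summed, applies an optional per-element epilogue to the accumulator before storing, and mirrors each value to its transposed position. This is a hypothetical model of the control flow only (`syrk_lower_tiled`, its `epilogue` callback, and `split_k` parameter are illustrative names); the real CuTe kernels use tensor cores, shared-memory staging, and `cp.async`, none of which is modeled here.

```python
from typing import Callable, Optional

Matrix = list[list[float]]
Epilogue = Callable[[int, int, float], float]


def syrk_lower_tiled(X: Matrix, tile: int = 2, split_k: int = 1,
                     epilogue: Optional[Epilogue] = None) -> tuple[Matrix, int]:
    """Scalar model of S = X @ X.T computed tile by tile.

    Only tiles on or below the diagonal are visited; each stored value is mirrored
    to its transposed position (dual-write). Returns (S, number_of_tiles_visited).
    """
    m, n = len(X), len(X[0])
    chunk = -(-n // split_k)                   # ceil(n / split_k) columns per K-slice
    S = [[0.0] * m for _ in range(m)]
    tiles = 0
    for ti in range(0, m, tile):
        for tj in range(0, ti + 1, tile):      # tj <= ti: lower-triangle tiles only
            tiles += 1
            for i in range(ti, min(ti + tile, m)):
                for j in range(tj, min(tj + tile, m)):
                    # Split-K: partial dot products per K-slice (separate blocks in a
                    # real Split-K kernel), followed by a reduction.
                    partials = [sum(X[i][k] * X[j][k] for k in range(k0, min(k0 + chunk, n)))
                                for k0 in range(0, n, chunk)]
                    acc = sum(partials)
                    val = epilogue(i, j, acc) if epilogue is not None else acc
                    S[i][j] = val              # store the tile entry
                    S[j][i] = val              # dual-write: mirror to the transposed position
    return S, tiles


# Usage: GEMM1 with Split-K, then GEMM2 with the polynomial fused into the epilogue.
# A @ A.T equals A @ A here because A is symmetric.
a, b, c = 3.4445, -4.7750, 2.0315               # assumed Muon quintic coefficients
X = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
A, tiles_a = syrk_lower_tiled(X, tile=2, split_k=2)
B, tiles_b = syrk_lower_tiled(
    A, tile=2, epilogue=lambda i, j, acc: c * acc + b * A[i][j] + (a if i == j else 0.0)
)
```

For a 3×3 output with 2×2 tiles, 3 of the 4 tiles are visited; the saving approaches half of the tiles as the matrix grows.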
The dispatcher selects a path per shape:

```
m ≤ 1280:            cuBLAS fallback (too few SYRK blocks)
m ≥ 2048, n/m ≤ 4:   CuTe SYRK 64×64 or 128×128
5 ≤ n/m ≤ 8:         CuTe SYRK 128×128 + Split-K
n/m > 8:             cuBLAS fallback (extreme aspect ratio)
```
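
The rules above can be encoded as a small selection function. This is a hypothetical sketch, not the library's dispatcher: the order in which the rules are checked is an assumption, the 5 ≤ n/m ≤ 8 rule is applied without an m condition because none is stated, the 64 vs 128 tile threshold is not given so both are reported as one path, and shapes the listed rules do not cover return `None`.

```python
from typing import Optional


def choose_syrk_path(m: int, n: int) -> Optional[str]:
    """Map an (m, n) input with m <= n to a kernel path following the listed rules.

    Returns None for shapes the listed rules do not cover, e.g. 1280 < m < 2048
    with n/m <= 4, or 4 < n/m < 5.
    """
    ratio = n / m
    if m <= 1280:
        return "cublas"               # too few SYRK blocks to fill the GPU
    if ratio > 8:
        return "cublas"               # extreme aspect ratio
    if m >= 2048 and ratio <= 4:
        return "syrk"                 # 64x64 or 128x128 tile
    if 5 <= ratio <= 8:
        return "syrk_128x128_splitk"  # Split-K over the long n dimension
    return None
```

For example, the Qwen3-4B Down shape (2560, 9728) has n/m = 3.8 and lands on the plain SYRK path.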
See `docs/optimization_report.md` for full NCU profiling data, register pressure analysis, and detailed optimization history.
```python
from muon_fused import FusedMuon

# 2D+ params → Muon (SGD momentum + NS orthogonalization)
# 1D params (bias, norm) → AdamW fallback
param_groups = [
    {"params": [p for p in model.parameters() if p.ndim >= 2], "use_muon": True},
    {"params": [p for p in model.parameters() if p.ndim < 2], "use_muon": False},
]
optimizer = FusedMuon(param_groups, lr=0.02, momentum=0.95, ns_steps=5)
```

```python
from muon_fused import fused_newton_schulz

# Use in any custom optimizer
# Handles: bf16 cast, normalization, transpose (if m > n), multi-step iteration
X_ortho = fused_newton_schulz(gradient_matrix, steps=5)
```

| Requirement | Version |
|---|---|
| NVIDIA GPU | SM80+ (A100, A800, H100, H200) |
| PyTorch | ≥ 2.0 |
| CUDA Toolkit | ≥ 11.8 |
| CUTLASS | Included as git submodule (header-only) |
```bash
git clone --recursive https://github.com/StarrickLiu/fused-muon.git
cd fused-muon
pip install -e .
```

If the CUTLASS submodule is missing:

```bash
git submodule update --init --recursive
```

Run the tests:

```bash
pytest tests/ -v
```

Run the benchmarks:

```bash
python benchmarks/bench_ns_step.py     # NS step breakdown
python benchmarks/train_cifar10.py     # CIFAR-10 training
python benchmarks/plot_results.py      # Generate figures
```

Important: for reliable results, use a dedicated GPU with no other processes. GPU DVFS (dynamic voltage and frequency scaling) causes significant measurement variance on shared machines.
```bibtex
@software{fused_muon,
  title = {Fused Muon: CUDA-Optimized Newton-Schulz Iteration for the Muon Optimizer},
  author = {Xingchen Liu},
  year = {2025},
  url = {https://github.com/StarrickLiu/fused-muon}
}
```

- Muon optimizer by Keller Jordan: the original Muon algorithm
- CUTLASS & CuTe by NVIDIA: tensor core abstraction
- Moonlight by Moonshot AI: distributed Muon implementation reference



