
⚡ fused-muon

Fused CUDA kernels for the Muon optimizer's Newton-Schulz iteration

Drop-in replacement for the Muon optimizer that accelerates the Newton-Schulz orthogonalization by exploiting matrix symmetry with custom CuTe SYRK kernels. It achieves a stable ~1.5x speedup of the Newton-Schulz step over torch.compile on a dedicated A800, with identical training dynamics.

✨ Highlights

  • 🔺 SYRK symmetry exploitation: X @ Xᵀ is symmetric, so only the lower triangle is computed → ~50% fewer FLOPs in the symmetric GEMMs
  • 🔗 Fused GEMM epilogue: c·A² + b·A + a·I in a single kernel, eliminating 2 extra kernel launches
  • 🎯 Adaptive dispatch: 64/128 tiles, Split-K, or cuBLAS fallback, auto-selected per shape
  • 🔌 Drop-in API: from muon_fused import FusedMuon, with the same interface as standard Muon
  • 🛡️ Graceful fallback: automatically falls back to pure PyTorch when the CUDA extension is unavailable

🚀 Quick Start

Installation

git clone --recursive https://github.com/StarrickLiu/fused-muon.git
cd fused-muon
pip install -e .

Usage

from muon_fused import FusedMuon

# Drop-in replacement — same API as standard Muon
optimizer = FusedMuon(model.parameters(), lr=0.02, momentum=0.95, ns_steps=5)

Or use just the optimized Newton-Schulz function in your own optimizer:

from muon_fused import fused_newton_schulz

# Replace zeropower_via_newtonschulz5() with this
X_ortho = fused_newton_schulz(G, steps=5)

📊 Benchmarks

All benchmarks were measured on a dedicated NVIDIA A800-SXM4-80GB (no other processes) in BF16, with a 500-iteration warmup per method to ensure a stable GPU frequency.

Newton-Schulz Step Speedup (5 iterations, vs torch.compile)

The baseline is vanilla NS decorated with @torch.compile; timings are steady-state, taken after compilation warmup:

| Shape (m, n) | torch.compile (µs) | Fused SYRK (µs) | Speedup |
|---|---|---|---|
| (896, 1152) | 465 | 269 | 1.73x |
| (896, 896) | 478 | 232 | 2.05x |
| (2048, 2560) | 2114 | 1403 | 1.51x |
| (2048, 2048) | 1965 | 1305 | 1.51x |
| (2560, 4096) | 3443 | 2349 | 1.47x |
| (3584, 4608) | 7687 | 5228 | 1.47x |
| (3584, 3584) | 6587 | 4430 | 1.49x |
| (4096, 4096) | 9351 | 6294 | 1.49x |

The speedup over torch.compile is a consistent ~1.5x for all shapes with m ≥ 2048 (post-warmup steady state). Small shapes (m ≤ 896) benefit more (1.73–2.05x) because the kernel-launch overhead removed by fusion is a larger fraction of their runtime.
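As a back-of-the-envelope check on the ~1.5x figure, the ideal FLOP reduction can be estimated from the three GEMMs described in How It Works. This sketch assumes SYRK does exactly half the work of the full product (ignoring diagonal tiles) and that every kernel runs at the same throughput; it is a simple model, not the project's measurement method, and the function names are illustrative.

```python
def ns_step_flops(m: int, n: int, use_syrk: bool) -> int:
    """Approximate FLOPs for one Newton-Schulz step on an (m, n) matrix, m <= n.

    GEMM1: A = X @ X^T    (m x n) @ (n x m) -> 2*m*m*n FLOPs
    GEMM2: A @ A          (m x m) @ (m x m) -> 2*m**3 FLOPs
    GEMM3: X_new = B @ X  (m x m) @ (m x n) -> 2*m*m*n FLOPs
    With SYRK, GEMM1 and GEMM2 compute one triangle only (~half the FLOPs).
    """
    gemm1 = 2 * m * m * n
    gemm2 = 2 * m ** 3
    gemm3 = 2 * m * m * n
    if use_syrk:
        gemm1 //= 2
        gemm2 //= 2
    return gemm1 + gemm2 + gemm3


def ideal_speedup(m: int, n: int) -> float:
    """Full-GEMM FLOPs / SYRK FLOPs, assuming equal throughput for all kernels."""
    return ns_step_flops(m, n, use_syrk=False) / ns_step_flops(m, n, use_syrk=True)


for shape in [(2048, 2560), (4096, 4096)]:
    print(shape, round(ideal_speedup(*shape), 3))
```

For square matrices the model gives exactly 6m³ / 4m³ = 1.5x, and for (2048, 2560) it gives about 1.47x, close to the measured 1.49x and 1.51x. The model ignores launch overhead and fusion, which the table above credits for the larger gains at small m.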

CIFAR-10 Training: FusedMuon vs VanillaMuon vs AdamW

FusedMuon and VanillaMuon produce overlapping loss and accuracy curves, indicating that the fused kernels preserve VanillaMuon's training dynamics.

End-to-End NS Iteration Speedup (Qwen Model Shapes)

| Model | Layer | Shape (m, n) | GEMM1 Speedup | GEMM2 Speedup | E2E Speedup |
|---|---|---|---|---|---|
| Qwen 3B | QKV | (2048, 2560) | 1.59x | 1.69x | 1.38x |
| Qwen3-4B | QKV | (2560, 6144) | 1.72x | 1.64x | 1.32x |
| Qwen3-4B | Down | (2560, 9728) | 1.80x | 1.64x | 1.31x |
| Qwen 7B | QKV | (3584, 4608) | 1.67x | 1.78x | 1.38x |
| Qwen 7B | O | (3584, 3584) | 1.61x | 1.78x | 1.39x |
| Standard | | (4096, 4096) | 1.66x | 1.84x | 1.42x |

🔬 How It Works

Each Newton-Schulz iteration step computes:

$$X_{k+1} = \left(aI + bA + cA^2\right) X_k, \quad A = X_k X_k^\top$$

with coefficients $(a, b, c) = (3.4445, -4.7750, 2.0315)$. This decomposes into 3 GEMMs:

GEMM1: A = X @ Xᵀ           → CuTe SYRK (50% FLOPs saved via symmetry)
GEMM2: B = c·A² + b·A + a·I → CuTe SYRK + fused polynomial epilogue (1 kernel)
GEMM3: X_new = B @ X         → cuBLAS (standard GEMM, ~80% MFU)
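To make the folding of a·I into GEMM2 concrete, here is a pure-PyTorch sketch (float64 on CPU, not the CuTe kernels) that computes one step via the three GEMMs above and checks it against the a·X + (b·A + c·A²)·X form of the equation. torch.addmm stands in for the fused epilogue, since it returns beta·input + alpha·(mat1 @ mat2) in one call. The function names are illustrative and not part of muon_fused.

```python
import torch

# Quintic Newton-Schulz coefficients (a, b, c) from the equation above.
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def ns_step_three_gemms(X: torch.Tensor) -> torch.Tensor:
    """One Newton-Schulz step written as GEMM1 / GEMM2 + epilogue / GEMM3."""
    a, b, c = A_COEF, B_COEF, C_COEF
    # GEMM1: A = X @ X^T (symmetric, m x m)
    A = X @ X.T
    # GEMM2 + epilogue: B = c * (A @ A) + b * A, then add a to the diagonal (a*I).
    B = torch.addmm(A, A, A, beta=b, alpha=c)
    B.diagonal().add_(a)
    # GEMM3: X_new = B @ X
    return B @ X


def ns_step_reference(X: torch.Tensor) -> torch.Tensor:
    """The same step in the a*X + (b*A + c*A^2) @ X form."""
    a, b, c = A_COEF, B_COEF, C_COEF
    A = X @ X.T
    return a * X + (b * A + c * (A @ A)) @ X


torch.manual_seed(0)
X = torch.randn(64, 128, dtype=torch.float64)
X = X / X.norm()
print(torch.allclose(ns_step_three_gemms(X), ns_step_reference(X)))  # True
```

Folding a·I into B is what lets GEMM3 be a single plain GEMM with no extra a·X addition afterwards.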

Key Kernel Optimizations

| Technique | Applied To | Benefit |
|---|---|---|
| SYRK lower-triangle | GEMM1, GEMM2 | 50% compute reduction (only lower-triangle tiles are computed) |
| Fused polynomial epilogue | GEMM2 | Merges c·acc + b·A + a·I into the SYRK kernel, eliminating 2 extra kernels |
| Dual-write epilogue | GEMM1, GEMM2 | Outputs the full symmetric matrix via a shared-memory transpose, at +2 μs overhead |
| BYPASS L1 (cp.async.cg) | All SYRK | Avoids L1 cache pollution for large working sets, +56% speedup |
| Split-K | GEMM1 (n ≫ m) | Splits the K dimension across blocks for better SM utilization |
| Adaptive tile dispatch | All | 64×64 tiles for small m (more blocks), 128×128 for large m (higher arithmetic intensity) |
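The lower-triangle, dual-write, and Split-K rows can be illustrated at tile granularity in plain PyTorch. This is a host-side sketch of the idea, not the CuTe kernel: the function name, tile size, and Split-K factor are arbitrary illustrative parameters. With 4 × 4 tiles it computes 10 tile products instead of 16, and as the tile count grows the fraction approaches the 50% quoted above.

```python
import torch


def syrk_lower_tiled(X: torch.Tensor, tile: int = 64, split_k: int = 1) -> tuple[torch.Tensor, int]:
    """Emulate a lower-triangle SYRK, C = X @ X^T, at tile granularity.

    Only tiles (i, j) with i >= j are computed; each result is also written,
    transposed, into tile (j, i) (the dual-write idea). With split_k > 1 the
    K dimension (columns of X) is cut into chunks whose partial products are
    summed, standing in for Split-K across thread blocks.
    Returns C and the number of tile products computed.
    """
    m, k = X.shape
    C = torch.empty(m, m, dtype=X.dtype, device=X.device)
    k_bounds = torch.linspace(0, k, split_k + 1).round().long().tolist()
    tiles_computed = 0
    for i0 in range(0, m, tile):
        i1 = min(i0 + tile, m)
        for j0 in range(0, i0 + 1, tile):  # j <= i: lower triangle only
            j1 = min(j0 + tile, m)
            acc = torch.zeros(i1 - i0, j1 - j0, dtype=X.dtype, device=X.device)
            for k0, k1 in zip(k_bounds[:-1], k_bounds[1:]):
                acc += X[i0:i1, k0:k1] @ X[j0:j1, k0:k1].T  # partial K-chunk product
            C[i0:i1, j0:j1] = acc
            C[j0:j1, i0:i1] = acc.T  # mirror into the upper triangle
            tiles_computed += 1
    return C, tiles_computed


torch.manual_seed(0)
X = torch.randn(256, 512, dtype=torch.float64)
C, n_tiles = syrk_lower_tiled(X, tile=64, split_k=4)
print(torch.allclose(C, X @ X.T), n_tiles, (256 // 64) ** 2)  # True 10 16
```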

Adaptive GEMM1 Dispatch

m ≤ 1280:            cuBLAS fallback (too few SYRK blocks)
m ≥ 2048, n/m ≤ 4:   CuTe SYRK 64×64 or 128×128
5 ≤ n/m ≤ 8:         CuTe SYRK 128×128 + Split-K
n/m > 8:             cuBLAS fallback (extreme aspect ratio)
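The dispatch rules above read as an ordered list of conditions. The sketch below transcribes them into a hypothetical helper, assuming the rules are checked top to bottom and the first match wins. It returns "unspecified" for shape ranges the rules leave out (1280 < m < 2048 with n/m ≤ 4, and 4 < n/m < 5) rather than guessing; the repository's actual dispatch code may use different boundaries.

```python
def select_gemm1_path(m: int, n: int) -> str:
    """Hypothetical transcription of the GEMM1 dispatch rules (first match wins)."""
    ratio = n / m
    if m <= 1280:
        return "cuBLAS"  # too few SYRK blocks
    if m >= 2048 and ratio <= 4:
        return "CuTe SYRK 64x64 or 128x128"
    if 5 <= ratio <= 8:
        return "CuTe SYRK 128x128 + Split-K"
    if ratio > 8:
        return "cuBLAS"  # extreme aspect ratio
    return "unspecified"  # range not covered by the listed rules


for shape in [(896, 1152), (2048, 2560), (2560, 9728), (2048, 12288), (2048, 20480)]:
    print(shape, select_gemm1_path(*shape))
```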

See docs/optimization_report.md for full NCU profiling data, register pressure analysis, and detailed optimization history.


📦 Advanced Usage

Parameter Groups (Recommended for LLM Training)

from muon_fused import FusedMuon

# 2D+ params → Muon (SGD momentum + NS orthogonalization)
# 1D params (bias, norm) → AdamW fallback
param_groups = [
    {"params": [p for p in model.parameters() if p.ndim >= 2], "use_muon": True},
    {"params": [p for p in model.parameters() if p.ndim < 2], "use_muon": False},
]
optimizer = FusedMuon(param_groups, lr=0.02, momentum=0.95, ns_steps=5)
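If the groups are built programmatically, a small helper can apply the same ndim rule in a single pass and leave out frozen parameters. This is a hypothetical convenience function, not part of muon_fused; it only builds the list passed as the first argument to FusedMuon, and skipping frozen parameters and empty groups are design choices of this sketch.

```python
from torch import nn


def muon_param_groups(model: nn.Module) -> list[dict]:
    """Split trainable params into a Muon group (ndim >= 2) and a fallback group.

    Frozen parameters (requires_grad=False) are skipped, and empty groups are
    dropped. The "use_muon" key mirrors the example above.
    """
    muon, other = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (muon if p.ndim >= 2 else other).append(p)
    groups = [
        {"params": muon, "use_muon": True},
        {"params": other, "use_muon": False},
    ]
    return [g for g in groups if g["params"]]


model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
print([len(g["params"]) for g in muon_param_groups(model)])  # [2, 4]
```

Here the two Linear weights go to the Muon group, while the two biases and the LayerNorm weight and bias go to the fallback group.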

Standalone Newton-Schulz Function

from muon_fused import fused_newton_schulz

# Use in any custom optimizer
# Handles: bf16 cast, normalization, transpose (if m > n), multi-step iteration
X_ortho = fused_newton_schulz(gradient_matrix, steps=5)
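For checking results, or for machines without the CUDA extension, the steps listed in the comment above can be written in plain PyTorch. This sketch follows the structure of the zeropower_via_newtonschulz5() function mentioned in Quick Start; the function name and the dtype parameter are illustrative, and it is not necessarily identical to muon_fused's internal fallback.

```python
import torch


def newton_schulz_reference(G: torch.Tensor, steps: int = 5, eps: float = 1e-7,
                            dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Pure-PyTorch quintic Newton-Schulz orthogonalization of a 2D matrix.

    1. cast to `dtype` (bf16 by default)
    2. normalize so the spectral norm is at most 1
    3. transpose tall matrices so X @ X^T is the smaller Gram matrix
    4. run `steps` iterations, then undo the transpose
    """
    assert G.ndim == 2, "expects a 2D matrix"
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(dtype)
    X = X / (X.norm() + eps)  # Frobenius norm >= spectral norm
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X


G = torch.randn(256, 192)
X = newton_schulz_reference(G, steps=5)
print(X.shape, X.dtype)  # torch.Size([256, 192]) torch.bfloat16
```

With these coefficients the output's singular values cluster near 1 rather than landing exactly on 1, so comparisons against the fused path should use a tolerance rather than exact equality.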

⚙️ Requirements

| Requirement | Version |
|---|---|
| NVIDIA GPU | SM80+ (A100, A800, H100, H200) |
| PyTorch | ≥ 2.0 |
| CUDA Toolkit | ≥ 11.8 |
| CUTLASS | Included as a git submodule (header-only) |

If the CUTLASS submodule is missing (for example, because the repository was cloned without --recursive), fetch it with:

git submodule update --init --recursive

🧪 Testing

pytest tests/ -v

📈 Reproduce Benchmarks

python benchmarks/bench_ns_step.py      # NS step breakdown
python benchmarks/train_cifar10.py      # CIFAR-10 training
python benchmarks/plot_results.py       # Generate figures

Important: For reliable results, use a dedicated GPU with no other processes. GPU dynamic voltage and frequency scaling (DVFS) causes significant measurement variance on shared machines.
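A minimal timing helper in the spirit of this methodology (warmup first, then synchronized wall-clock timing) is sketched below. It is not the code in benchmarks/bench_ns_step.py; the helper name is hypothetical, and the warmup default simply mirrors the 500 iterations quoted in Benchmarks. CUDA events or torch.utils.benchmark are reasonable alternatives.

```python
import time
from typing import Callable

import torch


def time_op(fn: Callable[[], object], warmup: int = 500, iters: int = 100) -> float:
    """Mean wall-clock time per call of `fn`, in microseconds."""
    use_cuda = torch.cuda.is_available()
    # Warmup: triggers torch.compile / autotuning and lets GPU clocks settle.
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # drain queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # CUDA launches are async; wait for completion
    return (time.perf_counter() - start) / iters * 1e6


device = "cuda" if torch.cuda.is_available() else "cpu"
X = torch.randn(896, 1152, device=device)
print(f"X @ X.T: {time_op(lambda: X @ X.T, warmup=10, iters=20):.1f} us")
```

Synchronizing only around the timed loop, rather than after every call, keeps the measurement close to steady-state throughput.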


📚 Citation

@software{fused_muon,
  title  = {Fused Muon: CUDA-Optimized Newton-Schulz Iteration for the Muon Optimizer},
  author = {Xingchen Liu},
  year   = {2025},
  url    = {https://github.com/StarrickLiu/fused-muon}
}

🙏 Acknowledgments

  • Muon optimizer by Keller Jordan — the original Muon algorithm
  • CUTLASS & CuTe by NVIDIA — tensor core abstraction
  • Moonlight by Moonshot AI — distributed Muon implementation reference

License

Apache 2.0
