tilde-research/comp-muon-release


Tilde Research

Partner-whitened steepest descent for the QK and OV attention circuits. [blog post]

Compositional Muon (CM) extends Muon from single-matrix steepest descent to the composed operators a transformer actually applies: the QK product $M = W_Q W_K^\top$ and the OV product $W_O W_V$. Each factor's gradient is whitened by its partner's inverse Gram root before the spectral sign and scaled by it again afterward, so the step size for each matrix adapts to the geometry of its partner.

How it works

The loss sees $W_Q$ and $W_K$ only through $M = W_Q W_K^\top$. Constraining the operator norm of the composed update $\Delta M$ and splitting the budget equally between the two factors gives the partner-whitened half-split rule

$$\Delta W_Q = -\tfrac{\eta}{2}\,\mathrm{msign}\left(G_Q C_K^{-1}\right) C_K^{-1}, \qquad \Delta W_K = -\tfrac{\eta}{2}\,\mathrm{msign}\left(G_K C_Q^{-1}\right) C_Q^{-1},$$

with $C_K = (W_K^\top W_K + \lambda I)^{1/2}$ and, symmetrically, $C_Q = (W_Q^\top W_Q + \lambda I)^{1/2}$. The same construction applies to the OV product $W_O W_V$ (with $W_V$ per-head and $W_O$ per-matrix). When each partner Gram is near-isotropic, the inverse root collapses to a scalar multiple of the identity, $C^{-1} \approx c^{-1} I$, recovering a cheap per-head dynamic learning rate.
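The half-split rule can be sketched numerically for a single head. This is a minimal illustration rather than the packaged implementation: it uses NumPy instead of torch, applies the rule to raw gradients with no momentum, computes the spectral sign with an exact SVD, and the helper names (`msign`, `inv_gram_root`, `half_split_step`) are hypothetical.

```python
import numpy as np

def msign(X):
    """Spectral sign U V^T from the thin SVD X = U S V^T."""
    U, _, Vt = np.linalg.svd(X, full_matrices=False)
    return U @ Vt

def inv_gram_root(W, lam=1e-6):
    """(W^T W + lam I)^{-1/2} via a symmetric eigendecomposition."""
    gram = W.T @ W + lam * np.eye(W.shape[1])
    evals, evecs = np.linalg.eigh(gram)
    # Scale each eigenvector column by 1/sqrt(eigenvalue), then rotate back.
    return (evecs / np.sqrt(evals)) @ evecs.T

def half_split_step(W_q, W_k, G_q, G_k, eta, lam=1e-6):
    """Partner-whitened updates for one head; factors are (d_model, head_dim)."""
    Ck_inv = inv_gram_root(W_k, lam)
    Cq_inv = inv_gram_root(W_q, lam)
    dW_q = -0.5 * eta * msign(G_q @ Ck_inv) @ Ck_inv
    dW_k = -0.5 * eta * msign(G_k @ Cq_inv) @ Cq_inv
    return dW_q, dW_k

# Random stand-ins for one head's weights and gradients (illustrative sizes).
rng = np.random.default_rng(0)
d_model, head_dim, eta = 16, 4, 0.1
W_q, W_k, G_q, G_k = (rng.standard_normal((d_model, head_dim)) for _ in range(4))

dW_q, dW_k = half_split_step(W_q, W_k, G_q, G_k, eta)

# Operator norm of each leg of the first-order change in M = W_q W_k^T.
q_leg = np.linalg.norm(dW_q @ W_k.T, 2)
k_leg = np.linalg.norm(W_q @ dW_k.T, 2)
```

To first order, $\Delta M = \Delta W_Q W_K^\top + W_Q \Delta W_K^\top$. Because $\lVert C_K^{-1} W_K^\top \rVert_{\mathrm{op}} \le 1$, each leg contributes at most $\eta/2$ in operator norm, so `q_leg` and `k_leg` stay below `eta / 2` and the composed first-order update stays within the budget $\eta$.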

Usage

Requires `torch`. `cm_ov` / `cm_qk` take the attention weights in `nn.Linear` convention plus caller-managed momentum buffers, and apply one CM update in place:

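```python
import torch
from compositional_muon import cm_ov, cm_qk

# attn has q_proj/k_proj/v_proj/o_proj as bias-free nn.Linear layers; the momentum
# buffers (zeros at init) are caller-managed, one per weight.
# Allocate them once, before the training loop (assumed here to match each weight's shape).
m_q, m_k, m_v, m_o = (
    torch.zeros_like(proj.weight)
    for proj in (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj)
)

# After loss.backward(), apply one CM update to each pair in place.
cm_qk(attn.q_proj.weight, attn.k_proj.weight,
      attn.q_proj.weight.grad, attn.k_proj.weight.grad,
      m_q, m_k, head_dim=attn.head_dim, eta=lr)

cm_ov(attn.v_proj.weight, attn.o_proj.weight,
      attn.v_proj.weight.grad, attn.o_proj.weight.grad,
      m_v, m_o, head_dim=attn.head_dim, eta=lr)
```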
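For readers mapping the `nn.Linear` convention onto the math above, the following sketch shows how a stacked projection weight relates to the per-head factors in $M = W_Q W_K^\top$. It assumes the common layout in which `weight` has shape `(n_heads * head_dim, d_model)` and each head owns a contiguous block of `head_dim` rows; the package's internal layout may differ, and `per_head_factors` is a hypothetical helper written in NumPy for illustration.

```python
import numpy as np

def per_head_factors(weight, head_dim):
    """Split an (n_heads * head_dim, d_model) projection weight into
    per-head factors of shape (n_heads, d_model, head_dim)."""
    out_features, d_model = weight.shape
    n_heads = out_features // head_dim
    return weight.reshape(n_heads, head_dim, d_model).transpose(0, 2, 1)

rng = np.random.default_rng(0)
d_model, n_heads, head_dim = 8, 2, 4
Wq = rng.standard_normal((n_heads * head_dim, d_model))  # like q_proj.weight
Wk = rng.standard_normal((n_heads * head_dim, d_model))  # like k_proj.weight
x = rng.standard_normal((3, d_model))                    # three token embeddings

Q = per_head_factors(Wq, head_dim)
K = per_head_factors(Wk, head_dim)
M = Q @ np.swapaxes(K, 1, 2)  # per-head composed operator, (n_heads, d_model, d_model)

# Head-0 attention logits (unscaled), computed through the projections and through M.
q0 = (x @ Wq.T)[:, :head_dim]
k0 = (x @ Wk.T)[:, :head_dim]
logits_direct = q0 @ k0.T
logits_via_M = x @ M[0] @ x.T
```

Both routes give the same logits, which is the sense in which the loss depends on the two projections only through their product.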

CM governs only the attention QK and OV pairs; update the other parameters with your optimizer of choice. `src/main.py` is a runnable demo: a small transformer trained with CM on attention and Muon on the rest.
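One way to route parameters is to partition them by name before building the second optimizer. The sketch below is an assumption about how a caller might organize this, not part of the package: `split_params` is a hypothetical helper, and the name suffixes simply follow the `q_proj`/`k_proj`/`v_proj`/`o_proj` attribute names used above.

```python
ATTN_SUFFIXES = ("q_proj.weight", "k_proj.weight", "v_proj.weight", "o_proj.weight")

def split_params(named_params):
    """Partition (name, param) pairs into attention QK/OV weights and the rest."""
    attention, rest = {}, []
    for name, param in named_params:
        if name.endswith(ATTN_SUFFIXES):
            attention[name] = param   # updated by cm_qk / cm_ov
        else:
            rest.append(param)        # handed to the optimizer of your choice
    return attention, rest

# Stand-in for model.named_parameters(); strings play the role of tensors here.
named = [
    ("embed.weight", "E"),
    ("layers.0.attn.q_proj.weight", "Wq"),
    ("layers.0.attn.k_proj.weight", "Wk"),
    ("layers.0.attn.v_proj.weight", "Wv"),
    ("layers.0.attn.o_proj.weight", "Wo"),
    ("layers.0.mlp.up.weight", "Wup"),
]
attention, rest = split_params(named)
```

With a real model, `rest` would be passed to the second optimizer's constructor, while the entries of `attention` are updated only through `cm_qk` and `cm_ov`.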

Variants

| argument | default | values | description |
| --- | --- | --- | --- |
| `method` | `"half_split"` | `"half_split"`, `"joint"` | split the budget per factor, or one shared spectral sign over the stacked factors |
| `isotropic` | `False` | `False`, `True` | full matrix partner whitening, or its per-head scalar approximation |
| `hybrid` (OV) | `True` | `True`, `False` | $W_O$ per-matrix spectral sign with $W_V$ per-head, or both per-head |
| `whitening` | `"both"` | `"both"`, `"pre"`, `"post"`, `"none"` | which side(s) of the partner whitening to apply |
| `connection` | `"none"` | `"none"`, `"frobenius"`, `"scale_aware"`, `"frobenius_scalar"`, `"scale_aware_scalar"` | gauge fix removing the vertical (gauge) component of the update |
| `momentum_reproject` | `False` | `False`, `True` | project the momentum onto the horizontal (gauge-fixed) bundle |
| `per_mat_renorm` | `False` | `False`, `True` | restore each leg to its pre-whiten Frobenius norm |
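To illustrate what the `isotropic` option approximates, the sketch below compares full partner whitening with a scalar stand-in for one head. It is a NumPy illustration under stated assumptions, not the package's code: the helper names are hypothetical, and the scalar is taken as $c^2 = \mathrm{tr}(W^\top W)/d_h + \lambda$ (the mean eigenvalue of the regularized partner Gram), which is one natural choice rather than a documented one. Since $\mathrm{msign}(aX) = \mathrm{msign}(X)$ for $a > 0$, the pre-whitening scalar drops out and the rule reduces to Muon's spectral sign with learning rate $\eta / (2c)$.

```python
import numpy as np

def msign(X):
    """Spectral sign U V^T from the thin SVD X = U S V^T."""
    U, _, Vt = np.linalg.svd(X, full_matrices=False)
    return U @ Vt

def inv_gram_root(W, lam):
    """(W^T W + lam I)^{-1/2} via a symmetric eigendecomposition."""
    gram = W.T @ W + lam * np.eye(W.shape[1])
    evals, evecs = np.linalg.eigh(gram)
    return (evecs / np.sqrt(evals)) @ evecs.T

def full_step(G, W_partner, eta, lam):
    """Half-split update with full matrix partner whitening."""
    C_inv = inv_gram_root(W_partner, lam)
    return -0.5 * eta * msign(G @ C_inv) @ C_inv

def isotropic_step(G, W_partner, eta, lam):
    """Scalar approximation: C^{-1} ~ I / c (choice of c is an assumption)."""
    c = np.sqrt(np.sum(W_partner ** 2) / W_partner.shape[1] + lam)
    # msign is invariant to positive rescaling, so only the post-factor 1/c remains.
    return -(0.5 * eta / c) * msign(G)

rng = np.random.default_rng(0)
d_model, head_dim, eta, lam = 16, 4, 0.1, 1e-6

# A partner with exactly isotropic Gram: orthonormal columns scaled by s.
s = 3.0
ortho, _ = np.linalg.qr(rng.standard_normal((d_model, head_dim)))
W_partner = s * ortho
G = rng.standard_normal((d_model, head_dim))

exact = full_step(G, W_partner, eta, lam)
approx = isotropic_step(G, W_partner, eta, lam)
```

When the partner Gram is exactly a multiple of the identity, as constructed here, `exact` and `approx` coincide; for a partner that is only near-isotropic they differ by an amount that grows with the spread of the Gram's eigenvalues.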

License

Apache 2.0
