Partner-whitened steepest descent for the QK and OV attention circuits. [blog post]
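Compositional Muon (CM) extends Muon from single-matrix steepest descent to the
composed operators a transformer actually applies: the QK product $W_Q^{h\top} W_K^h$ and the OV product $W_O^h W_V^h$ of each attention head $h$.
The loss sees only these products, not the individual factors, with $W_Q^h, W_K^h, W_V^h \in \mathbb{R}^{d_{\text{head}} \times d_{\text{model}}}$ the rows of the query, key, and value projections that belong to head $h$, and $W_O^h \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}$ the matching columns of the output projection.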
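To make the composed operators concrete, here is a small sketch. It uses NumPy rather than torch only to stay dependency-light. It slices `nn.Linear`-convention weights into per-head factors, forms each head's QK and OV products, and checks that a per-head change of basis leaves both products unchanged. That invariance is the gauge freedom the `connection` options in the table refer to. The contiguous head layout and the helper names `head_blocks` / `qk_ov_products` are assumptions for illustration, not part of `compositional_muon`.

```python
import numpy as np

def head_blocks(w, head_dim, axis=0):
    """Yield the per-head slices of an nn.Linear weight.

    Assumes heads are laid out contiguously: rows h*head_dim:(h+1)*head_dim of
    q/k/v (shape [n_heads*head_dim, d_model]), or the same columns of o
    (shape [d_model, n_heads*head_dim]), belong to head h.
    """
    n_heads = w.shape[axis] // head_dim
    for h in range(n_heads):
        s = slice(h * head_dim, (h + 1) * head_dim)
        yield w[s] if axis == 0 else w[:, s]

def qk_ov_products(w_q, w_k, w_v, w_o, head_dim):
    """Per-head QK products W_Q^T W_K and OV products W_O W_V (each d_model x d_model)."""
    qk = [q.T @ k for q, k in zip(head_blocks(w_q, head_dim), head_blocks(w_k, head_dim))]
    ov = [o @ v for o, v in zip(head_blocks(w_o, head_dim, axis=1), head_blocks(w_v, head_dim))]
    return qk, ov

rng = np.random.default_rng(0)
d_model, n_heads, head_dim = 8, 2, 4
w_q, w_k, w_v = (rng.standard_normal((n_heads * head_dim, d_model)) for _ in range(3))
w_o = rng.standard_normal((d_model, n_heads * head_dim))
qk, ov = qk_ov_products(w_q, w_k, w_v, w_o, head_dim)

# Gauge transformation with an independent invertible matrix per head and circuit:
#   W_Q -> G W_Q, W_K -> G^{-T} W_K   and   W_V -> H W_V, W_O -> W_O H^{-1}.
t_q, t_k, t_v, t_o = w_q.copy(), w_k.copy(), w_v.copy(), w_o.copy()
for h in range(n_heads):
    s = slice(h * head_dim, (h + 1) * head_dim)
    g_qk = rng.standard_normal((head_dim, head_dim)) + 3 * np.eye(head_dim)
    g_ov = rng.standard_normal((head_dim, head_dim)) + 3 * np.eye(head_dim)
    t_q[s], t_k[s] = g_qk @ w_q[s], np.linalg.inv(g_qk).T @ w_k[s]
    t_v[s], t_o[:, s] = g_ov @ w_v[s], w_o[:, s] @ np.linalg.inv(g_ov)

qk2, ov2 = qk_ov_products(t_q, t_k, t_v, t_o, head_dim)
print(all(np.allclose(a, b) for a, b in zip(qk, qk2)))  # True
print(all(np.allclose(a, b) for a, b in zip(ov, ov2)))  # True
```

The individual factors change under the transformation but the products do not, which is why CM reasons about the pairs rather than about each weight separately.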
Requires `torch`. `cm_ov` / `cm_qk` take the attention weights in `nn.Linear`
convention plus caller-managed momentum buffers, and apply one CM update in place:
```python
import torch
from compositional_muon import cm_ov, cm_qk

# attn has q_proj/k_proj/v_proj/o_proj as bias-free nn.Linear layers; the momentum
# buffers (zeros at init) are caller-managed, one per weight. Create them once.
m_q, m_k, m_v, m_o = (torch.zeros_like(layer.weight) for layer in
                      (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj))

# Each step, after loss.backward() has filled the .grad fields:
cm_qk(attn.q_proj.weight, attn.k_proj.weight,
      attn.q_proj.weight.grad, attn.k_proj.weight.grad,
      m_q, m_k, head_dim=attn.head_dim, eta=lr)
cm_ov(attn.v_proj.weight, attn.o_proj.weight,
      attn.v_proj.weight.grad, attn.o_proj.weight.grad,
      m_v, m_o, head_dim=attn.head_dim, eta=lr)
```

CM governs only the attention QK and OV pairs; update the other parameters with
your optimizer of choice. `src/main.py` is a runnable demo (a small transformer
trained with CM on attention and Muon on the rest).
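For comparison with the attention pairs, the single-matrix Muon step that the demo applies to the remaining weights moves along the spectral sign $UV^\top$ of the momentum. Among updates of unit spectral norm, this is the one most aligned with the momentum. A simplified NumPy sketch follows. It makes several assumptions: it uses an exact SVD instead of the Newton–Schulz iterations the reference Muon implementation uses to approximate $UV^\top$, it uses plain heavy-ball momentum, and it applies no shape-dependent scaling. `msign` and `muon_step` are hypothetical names, not an API of this repo or of any Muon package.

```python
import numpy as np

def msign(m):
    """Spectral sign (orthogonal polar factor) U V^T of m, via a thin SVD."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

def muon_step(w, grad, buf, lr=0.02, beta=0.95):
    """One simplified Muon step, updating w and the momentum buffer buf in place.

    Heavy-ball momentum, then a step along the spectral sign of the momentum:
    the steepest-descent direction when the step is measured in spectral norm.
    """
    buf *= beta
    buf += grad
    w -= lr * msign(buf)

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 6))
grad = rng.standard_normal((4, 6))
buf = np.zeros_like(w)

w_before = w.copy()
muon_step(w, grad, buf, lr=0.02)
step = (w_before - w) / 0.02  # the update direction, msign(buf)

print(np.round(np.linalg.norm(step, ord=2), 6))  # 1.0: every singular value of U V^T is 1
```

CM keeps this spectral-sign idea but applies it to the QK and OV pairs jointly, with partner whitening, instead of to each weight in isolation.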
| argument | default | values | description |
|---|---|---|---|
| `method` | `"half_split"` | `"half_split"`, `"joint"` | split the budget per factor, or one shared spectral sign over the stacked factors |
| `isotropic` | `False` | `False`, `True` | full-matrix partner whitening, or its per-head scalar approximation |
| `hybrid` (OV) | `True` | `True`, `False` | |
| `whitening` | `"both"` | `"both"`, `"pre"`, `"post"`, `"none"` | which side(s) of the partner whitening to apply |
| `connection` | `"none"` | `"none"`, `"frobenius"`, `"scale_aware"`, `"frobenius_scalar"`, `"scale_aware_scalar"` | gauge fix removing the vertical (gauge) component of the update |
| `momentum_reproject` | `False` | `False`, `True` | project the momentum onto the horizontal (gauge-fixed) bundle |
| `per_mat_renorm` | `False` | `False`, `True` | restore each leg to its pre-whitening Frobenius norm |
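The `connection` and `momentum_reproject` options split an update into a vertical part and a horizontal remainder. The vertical part is motion along the gauge orbit, which leaves the QK product unchanged. The sketch below is one plausible reading of the two Frobenius options only, and it is an assumption. For one head's QK pair, it subtracts the Frobenius-orthogonal projection of the update onto the vertical directions $(A W_Q, -A^\top W_K)$, either for any matrix $A$ or for scalar $A = cI$. The `scale_aware` variants are not attempted. `horizontal_qk`, `horizontal_qk_scalar`, and `product_change` are hypothetical helpers, not part of `compositional_muon`. Applying the same projection to the momentum buffer instead of the update would correspond to the idea behind `momentum_reproject`.

```python
import numpy as np

# Hypothetical helpers illustrating a Frobenius-metric gauge fix for one head's
# QK pair; q, k are head_dim x d_model slices in nn.Linear convention.

def horizontal_qk(q, k, dq, dk):
    """Remove the vertical part of (dq, dk) under the Frobenius metric.

    Vertical directions are (A q, -A^T k) for any head_dim x head_dim A: the
    tangents of the gauge orbit q -> G q, k -> G^{-T} k at G = I. The
    least-squares A solves the Sylvester equation
        (k k^T) A + A (q q^T) = dq q^T - k dk^T.
    """
    h = q.shape[0]
    eye = np.eye(h)
    # Row-major vec: vec(Y A) = (Y kron I) vec(A), vec(A X) = (I kron X^T) vec(A).
    lhs = np.kron(k @ k.T, eye) + np.kron(eye, (q @ q.T).T)
    rhs = dq @ q.T - k @ dk.T
    a = np.linalg.solve(lhs, rhs.reshape(-1)).reshape(h, h)
    return dq - a @ q, dk + a.T @ k

def horizontal_qk_scalar(q, k, dq, dk):
    """Same projection restricted to the scalar gauge A = c I."""
    c = (np.sum(dq * q) - np.sum(dk * k)) / (np.sum(q * q) + np.sum(k * k))
    return dq - c * q, dk + c * k

def product_change(q, k, dq, dk):
    """First-order change of the QK product q^T k along the update (dq, dk)."""
    return dq.T @ k + q.T @ dk

rng = np.random.default_rng(0)
q, k, dq, dk = (rng.standard_normal((4, 8)) for _ in range(4))

hq, hk = horizontal_qk(q, k, dq, dk)
b = rng.standard_normal((4, 4))
vq, vk = b @ q, -b.T @ k  # an arbitrary vertical direction
print(np.isclose(np.sum(hq * vq) + np.sum(hk * vk), 0.0))                       # True
print(np.allclose(product_change(q, k, hq, hk), product_change(q, k, dq, dk)))  # True
```

The second check is the point of the gauge fix. Removing the vertical component does not alter how the update changes the product $W_Q^\top W_K$ to first order. It only discards the part that would move the factors along the orbit.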
Apache 2.0
