eqx.nn.LayerNorm uses two-pass variance (jnp.mean + jnp.var). Under vmap + jax.grad, this produces ~27% more HLO reduce ops and runs 1.45x slower than an identical model using jax.nn.standardize (one-pass variance: E[x²] - E[x]²).
Benchmark
8-block residual MLP, DIM=64, BATCH=256, SEQ=32, double vmap (batch × seq):
| Variant | Time | HLO lines | Reduces | Broadcasts |
|---|---|---|---|---|
| (a) eqx.nn.LayerNorm | 49.7 ms | 1385 | 132 | 342 |
| (b) LayerNormNew (drop-in) | 34.2 ms | 1011 | 104 | 238 |
Full script here to show timings.
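For anyone who wants to inspect op counts locally, here is a rough sketch. It is not the full script referenced above. It assumes a single LayerNorm under double vmap + jax.grad instead of the 8-block MLP. It also assumes that reduce and broadcast ops can be counted by scanning the compiled HLO text for ` reduce(` and ` broadcast(`, so the absolute numbers will differ from the table. `ln_two_pass`, `ln_one_pass`, `make_loss` and `hlo_op_counts` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def ln_two_pass(x, eps=1e-5):
    # Same arithmetic as the eqx.nn.LayerNorm excerpt (no affine part).
    mean = jnp.mean(x, keepdims=True)
    variance = jnp.maximum(0.0, jnp.var(x, keepdims=True))
    return (x - mean) * jax.lax.rsqrt(variance + eps)


def ln_one_pass(x, eps=1e-5):
    # One-pass variance via jax.nn.standardize over every axis of x.
    return jax.nn.standardize(x, axis=tuple(range(x.ndim)), epsilon=eps)


def make_loss(ln):
    # Double vmap over (batch, seq); ln sees one DIM-sized vector at a time.
    batched = jax.vmap(jax.vmap(ln))
    return lambda x: jnp.sum(jnp.sin(batched(x)))


def hlo_op_counts(ln, shape=(256, 32, 64)):
    # Compile grad(loss) and scan the optimised HLO text line by line.
    # Assumption: counting " reduce(" / " broadcast(" lines approximates
    # the reduce/broadcast op counts; the original method is not shown.
    x = jnp.zeros(shape, jnp.float32)
    text = jax.jit(jax.grad(make_loss(ln))).lower(x).compile().as_text()
    lines = text.splitlines()
    return {
        "hlo_lines": len(lines),
        "reduces": sum(" reduce(" in line for line in lines),
        "broadcasts": sum(" broadcast(" in line for line in lines),
    }


for name, ln in [("two-pass", ln_two_pass), ("one-pass", ln_one_pass)]:
    print(name, hlo_op_counts(ln))
```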
Algorithm
eqx.nn.LayerNorm.__call__ (in _normalisation.py) computes:
```python
mean = jnp.mean(x, keepdims=True)
variance = jnp.var(x, keepdims=True)  # internally: mean((x - mean)²)
variance = jnp.maximum(0.0, variance)
inv = jax.lax.rsqrt(variance + self.eps)
out = (x - mean) * inv
```
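jnp.var is two-pass: it recomputes the mean internally and then reduces over (x - mean)², so its second reduction has to wait for the first. jax.nn.standardize instead uses one-pass variance, where the reductions over x and x² are independent of each other: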
```python
mean = jnp.mean(x, axis, keepdims=True)
variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
```
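To make the difference concrete, here is a small sketch of both variance formulas in plain jax.numpy. `two_pass_var` and `one_pass_var` are hypothetical helper names, not JAX or equinox APIs. On well-scaled data the two agree to float32 rounding. The shifted example illustrates the usual caveat of the one-pass form: catastrophic cancellation when |E[x]| is large relative to the spread. This is also why the result needs clamping at zero before rsqrt.

```python
import jax
import jax.numpy as jnp


def two_pass_var(x):
    # Mirrors jnp.var: reduce once for the mean, then reduce (x - mean)**2.
    mean = jnp.mean(x, keepdims=True)
    return jnp.mean(jnp.square(x - mean), keepdims=True)


def one_pass_var(x):
    # Mirrors jax.nn.standardize: E[x^2] - E[x]^2. The reductions over x
    # and over x**2 do not depend on each other.
    mean = jnp.mean(x, keepdims=True)
    return jnp.mean(jnp.square(x), keepdims=True) - jnp.square(mean)


x = jax.random.normal(jax.random.PRNGKey(0), (64,))
print(two_pass_var(x), one_pass_var(x), jnp.var(x, keepdims=True))

# Same data shifted by a large constant: the true variance is unchanged, but
# in float32 E[x^2] and E[x]^2 are both ~1e8 and their difference loses most
# of its significant digits (catastrophic cancellation).
shifted = x + 1e4
print(two_pass_var(shifted), one_pass_var(shifted))
```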
Suggested fix
In _normalisation.py, replace the two-pass variance:
```python
mean = jnp.mean(x, keepdims=True)
variance = jnp.var(x, keepdims=True)
variance = jnp.maximum(0.0, variance)
inv = jax.lax.rsqrt(variance + self.eps)
out = (x - mean) * inv
```
with jax.nn.standardize:
```python
out = jax.nn.standardize(x, axis=tuple(range(x.ndim)), epsilon=self.eps)
```
This is a one-line change. jax.nn.standardize already handles the clipping (jnp.clip(variance, 0)), so jnp.maximum is no longer needed.
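A quick way to check the drop-in claim outside equinox is the following minimal sketch. It assumes plain functions can stand in for the module. `layer_norm_old`, `layer_norm_new` and `make_loss` are hypothetical names, and the `weight * out + bias` affine step is assumed to match the module's elementwise affine. The sketch compares outputs and gradients under the same double vmap + jax.grad pattern as the benchmark, but it does not exercise equinox itself.

```python
import jax
import jax.numpy as jnp


def layer_norm_old(x, weight, bias, eps=1e-5):
    # Hypothetical stand-in for the current code: two-pass variance.
    mean = jnp.mean(x, keepdims=True)
    variance = jnp.maximum(0.0, jnp.var(x, keepdims=True))
    out = (x - mean) * jax.lax.rsqrt(variance + eps)
    return weight * out + bias  # assumed elementwise affine


def layer_norm_new(x, weight, bias, eps=1e-5):
    # Hypothetical stand-in for the proposed code: jax.nn.standardize.
    out = jax.nn.standardize(x, axis=tuple(range(x.ndim)), epsilon=eps)
    return weight * out + bias  # assumed elementwise affine


def make_loss(ln):
    # Double vmap over (batch, seq); weight and bias are shared (in_axes=None).
    per_token = jax.vmap(ln, in_axes=(0, None, None))
    batched = jax.vmap(per_token, in_axes=(0, None, None))
    return lambda x, w, b: jnp.sum(jnp.tanh(batched(x, w, b)))


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (8, 4, 64))
weight = 1.0 + 0.1 * jax.random.normal(kw, (64,))
bias = jnp.zeros(64)

loss_old, loss_new = make_loss(layer_norm_old), make_loss(layer_norm_new)
grads_old = jax.grad(loss_old, argnums=(0, 1, 2))(x, weight, bias)
grads_new = jax.grad(loss_new, argnums=(0, 1, 2))(x, weight, bias)
for g_old, g_new in zip(grads_old, grads_new):
    # Largest absolute gradient difference for x, weight, bias.
    print(float(jnp.max(jnp.abs(g_old - g_new))))
```

Agreement here is up to float32 rounding rather than bit-for-bit, since the two formulas round differently.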
Confirmation of Correctness
Here is a suite of 50 unit tests confirming identical outputs, just faster.
Versions
- equinox 0.13.4
- jax 0.9.0
- CPU (no GPU)
I would be happy to prepare a PR if useful.