
nn.LayerNorm two-pass variance produces ~27% more HLO reduce ops under vmap + grad #1196

Description

@CatchemAL

eqx.nn.LayerNorm uses a two-pass variance (jnp.mean + jnp.var). Under vmap + jax.grad, this produces ~27% more HLO reduce ops (132 vs. 104) and runs 1.45x slower (49.7 ms vs. 34.2 ms) than an otherwise identical model using jax.nn.standardize (one-pass variance: E[x²] - E[x]²).

Benchmark

8-block residual MLP, DIM=64, BATCH=256, SEQ=32, double vmap (batch × seq):

| Variant | Time | HLO lines | Reduces | Broadcasts |
| --- | --- | --- | --- | --- |
| (a) eqx.nn.LayerNorm | 49.7 ms | 1385 | 132 | 342 |
| (b) LayerNormNew (drop-in) | 34.2 ms | 1011 | 104 | 238 |

Full script here to show timings.
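
The linked script is not reproduced here. As a rough illustration of how op counts like the ones in the table could be gathered, the sketch below lowers a vmap + grad computation and counts lines of the lowered IR text that mention "reduce". The helper names (`two_pass_norm`, `one_pass_norm`, `count_ops`), the tiny input shape, and the textual counting are all assumptions for illustration; this inspects the pre-optimisation IR of a bare normalisation, so its numbers will not match the table above.

```python
import jax
import jax.numpy as jnp


def two_pass_norm(x, eps=1e-5):
    # Mirrors the mean/var formulation quoted in this issue.
    mean = jnp.mean(x, keepdims=True)
    variance = jnp.maximum(0.0, jnp.var(x, keepdims=True))
    return (x - mean) * jax.lax.rsqrt(variance + eps)


def one_pass_norm(x, eps=1e-5):
    # Normalise over every axis of the (unbatched) input.
    return jax.nn.standardize(x, axis=tuple(range(x.ndim)), epsilon=eps)


def count_ops(norm_fn, x, op_name="reduce"):
    """Count lines of the lowered IR text that mention `op_name` (a crude proxy)."""

    def loss(batch):
        # Double vmap (batch x seq), then a scalar loss so jax.grad applies.
        return jnp.sum(jax.vmap(jax.vmap(norm_fn))(batch) ** 2)

    ir_text = jax.jit(jax.grad(loss)).lower(x).as_text()
    return sum(op_name in line for line in ir_text.splitlines())


x = jnp.ones((4, 8, 16))  # (batch, seq, dim), deliberately small
reduce_counts = {fn.__name__: count_ops(fn, x) for fn in (two_pass_norm, one_pass_norm)}
print(reduce_counts)
```

Counting on the compiled module instead (via `.lower(x).compile().as_text()`) would reflect what XLA actually runs after fusion, which may be closer to what the benchmark measured.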

Algorithm

eqx.nn.LayerNorm.__call__ (in _normalisation.py) computes:

mean = jnp.mean(x, keepdims=True)
variance = jnp.var(x, keepdims=True)        # internally: mean((x - mean)²)
variance = jnp.maximum(0.0, variance)
inv = jax.lax.rsqrt(variance + self.eps)
out = (x - mean) * inv

jnp.var is two-pass, whereas jax.nn.standardize uses a one-pass variance:

mean = jnp.mean(x, axis, keepdims=True)
variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
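
To make the difference between the two formulas concrete, here is a minimal sketch that computes both on the same data. The function names `var_two_pass` and `var_one_pass` are hypothetical, and the inputs are random draws chosen purely for illustration. On well-scaled activations the two should agree closely; the one-pass form subtracts two potentially large numbers, so it can lose precision in float32 when the mean is large relative to the spread, which is the reason a clip at zero is applied to the variance.

```python
import jax
import jax.numpy as jnp


def var_two_pass(x):
    # Pass 1: mean. Pass 2: mean of squared deviations from that mean.
    mean = jnp.mean(x)
    return jnp.mean(jnp.square(x - mean))


def var_one_pass(x):
    # E[x^2] - E[x]^2: both reductions read x directly and are independent.
    return jnp.mean(jnp.square(x)) - jnp.square(jnp.mean(x))


x = jax.random.normal(jax.random.PRNGKey(0), (64,))
v_two, v_one = var_two_pass(x), var_one_pass(x)
print("well-scaled input:", v_two, v_one)

# Same spread, large offset: the one-pass result may drift in float32.
shifted = x + 1.0e4
print("shifted input:    ", var_two_pass(shifted), var_one_pass(shifted))
```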

Suggested fix

In _normalisation.py, replace the two-pass variance:

mean = jnp.mean(x, keepdims=True)
variance = jnp.var(x, keepdims=True)
variance = jnp.maximum(0.0, variance)
inv = jax.lax.rsqrt(variance + self.eps)
out = (x - mean) * inv

with jax.nn.standardize:

out = jax.nn.standardize(x, axis=tuple(range(x.ndim)), epsilon=self.eps)

This is a one-line change. jax.nn.standardize already handles the clipping (jnp.clip(variance, 0)), so jnp.maximum is no longer needed.

Confirmation of Correctness

Here is a suite of 50 unit tests confirming identical outputs, just faster.
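
For readers without access to the linked suite, a single check of this kind might look like the sketch below. It is not taken from that suite. The standalone functions `layer_norm_two_pass` and `layer_norm_standardize`, the elementwise `weight`/`bias` affine step, the epsilon value, the shapes, and the tolerances are all assumptions for illustration. It compares both outputs and gradients under vmap.

```python
import jax
import jax.numpy as jnp

EPS = 1e-5  # assumed epsilon, for illustration only


def layer_norm_two_pass(x, weight, bias):
    # Formulation quoted in this issue, followed by an assumed affine step.
    mean = jnp.mean(x, keepdims=True)
    variance = jnp.maximum(0.0, jnp.var(x, keepdims=True))
    return (x - mean) * jax.lax.rsqrt(variance + EPS) * weight + bias


def layer_norm_standardize(x, weight, bias):
    # Proposed replacement: one-pass variance via jax.nn.standardize.
    out = jax.nn.standardize(x, axis=tuple(range(x.ndim)), epsilon=EPS)
    return out * weight + bias


kx, kw, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (32, 64))  # (batch, dim)
weight = jax.random.normal(kw, (64,))
bias = jax.random.normal(kb, (64,))


def batched(norm_fn):
    # Map over the batch axis only; weight and bias are shared.
    return jax.vmap(norm_fn, in_axes=(0, None, None))


def loss(norm_fn):
    return lambda xs: jnp.sum(batched(norm_fn)(xs, weight, bias) ** 2)


out_a = batched(layer_norm_two_pass)(x, weight, bias)
out_b = batched(layer_norm_standardize)(x, weight, bias)
grad_a = jax.grad(loss(layer_norm_two_pass))(x)
grad_b = jax.grad(loss(layer_norm_standardize))(x)

outputs_match = bool(jnp.allclose(out_a, out_b, atol=1e-4))
grads_match = bool(jnp.allclose(grad_a, grad_b, rtol=1e-4, atol=1e-3))
print(outputs_match, grads_match)
```

Agreement here is up to floating-point tolerance rather than bit-for-bit, since the two formulas round differently.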

Versions

  • equinox 0.13.4
  • jax 0.9.0
  • CPU (no GPU)

I would be happy to prepare a PR if useful.
