Summary
XLA compiles a*x + b*x into an FMA (fused multiply-add) instruction fma(a, x, b*x). The FMA evaluates the first product a*x exactly, with no intermediate rounding, and rounds only the final sum, whereas b*x is rounded to float32 before it is added. When a ≈ -b and the products are large, this inconsistent rounding causes catastrophic error: the XLA result can be many orders of magnitude away from both the non-JIT result and the mathematically correct answer.
Without JIT, both products are rounded to float32 the same way, yielding intermediate values of equal magnitude and opposite sign that cancel correctly to 0.
Reproduced in both TensorFlow (jit_compile=True) and JAX (jax.jit), confirming this is an XLA-level issue.
Reproduction
TensorFlow:
import tensorflow as tf
a = tf.constant(1e20, dtype=tf.float32)
b = tf.constant(-1e20, dtype=tf.float32)
x = tf.constant(1e10, dtype=tf.float32)
print((a * x + b * x).numpy())
# 0.0 ← correct (eager)
print(tf.function(lambda a, b, x: a * x + b * x, jit_compile=True)(a, b, x).numpy())
# 4.993411122843311e+21 ← wrong (XLA)
JAX:
import jax
import jax.numpy as jnp
a = jnp.float32(1e20)
b = jnp.float32(-1e20)
x = jnp.float32(1e10)
print(float(a * x + b * x))
# 0.0 ← correct (eager)
print(float(jax.jit(lambda a, b, x: a * x + b * x)(a, b, x)))
# 4.993411122843311e+21 ← wrong (jit)
Both CPU and GPU XLA produce the same wrong result.
Scale of the problem
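The bug triggers whenever a ≈ -b and the exact product a*x is not representable in float32 (it needs more than 24 significant bits), so that rounding it loses information. For b = -a the residual is exactly the float32 rounding error of a*x. It affects a wide range of magnitudes (b = -a in every row):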
| a | x | Non-JIT (correct) | XLA JIT (wrong) |
|---|---|---|---|
| 1e10 | 1e10 | 0.0 | -2.00e+12 |
| 1e14 | 1e10 | 0.0 | -1.01e+16 |
| 1e18 | 1e10 | 0.0 | 4.01e+20 |
| 1e20 | 1e10 | 0.0 | 4.99e+21 |
| 1e24 | 1e10 | 0.0 | -2.71e+26 |
| 1e28 | 1e10 | 0.0 | -2.38e+30 |
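To check whether a given pair (a, x) can trigger the residual, one option is to test whether the exact product survives rounding to float32. The sketch below is a stdlib-only emulation, not XLA code; `to_f32` and `product_is_exact_in_f32` are hypothetical helper names introduced here, and the sketch assumes the product stays within float32's finite range.

```python
from fractions import Fraction
import struct


def to_f32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    # struct raises OverflowError if the value is outside float32's range.
    return struct.unpack("<f", struct.pack("<f", value))[0]


def product_is_exact_in_f32(a: float, x: float) -> bool:
    """Return True if float32(a) * float32(x) is exactly representable in float32."""
    a32, x32 = to_f32(a), to_f32(x)
    # Exact rational product: no rounding at all.
    exact = Fraction(a32) * Fraction(x32)
    # The float64 product of two float32 values is exact (at most 48 significant
    # bits), so to_f32 applies a single rounding here.
    rounded = to_f32(a32 * x32)
    return Fraction(rounded) == exact


print(product_is_exact_in_f32(3.0, 5.0))    # True: 15 fits in 24 bits
print(product_is_exact_in_f32(1e10, 1e10))  # False: first row of the table
print(product_is_exact_in_f32(1e20, 1e10))  # False: the reproducer above
```

Pairs for which this returns True should be unaffected, because the fused and unfused products then agree.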
All equivalent patterns are affected:
- a*x + b*x (common factor x, b = -a)
- a*x + a*y (common factor a, y = -x)
- a*x - b*x (b = a)
Note: a*x - a*x (literally the same variable twice) is correctly optimized to 0 by XLA — the bug only occurs when the two products are computed from distinct (but equal-magnitude) operands.
Root cause
XLA's FMA contraction fuses a*x + b*x into fma(a, x, b*x):
Standard float32: round(a*x) + round(b*x) = 1.000000015e30 + (-1.000000015e30) = 0.0
FMA instruction: exact(a*x) + round(b*x) = 1.000000020e30 + (-1.000000015e30) = 4.993e21
The FMA instruction computes a*x exactly and rounds only once, after the addition, but b*x is computed separately and rounded to float32 first. This asymmetry means the two terms no longer cancel, exposing a ~5e21 residual that is pure rounding artifact: it is the float32 rounding error of a*x.
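The arithmetic above can be checked without XLA. The following is a minimal stdlib-only sketch that emulates both evaluation orders with exact rational arithmetic; `to_f32`, `mul_add_unfused`, and `mul_add_fused` are hypothetical names introduced for illustration, and the emulation assumes that going through float64 does not introduce a second rounding, which holds for these inputs.

```python
from fractions import Fraction
import struct


def to_f32(value) -> float:
    """Round a number to the nearest float32 value (via float64)."""
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


def mul_add_unfused(a: float, x: float, c: float) -> float:
    """Separate multiply and add: the product is rounded to float32 first."""
    return to_f32(to_f32(a * x) + c)


def mul_add_fused(a: float, x: float, c: float) -> float:
    """Emulated FMA: a*x + c is evaluated exactly, then rounded once."""
    exact = Fraction(a) * Fraction(x) + Fraction(c)
    return to_f32(exact)


# Same inputs as the reproducer, rounded to float32 first.
a = to_f32(1e20)
b = to_f32(-1e20)
x = to_f32(1e10)
bx = to_f32(b * x)  # b*x is rounded to float32 in both variants

print(mul_add_unfused(a, x, bx))  # 0.0
print(mul_add_fused(a, x, bx))    # 4.993411122843311e+21
```

The fused value matches the residual reported for the JIT path, which is consistent with the contraction explanation.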
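This is a known hazard of FMA contraction in compilers. GCC and Clang expose the -ffp-contract flag to control it: Clang defaults to -ffp-contract=on, which fuses only within a single statement, and GCC defaults to -ffp-contract=off in ISO C modes (and to fast otherwise). IEEE 754 defines fused multiply-add as a distinct operation with a single rounding, so replacing a separate multiply and add with it can change results.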
Suggested fix
XLA should do one of the following:
- Disable FMA contraction for additions of independent multiplies, and only use FMA when the user explicitly requests it, or
- Contract both multiplies symmetrically: when fusing a*x + b*x, ensure both products use the same precision level, or
- Provide a user-facing flag to control FMA contraction behavior (similar to -ffp-contract=off).
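To make the first two options concrete, the sketch below emulates their intended semantics in plain Python. It is an illustration under stated assumptions, not a proposal for how XLA would implement either option; the function names are hypothetical, and the "symmetric" variant is modelled as keeping both products exact and rounding once at the end.

```python
from fractions import Fraction
import struct


def to_f32(value) -> float:
    """Round a number to the nearest float32 value (via float64)."""
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


def sum_of_products_uncontracted(a, x, b, y) -> float:
    """Option 1: no contraction, each product is rounded to float32 before the add."""
    return to_f32(to_f32(a * x) + to_f32(b * y))


def sum_of_products_symmetric(a, x, b, y) -> float:
    """Option 2 (assumed semantics): both products kept exact, one final rounding."""
    exact = Fraction(a) * Fraction(x) + Fraction(b) * Fraction(y)
    return to_f32(exact)


a = to_f32(1e20)
b = to_f32(-1e20)
x = to_f32(1e10)

print(sum_of_products_uncontracted(a, x, b, x))  # 0.0
print(sum_of_products_symmetric(a, x, b, x))     # 0.0
```

Both variants treat the two products identically, so they cancel for the reproducer; only the mixed case, where one product is exact and the other is rounded, leaves a residual.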
Environment
- TensorFlow 2.22.0-dev20260421 (nightly), JAX 0.10.0
- Both CPU and GPU XLA affected
- Non-JIT execution is NOT affected