
# XLA FMA contraction silently changes a*x + b*x semantics, producing results many orders of magnitude wrong #41579

@wuyii8941


Summary

XLA compiles a*x + b*x into an FMA (fused multiply-add) instruction, fma(a, x, b*x). An FMA uses the product a*x exactly, with no intermediate rounding, and rounds only once, after adding b*x. The term b*x, however, has already been rounded to float32 by a separate multiply. When b = -a, so that the two products should cancel exactly, this inconsistent rounding leaves a nonzero residual. Because that residual grows with the size of the products, the XLA result can be many orders of magnitude away from both the non-JIT result and the mathematically correct answer (0).

Without JIT, each product is rounded to float32 separately. Round-to-nearest is symmetric, so round(b*x) is exactly -round(a*x), and the two terms cancel to exactly 0.
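Both evaluation orders can be reproduced without XLA by emulating float32 arithmetic exactly with Python's `fractions` module. This is a minimal sketch that assumes IEEE round-to-nearest-even and values in float32's normal range. `round_f32`, `mul_f32`, `add_f32` and `fma_f32` are hypothetical helpers written for this illustration and are not TensorFlow or JAX APIs.

```python
import math
from fractions import Fraction


def round_f32(q: Fraction) -> float:
    """Round an exact rational to the nearest float32 value (ties to even).

    Sketch only: assumes q is zero or lies in float32's normal range,
    so overflow and subnormals are not handled.
    """
    if q == 0:
        return 0.0
    mag = abs(q)
    # Choose e so that mag / 2**e falls in [2**23, 2**24): a 24-bit significand.
    e = mag.numerator.bit_length() - mag.denominator.bit_length() - 24
    while mag / Fraction(2) ** e >= 2 ** 24:
        e += 1
    while mag / Fraction(2) ** e < 2 ** 23:
        e -= 1
    m = round(mag / Fraction(2) ** e)  # round() on a Fraction rounds half to even
    value = math.ldexp(m, e)
    return -value if q < 0 else value


def mul_f32(p: float, q: float) -> float:
    """Unfused float32 multiply: exact product, rounded once."""
    return round_f32(Fraction(p) * Fraction(q))


def add_f32(p: float, q: float) -> float:
    """float32 add: exact sum, rounded once."""
    return round_f32(Fraction(p) + Fraction(q))


def fma_f32(p: float, q: float, c: float) -> float:
    """Fused multiply-add: p*q + c computed exactly, rounded once at the end."""
    return round_f32(Fraction(p) * Fraction(q) + Fraction(c))


# The report's inputs, rounded to float32 first (as tf.constant / jnp.float32 do).
a = round_f32(Fraction(1e20))
b = round_f32(Fraction(-1e20))
x = round_f32(Fraction(1e10))

unfused = add_f32(mul_f32(a, x), mul_f32(b, x))  # round(round(a*x) + round(b*x))
fused = fma_f32(a, x, mul_f32(b, x))             # fma(a, x, round(b*x))
print(unfused, fused)  # 0.0 4.993411122843311e+21
```

For these inputs the emulated fused result equals the 4.993411122843311e+21 printed by the jitted TensorFlow and JAX calls in the reproduction. That agreement supports the FMA explanation, but on its own it does not show which instruction XLA actually emits.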

Reproduced in both TensorFlow (jit_compile=True) and JAX (jax.jit), confirming this is an XLA-level issue.

Reproduction

TensorFlow:

import tensorflow as tf

a = tf.constant(1e20, dtype=tf.float32)
b = tf.constant(-1e20, dtype=tf.float32)
x = tf.constant(1e10, dtype=tf.float32)

print((a * x + b * x).numpy())
# 0.0  ← correct (eager)

print(tf.function(lambda a, b, x: a * x + b * x, jit_compile=True)(a, b, x).numpy())
# 4.993411122843311e+21  ← wrong (XLA)

JAX:

import jax
import jax.numpy as jnp

a = jnp.float32(1e20)
b = jnp.float32(-1e20)
x = jnp.float32(1e10)

print(float(a * x + b * x))
# 0.0  ← correct (eager)

print(float(jax.jit(lambda a, b, x: a * x + b * x)(a, b, x)))
# 4.993411122843311e+21  ← wrong (jit)

Both CPU and GPU XLA produce the same wrong result.

Scale of the problem

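The bug triggers whenever b = -a and the exact product a*x is not representable in float32 (its significand needs more than float32's 24 bits). The residual is the rounding error of a*x, so its absolute size grows with the magnitude of the products. With x = 1e10 and b = -a in every row: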

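| a    | x    | Non-JIT (correct) | XLA JIT (wrong) |
|------|------|-------------------|-----------------|
| 1e10 | 1e10 | 0.0               | -2.00e+12       |
| 1e14 | 1e10 | 0.0               | -1.01e+16       |
| 1e18 | 1e10 | 0.0               | 4.01e+20        |
| 1e20 | 1e10 | 0.0               | 4.99e+21        |
| 1e24 | 1e10 | 0.0               | -2.71e+26       |
| 1e28 | 1e10 | 0.0               | -2.38e+30       |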
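If the FMA explanation holds, each XLA value in the table should equal exact(a*x) - round(a*x). The rounding error of a float32 product is itself exactly representable, so fma(a, x, round(b*x)) with b = -a returns that error with no further rounding. The sketch below computes this predicted value for each row, together with the half-ulp bound on it. It is a stdlib-only illustration, and `to_f32` and `fused_residual` are hypothetical helper names.

```python
import math
import struct
from fractions import Fraction


def to_f32(v: float) -> float:
    """Round a Python float (a double) to the nearest float32 via struct packing."""
    return struct.unpack("<f", struct.pack("<f", v))[0]


def fused_residual(a_lit: float, x_lit: float) -> tuple[float, float]:
    """Return (exact(a*x) - round(a*x), half a float32 ulp of round(a*x)).

    With b = -a, the first value is what fma(a, x, round(b*x)) evaluates to.
    """
    a, x = to_f32(a_lit), to_f32(x_lit)
    # Two 24-bit float32 significands multiply into at most 48 bits, which a
    # double holds exactly, so `a * x` here is the exact product.
    exact = Fraction(a) * Fraction(x)
    rounded = to_f32(a * x)       # the product an unfused float32 multiply returns
    _, exp = math.frexp(rounded)  # rounded = m * 2**exp with 0.5 <= |m| < 1
    half_ulp = math.ldexp(1.0, exp - 25)
    return float(exact - Fraction(rounded)), half_ulp


for a_lit in (1e10, 1e14, 1e18, 1e20, 1e24, 1e28):
    residual, bound = fused_residual(a_lit, 1e10)
    print(f"a={a_lit:.0e}  predicted fused result={residual:+.2e}  "
          f"half-ulp bound={bound:.2e}")
```

For the 1e10 and 1e20 rows this yields -2.00e+12 and +4.99e+21, the same values as the XLA column. The remaining rows can be checked against the table in the same way.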

All equivalent patterns are affected:

  • a*x + b*x (common factor x, b = -a)
  • a*x + a*y (common factor a, y = -x)
  • a*x - b*x (b = a)

Note: a*x - a*x (literally the same variable twice) is correctly optimized to 0 by XLA — the bug only occurs when the two products are computed from distinct (but equal-magnitude) operands.

Root cause

XLA's FMA contraction fuses a*x + b*x into fma(a, x, b*x):

Standard float32:  round(a*x) + round(b*x) = 1.000000015e30 + (-1.000000015e30) = 0.0
FMA instruction:   exact(a*x) + round(b*x) = 1.000000020e30 + (-1.000000015e30) = 4.993e21

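The FMA instruction uses the exact product a*x (no intermediate rounding), while b*x has already been rounded to float32 by a separate multiply. Because of this asymmetry the two terms no longer cancel. The result is exactly the rounding error of a*x, exact(a*x) - round(a*x), which is about 5e21 here and is a pure rounding artifact. That error is at most half a float32 ulp of a*x, which is why the residual in the table above grows with the magnitude of the products.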
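The same argument can be written out as a worked derivation. Here fl(·) denotes float32 round-to-nearest, the inputs are the report's a = fl(10^20), x = 10^10 (exact in float32) and b = -a. The step fl(ax - fl(ax)) = ax - fl(ax) relies on the standard fact that the rounding error of a floating-point product is itself exactly representable when no underflow occurs. The integer decompositions in the last line were worked out for this illustration and agree with the 1.000000015e30, 1.000000020e30 and 4.993e21 figures above.

```latex
% fl(.) = float32 round-to-nearest-even; b = -a, hence fl(bx) = -fl(ax)
\[
\begin{aligned}
\text{unfused:}\quad & \mathrm{fl}\bigl(\mathrm{fl}(ax) + \mathrm{fl}(bx)\bigr)
  = \mathrm{fl}\bigl(\mathrm{fl}(ax) - \mathrm{fl}(ax)\bigr) = 0 \\
\text{fused:}\quad & \mathrm{fl}\bigl(ax + \mathrm{fl}(bx)\bigr)
  = \mathrm{fl}\bigl(ax - \mathrm{fl}(ax)\bigr) = ax - \mathrm{fl}(ax) =: \delta,
  \qquad |\delta| \le \tfrac{1}{2}\,\mathrm{ulp}(ax) \\
\text{here:}\quad & ax = 2842171 \cdot 9765625 \cdot 2^{55}, \quad
  \mathrm{fl}(ax) = 13234890 \cdot 2^{76}, \quad
  \delta = 138595 \cdot 2^{55} \approx 4.993 \times 10^{21}
\end{aligned}
\]
```

With ulp(ax) = 2^76 ≈ 7.6e22 at this magnitude, the bound allows a residual of up to about 3.8e22. The observed 4.993e21 sits well inside it.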

This is a known hazard of FMA contraction in compilers. GCC and Clang both expose it through the -ffp-contract flag (for example off, on, fast), and -ffp-contract=off disables contraction entirely. IEEE 754-2008 likewise lists synthesizing a fusedMultiplyAdd from a separate multiply and add among the value-changing optimizations that languages should let users allow or disallow.

Suggested fix

XLA should either:

  1. Disable FMA contraction for addition of independent multiplies — only use FMA when the user explicitly requests it, or
  2. Contract both multiplies symmetrically — if fusing a*x + b*x, ensure both products use the same precision level, or
  3. Provide a user-facing flag to control FMA contraction behavior (similar to -ffp-contract=off)

Environment

  • TensorFlow 2.22.0-dev20260421 (nightly), JAX 0.10.0
  • Both CPU and GPU XLA affected
  • Non-JIT execution is NOT affected
