
[XLA] InternalError: mixed floating-point precision in TanhGrad during XLA compilation — f32 multiply sees mismatched types despite no explicit mixed precision #43052

@jasminetrail

Description

TensorFlow XLA compilation fails during gradient lowering when a float16 activation flows through ResizeBilinear, which always returns float32 regardless of its input dtype. The forward pass succeeds in both eager and XLA modes, and the eager gradient also succeeds. However, compiling the gradient with XLA fails with:

tensorflow.python.framework.errors_impl.InternalError: during context [Unknown]: Seen floating point types of different precisions in %TanhGrad.9 = f32[5,4,8]{2,1,0} multiply(%unstack.5, %TanhGrad.8), metadata={op_type="TanhGrad" op_name="gradient_tape/TanhGrad" source_file="/home/test/.venv/lib/python3.13/site-packages/tensorflow/python/framework/ops.py" source_line=1221}, but mixed precision is disallowed. [Op:__inference_grad_fn_29]

This appears to happen because:

  • the tf.nn.tanh output is float16
  • ResizeBilinear outputs float32, so everything downstream of it is float32
  • during backprop, TanhGrad receives:
    • the upstream gradient (dy) in float32
    • the saved forward activation (y) in float16

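The TanhGrad lowering then emits an f32 multiply whose operands have different precisions instead of inserting an explicit cast, and XLA rejects it because mixed precision is disallowed.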
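To make the dtype mismatch concrete, here is a small NumPy sketch (an illustration only, not TensorFlow's or XLA's code) of the TanhGrad math, grad = dy * (1 - y*y) with y = tanh(x), as documented for tf.raw_ops.TanhGrad. The helper name `tanh_grad` is hypothetical, and its explicit `astype` stands in for the cast a lowering would need to insert. NumPy would otherwise promote f16 * f32 implicitly, whereas XLA rejects the mixed-precision multiply shown in the error above.

```python
import numpy as np

def tanh_grad(y, dy):
    """Hypothetical reference for TanhGrad: grad = dy * (1 - y*y).

    y is the saved forward output of tanh; dy is the upstream gradient.
    The explicit cast brings y to dy's dtype before the multiply, so both
    operands share one precision (the step that is missing in the failing
    XLA lowering, per the error message).
    """
    y = y.astype(dy.dtype)  # explicit convert, e.g. float16 -> float32
    return dy * (1 - y * y)

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(5, 4, 8)).astype(np.float16)
y = np.tanh(x)                       # float16, like tf.nn.tanh on an f16 input
dy = np.ones((5, 4, 8), np.float32)  # float32, like the gradient coming back through ResizeBilinear
g = tanh_grad(y, dy)
print(y.dtype, dy.dtype, g.dtype)    # float16 float32 float32
```

With the cast in place the multiply runs entirely in float32; without it, the two operands reach the multiply at different precisions, which is the situation the HLO in the error message describes.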

Observed behavior:

  • eager forward: OK
  • eager gradient: OK
  • XLA forward: OK
  • XLA gradient compilation: FAILS

Minimal code to reproduce:

import tensorflow as tf
print(tf.__version__)

x = tf.random.uniform(
    [5, 4, 8],
    minval=-1.0,
    maxval=1.0,
    dtype=tf.float16,
)

@tf.function(jit_compile=True)
def grad_fn(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = tf.nn.tanh(x)
        y = tf.stack([y, tf.sign(x)], axis=3)
        # ResizeBilinear always returns float32, so the fp16 -> fp32 transition happens here
        y = tf.raw_ops.ResizeBilinear(
            images=y,
            size=tf.constant([4, 8], tf.int32),
        )
        loss = tf.reduce_sum(y)
    return tape.gradient(loss, x)

print(grad_fn(x))

Error logs

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1779385617.959124 1655238 port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
I0000 00:00:1779385617.985273 1655238 cpu_feature_guard.cc:227] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2.21.0
I0000 00:00:1779385618.915190 1655238 gpu_device.cc:2043] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22107 MB memory: -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:01:00.0, compute capability: 8.9
E0000 00:00:1779385619.071301 1655238 util.cc:131] oneDNN supports DT_HALF only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
I0000 00:00:1779385619.129295 1655238 service.cc:153] XLA service 0x98e2b80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1779385619.129308 1655238 service.cc:161] StreamExecutor [0]: NVIDIA GeForce RTX 4090, Compute Capability 8.9 (Driver: 13.0.0; Runtime: 12.9.0; Toolkit: 12.5.0; DNN: 9.19.0)
I0000 00:00:1779385619.131926 1655238 dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var MLIR_CRASH_REPRODUCER_DIRECTORY to enable.
W0000 00:00:1779385619.138950 1655238 op_kernel.cc:1858] OP_REQUIRES failed at xla_ops.cc:602 : INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %TanhGrad.9 = f32[5,4,8]{2,1,0} multiply(%unstack.5, %TanhGrad.8), metadata={op_type="TanhGrad" op_name="gradient_tape/TanhGrad" source_file="/home/test/.venv/lib/python3.13/site-packages/tensorflow/python/framework/ops.py" source_line=1221}, but mixed precision is disallowed.
Traceback (most recent call last):
  File "/home/test/bugs/tf/crash_5e89acba.py", line 31, in <module>
    print(grad_fn(x))
          ~~~~~~~^^^
  File "/home/test/.venv/lib/python3.13/site-packages/tensorflow/python/util/traceback_utils.py", line 167, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/test/.venv/lib/python3.13/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    except TypeError as e:
    ...<5 lines>...
      raise e
tensorflow.python.framework.errors_impl.InternalError: during context [Unknown]: Seen floating point types of different precisions in %TanhGrad.9 = f32[5,4,8]{2,1,0} multiply(%unstack.5, %TanhGrad.8), metadata={op_type="TanhGrad" op_name="gradient_tape/TanhGrad" source_file="/home/test/.venv/lib/python3.13/site-packages/tensorflow/python/framework/ops.py" source_line=1221}, but mixed precision is disallowed. [Op:__inference_grad_fn_29]

Versions
TensorFlow version: 2.21.0
