equinox.error_if hangs on error if one of the arguments has a multi-device replicated sharding #1232

Description

@Gattocrucco

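In eager mode, equinox.error_if(x, cond, msg) never raises when cond is true and either cond or x has a NamedSharding whose mesh spans more than one device, even though the value is fully replicated. Instead the call blocks in an XLA all-reduce rendezvous until the rendezvous termination timeout (40 seconds in the log below) aborts the process.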

Expected behavior: raise EquinoxRuntimeError(msg) exactly as it does for a single-device array.
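To make the triggering condition concrete, here is a minimal sketch of a check for "fully replicated across more than one device". The helper name is_multi_device_replicated is hypothetical and not part of JAX or Equinox; it only uses the public Sharding attributes is_fully_replicated and device_set.

```python
import os
# Must be set before JAX initializes its backends; setdefault keeps any
# XLA_FLAGS already present in the environment.
os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=2')

import jax
import jax.numpy as jnp


def is_multi_device_replicated(arr):
    """Hypothetical helper: True if `arr` is fully replicated on >1 device."""
    sharding = arr.sharding
    return sharding.is_fully_replicated and len(sharding.device_set) > 1


mesh = jax.sharding.Mesh(jax.devices('cpu'), ('x',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

plain = jnp.array(True)                      # default single-device placement
sharded = jax.device_put(plain, replicated)  # replicated over the whole mesh

print(is_multi_device_replicated(plain))
print(is_multi_device_replicated(sharded))
```

Arrays for which this check is true are the ones that trigger the problem in the reproducer below, assuming the mesh really contains two CPU devices.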

Minimal reproducer

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

import jax
import jax.numpy as jnp
import equinox

mesh = jax.sharding.Mesh(jax.devices('cpu'), ('x',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

x = jax.device_put(jnp.array(1.0), replicated)
cond = jax.device_put(jnp.array(True), replicated)

try:
    equinox.error_if(x, cond, 'err')
except equinox.EquinoxRuntimeError:
    pass

Output:

E0519 23:08:20.320073 52455925 rendezvous.cc:116] [id=0] This thread has been waiting for `all reduce RendezvousKey{run_id=RunId: -1375800151, global_devices=[0, 1], num_local_participants=2, collective_op_kind=cross_module, op_id=1}` for 20 seconds and may be stuck. Expected 2 threads to join the rendezvous, but not all of them arrived on time.
E0519 23:09:00.321819 52455925 rendezvous.cc:41] All thread stack traces (if supported on this platform):
E0519 23:09:00.321926 52455925 rendezvous.cc:43] 
F0519 23:09:00.321934 52455925 rendezvous.cc:161] [id=0] Termination timeout for `all reduce RendezvousKey{run_id=RunId: -1375800151, global_devices=[0, 1], num_local_participants=2, collective_op_kind=cross_module, op_id=1}` of 40 seconds exceeded. Exiting to ensure a consistent program state. Expected 2 threads to join the rendezvous, but only 1 of them arrived on time.
*** Check failure stack trace: ***
    @        0x11d42a900  absl::lts_20260107::log_internal::LogMessage::Flush()
    @        0x11fd725c0  xla::internal::AwaitAndLogIfStuck()
    @        0x11fd5c86c  xla::cpu::InProcessCommunicator::AllReduce()
    @        0x11e1157ac  absl::lts_20260107::internal_any_invocable::RemoteInvoker<>()
    @        0x11e99e1ac  xla::cpu::CollectiveThunk::ExecuteWithCommunicator()
    @        0x11e115080  xla::cpu::AllReduceThunk::Execute()
    @        0x11fd50188  xla::cpu::ThunkExecutor::ExecuteSequential()
    @        0x11e119b0c  xla::cpu::ConditionalThunk::Execute()
    @        0x11fd50188  xla::cpu::ThunkExecutor::ExecuteSequential()
    @        0x11d6b3f3c  xla::CpuPjRtRawLoadedExecutable::Execute()::$_0::operator()()
    @        0x11d6bb8a0  absl::lts_20260107::internal_any_invocable::RemoteInvoker<>()
    @        0x11d6f3304  std::__1::__function::__func<>::operator()()
    @        0x120996e80  Eigen::ThreadPoolTempl<>::WorkerLoop()
    @        0x11d300100  Eigen::ThreadPoolTempl<>::WorkerLoop()
    @        0x11d2ffef8  std::__1::invoke[abi:nn190102]<>()
    @        0x11d2f0f90  tsl::(anonymous namespace)::PThread::ThreadFn()
    @        0x183587c58  _pthread_start
    @        0x183582c1c  thread_start

Observations:

  • Placing either x alone or cond alone on the multi-device replicated sharding, with the other left unsharded, is sufficient to trigger the hang.
  • With cond=False (no error to raise) the call returns normally.
  • With a single-device mesh (--xla_force_host_platform_device_count=1) the call raises as expected.
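Because the failing variants block instead of raising, one way to check the observations above without freezing the calling session is to run each variant in a fresh interpreter with a timeout. The following is a rough sketch, not a tested harness: the names TEMPLATE, make_script, run_with_timeout and check_variants are hypothetical, and it assumes equinox and jax are installed in the interpreter being launched.

```python
import subprocess
import sys

# Script template mirroring the reproducer; the placeholders select the variant.
TEMPLATE = """
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count={n_devices}'
import jax, jax.numpy as jnp, equinox
mesh = jax.sharding.Mesh(jax.devices('cpu'), ('x',))
rep = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
x = jnp.array(1.0)
cond = jnp.array({cond_value})
if {shard_x}: x = jax.device_put(x, rep)
if {shard_cond}: cond = jax.device_put(cond, rep)
try:
    equinox.error_if(x, cond, 'err')
except equinox.EquinoxRuntimeError:
    pass
"""


def make_script(n_devices, shard_x, shard_cond, cond_value):
    """Fill the template for one variant of the reproducer."""
    return TEMPLATE.format(n_devices=n_devices, shard_x=shard_x,
                           shard_cond=shard_cond, cond_value=cond_value)


def run_with_timeout(source, timeout_s):
    """Run `source` in a child interpreter; return 'ok', 'failed' or 'timeout'."""
    try:
        proc = subprocess.run([sys.executable, '-c', source],
                              capture_output=True, text=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return 'timeout'
    return 'ok' if proc.returncode == 0 else 'failed'


def check_variants(timeout_s=30):
    """Run the variants described in the observations and print each outcome."""
    variants = {
        'both sharded, cond=True': (2, True, True, True),
        'only x sharded, cond=True': (2, True, False, True),
        'only cond sharded, cond=True': (2, False, True, True),
        'both sharded, cond=False': (2, True, True, False),
        'single-device mesh, cond=True': (1, True, True, True),
    }
    for name, args in variants.items():
        print(name, '->', run_with_timeout(make_script(*args), timeout_s))
```

Calling check_variants() runs the five variants in turn. Going by the log above, a hanging variant would presumably show up as 'timeout' with a limit under 40 seconds, and as 'failed' with a longer limit, since the process aborts once the rendezvous termination timeout is exceeded.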

Note

Written by Opus 4.7, edited by a human.

System info

jax:    0.10.0
jaxlib: 0.10.0
numpy:  2.4.6
python: 3.14.0 (main, Oct 14 2025, 21:10:22) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', release='25.4.0', version='Darwin Kernel Version 25.4.0: Thu Mar 19 19:30:44 PDT 2026; root:xnu-12377.101.15~1/RELEASE_ARM64_T6000', machine='arm64')
equinox: 0.13.6
