In eager mode, `equinox.error_if(x, cond, msg)` hangs in an XLA all-reduce rendezvous when the truthy `cond` (or `x`) lives on a `NamedSharding` whose mesh has more than one device, even though the value is fully replicated. In the CPU run below, the hang ends only when XLA aborts the process after the 40-second rendezvous termination timeout.
Expected behavior: raise `EquinoxRuntimeError(msg)` exactly as it does for a single-device array.
Minimal reproducer
```python
import os

# Must be set before JAX is imported so that two CPU devices are exposed.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

import jax
import jax.numpy as jnp
import equinox

mesh = jax.sharding.Mesh(jax.devices('cpu'), ('x',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
x = jax.device_put(jnp.array(1.0), replicated)
cond = jax.device_put(jnp.array(True), replicated)

try:
    equinox.error_if(x, cond, 'err')  # hangs here instead of raising
except equinox.EquinoxRuntimeError:
    pass
```
```
E0519 23:08:20.320073 52455925 rendezvous.cc:116] [id=0] This thread has been waiting for `all reduce RendezvousKey{run_id=RunId: -1375800151, global_devices=[0, 1], num_local_participants=2, collective_op_kind=cross_module, op_id=1}` for 20 seconds and may be stuck. Expected 2 threads to join the rendezvous, but not all of them arrived on time.
E0519 23:09:00.321819 52455925 rendezvous.cc:41] All thread stack traces (if supported on this platform):
E0519 23:09:00.321926 52455925 rendezvous.cc:43]
F0519 23:09:00.321934 52455925 rendezvous.cc:161] [id=0] Termination timeout for `all reduce RendezvousKey{run_id=RunId: -1375800151, global_devices=[0, 1], num_local_participants=2, collective_op_kind=cross_module, op_id=1}` of 40 seconds exceeded. Exiting to ensure a consistent program state. Expected 2 threads to join the rendezvous, but only 1 of them arrived on time.
*** Check failure stack trace: ***
@ 0x11d42a900 absl::lts_20260107::log_internal::LogMessage::Flush()
@ 0x11fd725c0 xla::internal::AwaitAndLogIfStuck()
@ 0x11fd5c86c xla::cpu::InProcessCommunicator::AllReduce()
@ 0x11e1157ac absl::lts_20260107::internal_any_invocable::RemoteInvoker<>()
@ 0x11e99e1ac xla::cpu::CollectiveThunk::ExecuteWithCommunicator()
@ 0x11e115080 xla::cpu::AllReduceThunk::Execute()
@ 0x11fd50188 xla::cpu::ThunkExecutor::ExecuteSequential()
@ 0x11e119b0c xla::cpu::ConditionalThunk::Execute()
@ 0x11fd50188 xla::cpu::ThunkExecutor::ExecuteSequential()
@ 0x11d6b3f3c xla::CpuPjRtRawLoadedExecutable::Execute()::$_0::operator()()
@ 0x11d6bb8a0 absl::lts_20260107::internal_any_invocable::RemoteInvoker<>()
@ 0x11d6f3304 std::__1::__function::__func<>::operator()()
@ 0x120996e80 Eigen::ThreadPoolTempl<>::WorkerLoop()
@ 0x11d300100 Eigen::ThreadPoolTempl<>::WorkerLoop()
@ 0x11d2ffef8 std::__1::invoke[abi:nn190102]<>()
@ 0x11d2f0f90 tsl::(anonymous namespace)::PThread::ThreadFn()
@ 0x183587c58 _pthread_start
@ 0x183582c1c thread_start
```
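The log says XLA expected two participants (one per device) at the all-reduce rendezvous, which the stack trace shows running inside a `ConditionalThunk`, but only one arrived. As a toy analogy only, the sketch below models that failure mode with Python's `threading.Barrier`: each thread stands in for a device that either reaches the collective or does not, and a participant that never arrives leaves the others blocked until a timeout. It does not explain why one device fails to arrive inside `error_if`; `run_toy_rendezvous` and its parameters are hypothetical names, not part of JAX, XLA, or Equinox.

```python
import threading


def run_toy_rendezvous(arrives: list[bool], timeout: float = 1.0) -> list[str]:
    """Toy model of a collective rendezvous (hypothetical; not XLA code).

    Each thread stands in for one device. A thread whose `arrives` entry is
    True waits at a shared barrier, which plays the role of the all-reduce;
    a thread whose entry is False never reaches the collective.
    """
    barrier = threading.Barrier(len(arrives))
    results = [""] * len(arrives)

    def device(i: int) -> None:
        if not arrives[i]:
            results[i] = "never arrived"
            return
        try:
            # Blocks until every participant has called wait(), like the
            # rendezvous in the log; raises BrokenBarrierError if the others
            # have not arrived within `timeout` seconds.
            barrier.wait(timeout=timeout)
            results[i] = "reduced"
        except threading.BrokenBarrierError:
            results[i] = "timed out"

    threads = [threading.Thread(target=device, args=(i,)) for i in range(len(arrives))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With `[True, True]` both participants report `"reduced"`; with `[True, False]` the first reports `"timed out"` and the second `"never arrived"`. In the toy the timeout is an ordinary exception, whereas in the log above XLA treats it as fatal and exits the process.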
Observations:
- Placing only `x` (and leaving `cond` unsharded) or only `cond` on the multi-device replicated sharding is sufficient to trigger the hang.
- With `cond=False` (no error to raise), the call returns normally.
- With a single-device mesh (`--xla_force_host_platform_device_count=1`), the call raises as expected.
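Because XLA aborts the whole process once the rendezvous timeout is exceeded, checking the configurations in the observations above one by one in a single script is awkward. A minimal sketch, assuming each case can run in a fresh interpreter: the hypothetical helper `classify` runs a script in a child process with a timeout and reports whether `error_if` raised, returned, hung, or the process crashed. `CASE_TEMPLATE`, `CASES`, and `run_all` are illustrative names; the template is the reproducer above with the device count, the value of `cond`, and which of `x` / `cond` is sharded made configurable.

```python
import subprocess
import sys
import textwrap

# Hypothetical template: the reproducer from this report with the device count,
# the value of `cond`, and which of `x` / `cond` is sharded made configurable.
CASE_TEMPLATE = textwrap.dedent("""
    import os
    os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count={n_devices}'
    import jax
    import jax.numpy as jnp
    import equinox

    mesh = jax.sharding.Mesh(jax.devices('cpu'), ('x',))
    replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    x = jnp.array(1.0)
    cond = jnp.array({cond})
    if {shard_x}:
        x = jax.device_put(x, replicated)
    if {shard_cond}:
        cond = jax.device_put(cond, replicated)
    try:
        equinox.error_if(x, cond, 'err')
    except equinox.EquinoxRuntimeError:
        print('RAISED')
    else:
        print('RETURNED')
""")

# One entry per configuration mentioned in the report (names are illustrative).
CASES = {
    "x and cond replicated, 2 devices": dict(n_devices=2, cond=True, shard_x=True, shard_cond=True),
    "only x replicated, 2 devices": dict(n_devices=2, cond=True, shard_x=True, shard_cond=False),
    "only cond replicated, 2 devices": dict(n_devices=2, cond=True, shard_x=False, shard_cond=True),
    "cond=False, 2 devices": dict(n_devices=2, cond=False, shard_x=True, shard_cond=True),
    "single-device mesh": dict(n_devices=1, cond=True, shard_x=True, shard_cond=True),
}


def classify(script: str, timeout: float = 60.0) -> str:
    """Run `script` in a fresh interpreter and summarise how it ended."""
    try:
        proc = subprocess.run(
            [sys.executable, "-c", script],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return "hung"  # the child was killed after `timeout` seconds
    if proc.returncode != 0:
        return "crashed"  # uncaught exception or a fatal abort such as XLA's F-level check
    if "RAISED" in proc.stdout:
        return "raised"
    if "RETURNED" in proc.stdout:
        return "returned"
    return "unknown"


def run_all(timeout: float = 60.0) -> dict[str, str]:
    """Classify every configuration in CASES (requires jax and equinox)."""
    return {
        name: classify(CASE_TEMPLATE.format(**params), timeout)
        for name, params in CASES.items()
    }
```

Calling `run_all()` in an environment with `jax` and `equinox` installed returns a mapping from case name to outcome. Choosing a `timeout` shorter than XLA's 40-second termination timeout should report the hanging cases as `"hung"` rather than `"crashed"`.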
Related
- `error_if` and sharding #939: covers the same overall topic of `error_if` interacting with multi-device shardings, but focuses on the all-reduce overhead in the no-error case rather than the eager-mode hang reported here.

Note
Written by Opus 4.7, edited by a human.

System info
```
jax: 0.10.0
jaxlib: 0.10.0
numpy: 2.4.6
python: 3.14.0 (main, Oct 14 2025, 21:10:22) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', release='25.4.0', version='Darwin Kernel Version 25.4.0: Thu Mar 19 19:30:44 PDT 2026; root:xnu-12377.101.15~1/RELEASE_ARM64_T6000', machine='arm64')
equinox: 0.13.6
```