From ed1e01438efe1ee5a4a6cbab583b991f7fc0509b Mon Sep 17 00:00:00 2001 From: Kaan Baloglu Date: Tue, 2 Jun 2026 17:59:38 -0700 Subject: [PATCH] Surface 2D-sharding asymmetry root cause at the boundary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: `EmbeddingBagCollectionSharder.shard` fires the `Number of local shards (0) does not match the length of sharded_tensor_metadata` assertion ~20K times / 28d on `torchrec_event_logging`, 99% concentrated in two IG models (`ig_organic_feed_mtml`, `ig_reels_tab_esr_ttsn`). Static analysis points at two candidate root causes in the 2D-parallel path: 1. `BaseTwRwEmbeddingSharding._shard` at `twrw_sharding.py:142-148` calls `peer_group.index(info.param_sharding.ranks[0])` to translate the planner's rank value into a local-to-sharding-PG index. Under non-contiguous 2D sharding-PG group layouts (`comm.py:230-310`), the planner's rank value may not exist in the peer_group list, causing a bare `ValueError` to escape with no table or rank context. 2. `ShardedEmbeddingBagCollection._initialize_torch_state` at `embeddingbag.py:1373-1386` then calls `ShardedTensor._init_from_local_shards_and_global_metadata(local_shards, sharded_tensor_metadata, process_group=sharding_pg)`. When `local_shards` came back empty due to (1), the assertion fires inside `torch/distributed/_shard/sharded_tensor/api.py:971` with a generic "0 does not match N" message. This diff adds two boundary diagnostics so the failure carries actionable context: - `twrw_sharding.py`: wrap `peer_group.index()` in a try/except that logs the table name, planner rank, peer-group size, peer-group head, sharding-PG size, and global world size before re-raising. Only fires on ValueError; healthy 2D shards see no log. - `embeddingbag.py`: just before the `ShardedTensor` construction in the ShardingEnv2D branch, compute the per-rank expected shard count from `sharded_tensor_metadata.shards_metadata` and log a warning if it disagrees with `len(local_shards)`. 
Healthy ShardingEnv2D paths emit one comparison per table that produces no log; mismatched paths emit one warning per table before the torch core assert fires. No behavior change for healthy jobs or non-2D jobs. No new exception types. Diagnostic-only — intended to be the first step of a three-step plan: (1) land diagnostic, (2) canary on one of the two affected IG models, (3) land the targeted fix based on which signal fires. Once we have the canary signal: - If the twrw warning fires, the fix is to drop `peer_group.index()` under 2D — the planner-rank is already the local index. ~5 LOC. - If only the embeddingbag warning fires, the fix is to filter shards by peer_group membership and rebuild the metadata. Larger, touches both files. The embeddingbag warning is worth keeping permanently as a guard against future regressions; the twrw warning can be removed once the underlying assignment is fixed. Differential Revision: D106892622 --- torchrec/distributed/embeddingbag.py | 34 ++++++++++++++++--- .../distributed/sharding/twrw_sharding.py | 29 ++++++++++++---- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 47926f1b6..d6dfd50eb 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1370,13 +1370,39 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no ), ) + sharded_tensor_metadata = sharding_spec.build_metadata( + tensor_sizes=self._name_to_table_size[table_name], + tensor_properties=tensor_properties, + ) + + # Use global_rank: torch core's assert in + # ShardedTensor._init_from_local_shards_and_global_metadata + # compares placement.rank() against dist.get_rank() (global). 
+ if isinstance(self._env, ShardingEnv2D): + my_global_rank = self._env.global_rank + expected_for_this_rank = sum( + 1 + for sm in sharded_tensor_metadata.shards_metadata + if sm.placement is not None + and sm.placement.rank() == my_global_rank + ) + if len(local_shards) != expected_for_this_rank: + logger.warning( + "[2d-sharding-diag] local_shards/metadata mismatch: " + f"table={table_name} " + f"local_shards={len(local_shards)} " + f"expected_for_this_rank={expected_for_this_rank} " + f"total_metadata_shards={len(sharded_tensor_metadata.shards_metadata)} " + f"global_rank={my_global_rank} " + f"sharding_pg_rank={self._env.rank} " + f"sharding_pg_size={self._env.world_size} " + f"global_world_size={self._env.global_world_size}" + ) + self._model_parallel_name_to_sharded_tensor[table_name] = ( ShardedTensor._init_from_local_shards_and_global_metadata( local_shards=local_shards, - sharded_tensor_metadata=sharding_spec.build_metadata( - tensor_sizes=self._name_to_table_size[table_name], - tensor_properties=tensor_properties, - ), + sharded_tensor_metadata=sharded_tensor_metadata, process_group=( self._env.sharding_pg if isinstance(self._env, ShardingEnv2D) diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index 897080060..c91aff26f 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -8,6 +8,7 @@ # pyre-strict import itertools +import logging import math from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar @@ -64,6 +65,8 @@ T = TypeVar("T") W = TypeVar("W") +logger: logging.Logger = logging.getLogger(__name__) + class BaseTwRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]): """ @@ -139,13 +142,25 @@ def _shard( peer_group = get_process_group_ranks(self._pg) if self._is_2D_parallel else None for info in sharding_infos: # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme - rank = ( - # 
pyrefly: ignore[unsupported-operation] - peer_group.index(info.param_sharding.ranks[0]) - if peer_group is not None - # pyrefly: ignore[unsupported-operation] - else info.param_sharding.ranks[0] - ) + # pyrefly: ignore[unsupported-operation] + planner_rank = info.param_sharding.ranks[0] + if peer_group is not None: + pg_members: List[int] = peer_group + try: + rank = pg_members.index(planner_rank) + except ValueError: + logger.warning( + "[2d-sharding-diag] peer_group.index() failed: " + f"table={info.embedding_config.name} " + f"planner_rank={planner_rank} " + f"peer_group_size={len(pg_members)} " + f"peer_group_head={pg_members[:8]} " + f"sharding_pg_size={self._world_size} " + f"global_world_size={dist.get_world_size()}" + ) + raise + else: + rank = planner_rank table_node = rank // local_size # pyrefly: ignore[missing-attribute] shards = info.param_sharding.sharding_spec.shards