diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 47926f1b6..d6dfd50eb 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1370,13 +1370,39 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no ), ) + sharded_tensor_metadata = sharding_spec.build_metadata( + tensor_sizes=self._name_to_table_size[table_name], + tensor_properties=tensor_properties, + ) + + # Use global_rank: torch core's assert in + # ShardedTensor._init_from_local_shards_and_global_metadata + # compares placement.rank() against dist.get_rank() (global). + if isinstance(self._env, ShardingEnv2D): + my_global_rank = self._env.global_rank + expected_for_this_rank = sum( + 1 + for sm in sharded_tensor_metadata.shards_metadata + if sm.placement is not None + and sm.placement.rank() == my_global_rank + ) + if len(local_shards) != expected_for_this_rank: + logger.warning( + "[2d-sharding-diag] local_shards/metadata mismatch: " + f"table={table_name} " + f"local_shards={len(local_shards)} " + f"expected_for_this_rank={expected_for_this_rank} " + f"total_metadata_shards={len(sharded_tensor_metadata.shards_metadata)} " + f"global_rank={my_global_rank} " + f"sharding_pg_rank={self._env.rank} " + f"sharding_pg_size={self._env.world_size} " + f"global_world_size={self._env.global_world_size}" + ) + self._model_parallel_name_to_sharded_tensor[table_name] = ( ShardedTensor._init_from_local_shards_and_global_metadata( local_shards=local_shards, - sharded_tensor_metadata=sharding_spec.build_metadata( - tensor_sizes=self._name_to_table_size[table_name], - tensor_properties=tensor_properties, - ), + sharded_tensor_metadata=sharded_tensor_metadata, process_group=( self._env.sharding_pg if isinstance(self._env, ShardingEnv2D) diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py index 897080060..c91aff26f 100644 --- a/torchrec/distributed/sharding/twrw_sharding.py +++ b/torchrec/distributed/sharding/twrw_sharding.py @@ -8,6 +8,7 @@ # pyre-strict import itertools +import logging import math from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar @@ -64,6 +65,8 @@ T = TypeVar("T") W = TypeVar("W") +logger: logging.Logger = logging.getLogger(__name__) + class BaseTwRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]): """ @@ -139,13 +142,25 @@ def _shard( peer_group = get_process_group_ranks(self._pg) if self._is_2D_parallel else None for info in sharding_infos: # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme - rank = ( - # pyrefly: ignore[unsupported-operation] - peer_group.index(info.param_sharding.ranks[0]) - if peer_group is not None - # pyrefly: ignore[unsupported-operation] - else info.param_sharding.ranks[0] - ) + # pyrefly: ignore[unsupported-operation] + planner_rank = info.param_sharding.ranks[0] + if peer_group is not None: + pg_members: List[int] = peer_group + try: + rank = pg_members.index(planner_rank) + except ValueError: + logger.warning( + "[2d-sharding-diag] peer_group.index() failed: " + f"table={info.embedding_config.name} " + f"planner_rank={planner_rank} " + f"peer_group_size={len(pg_members)} " + f"peer_group_head={pg_members[:8]} " + f"sharding_pg_size={self._world_size} " + f"global_world_size={dist.get_world_size()}" + ) + raise + else: + rank = planner_rank table_node = rank // local_size # pyrefly: ignore[missing-attribute] shards = info.param_sharding.sharding_spec.shards