Surface 2D-sharding asymmetry root cause at the boundary #4310
Open
kaanbaloglu wants to merge 1 commit into
Summary:
`EmbeddingBagCollectionSharder.shard` fires the `Number of local shards (0) does not match the length of sharded_tensor_metadata` assertion ~20K times per 28 days on `torchrec_event_logging`, 99% concentrated in two IG models (`ig_organic_feed_mtml`, `ig_reels_tab_esr_ttsn`). Static analysis points at two candidate root causes in the 2D-parallel path:

1. `TwRwBaseSharding._shard` at `twrw_sharding.py:142-148` calls `peer_group.index(info.param_sharding.ranks[0])` to translate the planner's rank value into an index local to the sharding PG. Under non-contiguous 2D sharding-PG group layouts (`comm.py:230-310`), the planner's rank value may not exist in the `peer_group` list, so `index()` raises a `ValueError` that escapes with no table or process-group context.
2. `ShardedEmbeddingBagCollection._initialize_torch_state` at `embeddingbag.py:1373-1386` then calls `ShardedTensor._init_from_local_shards_and_global_metadata(local_shards, sharded_tensor_metadata, process_group=sharding_pg)`. When `local_shards` comes back empty due to (1), the assertion fires inside `torch/distributed/_shard/sharded_tensor/api.py:971` with a generic "0 does not match N" message.

This diff adds two boundary diagnostics so the failure carries actionable context:

- `twrw_sharding.py`: wrap `peer_group.index()` in a try/except that logs the table name, planner rank, peer-group size, peer-group head, sharding-PG size, and global world size before re-raising. It only fires on `ValueError`; healthy 2D shards see no log.
- `embeddingbag.py`: just before the `ShardedTensor` construction in the `ShardingEnv2D` branch, compute the per-rank expected shard count from `sharded_tensor_metadata.shards_metadata` and log a warning if it disagrees with `len(local_shards)`. Healthy `ShardingEnv2D` paths run one comparison per table and produce no log; mismatched paths emit one warning per table before the torch core assert fires.

No behavior change for healthy jobs or non-2D jobs. No new exception types.
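The two diagnostics described above can be sketched in plain Python. This is a minimal illustration, not the code in the diff: the helper names `local_index_with_context` and `shard_count_matches`, the `ShardMeta` stand-in (a simplified placeholder for an entry of `shards_metadata`), the log wording, and the reading of "peer-group head" as the first few entries of the list are all assumptions made for the example.

```python
import logging
from dataclasses import dataclass
from typing import Sequence

logger = logging.getLogger(__name__)


@dataclass
class ShardMeta:
    """Hypothetical stand-in for one shards_metadata entry; only the owning rank is modeled."""
    rank: int


def local_index_with_context(
    peer_group: Sequence[int],
    planner_rank: int,
    table_name: str,
    sharding_pg_size: int,
    world_size: int,
) -> int:
    """Translate a planner rank into a peer-group index, logging context on failure."""
    try:
        return list(peer_group).index(planner_rank)
    except ValueError:
        # Log the boundary context, then re-raise so behavior is unchanged.
        logger.warning(
            "2D sharding: planner rank %d not in peer group for table %s "
            "(peer_group_size=%d, peer_group_head=%s, "
            "sharding_pg_size=%d, world_size=%d)",
            planner_rank,
            table_name,
            len(peer_group),
            list(peer_group[:4]),  # assumed meaning of "peer-group head"
            sharding_pg_size,
            world_size,
        )
        raise


def shard_count_matches(
    table_name: str,
    local_shards: Sequence[object],
    shards_metadata: Sequence[ShardMeta],
    my_rank: int,
) -> bool:
    """Compare the shard count the metadata expects on this rank with the local count."""
    expected = sum(1 for meta in shards_metadata if meta.rank == my_rank)
    actual = len(local_shards)
    if expected != actual:
        # Warn only; the caller still proceeds to the ShardedTensor construction.
        logger.warning(
            "2D sharding: table %s expects %d local shard(s) on rank %d, found %d",
            table_name,
            expected,
            my_rank,
            actual,
        )
        return False
    return True
```

In this sketch a healthy call such as `local_index_with_context([0, 2, 4, 6], 4, "t1", 4, 8)` returns the local index without logging, while a planner rank missing from the peer group logs once and then re-raises the original `ValueError`, so callers see no new exception type.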
Diagnostic-only. This is intended to be the first step of a three-step plan: (1) land diagnostic, (2) canary on one of the two affected IG models, (3) land the targeted fix based on which signal fires.

Once we have the canary signal:

- If the twrw warning fires, the fix is to drop `peer_group.index()` under 2D, since the planner rank is already the local index. ~5 LOC.
- If only the embeddingbag warning fires, the fix is to filter shards by `peer_group` membership and rebuild the metadata. Larger, touches both files.

The embeddingbag warning is worth keeping permanently as a guard against future regressions; the twrw warning can be removed once the underlying assignment is fixed.

Differential Revision: D106892622
@kaanbaloglu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D106892622.