Surface 2D-sharding asymmetry root cause at the boundary#4310

Open
kaanbaloglu wants to merge 1 commit into
meta-pytorch:mainfrom
kaanbaloglu:export-D106892622

@kaanbaloglu


Summary:
EmbeddingBagCollectionSharder.shard fires the Number of local shards (0) does not match the length of sharded_tensor_metadata assertion ~20K times / 28d on torchrec_event_logging, 99% concentrated in two IG models (ig_organic_feed_mtml, ig_reels_tab_esr_ttsn). Static analysis points at two candidate root causes in the 2D-parallel path:

  1. TwRwBaseSharding._shard at twrw_sharding.py:142-148 calls peer_group.index(info.param_sharding.ranks[0]) to translate the planner's rank value into a local-to-sharding-PG index. Under non-contiguous 2D sharding-PG group layouts (comm.py:230-310), the planner's rank value may not exist in the peer_group list, in which case list.index raises a ValueError that carries no table or rank context.
  2. ShardedEmbeddingBagCollection._initialize_torch_state at embeddingbag.py:1373-1386 then calls ShardedTensor._init_from_local_shards_and_global_metadata(local_shards, sharded_tensor_metadata, process_group=sharding_pg). When local_shards came back empty due to (1), the assertion fires inside torch/distributed/_shard/sharded_tensor/api.py:971 with a generic "0 does not match N" message.
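The lookup failure described in (1) can be reproduced with plain Python lists. This is a minimal sketch, assuming a hypothetical strided layout; the rank values and the helper name to_local_index are illustrative and are not taken from twrw_sharding.py or comm.py.

```python
# Hypothetical layout: 8 global ranks, sharding group built with stride 2.
# These values are illustrative only and are not read from comm.py.
peer_group = [0, 2, 4, 6]


def to_local_index(planner_rank, group):
    """Translate a rank value into its position within the group."""
    return group.index(planner_rank)


# A rank that is a member of the group translates cleanly.
member_index = to_local_index(4, peer_group)

# A rank outside the group makes list.index raise ValueError.
try:
    to_local_index(1, peer_group)
    lookup_failed = False
except ValueError as exc:
    lookup_failed = True
    # The message names only the missing value, not the table or the group.
    failure_message = str(exc)
```

The exception text identifies the missing value but nothing about which table or peer group was involved, which is the gap the diagnostics in this diff are meant to close.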

This diff adds two boundary diagnostics so the failure carries actionable context:

  • twrw_sharding.py: wrap peer_group.index() in a try/except that logs the table name, planner rank, peer-group size, peer-group head, sharding-PG size, and global world size before re-raising. Only fires on ValueError; healthy 2D shards see no log.
  • embeddingbag.py: just before the ShardedTensor construction in the ShardingEnv2D branch, compute the per-rank expected shard count from sharded_tensor_metadata.shards_metadata and log a warning if it disagrees with len(local_shards). Healthy ShardingEnv2D paths emit one comparison per table that produces no log; mismatched paths emit one warning per table before the torch core assert fires.
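The two diagnostics could look roughly like the sketch below. It is not the code in this diff: ShardMeta, local_index_with_context, warn_on_shard_count_mismatch, and the logger name are hypothetical stand-ins, and the shard metadata is modeled as a plain dataclass rather than the torch type.

```python
import logging
from dataclasses import dataclass
from typing import List, Sequence

logger = logging.getLogger("sharding_diagnostics")  # hypothetical logger name


@dataclass
class ShardMeta:
    """Hypothetical stand-in for one entry of shards_metadata."""
    table: str
    rank: int  # rank that owns this shard


def local_index_with_context(table: str, planner_rank: int,
                             peer_group: List[int],
                             sharding_pg_size: int, world_size: int) -> int:
    """Look up the planner rank; log context only if the lookup fails."""
    try:
        return peer_group.index(planner_rank)
    except ValueError:
        logger.error(
            "table=%s planner_rank=%d not in peer_group "
            "(size=%d, head=%s); sharding_pg_size=%d world_size=%d",
            table, planner_rank, len(peer_group), peer_group[:4],
            sharding_pg_size, world_size,
        )
        raise  # same exception type, no behavior change


def warn_on_shard_count_mismatch(table: str,
                                 shards_metadata: Sequence[ShardMeta],
                                 local_shards: Sequence[object],
                                 my_rank: int) -> bool:
    """Return True (and warn) when the local shard count looks wrong."""
    expected = sum(1 for meta in shards_metadata if meta.rank == my_rank)
    if expected != len(local_shards):
        logger.warning(
            "table=%s rank=%d expected %d local shards, got %d",
            table, my_rank, expected, len(local_shards),
        )
        return True
    return False
```

Both helpers stay silent on the healthy path: the first logs only inside the except branch, and the second logs only when the counts disagree.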

No behavior change for healthy jobs or non-2D jobs. No new exception types. Diagnostic-only — intended to be the first step of a three-step plan: (1) land diagnostic, (2) canary on one of the two affected IG models, (3) land the targeted fix based on which signal fires.

Once we have the canary signal:

  • If the twrw warning fires, the fix is to drop peer_group.index() under 2D — the planner-rank is already the local index. ~5 LOC.
  • If only the embeddingbag warning fires, the fix is to filter shards by peer_group membership and rebuild the metadata. Larger, touches both files.
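The shape of the two candidate fixes can be sketched as follows. This is an illustration under assumptions, not the planned patch: the is_2d flag, both function names, and the dict-based shard records are hypothetical, and the real fix would operate on the torchrec and torch types.

```python
from typing import Dict, List


def resolve_local_rank(planner_rank: int, peer_group: List[int],
                       is_2d: bool) -> int:
    """Candidate fix 1: under 2D, treat the planner rank as the local index."""
    if is_2d:
        return planner_rank
    return peer_group.index(planner_rank)


def filter_shards_to_peer_group(shards: List[Dict[str, int]],
                                peer_group: List[int]) -> List[Dict[str, int]]:
    """Candidate fix 2: keep only shards owned by ranks in the peer group.

    The caller would then rebuild the metadata from the filtered list.
    """
    members = set(peer_group)
    return [shard for shard in shards if shard["rank"] in members]
```

The first variant removes the lookup that can fail, while the second leaves the lookup alone and narrows the metadata instead, which is why it is expected to touch both files.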

The embeddingbag warning is worth keeping permanently as a guard against future regressions; the twrw warning can be removed once the underlying assignment is fixed.

Differential Revision: D106892622

@meta-cla meta-cla Bot added the CLA Signed label Jun 3, 2026
@meta-codesync

meta-codesync Bot commented Jun 3, 2026

@kaanbaloglu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D106892622.

Labels

CLA Signed, fb-exported, meta-exported
