Guard KJT-pool update against corrupt A2A lengths (#4327)

Open
hammad45 wants to merge 1 commit into meta-pytorch:main from hammad45:export-D107405007

Conversation

@hammad45 hammad45 commented Jun 8, 2026

Summary:

Row-wise `ShardedKeyedJaggedTensorPool.update()` can silently poison `_key_lengths` at small world sizes. The post-all-to-all lengths from `JaggedTensorAllToAll` are written verbatim into `_key_lengths` (`object_pool_lookups.py`), so a corrupted or desynced all-to-all yields out-of-range lengths (negative, or near ±2^31). These crash only much later, in the training-forward `lookup()`, as a negative dimension in `jagged_index_select_2d_forward_v2`. This is hard to debug because the object-pool `JaggedTensorAllToAll` has none of the int32-overflow guards that the embedding `KJTAllToAll` / `FusedKJTAllToAll` paths already have. See T273509522 (and the 2022 twin T118141711).
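As a rough illustration of why such lengths can look "negative or near ±2^31", the sketch below shows plain 32-bit wraparound in isolation. It is not TorchRec code and does not claim to be the root cause here; the helper `wrap_int32` and the example split sizes are assumptions made up for this illustration.

```python
import ctypes


def wrap_int32(value: int) -> int:
    """Reinterpret an integer as a signed 32-bit value (hypothetical helper)."""
    # ctypes integer types do no overflow checking; the value is truncated
    # to 32 bits and read back as signed.
    return ctypes.c_int32(value).value


# Two made-up split sizes whose exact sum does not fit in int32.
splits = [1_500_000_000, 1_500_000_000]
exact_total = sum(splits)            # Python ints are arbitrary precision
wrapped_total = wrap_int32(exact_total)

print(exact_total)    # 3000000000
print(wrapped_total)  # -1294967296
```

A value like this, once stored as a length, is indistinguishable from valid data until something tries to use it as a tensor dimension.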

This diff adds detection + fast-fail:

  • `_assert_valid_key_lengths()` in `object_pool_lookups.py`, called in all three KJT-pool `update()` methods before writing `_key_lengths`. Always on and eager-only (mirrors the existing `_update_preproc` max-length assert); it turns silent corruption into a clear error at the write site.
  • `_check_int_overflow` guards in `JaggedTensorAllToAll` on the lengths/values split sizes and on the min/max of the received lengths, mirroring the embedding all-to-all paths. These are a no-op unless `TORCHREC_OVERFLOW_DEBUG=1`, so the default path is unchanged (the min/max check, which requires a device sync, is also gated behind the env var).
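The shape of the two guards can be sketched in plain Python as follows. This is a minimal sketch under stated assumptions, not the code in this diff: the function names `assert_valid_key_lengths` and `check_split_sizes`, the `max_length` parameter, and the exception types are hypothetical, and lists of ints stand in for tensors. Only the `TORCHREC_OVERFLOW_DEBUG` variable name comes from the summary above.

```python
import os
from typing import Sequence

INT32_MAX = 2**31 - 1


def assert_valid_key_lengths(lengths: Sequence[int], max_length: int) -> None:
    """Always-on check before lengths are stored (hypothetical stand-in)."""
    if len(lengths) == 0:
        return
    lo, hi = min(lengths), max(lengths)
    if lo < 0 or hi > max_length:
        raise ValueError(
            f"corrupt key lengths: min={lo}, max={hi}, "
            f"allowed range [0, {max_length}]"
        )


def check_split_sizes(splits: Sequence[int], name: str) -> None:
    """Debug-only check, skipped unless TORCHREC_OVERFLOW_DEBUG=1 (assumed gating)."""
    if os.environ.get("TORCHREC_OVERFLOW_DEBUG") != "1":
        return  # default path: no extra work
    total = sum(splits)
    if any(s < 0 for s in splits) or total > INT32_MAX:
        raise OverflowError(
            f"{name}: splits={list(splits)} total={total} is not int32-safe"
        )
```

The split between the two mirrors the description: the cheap range check runs unconditionally at the write site, while the more expensive diagnostics stay opt-in.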

This is detection/hardening, not the root-cause fix. With `TORCHREC_OVERFLOW_DEBUG=1`, the 2x8 repro can now pinpoint the exact overflowing or mismatched split, after which that split/cumsum can be made int64-safe.

Differential Revision: D107405007

@meta-cla meta-cla Bot added the CLA Signed label Jun 8, 2026
@meta-codesync

meta-codesync Bot commented Jun 8, 2026

Contributor

@hammad45 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D107405007.

@meta-codesync meta-codesync Bot changed the title from "Guard KJT-pool update against corrupt A2A lengths" to "Guard KJT-pool update against corrupt A2A lengths (#4327)" Jun 11, 2026
@hammad45 hammad45 force-pushed the export-D107405007 branch from 7e4eecf to c69a036 June 11, 2026 21:52

Labels

CLA Signed, meta-exported
