Skip to content

Sync current stream on empty-offsets short-circuit in length_per_key (#4239)#4239

Open
kaanbaloglu wants to merge 2 commits into
meta-pytorch:mainfrom
kaanbaloglu:export-D104351493
Open

Sync current stream on empty-offsets short-circuit in length_per_key (#4239)#4239
kaanbaloglu wants to merge 2 commits into
meta-pytorch:mainfrom
kaanbaloglu:export-D104351493

Conversation

@kaanbaloglu

@kaanbaloglu kaanbaloglu commented May 8, 2026

Copy link
Copy Markdown
Contributor

Summary:

Follow-up to D103719817. Closes the same rank-divergence / cross-PG deadlock window in the offsets-only branch of _maybe_compute_length_per_key. When len(offsets) == 1 the branch's torch.diff(offsets) is empty: previously it fell through to either _length_per_key_from_stride_per_key (which returns [] from its empty-segment short-circuit without syncing) or to _safe_tolist(torch.sum(... .view(-1, stride), dim=1)) on an empty tensor. Both paths skip GPU synchronization while peer ranks taking the non-empty diff path implicitly sync inside _safe_tolist, leaving the empty-diff ranks free to race ahead into the next collective.

D103719817 explicitly flagged this with a TODO: routing through the lengths helper would change the return value from [] to [0] * len(keys), and the lengths helper already syncs on the empty branch. Verified that no caller depends on the [] return — searched every consumer of KeyedJaggedTensor.length_per_key() and _maybe_compute_length_per_key, plus the test suite. _maybe_compute_offset_per_key:1448-1469 still produces a mathematically consistent result ([0]*(N+1) offsets in both the _cumsum and the non-strict-export asynchronous_complete_cumsum paths) and several existing call sites — dist_splits at jagged_tensor.py:3202, KeyedTensor construction, zip(keys, length_per_key) patterns — silently assume the new shape and are currently latently broken with the [] return. test_empty_to_dict already documents [0, 0] as the contractually correct shape for the 2-key empty case.

Changes:

  • _maybe_compute_length_per_key offsets-path: detect len(offsets) == 1 before computing torch.diff(offsets). On that path, sync the current CUDA stream when offsets.is_cuda and return [0] * len(keys). Avoids allocating the empty diff tensor and matches the lengths-path empty-branch contract. TorchScript still bypasses the sync (script-mode divergence is pre-existing and out of scope, same as the lengths helper).
  • Removes the TODO comment on the offsets branch.
  • Tightens the type of _count_current_stream_syncs._patched's device parameter from object to torch.device | int | str | None so it matches torch.cuda.current_stream's signature (pyre fix; the helper was added in D103719817).

Differential Revision: D104351493

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 8, 2026
@meta-codesync

meta-codesync Bot commented May 8, 2026

Copy link
Copy Markdown
Contributor

@kaanbaloglu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104351493.

@meta-codesync meta-codesync Bot changed the title Sync current stream on empty-offsets short-circuit in length_per_key Sync current stream on empty-offsets short-circuit in length_per_key (#4239) May 11, 2026
kaanbaloglu added a commit to kaanbaloglu/torchrec that referenced this pull request May 11, 2026
…eta-pytorch#4239)

Summary:

Follow-up to D103719817. Closes the same rank-divergence / cross-PG deadlock window in the offsets-only branch of `_maybe_compute_length_per_key`. When `len(offsets) == 1` the branch's `torch.diff(offsets)` is empty: previously it fell through to either `_length_per_key_from_stride_per_key` (which returns `[]` from its empty-segment short-circuit without syncing) or to `_safe_tolist(torch.sum(... .view(-1, stride), dim=1))` on an empty tensor. Both paths skip GPU synchronization while peer ranks taking the non-empty diff path implicitly sync inside `_safe_tolist`, leaving the empty-diff ranks free to race ahead into the next collective.

D103719817 explicitly flagged this with a TODO: routing through the lengths helper would change the return value from `[]` to `[0] * len(keys)`, and the lengths helper already syncs on the empty branch. Verified that no caller depends on the `[]` return — searched every consumer of `KeyedJaggedTensor.length_per_key()` and `_maybe_compute_length_per_key`, plus the test suite. `_maybe_compute_offset_per_key:1448-1469` still produces a mathematically consistent result (`[0]*(N+1)` offsets in both the `_cumsum` and the non-strict-export `asynchronous_complete_cumsum` paths) and several existing call sites — `dist_splits` at jagged_tensor.py:3202, `KeyedTensor` construction, `zip(keys, length_per_key)` patterns — silently assume the new shape and are currently latently broken with the `[]` return. `test_empty_to_dict` already documents `[0, 0]` as the contractually correct shape for the 2-key empty case.

Changes:
- `_maybe_compute_length_per_key` offsets-path: detect `len(offsets) == 1` before computing `torch.diff(offsets)`. On that path, sync the current CUDA stream when `offsets.is_cuda` and return `[0] * len(keys)`. Avoids allocating the empty diff tensor and matches the lengths-path empty-branch contract. TorchScript still bypasses the sync (script-mode divergence is pre-existing and out of scope, same as the lengths helper).
- Removes the TODO comment on the offsets branch.
- Tightens the type of `_count_current_stream_syncs._patched`'s `device` parameter from `object` to `torch.device | int | str | None` so it matches `torch.cuda.current_stream`'s signature (pyre fix; the helper was added in D103719817).

Differential Revision: D104351493
kaanbaloglu added a commit to kaanbaloglu/torchrec that referenced this pull request May 11, 2026
…eta-pytorch#4239)

Summary:

Follow-up to D103719817. Closes the same rank-divergence / cross-PG deadlock window in the offsets-only branch of `_maybe_compute_length_per_key`. When `len(offsets) == 1` the branch's `torch.diff(offsets)` is empty: previously it fell through to either `_length_per_key_from_stride_per_key` (which returns `[]` from its empty-segment short-circuit without syncing) or to `_safe_tolist(torch.sum(... .view(-1, stride), dim=1))` on an empty tensor. Both paths skip GPU synchronization while peer ranks taking the non-empty diff path implicitly sync inside `_safe_tolist`, leaving the empty-diff ranks free to race ahead into the next collective.

D103719817 explicitly flagged this with a TODO: routing through the lengths helper would change the return value from `[]` to `[0] * len(keys)`, and the lengths helper already syncs on the empty branch. Verified that no caller depends on the `[]` return — searched every consumer of `KeyedJaggedTensor.length_per_key()` and `_maybe_compute_length_per_key`, plus the test suite. `_maybe_compute_offset_per_key:1448-1469` still produces a mathematically consistent result (`[0]*(N+1)` offsets in both the `_cumsum` and the non-strict-export `asynchronous_complete_cumsum` paths) and several existing call sites — `dist_splits` at jagged_tensor.py:3202, `KeyedTensor` construction, `zip(keys, length_per_key)` patterns — silently assume the new shape and are currently latently broken with the `[]` return. `test_empty_to_dict` already documents `[0, 0]` as the contractually correct shape for the 2-key empty case.

Changes:
- `_maybe_compute_length_per_key` offsets-path: detect `len(offsets) == 1` before computing `torch.diff(offsets)`. On that path, sync the current CUDA stream when `offsets.is_cuda` and return `[0] * len(keys)`. Avoids allocating the empty diff tensor and matches the lengths-path empty-branch contract. TorchScript still bypasses the sync (script-mode divergence is pre-existing and out of scope, same as the lengths helper).
- Removes the TODO comment on the offsets branch.
- Tightens the type of `_count_current_stream_syncs._patched`'s `device` parameter from `object` to `torch.device | int | str | None` so it matches `torch.cuda.current_stream`'s signature (pyre fix; the helper was added in D103719817).

Differential Revision: D104351493
…eta-pytorch#4219)

Summary:

The `[0] * len(keys)` early-return branch in `_maybe_compute_length_per_key` skipped GPU synchronization entirely, while the sibling `.tolist()` path (taken when `lengths.numel() != 0`) implicitly synced the current CUDA stream as part of materializing the result. Ranks taking the empty short-circuit therefore raced ahead of ranks taking the `.tolist()` path, leaving the next collective waiting on a rank that had already moved on. This is one possible upstream trigger for the cross-PG deadlocks observed in `mvai-training-online-2130305043` (rank 7 stuck inside `KJTAllToAllTensorsAwaitable.__init__` calling `dist.all_to_all_single` at `torchrec/distributed/dist_data.py:407`, while ranks 0–6 stuck in `_maybe_compute_length_per_key`'s `.tolist()` at `torchrec/sparse/jagged_tensor.py:1305`).

Same divergence exists at the `_safe_tolist` `numel() == 0` short-circuit added by D101235037 — the plain `.tolist()` it replaced still synced even on an empty CUDA tensor (going through `tensor.cpu()` to materialize), but the new short-circuit returned `[]` immediately with no sync.

Changes:
- New `_length_per_key_from_lengths(lengths, len_keys, stride)` helper in `torchrec/sparse/jagged_tensor.py`. Bundles the empty-lengths early-out with the GPU compute path. Empty branch (`numel == 0`) syncs the current CUDA stream then returns `[0] * len_keys`. Non-empty branch is the existing `view → sum → _safe_tolist` pipeline. TorchScript bypasses the sync (script-mode divergence is pre-existing and out of scope).
- `_maybe_compute_length_per_key`'s lengths-path (the case from the production trace) is collapsed to a single ternary that calls the new helper.
- `_safe_tolist`: removes the `if tensor.numel() == 0: return []` short-circuit. Empty CUDA tensors now fall through to `_cuda_to_cpu_safe(tensor).tolist()` — the empty CPU buffer returns `[]`, but the path now syncs the current stream first. Behavior is unchanged on the return value, but the wait-set now matches the non-empty path.
- TODO added on the `_maybe_compute_length_per_key` offsets-path: it has the same shape of latent divergence (`torch.diff(offsets)` can be empty when `len(offsets) == 1`), but routing it through the helper would change its return value from `[]` to `[0] * len(keys)` for that edge case. Out of scope for this diff; flagged for follow-up once we audit consumers.

Differential Revision: D103719817
…eta-pytorch#4239)

Summary:

Follow-up to D103719817. Closes the same rank-divergence / cross-PG deadlock window in the offsets-only branch of `_maybe_compute_length_per_key`. When `len(offsets) == 1` the branch's `torch.diff(offsets)` is empty: previously it fell through to either `_length_per_key_from_stride_per_key` (which returns `[]` from its empty-segment short-circuit without syncing) or to `_safe_tolist(torch.sum(... .view(-1, stride), dim=1))` on an empty tensor. Both paths skip GPU synchronization while peer ranks taking the non-empty diff path implicitly sync inside `_safe_tolist`, leaving the empty-diff ranks free to race ahead into the next collective.

D103719817 explicitly flagged this with a TODO: routing through the lengths helper would change the return value from `[]` to `[0] * len(keys)`, and the lengths helper already syncs on the empty branch. Verified that no caller depends on the `[]` return — searched every consumer of `KeyedJaggedTensor.length_per_key()` and `_maybe_compute_length_per_key`, plus the test suite. `_maybe_compute_offset_per_key:1448-1469` still produces a mathematically consistent result (`[0]*(N+1)` offsets in both the `_cumsum` and the non-strict-export `asynchronous_complete_cumsum` paths) and several existing call sites — `dist_splits` at jagged_tensor.py:3202, `KeyedTensor` construction, `zip(keys, length_per_key)` patterns — silently assume the new shape and are currently latently broken with the `[]` return. `test_empty_to_dict` already documents `[0, 0]` as the contractually correct shape for the 2-key empty case.

Changes:
- `_maybe_compute_length_per_key` offsets-path: detect `len(offsets) == 1` before computing `torch.diff(offsets)`. On that path, sync the current CUDA stream when `offsets.is_cuda` and return `[0] * len(keys)`. Avoids allocating the empty diff tensor and matches the lengths-path empty-branch contract. TorchScript still bypasses the sync (script-mode divergence is pre-existing and out of scope, same as the lengths helper).
- Removes the TODO comment on the offsets branch.
- Tightens the type of `_count_current_stream_syncs._patched`'s `device` parameter from `object` to `torch.device | int | str | None` so it matches `torch.cuda.current_stream`'s signature (pyre fix; the helper was added in D103719817).

Differential Revision: D104351493
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant