Sync current stream on empty-offsets short-circuit in length_per_key (#4239) #4239
Open
kaanbaloglu wants to merge 2 commits into
@kaanbaloglu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104351493.
kaanbaloglu added a commit to kaanbaloglu/torchrec that referenced this pull request on May 11, 2026
…eta-pytorch#4239)

Summary:
Follow-up to D103719817. Closes the same rank-divergence / cross-PG deadlock window in the offsets-only branch of `_maybe_compute_length_per_key`. When `len(offsets) == 1` the branch's `torch.diff(offsets)` is empty: previously it fell through to either `_length_per_key_from_stride_per_key` (which returns `[]` from its empty-segment short-circuit without syncing) or to `_safe_tolist(torch.sum(... .view(-1, stride), dim=1))` on an empty tensor. Both paths skip GPU synchronization while peer ranks taking the non-empty diff path implicitly sync inside `_safe_tolist`, leaving the empty-diff ranks free to race ahead into the next collective.

D103719817 explicitly flagged this with a TODO: routing through the lengths helper would change the return value from `[]` to `[0] * len(keys)`, and the lengths helper already syncs on the empty branch. Verified that no caller depends on the `[]` return: searched every consumer of `KeyedJaggedTensor.length_per_key()` and `_maybe_compute_length_per_key`, plus the test suite. `_maybe_compute_offset_per_key:1448-1469` still produces a mathematically consistent result (`[0]*(N+1)` offsets in both the `_cumsum` and the non-strict-export `asynchronous_complete_cumsum` paths) and several existing call sites (`dist_splits` at jagged_tensor.py:3202, `KeyedTensor` construction, `zip(keys, length_per_key)` patterns) silently assume the new shape and are currently latently broken with the `[]` return. `test_empty_to_dict` already documents `[0, 0]` as the contractually correct shape for the 2-key empty case.

Changes:
- `_maybe_compute_length_per_key` offsets-path: detect `len(offsets) == 1` before computing `torch.diff(offsets)`. On that path, sync the current CUDA stream when `offsets.is_cuda` and return `[0] * len(keys)`. Avoids allocating the empty diff tensor and matches the lengths-path empty-branch contract. TorchScript still bypasses the sync (script-mode divergence is pre-existing and out of scope, same as the lengths helper).
- Removes the TODO comment on the offsets branch.
- Tightens the type of `_count_current_stream_syncs._patched`'s `device` parameter from `object` to `torch.device | int | str | None` so it matches `torch.cuda.current_stream`'s signature (pyre fix; the helper was added in D103719817).

Differential Revision: D104351493
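The offsets-path control flow described in this commit message can be modeled without a GPU. What follows is a minimal, torch-free sketch under stated assumptions, not the TorchRec implementation: `length_per_key_from_offsets` and `SyncCounter` are hypothetical names, plain Python lists stand in for tensors, and calling `sync()` stands in for synchronizing the current CUDA stream (the real code only does so when `offsets.is_cuda`). It illustrates the two properties claimed above: both branches sync exactly once, and the empty branch returns one zero per key.

```python
from typing import Callable, List


class SyncCounter:
    """Hypothetical stand-in for a CUDA stream sync; counts how often it runs."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self) -> None:
        self.calls += 1


def length_per_key_from_offsets(
    keys: List[str],
    offsets: List[int],
    stride: int,
    sync: Callable[[], None],
) -> List[int]:
    """Model of the offsets-only branch (assumed shape, not the real code)."""
    if len(offsets) == 1:
        # Empty-diff short-circuit: sync first so this rank waits like its
        # peers, then return one zero per key instead of [].
        sync()
        return [0] * len(keys)
    # List equivalent of torch.diff(offsets): per-sample lengths.
    lengths = [b - a for a, b in zip(offsets, offsets[1:])]
    # Materializing the result on the host is where the real path syncs.
    sync()
    # List equivalent of view(-1, stride).sum(dim=1): one total per key.
    return [sum(lengths[i : i + stride]) for i in range(0, len(lengths), stride)]


keys = ["f1", "f2"]

empty_sync = SyncCounter()
empty_result = length_per_key_from_offsets(keys, [0], stride=2, sync=empty_sync)

full_sync = SyncCounter()
full_result = length_per_key_from_offsets(
    keys, [0, 1, 3, 3, 6], stride=2, sync=full_sync
)

# With the old [] return, zip(keys, []) silently drops every key.
old_pairs = list(zip(keys, []))
new_pairs = list(zip(keys, empty_result))
```

In this model `old_pairs` is empty while `new_pairs` contains one `(key, 0)` entry per key, which is the shape the `zip(keys, length_per_key)` consumers mentioned above rely on.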
c58d3e0 to 0649444
0649444 to 8fc156d
…eta-pytorch#4219)

Summary:
The `[0] * len(keys)` early-return branch in `_maybe_compute_length_per_key` skipped GPU synchronization entirely, while the sibling `.tolist()` path (taken when `lengths.numel() != 0`) implicitly synced the current CUDA stream as part of materializing the result. Ranks taking the empty short-circuit therefore raced ahead of ranks taking the `.tolist()` path, leaving the next collective waiting on a rank that had already moved on.

This is one possible upstream trigger for the cross-PG deadlocks observed in `mvai-training-online-2130305043` (rank 7 was stuck inside `KJTAllToAllTensorsAwaitable.__init__` calling `dist.all_to_all_single` at `torchrec/distributed/dist_data.py:407`, while ranks 0–6 were stuck in `_maybe_compute_length_per_key`'s `.tolist()` at `torchrec/sparse/jagged_tensor.py:1305`).

The same divergence exists at the `_safe_tolist` `numel() == 0` short-circuit added by D101235037: the plain `.tolist()` it replaced still synced even on an empty CUDA tensor (going through `tensor.cpu()` to materialize), but the new short-circuit returned `[]` immediately with no sync.

Changes:
- New `_length_per_key_from_lengths(lengths, len_keys, stride)` helper in `torchrec/sparse/jagged_tensor.py`. Bundles the empty-lengths early-out with the GPU compute path. The empty branch (`numel == 0`) syncs the current CUDA stream, then returns `[0] * len_keys`. The non-empty branch is the existing `view → sum → _safe_tolist` pipeline. TorchScript bypasses the sync (script-mode divergence is pre-existing and out of scope).
- `_maybe_compute_length_per_key`'s lengths-path (the case from the production trace) is collapsed to a single ternary that calls the new helper.
- `_safe_tolist`: removes the `if tensor.numel() == 0: return []` short-circuit. Empty CUDA tensors now fall through to `_cuda_to_cpu_safe(tensor).tolist()`; the empty CPU buffer returns `[]`, but the path now syncs the current stream first. Behavior is unchanged on the return value, but the wait-set now matches the non-empty path.
- TODO added on the `_maybe_compute_length_per_key` offsets-path: it has the same shape of latent divergence (`torch.diff(offsets)` can be empty when `len(offsets) == 1`), but routing it through the helper would change its return value from `[]` to `[0] * len(keys)` for that edge case. Out of scope for this diff; flagged for follow-up once we audit consumers.

Differential Revision: D103719817
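As a rough illustration of the `_safe_tolist` change and the new lengths helper, here is a torch-free model under stated assumptions. It is not the code from `torchrec/sparse/jagged_tensor.py`: `FakeStream`, `safe_tolist_old`, `safe_tolist_new`, and `length_per_key_from_lengths` are hypothetical names, plain lists replace tensors, and `stream.synchronize()` stands in for the stream sync that a device-to-host copy performs. The point is only to compare how many syncs each path issues on empty versus non-empty input.

```python
from typing import List


class FakeStream:
    """Hypothetical stand-in for the current CUDA stream; counts syncs."""

    def __init__(self) -> None:
        self.syncs = 0

    def synchronize(self) -> None:
        self.syncs += 1


def safe_tolist_old(values: List[int], stream: FakeStream) -> List[int]:
    """Model of the short-circuit that was removed: empty input skips the sync."""
    if len(values) == 0:
        return []
    stream.synchronize()
    return list(values)


def safe_tolist_new(values: List[int], stream: FakeStream) -> List[int]:
    """Model of the new behavior: always sync, even when the result is []."""
    stream.synchronize()
    return list(values)


def length_per_key_from_lengths(
    lengths: List[int], len_keys: int, stride: int, stream: FakeStream
) -> List[int]:
    """Model of the helper: both branches sync exactly once."""
    if len(lengths) == 0:
        # Empty early-out: sync, then one zero per key.
        stream.synchronize()
        return [0] * len_keys
    # List equivalent of view(-1, stride).sum(dim=1).
    per_key = [sum(lengths[i : i + stride]) for i in range(0, len(lengths), stride)]
    return safe_tolist_new(per_key, stream)


old_stream, new_stream = FakeStream(), FakeStream()
old_empty = safe_tolist_old([], old_stream)
new_empty = safe_tolist_new([], new_stream)

empty_rank, busy_rank = FakeStream(), FakeStream()
empty_lengths = length_per_key_from_lengths([], 3, 2, empty_rank)
busy_lengths = length_per_key_from_lengths([2, 0, 1, 4, 3, 3], 3, 2, busy_rank)
```

In this model the old and new `_safe_tolist` variants return the same `[]` for empty input but differ in sync count (zero versus one), while the helper issues one sync on both the empty and the non-empty rank.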
8fc156d to 3c0cf47