diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index b791df3eb..4c49914d7 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -2097,6 +2097,29 @@ def __init__( super().__init__() self._workers: int = pg.size() + # Object-pool lengths AllToAll. The received lengths are written + # verbatim into ShardedKeyedJaggedTensorPool._key_lengths, so a + # corrupted / desynced AllToAll (e.g. int32-overflowed splits at small + # world sizes) silently poisons the pool and only crashes much later in + # lookup as a negative dimension in jagged_index_select. Mirror the + # int32-overflow guards already on the embedding AllToAll paths; these + # are a no-op unless TORCHREC_OVERFLOW_DEBUG=1. See T273509522 / + # T118141711. + lengths_input_splits = _safe_tolist(num_items_to_send) + lengths_output_splits = _safe_tolist(num_items_to_receive) + _check_int_overflow( + "JaggedTensorAllToAll", + lengths_input_splits, + "lengths input_splits (num_items_to_send, BEFORE AllToAll)", + workers=self._workers, + ) + _check_int_overflow( + "JaggedTensorAllToAll", + lengths_output_splits, + "lengths output_splits (num_items_to_receive, BEFORE AllToAll)", + workers=self._workers, + ) + self._dist_lengths: torch.Tensor = torch.empty( sum(num_items_to_receive), device=jt.lengths().device, @@ -2106,12 +2129,28 @@ def __init__( dist.all_to_all_single( self._dist_lengths, jt.lengths(), - output_split_sizes=_safe_tolist(num_items_to_receive), - input_split_sizes=_safe_tolist(num_items_to_send), + output_split_sizes=lengths_output_splits, + input_split_sizes=lengths_input_splits, group=pg, async_op=False, ) + # The lengths AllToAll above is blocking, so self._dist_lengths is + # populated here. Surface out-of-range received lengths (negative or + # overflowed) at the source rather than as a downstream negative-dim + # crash. 
Gated on the debug env var so the min/max device sync only runs + # when explicitly debugging, keeping the default path sync-free. + if _TORCHREC_OVERFLOW_DEBUG and self._dist_lengths.numel() > 0: + _check_int_overflow( + "JaggedTensorAllToAll", + [ + int(self._dist_lengths.min()), + int(self._dist_lengths.max()), + ], + "received lengths min/max (AFTER AllToAll)", + workers=self._workers, + ) + # below will calculate chunks sums e.g. # num_batches_to_receive = [2,2] # lengths = [2,3,1,1] @@ -2126,6 +2165,12 @@ def __init__( self._dist_lengths, ) ) + _check_int_overflow( + "JaggedTensorAllToAll", + value_output_splits, + "value output_splits (BEFORE AllToAll)", + workers=self._workers, + ) self._dist_values: torch.Tensor = torch.empty( sum(value_output_splits), @@ -2142,6 +2187,12 @@ def __init__( jt.lengths(), ) ) + _check_int_overflow( + "JaggedTensorAllToAll", + value_input_splits, + "value input_splits (BEFORE AllToAll)", + workers=self._workers, + ) # pyrefly: ignore self._dist_values_req: dist.Work = dist.all_to_all_single( diff --git a/torchrec/modules/object_pool_lookups.py b/torchrec/modules/object_pool_lookups.py index d033309f3..c2a69eda7 100644 --- a/torchrec/modules/object_pool_lookups.py +++ b/torchrec/modules/object_pool_lookups.py @@ -25,6 +25,52 @@ torch.fx.wrap("jagged_index_select_with_empty") +def _assert_valid_key_lengths( + key_lengths: torch.Tensor, + feature_max_lengths: torch.Tensor, +) -> None: + """Fail fast if KJT-pool key lengths are out of range before they are stored. + + The row-wise ``ShardedKeyedJaggedTensorPool.update()`` path receives lengths + via an all-to-all (``torchrec.distributed.dist_data.JaggedTensorAllToAll``) + and writes them verbatim into ``_key_lengths``. A corrupted or desynced + all-to-all silently yields out-of-range lengths (e.g. +/-2**31 from an int32 + overflow at small world sizes), which only blow up much later in ``lookup()`` + as a negative dimension in ``jagged_index_select_2d_forward_v2``. 
Validating + here converts that silent corruption into an immediate, actionable error at + the write site. + + This runs only in ``update()`` (training-time pool refresh), never in the + scripted inference ``lookup()`` path. See T273509522 (and the 2022 precedent + T118141711). + + Args: + key_lengths (torch.Tensor): 2D ``(num_ids, num_features)`` tensor of + per-feature lengths about to be written into ``_key_lengths``. + feature_max_lengths (torch.Tensor): 1D ``(num_features,)`` tensor of the + configured maximum length for each feature. + """ + if key_lengths.numel() == 0: + return + # Single device->host sync: lengths must be non-negative and within the + # configured per-feature maxima (an invariant already enforced on the input + # side by _update_preproc, so this never fires in a healthy run). + valid = torch.logical_and( + (key_lengths >= 0).all(), + (key_lengths <= feature_max_lengths).all(), + ) + if not bool(valid): + raise RuntimeError( + "ShardedKeyedJaggedTensorPool.update received out-of-range key " + "lengths after the update all-to-all: " + f"min={int(key_lengths.min())}, " + f"per-feature max={key_lengths.max(dim=0).values.tolist()}, " + f"configured per-feature max={feature_max_lengths.tolist()}. This " + "indicates a corrupted or desynced update all-to-all (negative or " + "overflowed lengths), not a usage error. See T273509522." 
+ ) + + class KeyedJaggedTensorPoolLookup(abc.ABC, torch.nn.Module): """ Abstract base class for KeyedJaggedTensor pool lookups @@ -55,6 +101,7 @@ class KeyedJaggedTensorPoolLookup(abc.ABC, torch.nn.Module): _is_weighted: bool _total_lengths: int _total_lengths_t: torch.Tensor + _feature_max_lengths_t: torch.Tensor _key_lengths: torch.Tensor _jagged_lengths: torch.Tensor _jagged_offsets: torch.Tensor @@ -75,6 +122,13 @@ def __init__( self._total_lengths_t = torch.tensor( [self._total_lengths], device=device, dtype=torch.int32 ) + # Per-feature max lengths, used to validate post-all-to-all lengths in + # update() before they are written into _key_lengths (see T273509522). + self._feature_max_lengths_t = torch.tensor( + list(self._feature_max_lengths.values()), + device=device, + dtype=torch.int32, + ) self._is_weighted = is_weighted self._key_lengths = torch.zeros( @@ -254,11 +308,10 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor: def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: with record_function("## TensorPool update ##"): - key_lengths = ( - values.lengths().view(-1, len(self._feature_max_lengths)) - # pyrefly: ignore[no-matching-overload] - .sum(axis=1) - ) + key_lengths_2d = values.lengths().view(-1, len(self._feature_max_lengths)) + _assert_valid_key_lengths(key_lengths_2d, self._feature_max_lengths_t) + # pyrefly: ignore[no-matching-overload] + key_lengths = key_lengths_2d.sum(axis=1) key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths) padded_values = torch.ops.fbgemm.jagged_to_padded_dense( @@ -269,11 +322,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: ) self._values[ids] = padded_values.to(self._values.dtype) - self._key_lengths[ids] = ( - values.lengths() - .view(-1, len(self._feature_max_lengths)) - .to(self._key_lengths.dtype) - ) + self._key_lengths[ids] = key_lengths_2d.to(self._key_lengths.dtype) if values.weights_or_none() is not None: padded_weights = 
torch.ops.fbgemm.jagged_to_padded_dense( @@ -403,11 +452,10 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor: def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: with record_function("## UVMCachingInt64Lookup update ##"): - key_lengths = ( - values.lengths().view(-1, len(self._feature_max_lengths)) - # pyrefly: ignore[no-matching-overload] - .sum(axis=1) - ) + key_lengths_2d = values.lengths().view(-1, len(self._feature_max_lengths)) + _assert_valid_key_lengths(key_lengths_2d, self._feature_max_lengths_t) + # pyrefly: ignore[no-matching-overload] + key_lengths = key_lengths_2d.sum(axis=1) key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths) padded_values: torch.Tensor = torch.ops.fbgemm.jagged_to_padded_dense( values.values(), @@ -425,11 +473,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: self._tbe_state[ids] = state - self._key_lengths[ids] = ( - values.lengths() - .view(-1, len(self._feature_max_lengths)) - .to(self._key_lengths.dtype) - ) + self._key_lengths[ids] = key_lengths_2d.to(self._key_lengths.dtype) def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]: yield "values_upper_and_lower_bits", self._tbe_state @@ -542,11 +586,10 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor: def update(self, ids: torch.Tensor, values: JaggedTensor) -> None: with record_function("## UVMCachingInt32Lookup update##"): - key_lengths = ( - values.lengths().view(-1, len(self._feature_max_lengths)) - # pyrefly: ignore[no-matching-overload] - .sum(axis=1) - ) + key_lengths_2d = values.lengths().view(-1, len(self._feature_max_lengths)) + _assert_valid_key_lengths(key_lengths_2d, self._feature_max_lengths_t) + # pyrefly: ignore[no-matching-overload] + key_lengths = key_lengths_2d.sum(axis=1) key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths) state = torch.ops.fbgemm.jagged_to_padded_dense( values.values(), @@ -557,11 +600,7 @@ def update(self, ids: torch.Tensor, values: 
class KeyedJaggedTensorPoolLengthValidationTest(unittest.TestCase):
    """Regression coverage for T273509522.

    Out-of-range key lengths (negative, above the per-feature maximum, or at
    int32-overflow magnitude) must be rejected by ``update()`` before anything
    is written into the lookup's ``_key_lengths``, and in-range lengths must
    still be stored unchanged.
    """

    @staticmethod
    def _device() -> torch.device:
        """Return cuda:0 when available, otherwise the CPU device."""
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        return torch.device("cpu")

    @staticmethod
    def _make_lookup(device: torch.device) -> "TensorJaggedIndexSelectLookup":
        """Build a 4-row lookup with feature maxima f1=2 and f2=4."""
        return TensorJaggedIndexSelectLookup(
            pool_size=4,
            values_dtype=torch.int64,
            feature_max_lengths={"f1": 2, "f2": 4},
            is_weighted=False,
            device=device,
        )

    def test_assert_valid_key_lengths_accepts_valid(self) -> None:
        feature_max_lengths_t = torch.tensor([2, 4], dtype=torch.int32)
        # In-range lengths (including 0 and the per-feature max) must pass.
        _assert_valid_key_lengths(
            torch.tensor([[1, 3], [2, 0], [0, 4]], dtype=torch.int32),
            feature_max_lengths_t,
        )
        # An empty batch is a no-op (e.g. a rank that received nothing).
        _assert_valid_key_lengths(
            torch.zeros((0, 2), dtype=torch.int32), feature_max_lengths_t
        )

    def test_assert_valid_key_lengths_rejects_negative(self) -> None:
        feature_max_lengths_t = torch.tensor([2, 4], dtype=torch.int32)
        with self.assertRaisesRegex(RuntimeError, "out-of-range key lengths"):
            _assert_valid_key_lengths(
                torch.tensor([[1, -7], [2, 0]], dtype=torch.int32),
                feature_max_lengths_t,
            )

    def test_assert_valid_key_lengths_rejects_over_max(self) -> None:
        feature_max_lengths_t = torch.tensor([2, 4], dtype=torch.int32)
        # f1 length 5 exceeds its configured max of 2.
        with self.assertRaisesRegex(RuntimeError, "out-of-range key lengths"):
            _assert_valid_key_lengths(
                torch.tensor([[5, 1]], dtype=torch.int32), feature_max_lengths_t
            )

    def test_assert_valid_key_lengths_rejects_overflow_magnitude(self) -> None:
        # The actual bug signature: lengths at the +/-2**31 int32 boundary.
        feature_max_lengths_t = torch.tensor([2, 4], dtype=torch.int32)
        with self.assertRaisesRegex(RuntimeError, "out-of-range key lengths"):
            _assert_valid_key_lengths(
                torch.tensor([[2_115_462_454, -2_133_429_463]], dtype=torch.int32),
                feature_max_lengths_t,
            )

    def test_update_rejects_corrupted_lengths(self) -> None:
        device = self._device()
        lookup = self._make_lookup(device)
        # One id, two features: f1 length 5 exceeds its configured max of 2.
        corrupted = JaggedTensor(
            values=torch.arange(6, dtype=torch.int64, device=device),
            lengths=torch.tensor([5, 1], dtype=torch.int, device=device),
        )
        with self.assertRaisesRegex(RuntimeError, "out-of-range key lengths"):
            lookup.update(torch.tensor([0], device=device), corrupted)
        # Validation runs before any write in update(), so the zero-initialised
        # row must be untouched: rejecting must also mean "nothing was stored".
        self.assertEqual(lookup._key_lengths[0].tolist(), [0, 0])

    def test_update_accepts_valid_lengths(self) -> None:
        device = self._device()
        lookup = self._make_lookup(device)
        # In-range update must not raise (guards against false positives).
        valid = JaggedTensor(
            values=torch.arange(3, dtype=torch.int64, device=device),
            lengths=torch.tensor([2, 1], dtype=torch.int, device=device),
        )
        lookup.update(torch.tensor([0], device=device), valid)
        # The accepted per-feature lengths must land in the pool unchanged.
        self.assertEqual(lookup._key_lengths[0].tolist(), [2, 1])