Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions torchrec/distributed/dist_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2097,6 +2097,29 @@ def __init__(
super().__init__()
self._workers: int = pg.size()

# Object-pool lengths AllToAll. The received lengths are written
# verbatim into ShardedKeyedJaggedTensorPool._key_lengths, so a
# corrupted / desynced AllToAll (e.g. int32-overflowed splits at small
# world sizes) silently poisons the pool and only crashes much later in
# lookup as a negative dimension in jagged_index_select. Mirror the
# int32-overflow guards already on the embedding AllToAll paths; these
# are a no-op unless TORCHREC_OVERFLOW_DEBUG=1. See T273509522 /
# T118141711.
lengths_input_splits = _safe_tolist(num_items_to_send)
lengths_output_splits = _safe_tolist(num_items_to_receive)
_check_int_overflow(
"JaggedTensorAllToAll",
lengths_input_splits,
"lengths input_splits (num_items_to_send, BEFORE AllToAll)",
workers=self._workers,
)
_check_int_overflow(
"JaggedTensorAllToAll",
lengths_output_splits,
"lengths output_splits (num_items_to_receive, BEFORE AllToAll)",
workers=self._workers,
)

self._dist_lengths: torch.Tensor = torch.empty(
sum(num_items_to_receive),
device=jt.lengths().device,
Expand All @@ -2106,12 +2129,28 @@ def __init__(
dist.all_to_all_single(
self._dist_lengths,
jt.lengths(),
output_split_sizes=_safe_tolist(num_items_to_receive),
input_split_sizes=_safe_tolist(num_items_to_send),
output_split_sizes=lengths_output_splits,
input_split_sizes=lengths_input_splits,
group=pg,
async_op=False,
)

# The lengths AllToAll above is blocking, so self._dist_lengths is
# populated here. Surface out-of-range received lengths (negative or
# overflowed) at the source rather than as a downstream negative-dim
# crash. Gated on the debug env var so the min/max device sync only runs
# when explicitly debugging, keeping the default path sync-free.
if _TORCHREC_OVERFLOW_DEBUG and self._dist_lengths.numel() > 0:
_check_int_overflow(
"JaggedTensorAllToAll",
[
int(self._dist_lengths.min()),
int(self._dist_lengths.max()),
],
"received lengths min/max (AFTER AllToAll)",
workers=self._workers,
)

# below will calculate chunks sums e.g.
# num_batches_to_receive = [2,2]
# lengths = [2,3,1,1]
Expand All @@ -2126,6 +2165,12 @@ def __init__(
self._dist_lengths,
)
)
_check_int_overflow(
"JaggedTensorAllToAll",
value_output_splits,
"value output_splits (BEFORE AllToAll)",
workers=self._workers,
)

self._dist_values: torch.Tensor = torch.empty(
sum(value_output_splits),
Expand All @@ -2142,6 +2187,12 @@ def __init__(
jt.lengths(),
)
)
_check_int_overflow(
"JaggedTensorAllToAll",
value_input_splits,
"value input_splits (BEFORE AllToAll)",
workers=self._workers,
)

# pyrefly: ignore
self._dist_values_req: dist.Work = dist.all_to_all_single(
Expand Down
99 changes: 69 additions & 30 deletions torchrec/modules/object_pool_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,52 @@
torch.fx.wrap("jagged_index_select_with_empty")


def _assert_valid_key_lengths(
key_lengths: torch.Tensor,
feature_max_lengths: torch.Tensor,
) -> None:
"""Fail fast if KJT-pool key lengths are out of range before they are stored.

The row-wise ``ShardedKeyedJaggedTensorPool.update()`` path receives lengths
via an all-to-all (``torchrec.distributed.dist_data.JaggedTensorAllToAll``)
and writes them verbatim into ``_key_lengths``. A corrupted or desynced
all-to-all silently yields out-of-range lengths (e.g. +/-2**31 from an int32
overflow at small world sizes), which only blow up much later in ``lookup()``
as a negative dimension in ``jagged_index_select_2d_forward_v2``. Validating
here converts that silent corruption into an immediate, actionable error at
the write site.

This runs only in ``update()`` (training-time pool refresh), never in the
scripted inference ``lookup()`` path. See T273509522 (and the 2022 precedent
T118141711).

Args:
key_lengths (torch.Tensor): 2D ``(num_ids, num_features)`` tensor of
per-feature lengths about to be written into ``_key_lengths``.
feature_max_lengths (torch.Tensor): 1D ``(num_features,)`` tensor of the
configured maximum length for each feature.
"""
if key_lengths.numel() == 0:
return
# Single device->host sync: lengths must be non-negative and within the
# configured per-feature maxima (an invariant already enforced on the input
# side by _update_preproc, so this never fires in a healthy run).
valid = torch.logical_and(
(key_lengths >= 0).all(),
(key_lengths <= feature_max_lengths).all(),
)
if not bool(valid):
raise RuntimeError(
"ShardedKeyedJaggedTensorPool.update received out-of-range key "
"lengths after the update all-to-all: "
f"min={int(key_lengths.min())}, "
f"per-feature max={key_lengths.max(dim=0).values.tolist()}, "
f"configured per-feature max={feature_max_lengths.tolist()}. This "
"indicates a corrupted or desynced update all-to-all (negative or "
"overflowed lengths), not a usage error. See T273509522."
)


class KeyedJaggedTensorPoolLookup(abc.ABC, torch.nn.Module):
"""
Abstract base class for KeyedJaggedTensor pool lookups
Expand Down Expand Up @@ -55,6 +101,7 @@ class KeyedJaggedTensorPoolLookup(abc.ABC, torch.nn.Module):
_is_weighted: bool
_total_lengths: int
_total_lengths_t: torch.Tensor
_feature_max_lengths_t: torch.Tensor
_key_lengths: torch.Tensor
_jagged_lengths: torch.Tensor
_jagged_offsets: torch.Tensor
Expand All @@ -75,6 +122,13 @@ def __init__(
self._total_lengths_t = torch.tensor(
[self._total_lengths], device=device, dtype=torch.int32
)
# Per-feature max lengths, used to validate post-all-to-all lengths in
# update() before they are written into _key_lengths (see T273509522).
self._feature_max_lengths_t = torch.tensor(
list(self._feature_max_lengths.values()),
device=device,
dtype=torch.int32,
)
self._is_weighted = is_weighted

self._key_lengths = torch.zeros(
Expand Down Expand Up @@ -254,11 +308,10 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:
def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:

with record_function("## TensorPool update ##"):
key_lengths = (
values.lengths().view(-1, len(self._feature_max_lengths))
# pyrefly: ignore[no-matching-overload]
.sum(axis=1)
)
key_lengths_2d = values.lengths().view(-1, len(self._feature_max_lengths))
_assert_valid_key_lengths(key_lengths_2d, self._feature_max_lengths_t)
# pyrefly: ignore[no-matching-overload]
key_lengths = key_lengths_2d.sum(axis=1)
key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)

padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
Expand All @@ -269,11 +322,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
)

self._values[ids] = padded_values.to(self._values.dtype)
self._key_lengths[ids] = (
values.lengths()
.view(-1, len(self._feature_max_lengths))
.to(self._key_lengths.dtype)
)
self._key_lengths[ids] = key_lengths_2d.to(self._key_lengths.dtype)

if values.weights_or_none() is not None:
padded_weights = torch.ops.fbgemm.jagged_to_padded_dense(
Expand Down Expand Up @@ -403,11 +452,10 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:

def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
with record_function("## UVMCachingInt64Lookup update ##"):
key_lengths = (
values.lengths().view(-1, len(self._feature_max_lengths))
# pyrefly: ignore[no-matching-overload]
.sum(axis=1)
)
key_lengths_2d = values.lengths().view(-1, len(self._feature_max_lengths))
_assert_valid_key_lengths(key_lengths_2d, self._feature_max_lengths_t)
# pyrefly: ignore[no-matching-overload]
key_lengths = key_lengths_2d.sum(axis=1)
key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)
padded_values: torch.Tensor = torch.ops.fbgemm.jagged_to_padded_dense(
values.values(),
Expand All @@ -425,11 +473,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:

self._tbe_state[ids] = state

self._key_lengths[ids] = (
values.lengths()
.view(-1, len(self._feature_max_lengths))
.to(self._key_lengths.dtype)
)
self._key_lengths[ids] = key_lengths_2d.to(self._key_lengths.dtype)

def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
yield "values_upper_and_lower_bits", self._tbe_state
Expand Down Expand Up @@ -542,11 +586,10 @@ def lookup(self, ids: torch.Tensor) -> JaggedTensor:

def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
with record_function("## UVMCachingInt32Lookup update##"):
key_lengths = (
values.lengths().view(-1, len(self._feature_max_lengths))
# pyrefly: ignore[no-matching-overload]
.sum(axis=1)
)
key_lengths_2d = values.lengths().view(-1, len(self._feature_max_lengths))
_assert_valid_key_lengths(key_lengths_2d, self._feature_max_lengths_t)
# pyrefly: ignore[no-matching-overload]
key_lengths = key_lengths_2d.sum(axis=1)
key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)
state = torch.ops.fbgemm.jagged_to_padded_dense(
values.values(),
Expand All @@ -557,11 +600,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:

self._tbe_state[ids] = state

self._key_lengths[ids] = (
values.lengths()
.view(-1, len(self._feature_max_lengths))
.to(self._key_lengths.dtype)
)
self._key_lengths[ids] = key_lengths_2d.to(self._key_lengths.dtype)

def states_to_register(self) -> Iterator[Tuple[str, torch.Tensor]]:
yield "values", self._tbe_state
Expand Down
97 changes: 96 additions & 1 deletion torchrec/modules/tests/test_keyed_jagged_tensor_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

import torch
from torchrec.modules.keyed_jagged_tensor_pool import KeyedJaggedTensorPool
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.object_pool_lookups import (
_assert_valid_key_lengths,
TensorJaggedIndexSelectLookup,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


class KeyedJaggedTensorPoolTest(unittest.TestCase):
Expand Down Expand Up @@ -366,3 +370,94 @@ def test_empty_lookup(
kjt.lengths().cpu(),
torch.tensor([], dtype=torch.int, device=torch.device("cpu")),
)


class KeyedJaggedTensorPoolLengthValidationTest(unittest.TestCase):
    """Regression coverage for T273509522.

    When the row-wise update all-to-all is corrupted or desynced it hands back
    out-of-range key lengths (negative or +/-2**31 overflow). Those used to be
    stored silently in ``_key_lengths`` and only crashed much later in
    ``lookup()`` as a negative dimension in ``jagged_index_select``.
    ``update()`` must now reject them at the write site.
    """

    # Message fragment every rejection is expected to contain.
    _ERROR_PATTERN = "out-of-range key lengths"

    @staticmethod
    def _max_lengths() -> torch.Tensor:
        # Configured per-feature maxima shared by the helper-level tests.
        return torch.tensor([2, 4], dtype=torch.int32)

    @staticmethod
    def _pick_device() -> torch.device:
        # Use the first CUDA device when one is available, otherwise the CPU.
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        return torch.device("cpu")

    @staticmethod
    def _build_lookup(device: torch.device) -> TensorJaggedIndexSelectLookup:
        # Small unweighted pool with two features: f1 (max 2) and f2 (max 4).
        return TensorJaggedIndexSelectLookup(
            pool_size=4,
            values_dtype=torch.int64,
            feature_max_lengths={"f1": 2, "f2": 4},
            is_weighted=False,
            device=device,
        )

    def test_assert_valid_key_lengths_accepts_valid(self) -> None:
        max_lengths = self._max_lengths()
        # Lengths inside the allowed range, including 0 and the per-feature
        # max itself, must be accepted.
        in_range = torch.tensor([[1, 3], [2, 0], [0, 4]], dtype=torch.int32)
        _assert_valid_key_lengths(in_range, max_lengths)
        # An empty batch (e.g. a rank that received nothing) is a no-op.
        nothing = torch.zeros((0, 2), dtype=torch.int32)
        _assert_valid_key_lengths(nothing, max_lengths)

    def test_assert_valid_key_lengths_rejects_negative(self) -> None:
        max_lengths = self._max_lengths()
        negative = torch.tensor([[1, -7], [2, 0]], dtype=torch.int32)
        with self.assertRaisesRegex(RuntimeError, self._ERROR_PATTERN):
            _assert_valid_key_lengths(negative, max_lengths)

    def test_assert_valid_key_lengths_rejects_over_max(self) -> None:
        max_lengths = self._max_lengths()
        # The f1 length of 5 is above its configured max of 2.
        too_long = torch.tensor([[5, 1]], dtype=torch.int32)
        with self.assertRaisesRegex(RuntimeError, self._ERROR_PATTERN):
            _assert_valid_key_lengths(too_long, max_lengths)

    def test_assert_valid_key_lengths_rejects_overflow_magnitude(self) -> None:
        # The actual bug signature: lengths near the +/-2**31 int32 boundary.
        max_lengths = self._max_lengths()
        overflowed = torch.tensor(
            [[2_115_462_454, -2_133_429_463]], dtype=torch.int32
        )
        with self.assertRaisesRegex(RuntimeError, self._ERROR_PATTERN):
            _assert_valid_key_lengths(overflowed, max_lengths)

    def test_update_rejects_corrupted_lengths(self) -> None:
        device = self._pick_device()
        lookup = self._build_lookup(device)
        # One id, two features: the f1 length of 5 is above its configured max
        # of 2, which is the kind of out-of-range length a corrupted update
        # all-to-all would deliver. update() must refuse it rather than store
        # garbage.
        corrupted = JaggedTensor(
            values=torch.arange(6, dtype=torch.int64, device=device),
            lengths=torch.tensor([5, 1], dtype=torch.int, device=device),
        )
        target_ids = torch.tensor([0], device=device)
        with self.assertRaisesRegex(RuntimeError, self._ERROR_PATTERN):
            lookup.update(target_ids, corrupted)

    def test_update_accepts_valid_lengths(self) -> None:
        device = self._pick_device()
        lookup = self._build_lookup(device)
        # An in-range update must go through without raising (false-positive
        # guard).
        valid = JaggedTensor(
            values=torch.arange(3, dtype=torch.int64, device=device),
            lengths=torch.tensor([2, 1], dtype=torch.int, device=device),
        )
        target_ids = torch.tensor([0], device=device)
        lookup.update(target_ids, valid)
Loading