Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions fbgemm_gpu/test/tbe/cache/cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
ComputeDevice,
DEFAULT_ASSOC,
MultiPassPrefetchConfig,
SplitTableBatchedEmbeddingBagsCodegen,
)
Expand Down Expand Up @@ -1041,10 +1042,12 @@ def test_lru_cache_insert_large_grid(self) -> None:
(split_embeddings_cache/lru_cache_populate.cu, lock_cache_line=False
branch).

Block: dim3(kWarpSize=32, kMaxThreads/kWarpSize=32) = 1024 threads.
Grid: div_round_up(N, kMaxThreads / kWarpSize) = ceil(N / 32).
Total threads ~= N * kWarpSize = 32 * N. For N >= 2**27, total
threads exceed the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL ->
Block: dim3(kWarpSize, kMaxThreads/kWarpSize) = 1024 threads
(kWarpSize = 32 on NVIDIA, 64 on AMD; cache associativity ==
kWarpSize == DEFAULT_ASSOC).
Grid: div_round_up(N, kMaxThreads / kWarpSize). Total threads
~= N * kWarpSize. For N >= 2**27 this exceeds the HIP 2**32 limit
on both warp sizes, causing FBGEMM_LAUNCH_KERNEL ->
KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail
on ROCm.

Expand All @@ -1055,7 +1058,7 @@ def test_lru_cache_insert_large_grid(self) -> None:
"""
N = (1 << 27) + 1
D = 4
device = torch.accelerator.current_accelerator("cuda")
device = torch.accelerator.current_accelerator()

# T=1 table with E = N entries so all unique linear cache indices
# are valid hash keys. cache_hash_size_cumsum has shape (T+1,).
Expand All @@ -1071,16 +1074,22 @@ def test_lru_cache_insert_large_grid(self) -> None:
# kernel grid size.
linear_cache_indices = torch.arange(N, dtype=torch.int64, device=device)
# Small cache -> tiny lxu_cache_state / weights / lru_state.
# Cache associativity == warp size (DEFAULT_ASSOC: 32 on NVIDIA,
# 64 on AMD). The insert kernel strides cache rows by kWarpSize, so
# the slot dimension MUST match or the kernel indexes out of bounds
# on wavefront-64 hardware.
num_cache_sets = 1024
lxu_cache_state = torch.full(
(num_cache_sets, 32), -1, dtype=torch.int64, device=device
(num_cache_sets, DEFAULT_ASSOC), -1, dtype=torch.int64, device=device
)
# Flat weights tensor for the single table.
weights = torch.zeros(N * D, dtype=torch.float32, device=device)
lxu_cache_weights = torch.zeros(
(num_cache_sets * 32, D), dtype=torch.float32, device=device
(num_cache_sets * DEFAULT_ASSOC, D), dtype=torch.float32, device=device
)
lru_state = torch.zeros(
(num_cache_sets, DEFAULT_ASSOC), dtype=torch.int64, device=device
)
lru_state = torch.zeros((num_cache_sets, 32), dtype=torch.int64, device=device)

torch.ops.fbgemm.lru_cache_populate(
weights,
Expand All @@ -1102,14 +1111,14 @@ def test_lru_cache_insert_large_grid(self) -> None:
)
# Tier C structural invariants on the populated cache:
# 1. shape preserved.
self.assertEqual(lxu_cache_state.shape, (num_cache_sets, 32))
# 2. Cache is fully populated. With N >> num_cache_sets * 32 = 32K,
self.assertEqual(lxu_cache_state.shape, (num_cache_sets, DEFAULT_ASSOC))
# 2. Cache is fully populated. With N >> num_cache_sets * DEFAULT_ASSOC,
# every cache slot should hold a valid key (not the -1 sentinel)
# after the insert kernel runs. Pre-fix the kernel never runs;
# post-fix it grid-strides over all N indices and populates
# every slot.
num_filled = int((lxu_cache_state != -1).sum().item())
self.assertEqual(num_filled, num_cache_sets * 32)
self.assertEqual(num_filled, num_cache_sets * DEFAULT_ASSOC)
# 3. Every populated slot holds a valid linear cache index in
# [0, N). This catches "kernel wrote wrong key" bugs.
populated = lxu_cache_state[lxu_cache_state != -1]
Expand Down
Loading