From ab674f844caa7003d9f98d368d073a4064647e99 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Sun, 14 Jun 2026 00:07:58 -0700 Subject: [PATCH] Fix test_lru_cache_insert_large_grid associativity on ROCm wavefront64 Summary: `test_lru_cache_insert_large_grid` (added by D105282095) hardcodes the LXU cache associativity as `32`. The split-embeddings LXU cache is set-associative with associativity == warp size == `DEFAULT_ASSOC` (32 on NVIDIA, 64 on AMD). On AMD wavefront64 (gfx942 / MI300) `lru_cache_insert_kernel` strides cache rows by `kWarpSize = 64` and writes `lxu_cache_state` / `lxu_cache_weights` / `lru_state` for `slot` in `[0, 64)`, indexing past the 32-wide test allocations -> out-of-bounds -> non-deterministic memory corruption -> flaky `assertEqual(lru_state != time_stamp, 0)` failures in OSS ROCm CI (see P2378242263). On NVIDIA (32 == 32) the allocation matches the kernel, so the test passed. Fix (test-only; no kernel/production change): - Size the three cache tensors and assertions by `DEFAULT_ASSOC` instead of the literal `32`, matching the established pattern in `lxu_cache_test.py` and `nbit_cache_test.py`, so the allocation width matches the kernel's `kWarpSize` associativity on both platforms. - Fix `torch.accelerator.current_accelerator("cuda")` -> `current_accelerator()` (the string `"cuda"` was bound positionally to the `check_available` parameter and, being truthy, silently acted as `check_available=True`; flagged by ai_diff_reviewer). - Generalize the docstring's NVIDIA-specific (32) grid math. 
Differential Revision: D108540654 --- fbgemm_gpu/test/tbe/cache/cache_test.py | 31 ++++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py index efd047a157..af7d435462 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_test.py @@ -24,6 +24,7 @@ ) from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, + DEFAULT_ASSOC, MultiPassPrefetchConfig, SplitTableBatchedEmbeddingBagsCodegen, ) @@ -1041,10 +1042,12 @@ def test_lru_cache_insert_large_grid(self) -> None: (split_embeddings_cache/lru_cache_populate.cu, lock_cache_line=False branch). - Block: dim3(kWarpSize=32, kMaxThreads/kWarpSize=32) = 1024 threads. - Grid: div_round_up(N, kMaxThreads / kWarpSize) = ceil(N / 32). - Total threads ~= N * kWarpSize = 32 * N. For N >= 2**27, total - threads exceed the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL -> + Block: dim3(kWarpSize, kMaxThreads/kWarpSize) = 1024 threads + (kWarpSize = 32 on NVIDIA, 64 on AMD; cache associativity == + kWarpSize == DEFAULT_ASSOC). + Grid: div_round_up(N, kMaxThreads / kWarpSize). Total threads + ~= N * kWarpSize. For N >= 2**27 this exceeds the HIP 2**32 limit + on both warp sizes, causing FBGEMM_LAUNCH_KERNEL -> KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail on ROCm. @@ -1055,7 +1058,7 @@ def test_lru_cache_insert_large_grid(self) -> None: """ N = (1 << 27) + 1 D = 4 - device = torch.accelerator.current_accelerator("cuda") + device = torch.accelerator.current_accelerator() # T=1 table with E = N entries so all unique linear cache indices # are valid hash keys. cache_hash_size_cumsum has shape (T+1,). @@ -1071,16 +1074,22 @@ def test_lru_cache_insert_large_grid(self) -> None: # kernel grid size. linear_cache_indices = torch.arange(N, dtype=torch.int64, device=device) # Small cache -> tiny lxu_cache_state / weights / lru_state. 
+ # Cache associativity == warp size (DEFAULT_ASSOC: 32 on NVIDIA, + # 64 on AMD). The insert kernel strides cache rows by kWarpSize, so + # the slot dimension MUST match or the kernel indexes out of bounds + # on wavefront-64 hardware. num_cache_sets = 1024 lxu_cache_state = torch.full( - (num_cache_sets, 32), -1, dtype=torch.int64, device=device + (num_cache_sets, DEFAULT_ASSOC), -1, dtype=torch.int64, device=device ) # Flat weights tensor for the single table. weights = torch.zeros(N * D, dtype=torch.float32, device=device) lxu_cache_weights = torch.zeros( - (num_cache_sets * 32, D), dtype=torch.float32, device=device + (num_cache_sets * DEFAULT_ASSOC, D), dtype=torch.float32, device=device + ) + lru_state = torch.zeros( + (num_cache_sets, DEFAULT_ASSOC), dtype=torch.int64, device=device ) - lru_state = torch.zeros((num_cache_sets, 32), dtype=torch.int64, device=device) torch.ops.fbgemm.lru_cache_populate( weights, @@ -1102,14 +1111,14 @@ def test_lru_cache_insert_large_grid(self) -> None: ) # Tier C structural invariants on the populated cache: # 1. shape preserved. - self.assertEqual(lxu_cache_state.shape, (num_cache_sets, 32)) - # 2. Cache is fully populated. With N >> num_cache_sets * 32 = 32K, + self.assertEqual(lxu_cache_state.shape, (num_cache_sets, DEFAULT_ASSOC)) + # 2. Cache is fully populated. With N >> num_cache_sets * DEFAULT_ASSOC, # every cache slot should hold a valid key (not the -1 sentinel) # after the insert kernel runs. Pre-fix the kernel never runs; # post-fix it grid-strides over all N indices and populates # every slot. num_filled = int((lxu_cache_state != -1).sum().item()) - self.assertEqual(num_filled, num_cache_sets * 32) + self.assertEqual(num_filled, num_cache_sets * DEFAULT_ASSOC) # 3. Every populated slot holds a valid linear cache index in # [0, N). This catches "kernel wrote wrong key" bugs. populated = lxu_cache_state[lxu_cache_state != -1]