diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c6009a0af4..781c159d8a 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1105,9 +1105,11 @@ Tensor {{ embedding_cuda_op }}( used_shared_bytes ); - const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, work_group_size), - get_max_thread_blocks_()); + const auto cta_per_row_grid_size = utils::cuda::cap_grid_dim_x( + cuda_calc_xblock_count(total_unique_indices, work_group_size), + kThreadGroupSize * num_cta_per_row_groups, // block size + at::cuda::getCurrentCUDAStream(), + utils::cuda::BlockCapPolicy::Always); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, @@ -1243,9 +1245,11 @@ Tensor {{ embedding_cuda_op }}( auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - int32_t warp_per_row_grid_size = std::min( - div_round_up(total_unique_indices, num_warp_per_row_groups), - get_max_thread_blocks_()); + auto warp_per_row_grid_size = utils::cuda::cap_grid_dim_x( + cuda_calc_xblock_count(total_unique_indices, num_warp_per_row_groups), + kThreadGroupSize * num_warp_per_row_groups, // block size + at::cuda::getCurrentCUDAStream(), + utils::cuda::BlockCapPolicy::Always); #ifdef USE_ROCM {%- if is_optimized_hip_kernel_supported_mode %} @@ -1270,7 +1274,16 @@ Tensor {{ embedding_cuda_op }}( {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { - warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item(), segments_per_workgroup); + // HIP kernel: Use OverflowOnly to match original behavior (no cap on CUDA, + // cap only on ROCm when exceeding HIP 2^32 thread limit). The original code + // did not apply get_max_thread_blocks_() cap. 
+ warp_per_row_grid_size = utils::cuda::cap_grid_dim_x( + cuda_calc_xblock_count( + sorted_linear_indices_num_runs[0].item(), + segments_per_workgroup), + 256, // blockSize = dim3(256) = 256 threads + at::cuda::getCurrentCUDAStream(), + utils::cuda::BlockCapPolicy::OverflowOnly); blockSize = dim3(256); warp_per_row_smem_bytes = 0; diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index 2784adadbc..f41d7e1d7d 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -24,6 +24,7 @@ #include #include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/utils/cuda_utilities.cuh" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/find_qparams.cuh" #include "fbgemm_gpu/utils/fixed_divisor.cuh" @@ -50,17 +51,3 @@ DEVICE_INLINE int64_t gpuAtomicIncrement(int64_t* p) { static_cast(1))); #endif } - -namespace fbgemm_gpu { -namespace { - -// Based on the empirical study, max grid size that is 64x larger than the -// number of SMs gives good performance across the board -constexpr int MAX_THREAD_BLOCKS_FACTOR = 64; - -inline int get_max_thread_blocks_() { - return MAX_THREAD_BLOCKS_FACTOR * - at::cuda::getCurrentDeviceProperties()->multiProcessorCount; -} -} // namespace -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_utilities.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_utilities.cuh index 39c85de766..c763f2a50a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_utilities.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_utilities.cuh @@ -25,18 +25,6 @@ namespace fbgemm_gpu::utils::cuda { /// kernels: `max_blocks = MAX_THREAD_BLOCKS_FACTOR * #SMs`. constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64; -/// Returns `MAX_THREAD_BLOCKS_FACTOR * #SMs` for the device backing -/// `stream`. 
Legacy helper retained until all direct callers have migrated -/// to `cap_grid_dim_x{,_from_workload}` below. -/// -/// @param stream Stream whose device supplies `#SMs`. -/// @return Grid-size cap to apply at kernel launch. -inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) { - const auto device = stream.device_index(); - return MAX_THREAD_BLOCKS_FACTOR * - at::cuda::getDeviceProperties(device)->multiProcessorCount; -} - /// Selects how `cap_grid_dim_x{,_from_workload}` clamps the requested grid. /// /// - `Always`: cap on both CUDA and ROCm at `MAX_THREAD_BLOCKS_FACTOR * #SMs`. @@ -82,7 +70,10 @@ inline uint32_t cap_grid_dim_x( return blocks_uncapped; } - const auto max_blocks = static_cast(get_max_thread_blocks(stream)); + const auto max_blocks = static_cast( + MAX_THREAD_BLOCKS_FACTOR * + at::cuda::getDeviceProperties(stream.device_index()) + ->multiProcessorCount); if (policy == BlockCapPolicy::Always) { return std::min(blocks_uncapped, max_blocks); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu index 0cc622339f..ca0c43ed13 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu @@ -221,7 +221,12 @@ permute_2D_sparse_data_cuda( permuted_lengths = at::empty({T, B}, lengths.options()); constexpr int32_t threads_1 = 256; - const auto blocks_1 = cuda_calc_block_count(B * T, threads_1); + // HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA, + // which silently wraps). permute_2D_lengths_kernel uses CUDA_KERNEL_LOOP, + // which already grid-strides, so capping is correctness-preserving. 
+ // See: https://github.com/ROCm/hip/issues/2253 + const auto blocks_1 = utils::cuda::cap_grid_dim_x_from_workload( + B * T, threads_1, at::cuda::getCurrentCUDAStream()); AT_DISPATCH_INDEX_TYPES( lengths.scalar_type(), "permute_2D_lengths_kernel", [&] { FBGEMM_LAUNCH_KERNEL( @@ -250,7 +255,15 @@ permute_2D_sparse_data_cuda( constexpr int32_t BT_blocks = 32; dim3 threads_2(32, BT_blocks); - const auto blocks_2 = cuda_calc_block_count(B * T, BT_blocks); + // HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA, + // which silently wraps). Both permute_2D_data_kernel and + // permute_2D_data_kernel_vec grid-stride over b_t, so capping is + // correctness-preserving. + // See: https://github.com/ROCm/hip/issues/2253 + const auto blocks_2 = utils::cuda::cap_grid_dim_x( + cuda_calc_xblock_count(B * T, BT_blocks), + BT_blocks * 32, + at::cuda::getCurrentCUDAStream()); permuted_indices = at::empty(permuted_indices_size, indices.options()); AT_DISPATCH_INDEX_TYPES( @@ -415,8 +428,12 @@ permute_sparse_features_cuda( permuted_lengths = at::empty({num_output_features, B}, lengths.options()); constexpr int32_t threads_1 = 256; - const auto blocks_1 = - cuda_calc_block_count(B * num_output_features, threads_1); + // HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA, + // which silently wraps). permute_2D_lengths_kernel uses CUDA_KERNEL_LOOP, + // which already grid-strides, so capping is correctness-preserving. 
+ // See: https://github.com/ROCm/hip/issues/2253 + const auto blocks_1 = utils::cuda::cap_grid_dim_x_from_workload( + B * num_output_features, threads_1, at::cuda::getCurrentCUDAStream()); AT_DISPATCH_INDEX_TYPES( lengths.scalar_type(), "permute_2D_lengths_kernel", [&] { FBGEMM_LAUNCH_KERNEL( @@ -452,8 +469,14 @@ permute_sparse_features_cuda( constexpr int32_t BT_blocks = 32; dim3 threads_2(32, BT_blocks); - const auto blocks_2 = - cuda_calc_block_count(B * num_output_features, BT_blocks); + // HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA, + // which silently wraps). permute_indices_weights_kernel grid-strides over + // b_t, so capping is correctness-preserving. + // See: https://github.com/ROCm/hip/issues/2253 + const auto blocks_2 = utils::cuda::cap_grid_dim_x( + cuda_calc_xblock_count(B * num_output_features, BT_blocks), + BT_blocks * 32, + at::cuda::getCurrentCUDAStream()); permuted_indices = at::empty(permuted_lengths_sum, indices.options()); if (weights.has_value()) { const Tensor weights_value = weights.value(); diff --git a/fbgemm_gpu/test/sparse/permute_indices_test.py b/fbgemm_gpu/test/sparse/permute_indices_test.py index 1296d160c8..02fc9ebc05 100644 --- a/fbgemm_gpu/test/sparse/permute_indices_test.py +++ b/fbgemm_gpu/test/sparse/permute_indices_test.py @@ -866,6 +866,156 @@ def test_permute_1D_sparse_data_large_grid(self) -> None: torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu) self.assertIsNone(permuted_weights_gpu) + @unittest.skipIf(*gpu_unavailable) + # Skip on GPUs with insufficient HBM (need ~512 MB for the int32 + # lengths tensor at the chosen B). + @unittest.skipIf(*gpu_memory_lt_gb(4)) + def test_permute_2D_sparse_data_large_grid(self) -> None: + """ + Reproduces the HIP grid-overflow bug in permute_2D_sparse_data_cuda + and verifies output correctness at the same scale. 
+ + With BT_blocks=32 and dim3(32, 32) (block size 1024), the + permute_2D_data_kernel_vec launch grid is + cuda_calc_block_count(B*T, 32). For B*T > 2**27, total threads + exceed the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL -> + KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail on + ROCm pre-fix. With the production fix in place, this test + additionally validates output correctness against the CPU dispatch + of the same op — the GPU output must match the CPU reference + element-for-element. + + Uses ``T=2, B=2**26+1`` so ``B*T = 2**27 + 2`` strictly trips the + threshold. ``lengths`` is sparse: all zero except for four known + non-zero positions (two in each row), so HBM usage + stays bounded (~537 MB, i.e. 512 MiB, of int32) while the permutation logic is + still exercised. ``permute = [1, 0]`` is a deterministic row swap + on the T axis with ``perm[i] != i`` for every i, so any + "kernel computed identity instead of permutation" or wrong-``b_t`` + decoding bug surfaces in the assertion below. + """ + + # Choose B*T so that total threads strictly exceeds 2**32: + # cuda_calc_block_count(B*T, 32) * 1024 ~= B*T * 32; need B*T > 2**27. + T = 2 + B = (1 << 26) + 1 + + device = torch.device(torch.accelerator.current_accelerator() or "cuda") + + # Deterministic non-identity permute: row swap on the T axis. + # perm[0] == 1 and perm[1] == 0, so perm[i] != i for every i. + perm_cpu = torch.tensor([1, 0], dtype=torch.int32) + permute = perm_cpu.to(device) + + # Sparse non-zero lengths at four known positions. Total = 11. + lengths_cpu = torch.zeros((T, B), dtype=torch.int32) + lengths_cpu[0, 0] = 3 + lengths_cpu[0, B // 2] = 5 + lengths_cpu[1, 0] = 2 + lengths_cpu[1, B - 1] = 1 + lengths = lengths_cpu.to(device) + + # Distinct indices per segment so the permutation is fully observable. + indices_cpu = torch.arange(11, dtype=torch.int32) + indices = indices_cpu.to(device) + + # CPU reference oracle — same op, different dispatch.
+ ( + permuted_lengths_cpu, + permuted_indices_cpu, + _permuted_weights_cpu, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + perm_cpu, lengths_cpu, indices_cpu, None, None + ) + + # GPU op under test. Pre-fix, this launch trips + # KernelLauncher::checkThreadCountNotExceeded on ROCm. + ( + permuted_lengths_gpu, + permuted_indices_gpu, + permuted_weights_gpu, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute, lengths, indices, None, None + ) + + torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu) + torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu) + self.assertIsNone(permuted_weights_gpu) + + @unittest.skipIf(*gpu_unavailable) + # Skip on GPUs with insufficient HBM (need ~512 MB for the int32 + # lengths tensor at the chosen B). + @unittest.skipIf(*gpu_memory_lt_gb(4)) + def test_permute_sparse_features_large_grid(self) -> None: + """ + Reproduces the HIP grid-overflow bug in permute_sparse_features_cuda + and verifies output correctness at the same scale. + + With BT_blocks=32 and dim3(32, 32) (block size 1024), the + permute_indices_weights_kernel launch grid is + cuda_calc_block_count(B*T, 32). For B*T > 2**27, total threads + exceed the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL -> + KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail on + ROCm pre-fix. With the production fix in place, this test + additionally validates output correctness against the CPU dispatch + of the same op — the GPU output must match the CPU reference + element-for-element. + + Uses ``T=2, B=2**26+1`` so ``B*T = 2**27 + 2`` strictly trips the + threshold. ``lengths`` is sparse: all zero except for four known + non-zero positions (two in each row), so HBM usage + stays bounded (~537 MB, i.e. 512 MiB, of int32) while the permutation logic is + still exercised.
``permute = [1, 0]`` is a deterministic row swap + on the T axis with ``perm[i] != i`` for every i, so any + "kernel computed identity instead of permutation" or wrong-``b_t`` + decoding bug surfaces in the assertion below. + """ + + # Choose B*T so that total threads strictly exceeds 2**32: + # cuda_calc_block_count(B*T, 32) * 1024 ~= B*T * 32; need B*T > 2**27. + T = 2 + B = (1 << 26) + 1 + + device = torch.device(torch.accelerator.current_accelerator() or "cuda") + + # Deterministic non-identity permute: row swap on the T axis. + # perm[0] == 1 and perm[1] == 0, so perm[i] != i for every i. + perm_cpu = torch.tensor([1, 0], dtype=torch.int32) + permute = perm_cpu.to(device) + + # Sparse non-zero lengths at four known positions. Total = 11. + lengths_cpu = torch.zeros((T, B), dtype=torch.int32) + lengths_cpu[0, 0] = 3 + lengths_cpu[0, B // 2] = 5 + lengths_cpu[1, 0] = 2 + lengths_cpu[1, B - 1] = 1 + lengths = lengths_cpu.to(device) + + # Distinct indices per segment so the permutation is fully observable. + indices_cpu = torch.arange(11, dtype=torch.int32) + indices = indices_cpu.to(device) + + # CPU reference oracle — same op, different dispatch. + ( + permuted_lengths_cpu, + permuted_indices_cpu, + _permuted_weights_cpu, + ) = torch.ops.fbgemm.permute_sparse_features( + perm_cpu, lengths_cpu, indices_cpu, None + ) + + # GPU op under test. Pre-fix, this launch trips + # KernelLauncher::checkThreadCountNotExceeded on ROCm. + ( + permuted_lengths_gpu, + permuted_indices_gpu, + permuted_weights_gpu, + ) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, None) + + torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu) + torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu) + self.assertIsNone(permuted_weights_gpu) + extend_test_class(PermuteIndicesTest)