Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1105,9 +1105,11 @@ Tensor {{ embedding_cuda_op }}(
used_shared_bytes
);

const int32_t cta_per_row_grid_size = std::min(
div_round_up(total_unique_indices, work_group_size),
get_max_thread_blocks_());
const auto cta_per_row_grid_size = utils::cuda::cap_grid_dim_x(
cuda_calc_xblock_count(total_unique_indices, work_group_size),
kThreadGroupSize * num_cta_per_row_groups, // block size
at::cuda::getCurrentCUDAStream(),
utils::cuda::BlockCapPolicy::Always);

FBGEMM_LAUNCH_KERNEL(
backward_cta_per_row_kernel,
Expand Down Expand Up @@ -1243,9 +1245,11 @@ Tensor {{ embedding_cuda_op }}(

auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups);

int32_t warp_per_row_grid_size = std::min(
div_round_up(total_unique_indices, num_warp_per_row_groups),
get_max_thread_blocks_());
auto warp_per_row_grid_size = utils::cuda::cap_grid_dim_x(
cuda_calc_xblock_count(total_unique_indices, num_warp_per_row_groups),
kThreadGroupSize * num_warp_per_row_groups, // block size
at::cuda::getCurrentCUDAStream(),
utils::cuda::BlockCapPolicy::Always);

#ifdef USE_ROCM
{%- if is_optimized_hip_kernel_supported_mode %}
Expand All @@ -1270,7 +1274,16 @@ Tensor {{ embedding_cuda_op }}(
{%- for kWeightDecayMode in [0, 1, 2] %}
if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }})
{
warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item<int32_t>(), segments_per_workgroup);
// HIP kernel: Use OverflowOnly to match original behavior (no cap on CUDA,
// cap only on ROCm when exceeding HIP 2^32 thread limit). The original code
// did not apply get_max_thread_blocks_() cap.
warp_per_row_grid_size = utils::cuda::cap_grid_dim_x(
cuda_calc_xblock_count(
sorted_linear_indices_num_runs[0].item<int32_t>(),
segments_per_workgroup),
256, // blockSize = dim3(256) = 256 threads
at::cuda::getCurrentCUDAStream(),
utils::cuda::BlockCapPolicy::OverflowOnly);
blockSize = dim3(256);
warp_per_row_smem_bytes = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <mutex>

#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/cuda_utilities.cuh"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/find_qparams.cuh"
#include "fbgemm_gpu/utils/fixed_divisor.cuh"
Expand All @@ -50,17 +51,3 @@ DEVICE_INLINE int64_t gpuAtomicIncrement(int64_t* p) {
static_cast<unsigned long long int>(1)));
#endif
}

namespace fbgemm_gpu {
namespace {

// Based on the empirical study, max grid size that is 64x larger than the
// number of SMs gives good performance across the board
constexpr int MAX_THREAD_BLOCKS_FACTOR = 64;

inline int get_max_thread_blocks_() {
return MAX_THREAD_BLOCKS_FACTOR *
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
}
} // namespace
} // namespace fbgemm_gpu
17 changes: 4 additions & 13 deletions fbgemm_gpu/include/fbgemm_gpu/utils/cuda_utilities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,6 @@ namespace fbgemm_gpu::utils::cuda {
/// kernels: `max_blocks = MAX_THREAD_BLOCKS_FACTOR * #SMs`.
constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64;

/// Returns `MAX_THREAD_BLOCKS_FACTOR * #SMs` for the device backing
/// `stream`. Legacy helper retained until all direct callers have migrated
/// to `cap_grid_dim_x{,_from_workload}` below.
///
/// @param stream Stream whose device supplies `#SMs`.
/// @return Grid-size cap to apply at kernel launch.
inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) {
const auto device = stream.device_index();
return MAX_THREAD_BLOCKS_FACTOR *
at::cuda::getDeviceProperties(device)->multiProcessorCount;
}

/// Selects how `cap_grid_dim_x{,_from_workload}` clamps the requested grid.
///
/// - `Always`: cap on both CUDA and ROCm at `MAX_THREAD_BLOCKS_FACTOR * #SMs`.
Expand Down Expand Up @@ -82,7 +70,10 @@ inline uint32_t cap_grid_dim_x(
return blocks_uncapped;
}

const auto max_blocks = static_cast<uint32_t>(get_max_thread_blocks(stream));
const auto max_blocks = static_cast<uint32_t>(
MAX_THREAD_BLOCKS_FACTOR *
at::cuda::getDeviceProperties(stream.device_index())
->multiProcessorCount);

if (policy == BlockCapPolicy::Always) {
return std::min<uint32_t>(blocks_uncapped, max_blocks);
Expand Down
35 changes: 29 additions & 6 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ permute_2D_sparse_data_cuda(
permuted_lengths = at::empty({T, B}, lengths.options());

constexpr int32_t threads_1 = 256;
const auto blocks_1 = cuda_calc_block_count(B * T, threads_1);
// HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA,
// which silently wraps). permute_2D_lengths_kernel uses CUDA_KERNEL_LOOP,
// which already grid-strides, so capping is correctness-preserving.
// See: https://github.com/ROCm/hip/issues/2253
const auto blocks_1 = utils::cuda::cap_grid_dim_x_from_workload(
B * T, threads_1, at::cuda::getCurrentCUDAStream());
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_2D_lengths_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
Expand Down Expand Up @@ -250,7 +255,15 @@ permute_2D_sparse_data_cuda(

constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 = cuda_calc_block_count(B * T, BT_blocks);
// HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA,
// which silently wraps). Both permute_2D_data_kernel and
// permute_2D_data_kernel_vec grid-stride over b_t, so capping is
// correctness-preserving.
// See: https://github.com/ROCm/hip/issues/2253
const auto blocks_2 = utils::cuda::cap_grid_dim_x(
cuda_calc_xblock_count(B * T, BT_blocks),
BT_blocks * 32,
at::cuda::getCurrentCUDAStream());
permuted_indices = at::empty(permuted_indices_size, indices.options());

AT_DISPATCH_INDEX_TYPES(
Expand Down Expand Up @@ -415,8 +428,12 @@ permute_sparse_features_cuda(
permuted_lengths = at::empty({num_output_features, B}, lengths.options());

constexpr int32_t threads_1 = 256;
const auto blocks_1 =
cuda_calc_block_count(B * num_output_features, threads_1);
// HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA,
// which silently wraps). permute_2D_lengths_kernel uses CUDA_KERNEL_LOOP,
// which already grid-strides, so capping is correctness-preserving.
// See: https://github.com/ROCm/hip/issues/2253
const auto blocks_1 = utils::cuda::cap_grid_dim_x_from_workload(
B * num_output_features, threads_1, at::cuda::getCurrentCUDAStream());
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_2D_lengths_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
Expand Down Expand Up @@ -452,8 +469,14 @@ permute_sparse_features_cuda(

constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 =
cuda_calc_block_count(B * num_output_features, BT_blocks);
// HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA,
// which silently wraps). permute_indices_weights_kernel grid-strides over
// b_t, so capping is correctness-preserving.
// See: https://github.com/ROCm/hip/issues/2253
const auto blocks_2 = utils::cuda::cap_grid_dim_x(
cuda_calc_xblock_count(B * num_output_features, BT_blocks),
BT_blocks * 32,
at::cuda::getCurrentCUDAStream());
permuted_indices = at::empty(permuted_lengths_sum, indices.options());
if (weights.has_value()) {
const Tensor weights_value = weights.value();
Expand Down
150 changes: 150 additions & 0 deletions fbgemm_gpu/test/sparse/permute_indices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,156 @@ def test_permute_1D_sparse_data_large_grid(self) -> None:
torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
self.assertIsNone(permuted_weights_gpu)

@unittest.skipIf(*gpu_unavailable)
# Skip on GPUs with insufficient HBM (need ~512 MB for the int32
# lengths tensor at the chosen B).
@unittest.skipIf(*gpu_memory_lt_gb(4))
def test_permute_2D_sparse_data_large_grid(self) -> None:
"""
Reproduces the HIP grid-overflow bug in permute_2D_sparse_data_cuda
and verifies output correctness at the same scale.

With BT_blocks=32 and dim3(32, 32) (block size 1024), the
permute_2D_data_kernel_vec launch grid is
cuda_calc_block_count(B*T, 32). For B*T > 2**27, total threads
exceed the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL ->
KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail on
ROCm pre-fix. With the production fix in place, this test
additionally validates output correctness against the CPU dispatch
of the same op — the GPU output must match the CPU reference
element-for-element.

Uses ``T=2, B=2**26+1`` so ``B*T = 2**27 + 2`` strictly trips the
threshold. ``lengths`` is sparse: all zero except for four known
non-zero positions (one per row, plus one mid-row), so HBM usage
stays bounded (~537 MB int32) while the permutation logic is
still exercised. ``permute = [1, 0]`` is a deterministic row swap
on the T axis with ``perm[i] != i`` for every i, so any
"kernel computed identity instead of permutation" or wrong-``b_t``
decoding bug surfaces in the assertion below.
"""

# Choose B*T so that total threads strictly exceeds 2**32:
# cuda_calc_block_count(B*T, 32) * 1024 ~= B*T * 32; need B*T > 2**27.
T = 2
B = (1 << 26) + 1

device = torch.device(torch.accelerator.current_accelerator() or "cuda")

# Deterministic non-identity permute: row swap on the T axis.
# perm[0] == 1 and perm[1] == 0, so perm[i] != i for every i.
perm_cpu = torch.tensor([1, 0], dtype=torch.int32)
permute = perm_cpu.to(device)

# Sparse non-zero lengths at four known positions. Total = 11.
lengths_cpu = torch.zeros((T, B), dtype=torch.int32)
lengths_cpu[0, 0] = 3
lengths_cpu[0, B // 2] = 5
lengths_cpu[1, 0] = 2
lengths_cpu[1, B - 1] = 1
lengths = lengths_cpu.to(device)

# Distinct indices per segment so the permutation is fully observable.
indices_cpu = torch.arange(11, dtype=torch.int32)
indices = indices_cpu.to(device)

# CPU reference oracle — same op, different dispatch.
(
permuted_lengths_cpu,
permuted_indices_cpu,
_permuted_weights_cpu,
) = torch.ops.fbgemm.permute_2D_sparse_data(
perm_cpu, lengths_cpu, indices_cpu, None, None
)

# GPU op under test. Pre-fix, this launch trips
# KernelLauncher::checkThreadCountNotExceeded on ROCm.
(
permuted_lengths_gpu,
permuted_indices_gpu,
permuted_weights_gpu,
) = torch.ops.fbgemm.permute_2D_sparse_data(
permute, lengths, indices, None, None
)

torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
self.assertIsNone(permuted_weights_gpu)

@unittest.skipIf(*gpu_unavailable)
# Skip on GPUs with insufficient HBM (need ~512 MB for the int32
# lengths tensor at the chosen B).
@unittest.skipIf(*gpu_memory_lt_gb(4))
def test_permute_sparse_features_large_grid(self) -> None:
"""
Reproduces the HIP grid-overflow bug in permute_sparse_features_cuda
and verifies output correctness at the same scale.

With BT_blocks=32 and dim3(32, 32) (block size 1024), the
permute_indices_weights_kernel launch grid is
cuda_calc_block_count(B*T, 32). For B*T > 2**27, total threads
exceed the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL ->
KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail on
ROCm pre-fix. With the production fix in place, this test
additionally validates output correctness against the CPU dispatch
of the same op — the GPU output must match the CPU reference
element-for-element.

Uses ``T=2, B=2**26+1`` so ``B*T = 2**27 + 2`` strictly trips the
threshold. ``lengths`` is sparse: all zero except for four known
non-zero positions (one per row, plus one mid-row), so HBM usage
stays bounded (~537 MB int32) while the permutation logic is
still exercised. ``permute = [1, 0]`` is a deterministic row swap
on the T axis with ``perm[i] != i`` for every i, so any
"kernel computed identity instead of permutation" or wrong-``b_t``
decoding bug surfaces in the assertion below.
"""

# Choose B*T so that total threads strictly exceeds 2**32:
# cuda_calc_block_count(B*T, 32) * 1024 ~= B*T * 32; need B*T > 2**27.
T = 2
B = (1 << 26) + 1

device = torch.device(torch.accelerator.current_accelerator() or "cuda")

# Deterministic non-identity permute: row swap on the T axis.
# perm[0] == 1 and perm[1] == 0, so perm[i] != i for every i.
perm_cpu = torch.tensor([1, 0], dtype=torch.int32)
permute = perm_cpu.to(device)

# Sparse non-zero lengths at four known positions. Total = 11.
lengths_cpu = torch.zeros((T, B), dtype=torch.int32)
lengths_cpu[0, 0] = 3
lengths_cpu[0, B // 2] = 5
lengths_cpu[1, 0] = 2
lengths_cpu[1, B - 1] = 1
lengths = lengths_cpu.to(device)

# Distinct indices per segment so the permutation is fully observable.
indices_cpu = torch.arange(11, dtype=torch.int32)
indices = indices_cpu.to(device)

# CPU reference oracle — same op, different dispatch.
(
permuted_lengths_cpu,
permuted_indices_cpu,
_permuted_weights_cpu,
) = torch.ops.fbgemm.permute_sparse_features(
perm_cpu, lengths_cpu, indices_cpu, None
)

# GPU op under test. Pre-fix, this launch trips
# KernelLauncher::checkThreadCountNotExceeded on ROCm.
(
permuted_lengths_gpu,
permuted_indices_gpu,
permuted_weights_gpu,
) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, None)

torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
self.assertIsNone(permuted_weights_gpu)


extend_test_class(PermuteIndicesTest)

Expand Down
Loading