From 1bf3b4f0dacaf1f71d12167d103a76f2c7e5e3f7 Mon Sep 17 00:00:00 2001 From: Gantaphon Chalumporn Date: Mon, 15 Jun 2026 11:16:18 -0700 Subject: [PATCH] Assert bad offsets in bounds_check_indices (Part 2) (#5895) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2814 Alternate version of D101843208 with a JustKnobs killswitch (`DISABLE_OFFSETS_ADJUSTMENT`, default `true`) gating the assert path, so we can recover production jobs without a code revert if any turn out to depend on the legacy silent-correction behavior. Default behavior matches D101843208: malformed offsets (`indices_start < 0`, `indices_start > indices_end`, or `indices_end > num_indices`) trigger CUDA_KERNEL_ASSERT in all bounds-check modes. Per-b_t asserts are lane-0-guarded (32x redundant asserts are wasteful). See D101843208 for the full motivation — the legacy path has known intra-warp races on the offsets buffer and incorrect assumptions about bag size. Plumbed via a `const static bool` JK lookup in the host wrapper, passed through both v1 and v2 launchers/kernels as a runtime arg. The kernel predicate is `disable_offsets_adjustment || mode == FATAL`; the else branch preserves the legacy printf + `adjust_offset_kernel` path. Differential Revision: D106829384 --- .../utils/embedding_bounds_check_host.cpp | 12 +- .../utils/embedding_bounds_check_v1.cu | 120 ++++++++++-------- .../utils/embedding_bounds_check_v2.cu | 88 ++++++------- fbgemm_gpu/fbgemm_gpu/config/feature_list.py | 3 + .../include/fbgemm_gpu/config/feature_gates.h | 3 +- .../tbe/utils/split_embeddings_utils_test.py | 2 +- 6 files changed, 124 insertions(+), 104 deletions(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp index c76a288f03..a7fb3a0b82 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp @@ -38,7 +38,8 @@ void _bounds_check_indices_cuda_v1( int64_t B, int64_t total_B, bool vbe, - bool prefetch_pipeline); + bool prefetch_pipeline, + bool disable_offsets_adjustment); void _bounds_check_indices_cuda_v2( Tensor& rows_per_table, @@ -56,7 +57,8 @@ void _bounds_check_indices_cuda_v2( int64_t B, int64_t total_B, bool vbe, - bool prefetch_pipeline); + bool prefetch_pipeline, + bool disable_offsets_adjustment); ///@ingroup embedding-cuda void bounds_check_indices_cuda( @@ -76,6 +78,9 @@ void bounds_check_indices_cuda( TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2); const static bool use_v2_jk = fbgemm_gpu::config::is_feature_enabled( fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2); + const static bool disable_offsets_adjustment = + fbgemm_gpu::config::is_feature_enabled( + fbgemm_gpu::config::FeatureGateName::DISABLE_OFFSETS_ADJUSTMENT); const auto bounds_check_indices_fn = (use_v2_jk || bounds_check_version == 2) ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1; @@ -140,7 +145,8 @@ void bounds_check_indices_cuda( B, total_B, vbe, - prefetch_pipeline); + prefetch_pipeline, + disable_offsets_adjustment); } // Deprecated for fb namespace! Please use fbgemm namespace instead! TORCH_LIBRARY_FRAGMENT(fb, m) { diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu index 4bbc079937..c7ceacfec5 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu @@ -19,6 +19,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1( BoundsCheckMode bounds_check_mode, pta::PackedTensorAccessor32 warning, FixedDivisor fd, + const bool disable_offsets_adjustment, TORCH_DSA_KERNEL_ARGS) { int32_t T = rows_per_table.size(0); auto b_t = blockIdx.x * blockDim.y + threadIdx.y; @@ -55,40 +56,42 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1( auto indices_end = offsets[b_t + 1]; const index_t num_indices = indices.size(0); - // Condition first, then branch on mode. - if (indices_start < 0 || indices_start > indices_end || + if (disable_offsets_adjustment || + bounds_check_mode == BoundsCheckMode::FATAL) { + CUDA_KERNEL_ASSERT( + indices_start >= 0 && "indices_start must be non-negative"); + CUDA_KERNEL_ASSERT( + indices_start <= indices_end && + "indices_start must not exceed indices_end"); + CUDA_KERNEL_ASSERT( + indices_end <= num_indices && + "indices_end must not exceed num_indices"); + } else if ( + indices_start < 0 || indices_start > indices_end || indices_end > num_indices) { - if (bounds_check_mode == BoundsCheckMode::FATAL) { - CUDA_KERNEL_ASSERT( - indices_start >= 0 && "indices_start must be non-negative"); - CUDA_KERNEL_ASSERT( - indices_start <= indices_end && - "indices_start must not exceed indices_end"); - CUDA_KERNEL_ASSERT( - indices_end <= num_indices && - "indices_end must not exceed num_indices"); - } else { - if (bounds_check_mode == BoundsCheckMode::WARNING) { - if (gpuAtomicIncrement(&warning[0]) == 0) { - printf( - "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for " - "batch: %d, table: %d, indices_start: %lld, indices_end: %lld," - " num_indices: %lld. Setting indices_start and indices_end within " - "the range.\n", - vbe ? "true" : "false", - b, - t, - static_cast(indices_start), - static_cast(indices_end), - static_cast(num_indices)); - } + if (bounds_check_mode == BoundsCheckMode::WARNING) { + if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) { + printf( + "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for " + "batch: %d, table: %d, indices_start: %lld, indices_end: %lld," + " num_indices: %lld. Setting indices_start and indices_end within " + "the range.\n", + vbe ? "true" : "false", + b, + t, + static_cast(indices_start), + static_cast(indices_end), + static_cast(num_indices)); } - adjust_offset_kernel( - indices_start, - indices_end, - num_indices, - &offsets[b_t], - &offsets[b_t + 1]); + } + indices_start = + std::max(static_cast(0), std::min(indices_start, num_indices)); + indices_end = std::max(indices_start, std::min(indices_end, num_indices)); + // Only thread 0 writes back the adjusted offsets to avoid the intra-warp + // race; no sync needed since offsets are not re-read in this kernel. + if (threadIdx.x == 0) { + offsets[b_t] = indices_start; + offsets[b_t + 1] = indices_end; } } @@ -130,29 +133,32 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1( } } - if (bounds_check_mode == BoundsCheckMode::FATAL) { - CUDA_KERNEL_ASSERT( - num_indices == offsets[total_B] && - "num_indices must match the last element in offsets"); - } else if (bounds_check_mode == BoundsCheckMode::WARNING) { - if (num_indices != offsets[total_B]) { - if (gpuAtomicIncrement(&warning[0]) == 0) { - printf( - "EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for " - "total batch size %s: %d, total table num T: %d, " - " last element in offsets: %lld, indices size: %lld. " - " Setting the last element in offsets to be indices size.\n", - vbe ? "true" : "false", - vbe ? "total_B" : "B", - vbe ? total_B : B, - T, - static_cast(offsets[total_B]), - static_cast(num_indices)); - } - offsets[total_B] = num_indices; + if (disable_offsets_adjustment || + bounds_check_mode == BoundsCheckMode::FATAL) { + if (b_t == 0 && threadIdx.x == 0) { + CUDA_KERNEL_ASSERT( + num_indices == offsets[total_B] && + "num_indices must match the last element in offsets"); } - } else if (bounds_check_mode == BoundsCheckMode::IGNORE) { - if (num_indices != offsets[total_B]) { + } else if (num_indices != offsets[total_B]) { + // The last-element check is a single global condition; one thread handles + // the warning and the correction (for both WARNING and IGNORE). + if (b_t == 0 && threadIdx.x == 0) { + if (bounds_check_mode == BoundsCheckMode::WARNING) { + if (gpuAtomicIncrement(&warning[0]) == 0) { + printf( + "EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for " + "total batch size %s: %d, total table num T: %d, " + " last element in offsets: %lld, indices size: %lld. " + " Setting the last element in offsets to be indices size.\n", + vbe ? "true" : "false", + vbe ? "total_B" : "B", + vbe ? total_B : B, + T, + static_cast(offsets[total_B]), + static_cast(num_indices)); + } + } offsets[total_B] = num_indices; } } @@ -174,7 +180,8 @@ void _bounds_check_indices_cuda_v1( int64_t B, int64_t /*total_B*/, bool vbe, - bool prefetch_pipeline) { + bool prefetch_pipeline, + bool disable_offsets_adjustment) { TORCH_CHECK( !prefetch_pipeline, "bounds_check_indices_v1 does not support prefetch_pipeline=true") @@ -205,6 +212,7 @@ void _bounds_check_indices_cuda_v1( vbe ? B_offsets.value().data_ptr() : nullptr, bounds_check_mode, PTA_B(warning, int64_t, 1, 32), - FixedDivisor(max_B_)); + FixedDivisor(max_B_), + disable_offsets_adjustment); }); } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index 13aa83e468..53e31ea3df 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -22,6 +22,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( const int32_t* const b_t_map, const int32_t info_B_num_bits, const int32_t info_B_mask, + const bool disable_offsets_adjustment, TORCH_DSA_KERNEL_ARGS) { int32_t T = rows_per_table.size(0); int32_t total_B = offsets.size(0) - 1; @@ -39,14 +40,15 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( const uint32_t active_threads = blockDim.x * blockDim.y * blockDim.z; #endif - // Check the last element + // Last-element check; one thread only. if (b_t_start == 0 && threadIdx.x == 0) { - if (bounds_check_mode == BoundsCheckMode::FATAL) { + if (disable_offsets_adjustment || + bounds_check_mode == BoundsCheckMode::FATAL) { CUDA_KERNEL_ASSERT( num_indices == offsets[total_B] && "num_indices must match the last element in offsets"); - } else if (bounds_check_mode == BoundsCheckMode::WARNING) { - if (num_indices != offsets[total_B]) { + } else if (num_indices != offsets[total_B]) { + if (bounds_check_mode == BoundsCheckMode::WARNING) { if (gpuAtomicIncrement(&warning[0]) == 0) { printf( "EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for " @@ -60,12 +62,8 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( static_cast(offsets[total_B]), static_cast(num_indices)); } - offsets[total_B] = num_indices; - } - } else if (bounds_check_mode == BoundsCheckMode::IGNORE) { - if (num_indices != offsets[total_B]) { - offsets[total_B] = num_indices; } + offsets[total_B] = num_indices; } } @@ -86,40 +84,42 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( auto indices_start = offsets[b_t]; auto indices_end = offsets[b_t + 1]; - // Condition first, then branch on mode. - if (indices_start < 0 || indices_start > indices_end || + if (disable_offsets_adjustment || + bounds_check_mode == BoundsCheckMode::FATAL) { + CUDA_KERNEL_ASSERT( + indices_start >= 0 && "indices_start must be non-negative"); + CUDA_KERNEL_ASSERT( + indices_start <= indices_end && + "indices_start must not exceed indices_end"); + CUDA_KERNEL_ASSERT( + indices_end <= num_indices && + "indices_end must not exceed num_indices"); + } else if ( + indices_start < 0 || indices_start > indices_end || indices_end > num_indices) { - if (bounds_check_mode == BoundsCheckMode::FATAL) { - CUDA_KERNEL_ASSERT( - indices_start >= 0 && "indices_start must be non-negative"); - CUDA_KERNEL_ASSERT( - indices_start <= indices_end && - "indices_start must not exceed indices_end"); - CUDA_KERNEL_ASSERT( - indices_end <= num_indices && - "indices_end must not exceed num_indices"); - } else { - if (bounds_check_mode == BoundsCheckMode::WARNING) { - if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) { - printf( - "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for " - "batch: %d, table: %d, indices_start: %lld, indices_end: %lld," - " num_indices: %lld. Setting indices_start and indices_end within " - "the range.\n", - vbe ? "true" : "false", - b, - t, - static_cast(indices_start), - static_cast(indices_end), - static_cast(num_indices)); - } + if (bounds_check_mode == BoundsCheckMode::WARNING) { + if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) { + printf( + "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for " + "batch: %d, table: %d, indices_start: %lld, indices_end: %lld," + " num_indices: %lld. Setting indices_start and indices_end within " + "the range.\n", + vbe ? "true" : "false", + b, + t, + static_cast(indices_start), + static_cast(indices_end), + static_cast(num_indices)); } - adjust_offset_kernel( - indices_start, - indices_end, - num_indices, - &offsets[b_t], - &offsets[b_t + 1]); + } + indices_start = std::max( + static_cast(0), std::min(indices_start, num_indices)); + indices_end = std::max(indices_start, std::min(indices_end, num_indices)); + // Only thread 0 writes back the adjusted offsets to avoid the intra-warp + // race; no sync needed since offsets are not re-read in this kernel. + if (threadIdx.x == 0) { + offsets[b_t] = indices_start; + offsets[b_t + 1] = indices_end; } } @@ -225,7 +225,8 @@ void _bounds_check_indices_cuda_v2( int64_t B, int64_t total_B, bool vbe, - bool prefetch_pipeline) { + bool prefetch_pipeline, + bool disable_offsets_adjustment) { if (vbe) { TORCH_CHECK(b_t_map.has_value()); TENSOR_NDIM_EQUALS(b_t_map.value(), 1); @@ -274,7 +275,8 @@ void _bounds_check_indices_cuda_v2( FixedDivisor(B), \ vbe ? b_t_map.value().data_ptr() : nullptr, \ info_B_num_bits, \ - info_B_mask); \ + info_B_mask, \ + disable_offsets_adjustment); \ }); \ } diff --git a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py index 84780aef4d..50f47affcb 100644 --- a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py +++ b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py @@ -69,6 +69,9 @@ def foo(): # Enable warp-parallel kernel for populate_bucketized_permute BUCKETIZED_PERMUTE_WARP_KERNEL = auto() + # Gate the bounds_check_indices offsets-adjustment assertions + DISABLE_OFFSETS_ADJUSTMENT = auto() + def is_enabled(self) -> bool: return FeatureGate.is_enabled(self) diff --git a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h index cb7c80a9aa..3746dfa632 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h +++ b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h @@ -65,7 +65,8 @@ namespace fbgemm_gpu::config { X(TBE_REPORT_INPUT_PARAMS) \ X(TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY) \ X(TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200) \ - X(BUCKETIZED_PERMUTE_WARP_KERNEL) + X(BUCKETIZED_PERMUTE_WARP_KERNEL) \ + X(DISABLE_OFFSETS_ADJUSTMENT) // X(EXAMPLE_FEATURE_FLAG) /// @ingroup fbgemm-gpu-config diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py index b34a58aeda..d5eb529dc7 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py @@ -391,7 +391,7 @@ def test_bounds_check( # noqa C901 offsets[0] = -100 if offsets.numel() > 1: offsets[-1] += 100 - if bounds_check_mode != BoundsCheckMode.FATAL: + if bounds_check_mode != BoundsCheckMode.FATAL and use_cpu: torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices,