Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void _bounds_check_indices_cuda_v1(
int64_t B,
int64_t total_B,
bool vbe,
bool prefetch_pipeline);
bool prefetch_pipeline,
bool disable_offsets_adjustment);

void _bounds_check_indices_cuda_v2(
Tensor& rows_per_table,
Expand All @@ -56,7 +57,8 @@ void _bounds_check_indices_cuda_v2(
int64_t B,
int64_t total_B,
bool vbe,
bool prefetch_pipeline);
bool prefetch_pipeline,
bool disable_offsets_adjustment);

///@ingroup embedding-cuda
void bounds_check_indices_cuda(
Expand All @@ -76,6 +78,9 @@ void bounds_check_indices_cuda(
TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2);
const static bool use_v2_jk = fbgemm_gpu::config::is_feature_enabled(
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
const static bool disable_offsets_adjustment =
fbgemm_gpu::config::is_feature_enabled(
fbgemm_gpu::config::FeatureGateName::DISABLE_OFFSETS_ADJUSTMENT);
const auto bounds_check_indices_fn = (use_v2_jk || bounds_check_version == 2)
? _bounds_check_indices_cuda_v2
: _bounds_check_indices_cuda_v1;
Expand Down Expand Up @@ -140,7 +145,8 @@ void bounds_check_indices_cuda(
B,
total_B,
vbe,
prefetch_pipeline);
prefetch_pipeline,
disable_offsets_adjustment);
}
// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
Expand Down
120 changes: 64 additions & 56 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
BoundsCheckMode bounds_check_mode,
pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
FixedDivisor fd,
const bool disable_offsets_adjustment,
TORCH_DSA_KERNEL_ARGS) {
int32_t T = rows_per_table.size(0);
auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -55,40 +56,42 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
auto indices_end = offsets[b_t + 1];
const index_t num_indices = indices.size(0);

// Condition first, then branch on mode.
if (indices_start < 0 || indices_start > indices_end ||
if (disable_offsets_adjustment ||
bounds_check_mode == BoundsCheckMode::FATAL) {
CUDA_KERNEL_ASSERT(
indices_start >= 0 && "indices_start must be non-negative");
CUDA_KERNEL_ASSERT(
indices_start <= indices_end &&
"indices_start must not exceed indices_end");
CUDA_KERNEL_ASSERT(
indices_end <= num_indices &&
"indices_end must not exceed num_indices");
} else if (
indices_start < 0 || indices_start > indices_end ||
indices_end > num_indices) {
if (bounds_check_mode == BoundsCheckMode::FATAL) {
CUDA_KERNEL_ASSERT(
indices_start >= 0 && "indices_start must be non-negative");
CUDA_KERNEL_ASSERT(
indices_start <= indices_end &&
"indices_start must not exceed indices_end");
CUDA_KERNEL_ASSERT(
indices_end <= num_indices &&
"indices_end must not exceed num_indices");
} else {
if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
"batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
" num_indices: %lld. Setting indices_start and indices_end within "
"the range.\n",
vbe ? "true" : "false",
b,
t,
static_cast<int64_t>(indices_start),
static_cast<int64_t>(indices_end),
static_cast<int64_t>(num_indices));
}
if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
"batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
" num_indices: %lld. Setting indices_start and indices_end within "
"the range.\n",
vbe ? "true" : "false",
b,
t,
static_cast<int64_t>(indices_start),
static_cast<int64_t>(indices_end),
static_cast<int64_t>(num_indices));
}
adjust_offset_kernel(
indices_start,
indices_end,
num_indices,
&offsets[b_t],
&offsets[b_t + 1]);
}
indices_start =
std::max(static_cast<index_t>(0), std::min(indices_start, num_indices));
indices_end = std::max(indices_start, std::min(indices_end, num_indices));
// Only thread 0 writes back the adjusted offsets to avoid the intra-warp
// race; no sync needed since offsets are not re-read in this kernel.
if (threadIdx.x == 0) {
offsets[b_t] = indices_start;
offsets[b_t + 1] = indices_end;
}
}

Expand Down Expand Up @@ -130,29 +133,32 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
}
}

if (bounds_check_mode == BoundsCheckMode::FATAL) {
CUDA_KERNEL_ASSERT(
num_indices == offsets[total_B] &&
"num_indices must match the last element in offsets");
} else if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (num_indices != offsets[total_B]) {
if (gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
"total batch size %s: %d, total table num T: %d, "
" last element in offsets: %lld, indices size: %lld. "
" Setting the last element in offsets to be indices size.\n",
vbe ? "true" : "false",
vbe ? "total_B" : "B",
vbe ? total_B : B,
T,
static_cast<int64_t>(offsets[total_B]),
static_cast<int64_t>(num_indices));
}
offsets[total_B] = num_indices;
if (disable_offsets_adjustment ||
bounds_check_mode == BoundsCheckMode::FATAL) {
if (b_t == 0 && threadIdx.x == 0) {
CUDA_KERNEL_ASSERT(
num_indices == offsets[total_B] &&
"num_indices must match the last element in offsets");
}
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
if (num_indices != offsets[total_B]) {
} else if (num_indices != offsets[total_B]) {
// The last-element check is a single global condition; one thread handles
// the warning and the correction (for both WARNING and IGNORE).
if (b_t == 0 && threadIdx.x == 0) {
if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
"total batch size %s: %d, total table num T: %d, "
" last element in offsets: %lld, indices size: %lld. "
" Setting the last element in offsets to be indices size.\n",
vbe ? "true" : "false",
vbe ? "total_B" : "B",
vbe ? total_B : B,
T,
static_cast<int64_t>(offsets[total_B]),
static_cast<int64_t>(num_indices));
}
}
offsets[total_B] = num_indices;
}
}
Expand All @@ -174,7 +180,8 @@ void _bounds_check_indices_cuda_v1(
int64_t B,
int64_t /*total_B*/,
bool vbe,
bool prefetch_pipeline) {
bool prefetch_pipeline,
bool disable_offsets_adjustment) {
TORCH_CHECK(
!prefetch_pipeline,
"bounds_check_indices_v1 does not support prefetch_pipeline=true")
Expand Down Expand Up @@ -205,6 +212,7 @@ void _bounds_check_indices_cuda_v1(
vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr,
bounds_check_mode,
PTA_B(warning, int64_t, 1, 32),
FixedDivisor(max_B_));
FixedDivisor(max_B_),
disable_offsets_adjustment);
});
}
88 changes: 45 additions & 43 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
const int32_t* const b_t_map,
const int32_t info_B_num_bits,
const int32_t info_B_mask,
const bool disable_offsets_adjustment,
TORCH_DSA_KERNEL_ARGS) {
int32_t T = rows_per_table.size(0);
int32_t total_B = offsets.size(0) - 1;
Expand All @@ -39,14 +40,15 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
const uint32_t active_threads = blockDim.x * blockDim.y * blockDim.z;
#endif

// Check the last element
// Last-element check; one thread only.
if (b_t_start == 0 && threadIdx.x == 0) {
if (bounds_check_mode == BoundsCheckMode::FATAL) {
if (disable_offsets_adjustment ||
bounds_check_mode == BoundsCheckMode::FATAL) {
CUDA_KERNEL_ASSERT(
num_indices == offsets[total_B] &&
"num_indices must match the last element in offsets");
} else if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (num_indices != offsets[total_B]) {
} else if (num_indices != offsets[total_B]) {
if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
Expand All @@ -60,12 +62,8 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
static_cast<int64_t>(offsets[total_B]),
static_cast<int64_t>(num_indices));
}
offsets[total_B] = num_indices;
}
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
if (num_indices != offsets[total_B]) {
offsets[total_B] = num_indices;
}
offsets[total_B] = num_indices;
}
}

Expand All @@ -86,40 +84,42 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
auto indices_start = offsets[b_t];
auto indices_end = offsets[b_t + 1];

// Condition first, then branch on mode.
if (indices_start < 0 || indices_start > indices_end ||
if (disable_offsets_adjustment ||
bounds_check_mode == BoundsCheckMode::FATAL) {
CUDA_KERNEL_ASSERT(
indices_start >= 0 && "indices_start must be non-negative");
CUDA_KERNEL_ASSERT(
indices_start <= indices_end &&
"indices_start must not exceed indices_end");
CUDA_KERNEL_ASSERT(
indices_end <= num_indices &&
"indices_end must not exceed num_indices");
} else if (
indices_start < 0 || indices_start > indices_end ||
indices_end > num_indices) {
if (bounds_check_mode == BoundsCheckMode::FATAL) {
CUDA_KERNEL_ASSERT(
indices_start >= 0 && "indices_start must be non-negative");
CUDA_KERNEL_ASSERT(
indices_start <= indices_end &&
"indices_start must not exceed indices_end");
CUDA_KERNEL_ASSERT(
indices_end <= num_indices &&
"indices_end must not exceed num_indices");
} else {
if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
"batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
" num_indices: %lld. Setting indices_start and indices_end within "
"the range.\n",
vbe ? "true" : "false",
b,
t,
static_cast<int64_t>(indices_start),
static_cast<int64_t>(indices_end),
static_cast<int64_t>(num_indices));
}
if (bounds_check_mode == BoundsCheckMode::WARNING) {
if (threadIdx.x == 0 && gpuAtomicIncrement(&warning[0]) == 0) {
printf(
"EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
"batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
" num_indices: %lld. Setting indices_start and indices_end within "
"the range.\n",
vbe ? "true" : "false",
b,
t,
static_cast<int64_t>(indices_start),
static_cast<int64_t>(indices_end),
static_cast<int64_t>(num_indices));
}
adjust_offset_kernel(
indices_start,
indices_end,
num_indices,
&offsets[b_t],
&offsets[b_t + 1]);
}
indices_start = std::max(
static_cast<index_t>(0), std::min(indices_start, num_indices));
indices_end = std::max(indices_start, std::min(indices_end, num_indices));
// Only thread 0 writes back the adjusted offsets to avoid the intra-warp
// race; no sync needed since offsets are not re-read in this kernel.
if (threadIdx.x == 0) {
offsets[b_t] = indices_start;
offsets[b_t + 1] = indices_end;
}
}

Expand Down Expand Up @@ -225,7 +225,8 @@ void _bounds_check_indices_cuda_v2(
int64_t B,
int64_t total_B,
bool vbe,
bool prefetch_pipeline) {
bool prefetch_pipeline,
bool disable_offsets_adjustment) {
if (vbe) {
TORCH_CHECK(b_t_map.has_value());
TENSOR_NDIM_EQUALS(b_t_map.value(), 1);
Expand Down Expand Up @@ -274,7 +275,8 @@ void _bounds_check_indices_cuda_v2(
FixedDivisor(B), \
vbe ? b_t_map.value().data_ptr<int32_t>() : nullptr, \
info_B_num_bits, \
info_B_mask); \
info_B_mask, \
disable_offsets_adjustment); \
}); \
}

Expand Down
3 changes: 3 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/config/feature_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def foo():
# Enable warp-parallel kernel for populate_bucketized_permute
BUCKETIZED_PERMUTE_WARP_KERNEL = auto()

# Gate the bounds_check_indices offsets-adjustment assertions
DISABLE_OFFSETS_ADJUSTMENT = auto()

def is_enabled(self) -> bool:
return FeatureGate.is_enabled(self)

Expand Down
3 changes: 2 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ namespace fbgemm_gpu::config {
X(TBE_REPORT_INPUT_PARAMS) \
X(TBE_CPU_OUTPUT_DISABLE_PINNED_MEMORY) \
X(TBE_USE_TUNED_SEGMENT_LENGTHS_CTA_B200) \
X(BUCKETIZED_PERMUTE_WARP_KERNEL)
X(BUCKETIZED_PERMUTE_WARP_KERNEL) \
X(DISABLE_OFFSETS_ADJUSTMENT)
// X(EXAMPLE_FEATURE_FLAG)

/// @ingroup fbgemm-gpu-config
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_bounds_check( # noqa C901
offsets[0] = -100
if offsets.numel() > 1:
offsets[-1] += 100
if bounds_check_mode != BoundsCheckMode.FATAL:
if bounds_check_mode != BoundsCheckMode.FATAL and use_cpu:
torch.ops.fbgemm.bounds_check_indices(
rows_per_table,
indices,
Expand Down
Loading