diff --git a/fbgemm_gpu/src/jagged_tensor_ops/common.cuh b/fbgemm_gpu/src/jagged_tensor_ops/common.cuh index 33938bec62..bb9742bfbd 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/common.cuh +++ b/fbgemm_gpu/src/jagged_tensor_ops/common.cuh @@ -224,7 +224,13 @@ inline std::tuple> check_shape_and_partition_( const int threads_x = inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size; +#ifndef USE_ROCM const int threads_y = kMaxThreads / kWarpSize; +#else + // AMD: ~256-thread blocks improve wavefront packing for common D=1 shapes. + constexpr int kTargetBlockThreads = 256; + const int threads_y = kTargetBlockThreads / threads_x; +#endif const dim3 blocks( div_round_up(outer_dense_size * jagged_folded_size, threads_y));