From 835ce65c4eb5bbb594ad1d32faf910d56a3d2fa0 Mon Sep 17 00:00:00 2001 From: yxs Date: Sun, 12 Apr 2026 21:17:03 -0600 Subject: [PATCH 1/2] Add cuBLAS+NCCL fast path for small GEMM in GEMM+ReduceScatter For small GEMM shapes (FLOPs < 100 GFLOPS per GPU), the persistent kernel's fixed overhead (TMA descriptors, barriers, signals) exceeds the overlap benefit. On H100, this caused 4-GPU to be 40% slower and 8-GPU to be 93% slower than PyTorch for M=4096 N=4096 K=8192. Fall back to cuBLAS matmul + NCCL reduce_scatter when the GEMM is too small for overlap to be profitable. Benchmarked on 8xH100 (NV18): - 4-GPU small shape: 0.286ms -> 0.231ms (+24%) - 8-GPU small shape: 0.323ms -> 0.167ms (+94%) - Large shapes (LLaMA-70B): no regression --- .../kernels/nvidia/gemm_reduce_scatter.py | 36 ++++++++++++++++--- python/triton_dist/layers/nvidia/tp_mlp.py | 1 + .../triton_dist/test/nvidia/test_gemm_rs.py | 2 +- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py b/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py index 23f4927d6..800eda1c6 100644 --- a/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py +++ b/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py @@ -42,8 +42,12 @@ from triton_dist.kernels.nvidia.gemm_perf_model import estimate_gemm_sol_time_ms from triton_dist.nv_utils import get_intranode_max_speed_gbps - ################### context ################### + +# fall back to cuBLAS + NCCL when overlap overhead > benefit +_SMALL_GEMM_FLOPS_THRESHOLD = 1e11 # 100 GFLOPS + + @dataclasses.dataclass class GEMMReduceScatterTensorParallelContext: rs_ctx: ReduceScatter2DContext @@ -58,6 +62,9 @@ class GEMMReduceScatterTensorParallelContext: # gemm kernel config num_gemm_sms: int + # process group for cuBLAS + NCCL fast path + tp_group: Optional[torch.distributed.ProcessGroup] = None + def finalize(self): self.rs_ctx.finalize() nvshmem_free_tensor_sync(self.gemm_out_bufs[self.rs_ctx.local_rank]) @@ -68,9 +75,17 @@ def get_gemm_out_buf(self, input): return self.gemm_out_bufs[local_rank][:M] -def create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_dtype: torch.dtype, - rs_stream: torch.cuda.Stream, - reduce_st: bool = False) -> GEMMReduceScatterTensorParallelContext: +def create_gemm_rs_context( + max_M, + N, + rank, + world_size, + local_world_size, + output_dtype: torch.dtype, + rs_stream: torch.cuda.Stream, + reduce_st: bool = False, + tp_group: Optional[torch.distributed.ProcessGroup] = None, +) -> GEMMReduceScatterTensorParallelContext: rs_ctx = create_reduce_scater_2d_ctx(max_M, N, rank, world_size, local_world_size, output_dtype) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count @@ -79,7 +94,7 @@ def create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_ gemm_out_bufs = nvshmem_create_tensors((max_M // world_size if reduce_st else max_M, N), output_dtype, rank, local_world_size) ctx = GEMMReduceScatterTensorParallelContext(rs_ctx=rs_ctx, output_dtype=output_dtype, gemm_out_bufs=gemm_out_bufs, - rs_stream=rs_stream, num_gemm_sms=num_gemm_sms) + rs_stream=rs_stream, num_gemm_sms=num_gemm_sms, tp_group=tp_group) nvshmem_barrier_all_on_stream(torch.cuda.current_stream()) return ctx @@ -619,6 +634,17 @@ def gemm_rs_op(A: torch.Tensor, B: torch.Tensor, ctx: GEMMReduceScatterTensorPar assert M % world_size == 0 M_per_rank = M // world_size + + # fast path: cuBLAS + NCCL for small GEMMs where overlap overhead > benefit + gemm_flops = 2 * M * N * local_K + if gemm_flops < _SMALL_GEMM_FLOPS_THRESHOLD and ctx.tp_group is not None and not reduce_st: + output = torch.empty((M_per_rank, N), dtype=output_dtype, device=A.device) + gemm_out = torch.matmul(A, B) + if gemm_out.dtype != output_dtype: + gemm_out = gemm_out.to(output_dtype) + torch.distributed.reduce_scatter_tensor(output, gemm_out, group=ctx.tp_group) + return output + current_stream = torch.cuda.current_stream() rs_stream.wait_stream(current_stream) diff --git a/python/triton_dist/layers/nvidia/tp_mlp.py b/python/triton_dist/layers/nvidia/tp_mlp.py index 55b90eb69..f299b05da 100644 --- a/python/triton_dist/layers/nvidia/tp_mlp.py +++ b/python/triton_dist/layers/nvidia/tp_mlp.py @@ -110,6 +110,7 @@ def _init_ctx(self, max_M, ag_intranode_stream: torch.cuda.Stream | None = None, local_world_size=self.world_size, output_dtype=self.dtype, rs_stream=ag_intranode_stream, + tp_group=self.group, ) nvshmem_barrier_all_on_stream(torch.cuda.current_stream()) torch.cuda.synchronize() diff --git a/python/triton_dist/test/nvidia/test_gemm_rs.py b/python/triton_dist/test/nvidia/test_gemm_rs.py index e3e510358..3c454307b 100644 --- a/python/triton_dist/test/nvidia/test_gemm_rs.py +++ b/python/triton_dist/test/nvidia/test_gemm_rs.py @@ -109,7 +109,7 @@ def __init__( self.rs_stream: torch.cuda.Stream = torch.cuda.Stream(priority=-1) self.ctx = create_gemm_rs_context(max_M, N, self.rank, self.world_size, self.local_world_size, output_dtype, - self.rs_stream, reduce_st) + self.rs_stream, reduce_st, tp_group=tp_group) self.reduce_st = reduce_st self.fuse_scatter = fuse_scatter self.persistent = persistent From 6aa8256c5f59e40d4265af926bc7ab789868ab67 Mon Sep 17 00:00:00 2001 From: yxs Date: Sun, 12 Apr 2026 23:02:56 -0600 Subject: [PATCH 2/2] Skip fast path for integer dtypes to avoid matmul overflow torch.matmul with int8 accumulates in int8 and wraps on overflow, while the Triton kernel accumulates in fp32. Restrict the fast path to floating-point dtypes only. --- python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py b/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py index 800eda1c6..b3b70cd9f 100644 --- a/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py +++ b/python/triton_dist/kernels/nvidia/gemm_reduce_scatter.py @@ -637,7 +637,8 @@ def gemm_rs_op(A: torch.Tensor, B: torch.Tensor, ctx: GEMMReduceScatterTensorPar # fast path: cuBLAS + NCCL for small GEMMs where overlap overhead > benefit gemm_flops = 2 * M * N * local_K - if gemm_flops < _SMALL_GEMM_FLOPS_THRESHOLD and ctx.tp_group is not None and not reduce_st: + if (gemm_flops < _SMALL_GEMM_FLOPS_THRESHOLD and ctx.tp_group is not None and not reduce_st + and A.dtype.is_floating_point): output = torch.empty((M_per_rank, N), dtype=output_dtype, device=A.device) gemm_out = torch.matmul(A, B) if gemm_out.dtype != output_dtype: