From 98c94e27334fa00037b869c3055c8bd26b5c180b Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Fri, 6 Feb 2026 17:03:55 -0600 Subject: [PATCH 1/7] Update rocshmem Signed-off-by: Eric Eaton --- 3rdparty/rocshmem | 2 +- shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/rocshmem b/3rdparty/rocshmem index 7b5765ec0..45b4f5610 160000 --- a/3rdparty/rocshmem +++ b/3rdparty/rocshmem @@ -1 +1 @@ -Subproject commit 7b5765ec0e137db04b50d657ed83ea239be8dd58 +Subproject commit 45b4f5610cc1899a48ba3ac18073a0f60f999987 diff --git a/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh b/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh index d1797c564..7854fc5d8 100755 --- a/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh +++ b/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh @@ -39,12 +39,12 @@ SOURCE_MAP=( ["${ROCSHMEM_SRC}/src/ipc/context_ipc_device_coll.cpp"]="rocshmem_context_ipc_device_coll.bc" ["${ROCSHMEM_SRC}/src/ipc_policy.cpp"]="rocshmem_ipc_policy.bc" ["${ROCSHMEM_SRC}/src/gda/context_gda_device.cpp"]="rocshmem_context_gda_device.bc" + ["${ROCSHMEM_SRC}/src/gda/context_gda_device_coll.cpp"]="rocshmem_context_gda_device_coll.bc" ["${ROCSHMEM_SRC}/src/gda/backend_gda.cpp"]="rocshmem_backend_gda.bc" ["${ROCSHMEM_SRC}/src/gda/queue_pair.cpp"]="rocshmem_queue_pair.bc" ["${ROCSHMEM_SRC}/src/gda/ionic/queue_pair_ionic.cpp"]="rocshmem_queue_pair_ionic.bc" ["${ROCSHMEM_SRC}/src/gda/mlx5/queue_pair_mlx5.cpp"]="rocshmem_queue_pair_mlx5.bc" ["${ROCSHMEM_SRC}/src/gda/mlx5/segment_builder.cpp"]="rocshmem_segment_builder.bc" - ["${ROCSHMEM_SRC}/src/gda/endian.cpp"]="rocshmem_endian.bc" ["${ROCSHMEM_SRC}/src/team.cpp"]="rocshmem_team.bc" ["${ROCSHMEM_SRC}/src/sync/abql_block_mutex.cpp"]="rocshmem_abql_block_mutex.bc" ["${ROCSHMEM_SRC}/src/util.cpp"]="rocshmem_util.bc" From 335c95ed27f888c5f12d4291e8478c289acbadee Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Tue, 10 Feb 2026 
15:50:47 -0600 Subject: [PATCH 2/7] Update CI to ROCm 6.4.4 Signed-off-by: Eric Eaton --- .github/workflows/amd-ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/amd-ci.yml b/.github/workflows/amd-ci.yml index 643172d7a..753dfb4fb 100644 --- a/.github/workflows/amd-ci.yml +++ b/.github/workflows/amd-ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: "amd-gfx942-mi325" timeout-minutes: 45 container: - image: rocm/pytorch:rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview + image: rocm/pytorch:rocm6.4.4_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview options: >- --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root --volume /home/runner/.triton:/github/home/.triton @@ -27,7 +27,7 @@ jobs: run: | rm -f /usr/local/bin/cmake apt-get update -y && apt install -y libopenmpi-dev git cython3 ibverbs-utils openmpi-bin libopenmpi-dev libpci-dev libdw1 locales cmake - pip3 install -i https://test.pypi.org/simple hip-python>=6.3.0 # (or whatever Rocm version you have) + pip3 install -i https://test.pypi.org/simple hip-python>=6.4.4 # (or whatever Rocm version you have) pip3 install pybind11 - name: Checkout uses: actions/checkout@v4 @@ -58,7 +58,7 @@ jobs: runs-on: "amd-gfx942-mi325" timeout-minutes: 60 container: - image: rocm/pytorch:rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview + image: rocm/pytorch:rocm6.4.4_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview options: >- --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root --volume /home/runner/.triton:/github/home/.triton @@ -73,7 +73,7 @@ jobs: - name: Install dependencies run: | apt-get update -y && apt install -y libopenmpi-dev - pip3 install -i https://test.pypi.org/simple hip-python>=6.3.0 # (or whatever Rocm version you have) + pip3 install -i https://test.pypi.org/simple hip-python>=6.4.4 # (or whatever Rocm version you have) pip3 install pybind11 - name: Checkout uses: 
actions/checkout@v4 From 2a02b90098a29b24b06a189dfd5dae72a81db3e1 Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Thu, 12 Feb 2026 18:03:55 -0600 Subject: [PATCH 3/7] Fix allgather.py This file was not included in the CI and wasn't updated with some recent changes. Signed-off-by: Eric Eaton --- .github/workflows/amd-ci.yml | 2 + python/triton_dist/kernels/amd/allgather.py | 71 +++++++-------------- 2 files changed, 24 insertions(+), 49 deletions(-) diff --git a/.github/workflows/amd-ci.yml b/.github/workflows/amd-ci.yml index 753dfb4fb..bf9fa99e8 100644 --- a/.github/workflows/amd-ci.yml +++ b/.github/workflows/amd-ci.yml @@ -45,6 +45,8 @@ jobs: bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_rocshmem_api.py # distributed ops bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_distributed-notify-wait.py + # all-gather + bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_allgather.py # ag gemm bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ag_gemm_intra_node.py 8192 11008 4096 # gemm rs diff --git a/python/triton_dist/kernels/amd/allgather.py b/python/triton_dist/kernels/amd/allgather.py index 903bfd45a..9b172143c 100644 --- a/python/triton_dist/kernels/amd/allgather.py +++ b/python/triton_dist/kernels/amd/allgather.py @@ -29,17 +29,16 @@ import triton import triton.language as tl +import triton_dist import triton_dist.language as dl -from triton.language.extra.hip.librocshmem_device import set_rocshmem_ctx -from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_with_ctx_on_stream, +from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_on_stream, barrier_on_this_grid) from triton_dist.kernels.amd.memcpy import memcpy_async_kernel from triton_dist.utils import launch_cooperative_grid_options -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_no_barrier_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -47,7 
+46,6 @@ def allgather_no_barrier_kernel( num_ranks: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) npid_per_rank = npid // num_ranks @@ -63,9 +61,8 @@ def allgather_no_barrier_kernel( tl.store(dst_ptr + offs, val, mask) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -74,7 +71,6 @@ def allgather_kernel( group_barrier_ptr, BLOCK_SIZE: tl.constexpr, ): - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) npid_per_rank = npid // num_ranks @@ -95,9 +91,8 @@ def allgather_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -114,7 +109,6 @@ def allgather_strided_chunked_kernel( """ with the local_ptr """ - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) @@ -181,7 +175,7 @@ def allgather_strided_chunked_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["pid", "npid", "rank"]) +@triton_dist.jit(do_not_specialize=["pid", "npid", "rank"]) def allgather_strided_chunked_pull_kernel( pid, npid, @@ -312,7 +306,7 @@ def make_2x_ptrs(val0, val1): return vals -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_packed_kernel( symm_ptr, # (M, N), M = M_per_rank * num_ranks M_per_rank, @@ -398,9 +392,8 @@ def allgather_strided_chunked_pull_packed_kernel( tl.store(dst_ptrs[None, :, None] + offs[:, None, :], val, mask=mask[:, None, :]) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_ctx_wrapper_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks 
M_per_rank, N, @@ -413,7 +406,6 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) npid = tl.num_programs(0) allgather_strided_chunked_pull_kernel( @@ -434,9 +426,8 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_fused_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -451,7 +442,6 @@ def allgather_strided_chunked_pull_fused_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) npid = tl.num_programs(0) @@ -479,9 +469,8 @@ def allgather_strided_chunked_pull_fused_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_chunked_pull_fused_packed_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -496,7 +485,6 @@ def allgather_chunked_pull_fused_packed_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) memcpy_async_kernel(local_ptr, symm_ptr + M_per_rank * stride_m * rank, M_per_rank * N, 8 * BLOCK_SIZE_N) @@ -521,9 +509,8 @@ def allgather_chunked_pull_fused_packed_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_opt_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -533,7 +520,6 @@ def allgather_opt_kernel( BLOCK_SIZE: tl.constexpr, ): """ compared to allagther_kernel, this use less CTAs to achieve to the same level """ - 
set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) total_tiles = tl.cdiv(N_per_rank, BLOCK_SIZE) @@ -617,22 +603,16 @@ def allgather_ipc_kernel( tl.store(dst_ptr + offs, val, mask) -@functools.lru_cache() -def get_pyrocshmem_device_ctx(): - return pyrocshmem.rocshmem_get_device_ctx() - - -def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor, barrier: torch.Tensor): +def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor): rank = pyrocshmem.rocshmem_my_pe() nranks = pyrocshmem.rocshmem_n_pes() assert shard.numel() * nranks == full_symm.numel() assert shard.dtype == full_symm.dtype - ctx = get_pyrocshmem_device_ctx() current_stream = torch.cuda.current_stream() - barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream) - allgather_no_barrier_kernel[(nranks * 4, )](ctx, full_symm, shard, shard.numel(), rank, nranks, + barrier_all_on_stream(current_stream) + allgather_no_barrier_kernel[(nranks * 4, )](full_symm, shard, shard.numel(), rank, nranks, BLOCK_SIZE=32 * 1024, **launch_cooperative_grid_options()) - barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream) + barrier_all_on_stream(current_stream) def allgather( @@ -645,11 +625,9 @@ def allgather( nranks = pyrocshmem.rocshmem_n_pes() assert shard.numel() * nranks == full_symm.numel() assert shard.dtype == full_symm.dtype - ctx = get_pyrocshmem_device_ctx() - workgroups_per_rank = 1 # use only 8 ranks can achieve similay bandwidth - allgather_opt_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, shard.numel(), rank, nranks, barrier, + allgather_opt_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, shard.numel(), rank, nranks, barrier, BLOCK_SIZE=32 * 1024, num_warps=16, **launch_cooperative_grid_options()) @@ -662,7 +640,6 @@ def allgather_chunked_pull( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape 
rank = pyrocshmem.rocshmem_my_pe() @@ -678,16 +655,16 @@ def allgather_chunked_pull( full_symm[M_start:M_end, :].copy_(shard) current_stream = torch.cuda.current_stream() - barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream) + barrier_all_on_stream(current_stream) workgroups_per_rank = 2 BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth allgather_strided_chunked_pull_ctx_wrapper_kernel[(nranks * workgroups_per_rank, )]( - ctx, full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, + full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16) - barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream) + barrier_all_on_stream(current_stream) def allgather_chunked_pull_fused( @@ -698,7 +675,6 @@ def allgather_chunked_pull_fused( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -713,7 +689,7 @@ def allgather_chunked_pull_fused( BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth allgather_strided_chunked_pull_fused_kernel[(nranks * workgroups_per_rank, )]( - ctx, full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, + full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16, **launch_cooperative_grid_options()) @@ -726,7 +702,6 @@ def allgather_chunked_pull_packed_fused( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 
M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -740,7 +715,7 @@ def allgather_chunked_pull_packed_fused( BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 2 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similar bandwidth - allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N, + allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, @@ -766,13 +741,11 @@ def allgather_chunked( assert shard.is_cuda and full_symm.is_cuda assert shard.stride() == full_symm.stride() # just for simple implementation - ctx = get_pyrocshmem_device_ctx() - workgroups_per_rank = 1 BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 32 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth - allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N, + allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, From cc5e5b43768ec89afbc23cc64b60d65fcf776216 Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Wed, 18 Feb 2026 15:34:28 -0600 Subject: [PATCH 4/7] Fix test_allgather.py Signed-off-by: Eric Eaton --- python/triton_dist/test/amd/test_allgather.py | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/python/triton_dist/test/amd/test_allgather.py b/python/triton_dist/test/amd/test_allgather.py index c05948974..0b6491c46 100644 --- a/python/triton_dist/test/amd/test_allgather.py +++ b/python/triton_dist/test/amd/test_allgather.py @@ -59,6 +59,12 @@ def parse_args(): default=False, help="Enable profiling (default False)", ) + 
parser.add_argument( + "--noperf", + action="store_true", + default=False, + help="Disable performance measurement" + ) parser.add_argument("--stream_priority", "--stream_priority", default=0, type=int, help="stream priority, 0 or -1 supported") parser.add_argument("--N", "-N", type=int, default=1376) @@ -137,25 +143,30 @@ def parse_nbytes(nbytes: str): for name, fn_triton in triton_fns.items(): torch.distributed.barrier(TP_GROUP) fn_triton() + torch.cuda.synchronize() assert_bitwise_equal(A_full.view(8, -1, args.N), A_full_torch.view(8, -1, args.N)) torch.distributed.barrier(TP_GROUP) A_full.random_() - _run_id = os.environ.get("TORCHELASTIC_RUN_ID") - exp_name = f"allgather_{_run_id}" - with group_profile(exp_name, do_prof=args.profile, merge_group=TP_GROUP): - sleep_async(10) - _, duration_ms_torch = perf_func(fn_torch, args.iters, args.warmup_iters) - triton_durations_ms = {} - for name, fn_triton in triton_fns.items(): + if not args.noperf: + _run_id = os.environ.get("TORCHELASTIC_RUN_ID") + exp_name = f"allgather_{_run_id}" + with group_profile(exp_name, do_prof=args.profile, merge_group=TP_GROUP): sleep_async(10) - _, triton_durations_ms[name] = perf_func(fn_triton, args.iters, args.warmup_iters) + _, duration_ms_torch = perf_func(fn_torch, args.iters, args.warmup_iters) + triton_durations_ms = {} + for name, fn_triton in triton_fns.items(): + sleep_async(10) + _, triton_durations_ms[name] = perf_func(fn_triton, args.iters, args.warmup_iters) + + bw_gpbs_torch = nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms_torch * 1e3 - bw_gpbs_torch = nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms_torch * 1e3 + print(f"RANK #{TP_GROUP.rank()} AllGather torch: {duration_ms_torch:0.3f} ms/iter {bw_gpbs_torch:0.1f} GB/s.") + for name, duration_ms in triton_durations_ms.items(): + bw_gpbs_fn = (nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms * 1e3) + print(f" triton {duration_ms:0.3f} ms/iter {bw_gpbs_fn:0.1f} GB/s.") - print(f"RANK 
#{TP_GROUP.rank()} AllGather torch: {duration_ms_torch:0.3f} ms/iter {bw_gpbs_torch:0.1f} GB/s.") - for name, duration_ms in triton_durations_ms.items(): - bw_gpbs_fn = (nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms * 1e3) - print(f" triton {duration_ms:0.3f} ms/iter {bw_gpbs_fn:0.1f} GB/s.") + del A_full + del group_barrier finalize_distributed() From 90079dafeda3aa0b5f76d1e7bae513dcb0feffab Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Fri, 20 Feb 2026 14:27:03 -0600 Subject: [PATCH 5/7] Fix the UUID matching in amd_utils.py There was a change in ROCm 7+ that makes it harder to match up the torch and amdsmi devices. This change to use KFD makes operations like sleep_async work again. Signed-off-by: Eric Eaton --- python/triton_dist/amd_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/python/triton_dist/amd_utils.py b/python/triton_dist/amd_utils.py index 4e132a4ba..7388d0f96 100644 --- a/python/triton_dist/amd_utils.py +++ b/python/triton_dist/amd_utils.py @@ -28,6 +28,8 @@ import json import warnings import re +import os +import sys from threading import Lock from hip import hip @@ -206,7 +208,19 @@ def _get_gpu_uuid_by_physical_device_id(device_id: int): _ensure_amdsmi_initialized() devices = amdsmi.amdsmi_get_processor_handles() handle = devices[device_id] - return amdsmi.amdsmi_get_gpu_device_uuid(handle) + # Due to a change in how UUIDs are generated for CPX mode, amdsmi no longer reports any uuid value that + # matches HIP/pytorch. HIP gets the value from sysfs, and we can also get the value there by getting + # the KFD info from amdsmi and then probing the sysfs directly. 
+ kfd_info = amdsmi.amdsmi_get_gpu_kfd_info(handle) + node_id = kfd_info["node_id"] + kfd_path = os.path.join("/sys/devices/virtual/kfd/kfd/topology/nodes", str(node_id), "properties") + key = "unique_id" + with open(kfd_path, "r") as fd: + for line in fd: + if line.startswith(key): + uuid_str = line[len(key)+1:] + uuid_str = hex(int(uuid_str)) + return uuid_str def torch_uuid_to_unique_id(torch_uuid: str) -> str: @@ -232,7 +246,8 @@ def get_uuid_by_physical_device_id(device_id: int | None = None): try: if has_amdsmi(): return _get_gpu_uuid_by_physical_device_id(device_id) - except Exception: + except Exception as e: + print(e, file=sys.stderr) warnings.warn("get_uuid_by_physical_device_id failed with amdsmi, try using rocm-smi") return _get_physical_gpu_uuid_rocm(device_id) From aa9418253d0b2db4e75463244fe8bc7823ab232c Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Fri, 20 Feb 2026 15:05:50 -0600 Subject: [PATCH 6/7] Copy the nvidia low_latency_allgather Signed-off-by: Eric Eaton --- python/triton_dist/kernels/amd/__init__.py | 14 + .../kernels/amd/low_latency_allgather.py | 988 ++++++++++++++++++ python/triton_dist/layers/amd/__init__.py | 26 + .../layers/amd/low_latency_allgather_layer.py | 198 ++++ .../test/amd/test_fast_allgather.py | 158 +++ 5 files changed, 1384 insertions(+) create mode 100644 python/triton_dist/kernels/amd/low_latency_allgather.py create mode 100644 python/triton_dist/layers/amd/__init__.py create mode 100644 python/triton_dist/layers/amd/low_latency_allgather_layer.py create mode 100644 python/triton_dist/test/amd/test_fast_allgather.py diff --git a/python/triton_dist/kernels/amd/__init__.py b/python/triton_dist/kernels/amd/__init__.py index 227725346..001f5d913 100644 --- a/python/triton_dist/kernels/amd/__init__.py +++ b/python/triton_dist/kernels/amd/__init__.py @@ -24,7 +24,21 @@ ################################################################################ from .allgather_gemm import ag_gemm_intra_node, 
create_ag_gemm_intra_node_context from .gemm_reduce_scatter import gemm_rs_intra_node, create_gemm_rs_intra_node_context +from .low_latency_allgather import (fast_allgather, create_fast_allgather_context, _forward_pull_kernel, + _forward_push_2d_kernel, _forward_push_3d_kernel, _forward_push_2d_ll_kernel, + _forward_push_2d_ll_multimem_kernel, _forward_push_numa_2d_ll_kernel, + _forward_push_numa_2d_kernel, _forward_push_numa_2d_ll_multinode_kernel) __all__ = [ + "_forward_pull_kernel", + "_forward_push_2d_kernel", + "_forward_push_3d_kernel", + "_forward_push_2d_ll_kernel", + "_forward_push_2d_ll_multimem_kernel", + "_forward_push_numa_2d_kernel", + "_forward_push_numa_2d_ll_kernel", + "_forward_push_numa_2d_ll_multinode_kernel", + "create_fast_allgather_context", + "fast_allgather", "ag_gemm_intra_node", "create_ag_gemm_intra_node_context", "gemm_rs_intra_node", "create_gemm_rs_intra_node_context" ] diff --git a/python/triton_dist/kernels/amd/low_latency_allgather.py b/python/triton_dist/kernels/amd/low_latency_allgather.py new file mode 100644 index 000000000..ccdaec477 --- /dev/null +++ b/python/triton_dist/kernels/amd/low_latency_allgather.py @@ -0,0 +1,988 @@ +################################################################################ +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ +import torch +from dataclasses import dataclass +from typing import List + +import triton.language as tl +import triton_dist +from triton_dist.kernels.common_ops import pack_b32_v2 + +from triton_dist.language.extra import libshmem_device +from triton_dist.language.extra.cuda.language_extra import ( + __syncthreads, + tid, + ntid, + load_v4_u32, + load_v2_b64, + st_v2_u32, + st, + multimem_st_b64, +) +from pyrocshmem import rocshmem_create_tensor +from triton_dist.utils import NVSHMEM_SIGNAL_DTYPE + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_pull_kernel(symm_ptr, bytes_per_rank, symm_flag, world_size, rank, signal_target): + pid = tl.program_id(0) + thread_idx = tid(0) + if pid == rank: + if thread_idx != rank and thread_idx < world_size: + libshmem_device.signal_op( + symm_flag + rank, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + thread_idx, + ) + __syncthreads() + else: + peer = pid + if thread_idx == 0: + libshmem_device.signal_wait_until(symm_flag + peer, libshmem_device.ROCSHMEM_CMP_EQ, signal_target) + __syncthreads() + libshmem_device.getmem_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + peer * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + peer * bytes_per_rank, + bytes_per_rank, + peer, + ) + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_numa_2d_kernel( + symm_ptr, + bytes_per_rank, + 
symm_flag, + n_numa_nodes: tl.constexpr, + world_size: tl.constexpr, + rank, + signal_target, +): + tl.static_assert(n_numa_nodes == 2, "only support NUMA node == 2") + numa_world_size = world_size // n_numa_nodes + local_rank = rank % numa_world_size + nid = rank // numa_world_size + + pid = tl.program_id(0) + peer_nid = pid // numa_world_size + peer_local_rank = pid % numa_world_size + thread_idx = tid(0) + + symm_ptr = tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + + if peer_local_rank == local_rank: # remote push + if peer_nid != nid: # pnid: peer node id. each block recv from pnid + peer_to = peer_nid * numa_world_size + local_rank + libshmem_device.putmem_signal_wg( + symm_ptr + rank * bytes_per_rank, + symm_ptr + rank * bytes_per_rank, + bytes_per_rank, + symm_flag + rank, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer_to, + ) # write and tell peer remote that remote copy is done + else: # pack ll data + # wait for all write done + if thread_idx < world_size and thread_idx != rank: + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + else: # local push + peer = nid * numa_world_size + peer_local_rank + segment = peer_nid * numa_world_size + local_rank + if peer_nid != nid: # wait for recv_ll done + if thread_idx == 0: + libshmem_device.signal_wait_until(symm_flag + segment, libshmem_device.ROCSHMEM_CMP_EQ, signal_target) + __syncthreads() + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_numa_2d_ll_kernel( + symm_ptr, + bytes_per_rank, + symm_flag, + symm_ll_buffer, + n_numa_nodes: 
tl.constexpr, + world_size: tl.constexpr, + rank, + signal_target, +): + tl.static_assert(n_numa_nodes == 2, "only support NUMA node == 2") + numa_world_size = world_size // n_numa_nodes + local_rank = rank % numa_world_size + nid = rank // numa_world_size + + pid = tl.program_id(0) + peer_nid = pid // numa_world_size + peer_local_rank = pid % numa_world_size + thread_idx = tid(0) + num_ints = bytes_per_rank // 4 + + symm_ll_buffer = tl.cast(symm_ll_buffer, tl.pointer_type(tl.int8)) + symm_ptr = tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + + if peer_local_rank == local_rank: # remote push + if peer_nid != nid: # pnid: peer node id. each block recv from pnid + segment = peer_nid * numa_world_size + local_rank + _recv_ll_block( + symm_ptr + segment * bytes_per_rank, + symm_ll_buffer + segment * bytes_per_rank * 2, + num_ints, + signal_target, + ) # magic number here + __syncthreads() + if thread_idx == 0: + st(symm_flag + segment, signal_target, scope="gpu", semantic="release") + else: # pack ll data + _pack_ll_block( + symm_ll_buffer + rank * bytes_per_rank * 2, + symm_ptr + rank * bytes_per_rank, + num_ints, + signal_target, + 2048, + ) # magic number here + __syncthreads() + + peer_to_nid = 1 - nid # only for n_numa_nodes == 2 + peer_to = peer_to_nid * numa_world_size + local_rank + libshmem_device.putmem_wg( + symm_ll_buffer + rank * bytes_per_rank * 2, + symm_ll_buffer + rank * bytes_per_rank * 2, + bytes_per_rank * 2, + peer_to, + ) # write and tell peer remote that remote copy is done + # wait for all write done + if thread_idx < world_size and thread_idx != rank: + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + + else: # local push + peer = nid * numa_world_size + peer_local_rank + segment = peer_nid * numa_world_size + local_rank + if peer_nid != nid: # wait for recv_ll done + if thread_idx == 0: + libshmem_device.signal_wait_until(symm_flag + segment, 
libshmem_device.ROCSHMEM_CMP_EQ, signal_target) + __syncthreads() + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_numa_2d_ll_multinode_kernel( + symm_ptr, + bytes_per_rank, + symm_flag, + symm_ll_buffer, + nnodes: tl.constexpr, + n_numa_nodes: tl.constexpr, + world_size: tl.constexpr, + rank, + signal_target, +): + # the communication pattern + # for rank (node_id, numa_id, numa_rank) with (i0, j0, k0) + # (i0, j0, k0) => (i0, j0, k_x) with intra NUMA communication with nvshmem_putmem_signal_block + # (i0, j1, k0) => (i0, j_x, k0) with inter NUMA communication with nvshmem_putmem_signal_block + # (i0, j1, k0) => (i_x, j0, k0) with intra NODE communication with nvshmem_putmem_nbi_warp with LL protocol + # BUT the difference: NIC can be done with nbi, but NUMA can't (or better not) + tl.static_assert(n_numa_nodes == 2, "only support NUMA node == 2") + local_world_size = world_size // nnodes + node_id = rank // local_world_size + local_rank = rank % local_world_size + numa_world_size = local_world_size // n_numa_nodes + numa_rank = local_rank % numa_world_size + local_numa_id = local_rank // numa_world_size + global_numa_id = rank // numa_world_size + + pid = tl.program_id(0) + peer_node_id = pid // local_world_size + peer_local_rank = pid % local_world_size + peer_numa_rank = peer_local_rank % numa_world_size + peer_local_numa_id = peer_local_rank // numa_world_size + peer_global_numa_id = pid // numa_world_size + thread_idx = tid(0) + num_ints = bytes_per_rank // 4 + + symm_ll_buffer = tl.cast(symm_ll_buffer, tl.pointer_type(tl.int8)) + symm_ptr = tl.cast(symm_ptr, 
tl.pointer_type(tl.int8)) + + # (i0, j0, k0) 2 conditions to put internode => (node_id, local_numa_id, numa_rank) + # 1. send intra NUMA (i0, j0, k_x) x!=0 + # 2. send inter NUMA (i0, j_x, k0) x!=0 + is_intra_numa = numa_rank != peer_numa_rank + is_inter_numa = node_id == peer_node_id and (local_numa_id != peer_local_numa_id and numa_rank == peer_numa_rank) + + if (is_intra_numa and global_numa_id == peer_global_numa_id): # no need to wait, just send + peer = global_numa_id * numa_world_size + peer_numa_rank + segment = rank + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + elif is_intra_numa and global_numa_id != peer_global_numa_id: + peer = global_numa_id * numa_world_size + peer_numa_rank + segment = (peer_node_id * local_world_size + peer_local_numa_id * numa_world_size + numa_rank) + # wait for segment ready + if thread_idx == 0: + libshmem_device.signal_wait_until(symm_flag + segment, libshmem_device.ROCSHMEM_CMP_EQ, signal_target) + __syncthreads() + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + elif is_inter_numa: + peer = (node_id * local_world_size + peer_local_numa_id * numa_world_size + peer_numa_rank) + segment = rank + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + 
libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + else: # remote push + if peer_node_id != node_id: # pnid: peer node id. each block recv from pnid + segment = peer_global_numa_id * numa_world_size + peer_numa_rank + _recv_ll_block( + symm_ptr + segment * bytes_per_rank, + symm_ll_buffer + segment * bytes_per_rank * 2, + num_ints, + signal_target, + ) # magic number here + __syncthreads() + + if thread_idx == 0: + st(symm_flag + segment, signal_target, scope="gpu", semantic="release") + else: # pack ll data and send to peer + _pack_ll_block( + symm_ll_buffer + rank * bytes_per_rank * 2, + symm_ptr + rank * bytes_per_rank, + num_ints, + signal_target, + 2048, + ) # magic number here + __syncthreads() + + for i in range(world_size // numa_world_size): + if i // n_numa_nodes != node_id: # only send to other nodes + peer_to = numa_rank + i * numa_world_size + libshmem_device.putmem_nbi_warp( + symm_ll_buffer + rank * bytes_per_rank * 2, + symm_ll_buffer + rank * bytes_per_rank * 2, + bytes_per_rank * 2, + peer_to, + ) # write and tell peer remote that remote copy is done + + # wait for all write done + if thread_idx < world_size and thread_idx != rank: + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_2d_kernel(symm_ptr, bytes_per_rank, symm_flag, NNODES, WORLD_SIZE, rank, signal_target): + LOCAL_WORLD_SIZE = WORLD_SIZE // NNODES + local_rank = rank % LOCAL_WORLD_SIZE + node_id = rank // LOCAL_WORLD_SIZE + rank_base = node_id * LOCAL_WORLD_SIZE + + pid = tl.program_id(0) + peer_rank = pid + peer_node_id = peer_rank // LOCAL_WORLD_SIZE + peer_local_rank = peer_rank % LOCAL_WORLD_SIZE + thread_idx = tid(0) + if peer_local_rank == local_rank: # remote push + if peer_rank != rank: + peer = peer_node_id * LOCAL_WORLD_SIZE + local_rank + segment = 
rank + libshmem_device.putmem_signal_nbi_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + else: + if thread_idx < WORLD_SIZE and thread_idx != rank: + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + else: # local push + peer = rank_base + peer_local_rank + segment = peer_node_id * LOCAL_WORLD_SIZE + local_rank + if peer_node_id != node_id: # wait for data from other nodes + if thread_idx == 0: + libshmem_device.signal_wait_until( + symm_flag + segment, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_3d_kernel( + symm_ptr, + bytes_per_rank, + symm_ll_buffer, + symm_flag, + NNODES, + N_NUMA_NODES, + WORLD_SIZE, + rank, + signal_target, + INTER_NODE_WITH_LL: tl.constexpr = False, +): + """inter-node / inter-NUMA / intra-NUMA""" + LOCAL_WORLD_SIZE = WORLD_SIZE // NNODES + NUMA_WORLD_SIZE = LOCAL_WORLD_SIZE // N_NUMA_NODES + local_rank = rank % LOCAL_WORLD_SIZE + node_id = rank // LOCAL_WORLD_SIZE + numa_rank = local_rank % NUMA_WORLD_SIZE + local_numa_id = local_rank // NUMA_WORLD_SIZE + + pid = tl.program_id(0) + peer_rank = pid + peer_node_id = peer_rank // LOCAL_WORLD_SIZE + peer_local_rank = peer_rank % LOCAL_WORLD_SIZE + peer_numa_rank = peer_local_rank % 
NUMA_WORLD_SIZE + peer_local_numa_id = peer_local_rank // NUMA_WORLD_SIZE + + thread_idx = tid(0) + num_ints = bytes_per_rank // 4 + symm_ptr = tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + symm_ll_buffer = tl.cast(symm_ll_buffer, tl.pointer_type(tl.int8)) + if peer_local_rank == local_rank: + if peer_node_id != node_id: + if INTER_NODE_WITH_LL: + segment = peer_node_id * LOCAL_WORLD_SIZE + local_rank + _recv_ll_block( + symm_ptr + segment * bytes_per_rank, + symm_ll_buffer + segment * bytes_per_rank * 2, + num_ints, + signal_target, + ) + __syncthreads() + if thread_idx == 0: + st(symm_flag + segment, signal_target, scope="gpu", semantic="release") + else: + wid = thread_idx // 32 + if INTER_NODE_WITH_LL: + segment = rank + _pack_ll_block(symm_ll_buffer + rank * bytes_per_rank * 2, symm_ptr + rank * bytes_per_rank, num_ints, + signal_target, 2048) + __syncthreads() + + if wid < NNODES and wid != node_id: + peer = wid * LOCAL_WORLD_SIZE + local_rank + libshmem_device.putmem_nbi_warp( + symm_ll_buffer + segment * bytes_per_rank * 2, + symm_ll_buffer + segment * bytes_per_rank * 2, + bytes_per_rank * 2, + peer, + ) # write and tell peer remote that remote copy is done + else: + if wid < NNODES and wid != node_id: + peer = wid * LOCAL_WORLD_SIZE + local_rank + segment = rank + libshmem_device.putmem_signal_nbi_warp( + symm_ptr + segment * bytes_per_rank, + symm_ptr + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + + __syncthreads() + if thread_idx < WORLD_SIZE and thread_idx != rank: + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + else: # local push with NUMA opt + # NIC consume all the PCI-e bandwidth. don't overlap with inter-NODE communication + # inter/intra NODE communication overlap pattern is too complex. 
+ if NNODES > 1: # no if for single node. + if thread_idx < WORLD_SIZE and (thread_idx % LOCAL_WORLD_SIZE == local_rank and thread_idx != rank): + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + + if peer_numa_rank == numa_rank: # NUMA write + peer = (node_id * LOCAL_WORLD_SIZE + peer_local_numa_id * NUMA_WORLD_SIZE + numa_rank) + segment = peer_node_id * LOCAL_WORLD_SIZE + local_rank + libshmem_device.putmem_signal_wg( + symm_ptr + segment * bytes_per_rank, + symm_ptr + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) + else: + peer = (node_id * LOCAL_WORLD_SIZE + local_numa_id * NUMA_WORLD_SIZE + peer_numa_rank) + segment = (peer_node_id * LOCAL_WORLD_SIZE + peer_local_numa_id * NUMA_WORLD_SIZE + numa_rank) + + if peer_local_numa_id != local_numa_id: # wait for data from other NUMA + if thread_idx == 0: + libshmem_device.signal_wait_until( + symm_flag + segment, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + + libshmem_device.putmem_signal_wg( + symm_ptr + segment * bytes_per_rank, + symm_ptr + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) + + +@triton_dist.jit +def _recv_ll_block(dest_ptr, src_ptr, num_ints, ll_flag): + """split src/dest outside of _recv_ll. this function is designed for a threadblock + + num_ints: of the pre-LL-packed num_ints. 
+ """ + thread_idx = tid(0) + block_size = ntid(0) + src_ptr = tl.cast(src_ptr, tl.pointer_type(tl.int32)) + dest_ptr = tl.cast(dest_ptr, tl.pointer_type(tl.int32)) + # manual load per vec + for n in range(thread_idx, num_ints // 2, block_size): + data1, flag1, data2, flag2 = load_v4_u32(src_ptr + n * 4) + while flag1 != ll_flag or flag2 != ll_flag: + data1, flag1, data2, flag2 = load_v4_u32(src_ptr + n * 4) + st_v2_u32(dest_ptr + n * 2, data1, data2) + + +@triton_dist.jit(do_not_specialize=["ll_flag"]) +def _pack_ll_block(dest_ptr, src_ptr, num_ints, ll_flag, BLOCK_SIZE: tl.constexpr): + """split src/dest outside of _recv_ll. this function is designed for a threadblock + + nbytes: of the pre-LL-packed bytes. + BLOCK_SIZE: count by ints, not bytes. + """ + iters = tl.cdiv(num_ints, BLOCK_SIZE) + src_ptr = tl.cast(src_ptr, dtype=tl.pi32_t) + dest_ptr = tl.cast(dest_ptr, dtype=tl.pi32_t) + for n in range(iters): + src_offsets = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + src_mask = src_offsets < num_ints + src = tl.load(src_ptr + src_offsets, mask=src_mask) + flags = tl.full((BLOCK_SIZE, ), ll_flag, tl.int32) + dst = tl.interleave(src, flags) + dest_offset = n * BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE * 2) + dest_mask = dest_offset < num_ints * 2 + tl.store(dest_ptr + dest_offset, dst, mask=dest_mask) + + +@triton_dist.jit +def _recv_ll_and_multimem_st_block(dest_ptr, src_ptr, num_ints, ll_flag): + """split src/dest outside of _recv_ll. this function is designed for a threadblock + + num_ints: of the pre-LL-packed num_ints. 
+ """ + thread_idx = tid(0) + block_size = ntid(0) + src_ptr = tl.cast(src_ptr, tl.pointer_type(tl.int32)) + dest_ptr = tl.cast(dest_ptr, tl.pointer_type(tl.int32)) + dest_mc_ptr = libshmem_device.remote_mc_ptr(libshmem_device.NVSHMEMX_TEAM_NODE, dest_ptr) + # manual load per vec + for n in range(thread_idx, num_ints // 2, block_size): + data1, flag1, data2, flag2 = load_v4_u32(src_ptr + n * 4) + while flag1 != ll_flag or flag2 != ll_flag: + data1, flag1, data2, flag2 = load_v4_u32(src_ptr + n * 4) + multimem_st_b64(dest_mc_ptr + n * 2, pack_b32_v2(data1, data2)) + + +@triton_dist.jit(do_not_specialize=["ll_flag"]) +def _recv_ll_and_multimem_st_ll_block(dest_ptr, src_ptr, num_ints, ll_flag): + """split src/dest outside of _recv_ll. this function is designed for a threadblock + + num_ints: of the pre-LL-packed num_ints. + """ + thread_idx = tid(0) + block_size = ntid(0) + src_ptr = tl.cast(src_ptr, tl.pointer_type(tl.int32)) + dest_ptr = tl.cast(dest_ptr, tl.pointer_type(tl.int32)) + dest_mc_ptr = libshmem_device.remote_mc_ptr(libshmem_device.NVSHMEMX_TEAM_NODE, dest_ptr) + # manual load per vec + for n in range(thread_idx, num_ints // 2, block_size): + data1, flag1, data2, flag2 = load_v4_u32(src_ptr + n * 4) + while flag1 != ll_flag or flag2 != ll_flag: + data1, flag1, data2, flag2 = load_v4_u32(src_ptr + n * 4) + multimem_st_b64(dest_mc_ptr + n * 4, pack_b32_v2(data1, flag1)) + multimem_st_b64(dest_mc_ptr + n * 4 + 2, pack_b32_v2(data2, flag2)) + + +@triton_dist.jit +def broadcast_naive_block(dst_ptr, src_ptr, nbytes): + thread_idx = tid(axis=0) + block_dim = ntid(axis=0) + src_ptr = tl.cast(src_ptr, tl.pointer_type(tl.int8)) + dst_ptr = tl.cast(dst_ptr, tl.pointer_type(tl.int8)) + dst_mc_ptr = libshmem_device.remote_mc_ptr(libshmem_device.NVSHMEMX_TEAM_NODE, dst_ptr) + num_int4 = nbytes // 16 + for n in range(thread_idx, num_int4, block_dim): + val0, val1 = load_v2_b64(src_ptr + 16 * n) + multimem_st_b64(dst_mc_ptr + n * 16, val0) + multimem_st_b64(dst_mc_ptr + 
n * 16 + 8, val1) + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_2d_ll_multimem_kernel( + symm_ptr, + bytes_per_rank, + symm_ll_buffer, + nnodes: tl.constexpr, + world_size: tl.constexpr, + rank, + signal_target, +): + """ + pack_ll and nvshmem_putmem_nbi, then recv_ll and multimem.st + """ + local_world_size = world_size // nnodes + local_rank = rank % local_world_size + nid = rank // local_world_size + + pid = tl.program_id(0) + peer_nid = pid // local_world_size + peer_local_rank = pid % local_world_size + num_ints = bytes_per_rank // 4 + thread_idx = tid(axis=0) + + ll_buffer_int8 = tl.cast(symm_ll_buffer, tl.pointer_type(tl.int8)) + symm_ptr = tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + + if peer_local_rank == local_rank: + if nid != peer_nid: + segment = peer_nid * local_world_size + local_rank + _recv_ll_and_multimem_st_ll_block( + ll_buffer_int8 + segment * bytes_per_rank * 2, + ll_buffer_int8 + segment * bytes_per_rank * 2, + num_ints, + signal_target, + ) # magic number here + _recv_ll_block( + symm_ptr + segment * bytes_per_rank, + ll_buffer_int8 + segment * bytes_per_rank * 2, + num_ints, + signal_target, + ) # magic number here + else: # already has data. 
pack only + _pack_ll_block( + ll_buffer_int8 + rank * bytes_per_rank * 2, + symm_ptr + rank * bytes_per_rank, + num_ints, + signal_target, + 2048, + ) # magic number here + __syncthreads() + wid = thread_idx // 32 + # send + if wid < nnodes and wid != nid: + peer_to = wid * local_world_size + local_rank + libshmem_device.putmem_nbi_warp( + ll_buffer_int8 + rank * bytes_per_rank * 2, + ll_buffer_int8 + rank * bytes_per_rank * 2, + bytes_per_rank * 2, + peer_to, + ) # write and tell peer remote that remote copy is done + + segment = peer_nid * local_world_size + local_rank + broadcast_naive_block( + ll_buffer_int8 + segment * bytes_per_rank * 2, + ll_buffer_int8 + segment * bytes_per_rank * 2, + bytes_per_rank * 2, + ) + else: + segment_recv_local = peer_nid * local_world_size + peer_local_rank + _recv_ll_block( + symm_ptr + segment_recv_local * bytes_per_rank, + ll_buffer_int8 + segment_recv_local * bytes_per_rank * 2, + num_ints, + signal_target, + ) # magic number here + + +@triton_dist.jit(do_not_specialize=["rank", "signal_target"]) +def _forward_push_2d_ll_kernel( + symm_ptr, + bytes_per_rank, + symm_flag, + symm_ll_buffer, + nnodes: tl.constexpr, + world_size: tl.constexpr, + rank, + signal_target, +): + local_world_size = world_size // nnodes + local_rank = rank % local_world_size + nid = rank // local_world_size + + pid = tl.program_id(0) + peer_nid = pid // local_world_size + peer_local_rank = pid % local_world_size + thread_idx = tid(0) + num_ints = bytes_per_rank // 4 + + ll_buffer_int8 = tl.cast(symm_ll_buffer, tl.pointer_type(tl.int8)) + symm_ptr = tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + + if peer_local_rank == local_rank: # remote push + if peer_nid != nid: # pnid: peer node id. 
each block recv from pnid + segment = peer_nid * local_world_size + local_rank + _recv_ll_block( + symm_ptr + segment * bytes_per_rank, + ll_buffer_int8 + segment * bytes_per_rank * 2, + num_ints, + signal_target, + ) # magic number here + __syncthreads() + if thread_idx == 0: + st(symm_flag + segment, signal_target, scope="gpu", semantic="release") + else: # pack ll data + _pack_ll_block( + ll_buffer_int8 + rank * bytes_per_rank * 2, + symm_ptr + rank * bytes_per_rank, + num_ints, + signal_target, + 2048, + ) # magic number here + __syncthreads() + wid = thread_idx // 32 + if wid < nnodes and wid != nid: # wid -> peer node id + peer_to = wid * local_world_size + local_rank + libshmem_device.putmem_nbi_warp( + ll_buffer_int8 + rank * bytes_per_rank * 2, + ll_buffer_int8 + rank * bytes_per_rank * 2, + bytes_per_rank * 2, + peer_to, + ) # write and tell peer remote that remote copy is done + # wait for all write done + if thread_idx < world_size and thread_idx != rank: + libshmem_device.signal_wait_until( + symm_flag + thread_idx, + libshmem_device.ROCSHMEM_CMP_EQ, + signal_target, + ) + __syncthreads() + + else: # local push + peer = nid * local_world_size + peer_local_rank + segment = peer_nid * local_world_size + local_rank + if peer_nid != nid: # wait for recv_ll done + if thread_idx == 0: + libshmem_device.signal_wait_until(symm_flag + segment, libshmem_device.ROCSHMEM_CMP_EQ, signal_target) + __syncthreads() + libshmem_device.putmem_signal_wg( + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank, + bytes_per_rank, + symm_flag + segment, + signal_target, + libshmem_device.ROCSHMEM_SIGNAL_SET, + peer, + ) # write and tell peer remote that remote copy is done + + +@dataclass +class FastAllGatherContext: + rank: int + node: int + num_ranks: int + num_nodes: int + signal_tensor: torch.Tensor + ll_buffers: List[torch.Tensor] # double buffer + grid_barrier: torch.Tensor + 
max_buffer_size: int = 2 * 32 * 1024 * 1024 + signal_target: int = 1 + + def finalize(self): + del self.signal_tensor + for ll_buffer in self.ll_buffers: + del ll_buffer + + +def create_fast_allgather_context(rank, node, num_ranks, num_nodes, max_buffer_size: int = 2 * 32 * 1024 * 1024): + signal_tensor = rocshmem_create_tensor((num_ranks, ), NVSHMEM_SIGNAL_DTYPE) + signal_tensor.zero_() + ll_buffers = [rocshmem_create_tensor((max_buffer_size, ), torch.int8) for _ in range(2)] + grid_barrier = torch.zeros((1, ), dtype=torch.uint32, device="cuda") + + ctx = FastAllGatherContext( + rank=rank, + node=node, + num_ranks=num_ranks, + num_nodes=num_nodes, + signal_tensor=signal_tensor, + ll_buffers=ll_buffers, + grid_barrier=grid_barrier, + max_buffer_size=max_buffer_size, + signal_target=15, + ) + + return ctx + + +def fast_allgather_pull(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + ctx.signal_target += 1 + return _forward_pull_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + ctx.signal_tensor, + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + + +def fast_allgather_push_2d(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + ctx.signal_target += 1 + _forward_push_2d_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + ctx.signal_tensor, + ctx.num_nodes, + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + return symm_buffer + + +def fast_allgather_push_3d(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + ctx.signal_target += 1 + _forward_push_3d_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + ctx.ll_buffers[ctx.signal_target % 2], + ctx.signal_tensor, + ctx.num_nodes, + 2, # TODO(houqi.1993) + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + INTER_NODE_WITH_LL=False, + num_warps=32, + ) + return symm_buffer + + +def fast_allgather_push_2d_ll(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + assert 
symm_buffer.nbytes * 2 < ctx.max_buffer_size + ctx.signal_target += 1 + ll_buffer = ctx.ll_buffers[ctx.signal_target % 2] + _forward_push_2d_ll_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + ctx.signal_tensor, + ll_buffer, + ctx.num_nodes, + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + + return symm_buffer + + +def fast_allgather_push_2d_ll_multimem(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < ctx.max_buffer_size + ctx.signal_target += 1 + ll_buffer = ctx.ll_buffers[ctx.signal_target % 2] + _forward_push_2d_ll_multimem_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + ll_buffer, + ctx.num_nodes, + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + + return symm_buffer + + +def fast_allgather_push_numa_2d(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < ctx.max_buffer_size + signal = ctx.signal[ctx.signal_target % 2] + _forward_push_numa_2d_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + signal, + 2, # TODO(houqi.1993) 2 NUMA nodes supported + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + ctx.signal_target += 1 + return symm_buffer + + +def fast_allgather_push_numa_2d_ll(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < ctx.max_buffer_size + assert ctx.num_nodes == 1 + signal = ctx.signal[ctx.signal_target % 2] + _forward_push_numa_2d_ll_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + signal, + ctx.ll_buffers[ctx.signal_target % 2], + 2, # TODO(houqi.1993) 2 NUMA nodes supported + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + ctx.signal_target += 1 + return symm_buffer + + +def fast_allgather_push_numa_2d_ll_multinode(ctx: FastAllGatherContext, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < 
ctx.max_buffer_size + signal = ctx.signal[ctx.signal_target % 2] + _forward_push_numa_2d_ll_multinode_kernel[(ctx.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // ctx.num_ranks, + signal, + ctx.ll_buffers[ctx.signal_target % 2], + ctx.num_nodes, + 2, # TODO(houqi.1993) 2 NUMA nodes supported + ctx.num_ranks, + ctx.rank, + ctx.signal_target, + num_warps=32, + ) + ctx.signal_target += 1 + return symm_buffer + + +FAST_ALLGATHER_FUNC_DISPATCH = { + "pull": fast_allgather_pull, + "push2d": fast_allgather_push_2d, + "push2d_ll": fast_allgather_push_2d_ll, + "push2d_ll_multimem": fast_allgather_push_2d_ll_multimem, + "push_numa_2d": fast_allgather_push_numa_2d, + "push_numa_2d_ll": fast_allgather_push_numa_2d_ll, + "push_numa_2d_ll_multinode": fast_allgather_push_numa_2d_ll_multinode, +} + + +def fast_allgather( + symm_buffer: torch.Tensor, + ctx: FastAllGatherContext = None, + rank=None, + node=None, + num_ranks=None, + num_nodes=None, + mode="pull", +): + assert mode in FAST_ALLGATHER_FUNC_DISPATCH + if ctx is None: + assert rank is not None and node is not None + assert num_ranks is not None and num_nodes is not None + ctx = create_fast_allgather_context( + rank, + node, + num_ranks, + num_nodes, + ) + return FAST_ALLGATHER_FUNC_DISPATCH[mode](ctx, symm_buffer) diff --git a/python/triton_dist/layers/amd/__init__.py b/python/triton_dist/layers/amd/__init__.py new file mode 100644 index 000000000..b6cb0b004 --- /dev/null +++ b/python/triton_dist/layers/amd/__init__.py @@ -0,0 +1,26 @@ +################################################################################ +# +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ +from .low_latency_allgather_layer import AllGatherLayer +__all__ = ["AllGatherLayer"] \ No newline at end of file diff --git a/python/triton_dist/layers/amd/low_latency_allgather_layer.py b/python/triton_dist/layers/amd/low_latency_allgather_layer.py new file mode 100644 index 000000000..b56d0ee4b --- /dev/null +++ b/python/triton_dist/layers/amd/low_latency_allgather_layer.py @@ -0,0 +1,198 @@ +################################################################################ +# +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ +import warnings +import torch +from triton_dist.kernels.amd import _forward_push_2d_ll_kernel, _forward_push_2d_kernel, _forward_push_3d_kernel, _forward_pull_kernel, _forward_push_2d_ll_multimem_kernel, _forward_push_numa_2d_ll_kernel, _forward_push_numa_2d_kernel, _forward_push_numa_2d_ll_multinode_kernel +from pyrocshmem import rocshmem_create_tensor +from triton_dist.utils import NVSHMEM_SIGNAL_DTYPE, get_numa_node_count_in_group, get_triton_dist_world, rocshmem_barrier_all_on_stream + + +class AllGatherLayer: + + def __init__(self, nnodes, world_size, rank, max_buffer_size: int = 2 * 32 * 128 * 128, stages=2): + self.rank = rank + self.num_ranks = world_size + self.symm_signal = rocshmem_create_tensor((stages, self.num_ranks), NVSHMEM_SIGNAL_DTYPE) + self.max_buffer_size = max_buffer_size + self.symm_ll_buffers = rocshmem_create_tensor((stages, self.max_buffer_size), torch.int8) + self.signal_target = 1 + self.symm_signal.zero_() + self.nnodes = nnodes + self.stages = stages + rocshmem_barrier_all_on_stream(torch.cuda.current_stream()) + + def finalize(self): + if self.symm_signal is not None: + del (self.symm_signal) + if self.symm_ll_buffers is not None: + del (self.symm_ll_buffers) + self.symm_signal = None + self.symm_ll_buffers = None + + def forward_pull(self, symm_buffer: torch.Tensor): + _forward_pull_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + self.symm_signal, + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + self.signal_target += 1 + return symm_buffer + + def forward_push_2d(self, symm_buffer: torch.Tensor): + _forward_push_2d_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + self.symm_signal, + self.nnodes, + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + self.signal_target += 1 + return symm_buffer + + def _forward_push_3d(self, 
symm_buffer: torch.Tensor, use_ll_protocol: bool = False): + symm_ll_buffer = self.symm_ll_buffers[self.signal_target % self.stages] + _forward_push_3d_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + symm_ll_buffer, + self.symm_signal, + self.nnodes, + 2, # TODO(houqi.1993) + self.num_ranks, + self.rank, + self.signal_target, + INTER_NODE_WITH_LL=use_ll_protocol, + num_warps=32, + ) + self.signal_target += 1 + + return symm_buffer + + def forward_push_3d(self, symm_buffer: torch.Tensor): + return self._forward_push_3d(symm_buffer, False) # no LL protocol + + def forward_push_3d_ll(self, symm_buffer: torch.Tensor): + return self._forward_push_3d(symm_buffer, True) # with LL protocol + + def forward_push_2d_ll(self, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < self.max_buffer_size + symm_signal = self.symm_signal[self.signal_target % self.stages] + symm_ll_buffer = self.symm_ll_buffers[self.signal_target % self.stages] + _forward_push_2d_ll_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + symm_signal, + symm_ll_buffer, + self.nnodes, + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + self.signal_target += 1 + return symm_buffer + + def forward_push_numa_2d(self, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < self.max_buffer_size + symm_signal = self.symm_signal[self.signal_target % self.stages] + n_numa_nodes = get_numa_node_count_in_group(get_triton_dist_world()) + if n_numa_nodes != 2: + warnings.warn("Only 2 NUMA nodes supported for now. 
use 1d algo instead") + return self.forward_push_2d(symm_buffer) + + _forward_push_numa_2d_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + symm_signal, + 2, + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + self.signal_target += 1 + return symm_buffer + + def forward_push_numa_2d_ll_multinode(self, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < self.max_buffer_size + symm_signal = self.symm_signal[self.signal_target % self.stages] + ll_buffer = self.symm_ll_buffers[self.signal_target % self.stages] + _forward_push_numa_2d_ll_multinode_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + symm_signal, + ll_buffer, + self.nnodes, + 2, # TODO(houqi.1993) 2 NUMA nodes supported + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + + self.signal_target += 1 + return symm_buffer + + def forward_push_numa_2d_ll(self, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < self.max_buffer_size + symm_signal = self.symm_signal[self.signal_target % self.stages] + symm_ll_buffer = self.symm_ll_buffers[self.signal_target % self.stages] + _forward_push_numa_2d_ll_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + symm_signal, + symm_ll_buffer, + 2, # TODO(houqi.1993) 2 NUMA nodes supported + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + self.signal_target += 1 + return symm_buffer + + def forward_push_2d_ll_multimem(self, symm_buffer: torch.Tensor): + assert symm_buffer.nbytes * 2 < self.max_buffer_size + symm_ll_buffer = self.symm_ll_buffers[self.signal_target % self.stages] + _forward_push_2d_ll_multimem_kernel[(self.num_ranks, )]( + symm_buffer, + symm_buffer.nbytes // self.num_ranks, + symm_ll_buffer, + self.nnodes, + self.num_ranks, + self.rank, + self.signal_target, + num_warps=32, + ) + self.signal_target += 1 + + return symm_buffer diff --git 
a/python/triton_dist/test/amd/test_fast_allgather.py b/python/triton_dist/test/amd/test_fast_allgather.py new file mode 100644 index 000000000..16d70c122 --- /dev/null +++ b/python/triton_dist/test/amd/test_fast_allgather.py @@ -0,0 +1,158 @@ +################################################################################ +# +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# Maps each ``--mode`` value to the ``AllGatherLayer`` method that implements it.
_MODE_TO_METHOD = {
    "push_numa_2d_ll": "forward_push_numa_2d_ll",
    "push_numa_2d": "forward_push_numa_2d",
    "push_2d_ll_multimem": "forward_push_2d_ll_multimem",
    "push_2d_ll": "forward_push_2d_ll",
    "push_2d": "forward_push_2d",
    "pull_1d": "forward_pull",
    "push_numa_2d_ll_multinode": "forward_push_numa_2d_ll_multinode",
    "push_3d": "forward_push_3d",
    "push_3d_ll": "forward_push_3d_ll",
}


def parse_args():
    """Parse the benchmark's command line.

    Exits through ``parser.error`` when ``--stepfactor`` is below 2 or
    ``--minbytes`` is not positive: in both cases the size sweep in the
    ``__main__`` driver (``nbytes *= stepfactor`` until ``maxbytes``) would
    never terminate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--stepfactor", default=2, type=int)
    parser.add_argument("-b", "--minbytes", type=int, default=1024 * 4)
    parser.add_argument("-e", "--maxbytes", type=int, default=1024 * 1024 * 32)
    # NOTE: parsed for CLI compatibility; the driver below always uses int8.
    parser.add_argument("--dtype", type=str, default="int32")
    parser.add_argument("--warmup_iters", type=int, default=30)
    parser.add_argument("--iters", type=int, default=500)
    parser.add_argument("--profile", action="store_true", default=False)
    parser.add_argument("--verify", default=True, action=argparse.BooleanOptionalAction)
    parser.add_argument("--mode", default="pull_1d", choices=list(_MODE_TO_METHOD))
    args = parser.parse_args()
    if args.stepfactor < 2:
        parser.error("--stepfactor must be >= 2")
    if args.minbytes < 1:
        parser.error("--minbytes must be >= 1")
    return args


def perf_ag(ag_op: "AllGatherLayer", ag_buffer: torch.Tensor, nbytes: int, do_verify: bool):
    """Optionally verify, then time, one all-gather mode at a single message size.

    Relies on the module globals set by the ``__main__`` driver: ``args``,
    ``dtype``, ``RANK``, ``WORLD_SIZE``, ``TP_GROUP``, ``warmup_iters`` and
    ``iters``.

    Args:
        ag_op: layer whose ``forward_*`` method (chosen by ``args.mode``) is run.
        ag_buffer: staged buffer indexed as ``ag_buffer[stage][:nbytes]``.
        nbytes: total gathered size in bytes.
        do_verify: run the correctness loop before timing.
    """
    nbytes_per_rank = nbytes // WORLD_SIZE
    index_start = nbytes_per_rank * RANK
    index_end = nbytes_per_rank * (RANK + 1)

    def _current_stage_buffer():
        # Re-evaluated on every call because each forward_* bumps signal_target.
        return ag_buffer[ag_op.signal_target % ag_op.stages]

    # Seed this rank's own slice of the current stage.
    ref_tensor = torch.arange(nbytes, dtype=dtype).cuda()
    _current_stage_buffer()[index_start:index_end].copy_(ref_tensor[index_start:index_end])

    def _run_with_ag_op():
        method_name = _MODE_TO_METHOD.get(args.mode)
        if method_name is None:
            raise ValueError(f"Unknown mode {args.mode}")
        return getattr(ag_op, method_name)(_current_stage_buffer()[:nbytes])

    def _verify():
        for _ in range(100):
            # Rank 0's random payload is broadcast so every rank compares
            # against the same reference.
            expected = torch.randint(0, 9999999, [nbytes // 4], dtype=torch.int32).view(dtype).cuda()
            torch.distributed.broadcast(expected, src=0)
            _current_stage_buffer()[index_start:index_end].copy_(expected[index_start:index_end])
            result = _run_with_ag_op()

            try:
                torch.testing.assert_close(result[:nbytes], expected, atol=0, rtol=0)
            except Exception:
                print(result[:nbytes].view(WORLD_SIZE, -1))
                print(expected.view(WORLD_SIZE, -1))
                print(f"❌ RANK[{RANK}] check failed")
                raise  # bare raise keeps the original traceback
        print(f"✅ RANK[{RANK}] check passed")

    if do_verify:
        _verify()
    rocshmem_barrier_all_on_stream(torch.cuda.current_stream())
    from triton_dist.profiler_utils import perf_func

    with group_profile(f"all_gather_op_{nbytes//1024}KB", do_prof=args.profile, group=TP_GROUP):
        sleep_async(1000)  # in case CPU bound
        _, ag_time_ms = perf_func(
            _run_with_ag_op,
            warmup_iters=warmup_iters,
            iters=iters,
        )

    # Each rank already holds 1/WORLD_SIZE of the data, hence the
    # (WORLD_SIZE - 1) / WORLD_SIZE factor.
    gbps = nbytes * 1e-9 / (ag_time_ms * 1e-3) * (WORLD_SIZE - 1) / WORLD_SIZE
    print(
        f"RANK = {RANK}, {nbytes // 1024} KB, Latency = {ag_time_ms * 1000:0.2f} us, Bandwith = {gbps:0.2f} GB/s"
    )


def align_to(value, alignment):
    """Round ``value`` up to the nearest multiple of ``alignment``."""
    return (value + alignment - 1) // alignment * alignment


if __name__ == "__main__":
    # The names assigned here are read as globals by perf_ag().
    args = parse_args()
    dtype = torch.int8
    warmup_iters = args.warmup_iters
    iters = args.iters

    RANK = int(os.environ.get("RANK", 0))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    TP_GROUP = initialize_distributed()

    stages = 2

    ag_buffer = rocshmem_create_tensor((stages, args.maxbytes), dtype)

    nnodes = WORLD_SIZE // LOCAL_WORLD_SIZE
    ag_op = AllGatherLayer(nnodes, WORLD_SIZE, RANK, max_buffer_size=args.maxbytes * 2, stages=stages)

    minbytes = align_to(args.minbytes, 16)
    maxbytes = align_to(args.maxbytes, 16)
    nbytes = minbytes
    # Upper bound is exclusive, as in the original sweep.
    while nbytes < maxbytes:
        perf_ag(ag_op, ag_buffer, nbytes, args.verify)
        nbytes = args.stepfactor * nbytes

    ag_op.finalize()
    del ag_buffer
    finalize_distributed()
a/python/triton_dist/kernels/amd/low_latency_allgather.py +++ b/python/triton_dist/kernels/amd/low_latency_allgather.py @@ -31,19 +31,24 @@ from triton_dist.kernels.common_ops import pack_b32_v2 from triton_dist.language.extra import libshmem_device -from triton_dist.language.extra.cuda.language_extra import ( +from triton_dist.language.extra.hip.language_extra import ( __syncthreads, tid, ntid, - load_v4_u32, - load_v2_b64, - st_v2_u32, + #load_v4_u32, + #load_v2_b64, + #st_v2_u32, st, - multimem_st_b64, + #multimem_st_b64, + load, ) from pyrocshmem import rocshmem_create_tensor from triton_dist.utils import NVSHMEM_SIGNAL_DTYPE +load_v4_u32 = load +load_v2_b64 = load +st_v2_u32 = st +multimem_st_b64 = st @triton_dist.jit(do_not_specialize=["rank", "signal_target"]) def _forward_pull_kernel(symm_ptr, bytes_per_rank, symm_flag, world_size, rank, signal_target): @@ -61,7 +66,7 @@ def _forward_pull_kernel(symm_ptr, bytes_per_rank, symm_flag, world_size, rank, else: peer = pid if thread_idx == 0: - libshmem_device.signal_wait_until(symm_flag + peer, libshmem_device.ROCSHMEM_CMP_EQ, signal_target) + libshmem_device.signal_wait_until(symm_flag + peer, libshmem_device.ROCSHMEM_CMP_GE, signal_target) __syncthreads() libshmem_device.getmem_wg( tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + peer * bytes_per_rank, @@ -371,7 +376,7 @@ def _forward_push_2d_kernel(symm_ptr, bytes_per_rank, symm_flag, NNODES, WORLD_S if thread_idx < WORLD_SIZE and thread_idx != rank: libshmem_device.signal_wait_until( symm_flag + thread_idx, - libshmem_device.ROCSHMEM_CMP_EQ, + libshmem_device.ROCSHMEM_CMP_GE, signal_target, ) __syncthreads() @@ -382,7 +387,7 @@ def _forward_push_2d_kernel(symm_ptr, bytes_per_rank, symm_flag, NNODES, WORLD_S if thread_idx == 0: libshmem_device.signal_wait_until( symm_flag + segment, - libshmem_device.ROCSHMEM_CMP_EQ, + libshmem_device.ROCSHMEM_CMP_GE, signal_target, ) __syncthreads() @@ -452,7 +457,7 @@ def _forward_push_3d_kernel( if wid < NNODES and wid != 
node_id: peer = wid * LOCAL_WORLD_SIZE + local_rank - libshmem_device.putmem_nbi_warp( + libshmem_device.putmem_nbi_wave( symm_ll_buffer + segment * bytes_per_rank * 2, symm_ll_buffer + segment * bytes_per_rank * 2, bytes_per_rank * 2, @@ -462,7 +467,7 @@ def _forward_push_3d_kernel( if wid < NNODES and wid != node_id: peer = wid * LOCAL_WORLD_SIZE + local_rank segment = rank - libshmem_device.putmem_signal_nbi_warp( + libshmem_device.putmem_signal_nbi_wave( symm_ptr + segment * bytes_per_rank, symm_ptr + segment * bytes_per_rank, bytes_per_rank, @@ -476,7 +481,7 @@ def _forward_push_3d_kernel( if thread_idx < WORLD_SIZE and thread_idx != rank: libshmem_device.signal_wait_until( symm_flag + thread_idx, - libshmem_device.ROCSHMEM_CMP_EQ, + libshmem_device.ROCSHMEM_CMP_GE, signal_target, ) __syncthreads() @@ -487,7 +492,7 @@ def _forward_push_3d_kernel( if thread_idx < WORLD_SIZE and (thread_idx % LOCAL_WORLD_SIZE == local_rank and thread_idx != rank): libshmem_device.signal_wait_until( symm_flag + thread_idx, - libshmem_device.ROCSHMEM_CMP_EQ, + libshmem_device.ROCSHMEM_CMP_GE, signal_target, ) __syncthreads() @@ -512,7 +517,7 @@ def _forward_push_3d_kernel( if thread_idx == 0: libshmem_device.signal_wait_until( symm_flag + segment, - libshmem_device.ROCSHMEM_CMP_EQ, + libshmem_device.ROCSHMEM_CMP_GE, signal_target, ) __syncthreads() diff --git a/python/triton_dist/language/extra/hip/librocshmem_device.py b/python/triton_dist/language/extra/hip/librocshmem_device.py index c3ed15781..539223133 100644 --- a/python/triton_dist/language/extra/hip/librocshmem_device.py +++ b/python/triton_dist/language/extra/hip/librocshmem_device.py @@ -358,6 +358,22 @@ def putmem_signal_nbi_wave(dest, source, nbytes, sig_addr, signal, sig_op, pe, _ ) +@core.extern +def signal_op(sig_addr, signal, sig_op, pe, _semantic=None): + return extern_call( + "librocshmem_device", "", [ + tl.cast(sig_addr, pi_u64_t, _semantic=_semantic), + tl.cast(signal, tl.uint64, _semantic=_semantic), + 
tl.cast(sig_op, tl.int32, _semantic=_semantic), + tl.cast(pe, tl.int32, _semantic=_semantic), + ], { + (pi_u64_t, tl.uint64, tl.int32, tl.int32): ( + "rocshmem_signal_op", + (), + ), + }, is_pure=False, _semantic=_semantic) + + @core.extern def signal_wait_until(sig_addr, cmp_, cmp_val, _semantic=None): tl.static_assert(sig_addr.dtype == pi_u64_t or sig_addr.dtype == pi_i64_t, diff --git a/python/triton_dist/layers/amd/low_latency_allgather_layer.py b/python/triton_dist/layers/amd/low_latency_allgather_layer.py index b56d0ee4b..bdfc69cf6 100644 --- a/python/triton_dist/layers/amd/low_latency_allgather_layer.py +++ b/python/triton_dist/layers/amd/low_latency_allgather_layer.py @@ -59,7 +59,7 @@ def forward_pull(self, symm_buffer: torch.Tensor): self.num_ranks, self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 return symm_buffer @@ -73,7 +73,7 @@ def forward_push_2d(self, symm_buffer: torch.Tensor): self.num_ranks, self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 return symm_buffer @@ -91,7 +91,7 @@ def _forward_push_3d(self, symm_buffer: torch.Tensor, use_ll_protocol: bool = Fa self.rank, self.signal_target, INTER_NODE_WITH_LL=use_ll_protocol, - num_warps=32, + num_warps=16, ) self.signal_target += 1 @@ -116,7 +116,7 @@ def forward_push_2d_ll(self, symm_buffer: torch.Tensor): self.num_ranks, self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 return symm_buffer @@ -137,7 +137,7 @@ def forward_push_numa_2d(self, symm_buffer: torch.Tensor): self.num_ranks, self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 return symm_buffer @@ -156,7 +156,7 @@ def forward_push_numa_2d_ll_multinode(self, symm_buffer: torch.Tensor): self.num_ranks, self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 @@ -175,7 +175,7 @@ def forward_push_numa_2d_ll(self, symm_buffer: torch.Tensor): self.num_ranks, 
self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 return symm_buffer @@ -191,7 +191,7 @@ def forward_push_2d_ll_multimem(self, symm_buffer: torch.Tensor): self.num_ranks, self.rank, self.signal_target, - num_warps=32, + num_warps=16, ) self.signal_target += 1 diff --git a/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc b/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc index 0e2515915..4501f7913 100644 --- a/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc +++ b/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc @@ -22,6 +22,7 @@ */ #include #include +#include using namespace rocshmem; extern "C" { @@ -165,6 +166,21 @@ rocshmem_putmem_nbi_wg_wrapper(void *dest, const void *source, size_t nbytes, rocshmem_putmem_nbi_wg(dest, source, nbytes, pe); } +__device__ void __attribute__((visibility("default"))) +rocshmem_signal_op(uint64_t *sig_addr, uint64_t signal, int sig_op, int pe) { + switch (sig_op) { + case ROCSHMEM_SIGNAL_SET: + rocshmem_ulong_atomic_set(sig_addr, signal, pe); + break; + case ROCSHMEM_SIGNAL_ADD: + rocshmem_ulong_atomic_add(sig_addr, signal, pe); + break; + default: + DPRINTF("[%s] Invalid sig_op value (%d)\n", __func__, sig_op); + break; + } +} + __device__ void __attribute__((visibility("default"))) rocshmem_ulong_wait_until_wrapper(void *sig_addr, int cmp, uint64_t cmp_val) { rocshmem_ulong_wait_until(static_cast(sig_addr), cmp, cmp_val);