From 98c94e27334fa00037b869c3055c8bd26b5c180b Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Fri, 6 Feb 2026 17:03:55 -0600 Subject: [PATCH 1/6] Update rocshmem Signed-off-by: Eric Eaton --- 3rdparty/rocshmem | 2 +- shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/rocshmem b/3rdparty/rocshmem index 7b5765ec0..45b4f5610 160000 --- a/3rdparty/rocshmem +++ b/3rdparty/rocshmem @@ -1 +1 @@ -Subproject commit 7b5765ec0e137db04b50d657ed83ea239be8dd58 +Subproject commit 45b4f5610cc1899a48ba3ac18073a0f60f999987 diff --git a/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh b/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh index d1797c564..7854fc5d8 100755 --- a/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh +++ b/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh @@ -39,12 +39,12 @@ SOURCE_MAP=( ["${ROCSHMEM_SRC}/src/ipc/context_ipc_device_coll.cpp"]="rocshmem_context_ipc_device_coll.bc" ["${ROCSHMEM_SRC}/src/ipc_policy.cpp"]="rocshmem_ipc_policy.bc" ["${ROCSHMEM_SRC}/src/gda/context_gda_device.cpp"]="rocshmem_context_gda_device.bc" + ["${ROCSHMEM_SRC}/src/gda/context_gda_device_coll.cpp"]="rocshmem_context_gda_device_coll.bc" ["${ROCSHMEM_SRC}/src/gda/backend_gda.cpp"]="rocshmem_backend_gda.bc" ["${ROCSHMEM_SRC}/src/gda/queue_pair.cpp"]="rocshmem_queue_pair.bc" ["${ROCSHMEM_SRC}/src/gda/ionic/queue_pair_ionic.cpp"]="rocshmem_queue_pair_ionic.bc" ["${ROCSHMEM_SRC}/src/gda/mlx5/queue_pair_mlx5.cpp"]="rocshmem_queue_pair_mlx5.bc" ["${ROCSHMEM_SRC}/src/gda/mlx5/segment_builder.cpp"]="rocshmem_segment_builder.bc" - ["${ROCSHMEM_SRC}/src/gda/endian.cpp"]="rocshmem_endian.bc" ["${ROCSHMEM_SRC}/src/team.cpp"]="rocshmem_team.bc" ["${ROCSHMEM_SRC}/src/sync/abql_block_mutex.cpp"]="rocshmem_abql_block_mutex.bc" ["${ROCSHMEM_SRC}/src/util.cpp"]="rocshmem_util.bc" From 335c95ed27f888c5f12d4291e8478c289acbadee Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Tue, 10 Feb 2026 15:50:47 -0600 Subject: [PATCH 2/6] Update CI to ROCm 6.4.4 Signed-off-by: Eric Eaton --- .github/workflows/amd-ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/amd-ci.yml b/.github/workflows/amd-ci.yml index 643172d7a..753dfb4fb 100644 --- a/.github/workflows/amd-ci.yml +++ b/.github/workflows/amd-ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: "amd-gfx942-mi325" timeout-minutes: 45 container: - image: rocm/pytorch:rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview + image: rocm/pytorch:rocm6.4.4_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview options: >- --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root --volume /home/runner/.triton:/github/home/.triton @@ -27,7 +27,7 @@ jobs: run: | rm -f /usr/local/bin/cmake apt-get update -y && apt install -y libopenmpi-dev git cython3 ibverbs-utils openmpi-bin libopenmpi-dev libpci-dev libdw1 locales cmake - pip3 install -i https://test.pypi.org/simple hip-python>=6.3.0 # (or whatever Rocm version you have) + pip3 install -i https://test.pypi.org/simple hip-python>=6.4.4 # (or whatever Rocm version you have) pip3 install pybind11 - name: Checkout uses: actions/checkout@v4 @@ -58,7 +58,7 @@ jobs: runs-on: "amd-gfx942-mi325" timeout-minutes: 60 container: - image: rocm/pytorch:rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview + image: rocm/pytorch:rocm6.4.4_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview options: >- --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root --volume /home/runner/.triton:/github/home/.triton @@ -73,7 +73,7 @@ jobs: - name: Install dependencies run: | apt-get update -y && apt install -y libopenmpi-dev - pip3 install -i https://test.pypi.org/simple hip-python>=6.3.0 # (or whatever Rocm version you have) + pip3 install -i https://test.pypi.org/simple hip-python>=6.4.4 # (or whatever Rocm version you have) pip3 install pybind11 - name: Checkout uses: actions/checkout@v4 From 2a02b90098a29b24b06a189dfd5dae72a81db3e1 Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Thu, 12 Feb 2026 18:03:55 -0600 Subject: [PATCH 3/6] Fix allgather.py This file was not included in the CI and wasn't updated with some recent changes. Signed-off-by: Eric Eaton --- .github/workflows/amd-ci.yml | 2 + python/triton_dist/kernels/amd/allgather.py | 71 +++++++-------------- 2 files changed, 24 insertions(+), 49 deletions(-) diff --git a/.github/workflows/amd-ci.yml b/.github/workflows/amd-ci.yml index 753dfb4fb..bf9fa99e8 100644 --- a/.github/workflows/amd-ci.yml +++ b/.github/workflows/amd-ci.yml @@ -45,6 +45,8 @@ jobs: bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_rocshmem_api.py # distributed ops bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_distributed-notify-wait.py + # all-gather + bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_allgather.py # ag gemm bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ag_gemm_intra_node.py 8192 11008 4096 # gemm rs diff --git a/python/triton_dist/kernels/amd/allgather.py b/python/triton_dist/kernels/amd/allgather.py index 903bfd45a..9b172143c 100644 --- a/python/triton_dist/kernels/amd/allgather.py +++ b/python/triton_dist/kernels/amd/allgather.py @@ -29,17 +29,16 @@ import triton import triton.language as tl +import triton_dist import triton_dist.language as dl -from triton.language.extra.hip.librocshmem_device import set_rocshmem_ctx -from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_with_ctx_on_stream, +from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_on_stream, barrier_on_this_grid) from triton_dist.kernels.amd.memcpy import memcpy_async_kernel from triton_dist.utils import launch_cooperative_grid_options -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_no_barrier_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -47,7 +46,6 @@ def allgather_no_barrier_kernel( num_ranks: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) npid_per_rank = npid // num_ranks @@ -63,9 +61,8 @@ def allgather_no_barrier_kernel( tl.store(dst_ptr + offs, val, mask) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -74,7 +71,6 @@ def allgather_kernel( group_barrier_ptr, BLOCK_SIZE: tl.constexpr, ): - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) npid_per_rank = npid // num_ranks @@ -95,9 +91,8 @@ def allgather_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -114,7 +109,6 @@ def allgather_strided_chunked_kernel( """ with the local_ptr """ - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) @@ -181,7 +175,7 @@ def allgather_strided_chunked_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["pid", "npid", "rank"]) +@triton_dist.jit(do_not_specialize=["pid", "npid", "rank"]) def allgather_strided_chunked_pull_kernel( pid, npid, @@ -312,7 +306,7 @@ def make_2x_ptrs(val0, val1): return vals -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_packed_kernel( symm_ptr, # (M, N), M = M_per_rank * num_ranks M_per_rank, @@ -398,9 +392,8 @@ def allgather_strided_chunked_pull_packed_kernel( tl.store(dst_ptrs[None, :, None] + offs[:, None, :], val, mask=mask[:, None, :]) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_ctx_wrapper_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks M_per_rank, N, @@ -413,7 +406,6 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) npid = tl.num_programs(0) allgather_strided_chunked_pull_kernel( @@ -434,9 +426,8 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_fused_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -451,7 +442,6 @@ def allgather_strided_chunked_pull_fused_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) npid = tl.num_programs(0) @@ -479,9 +469,8 @@ def allgather_strided_chunked_pull_fused_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_chunked_pull_fused_packed_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -496,7 +485,6 @@ def allgather_chunked_pull_fused_packed_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) memcpy_async_kernel(local_ptr, symm_ptr + M_per_rank * stride_m * rank, M_per_rank * N, 8 * BLOCK_SIZE_N) @@ -521,9 +509,8 @@ def allgather_chunked_pull_fused_packed_kernel( barrier_all_kernel(rank, num_ranks, group_barrier_ptr) -@triton.jit(do_not_specialize=["rank"]) +@triton_dist.jit(do_not_specialize=["rank"]) def allgather_opt_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -533,7 +520,6 @@ def allgather_opt_kernel( BLOCK_SIZE: tl.constexpr, ): """ compared to allagther_kernel, this use less CTAs to achieve to the same level """ - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) total_tiles = tl.cdiv(N_per_rank, BLOCK_SIZE) @@ -617,22 +603,16 @@ def allgather_ipc_kernel( tl.store(dst_ptr + offs, val, mask) -@functools.lru_cache() -def get_pyrocshmem_device_ctx(): - return pyrocshmem.rocshmem_get_device_ctx() - - -def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor, barrier: torch.Tensor): +def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor): rank = pyrocshmem.rocshmem_my_pe() nranks = pyrocshmem.rocshmem_n_pes() assert shard.numel() * nranks == full_symm.numel() assert shard.dtype == full_symm.dtype - ctx = get_pyrocshmem_device_ctx() current_stream = torch.cuda.current_stream() - barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream) - allgather_no_barrier_kernel[(nranks * 4, )](ctx, full_symm, shard, shard.numel(), rank, nranks, + barrier_all_on_stream(current_stream) + allgather_no_barrier_kernel[(nranks * 4, )](full_symm, shard, shard.numel(), rank, nranks, BLOCK_SIZE=32 * 1024, **launch_cooperative_grid_options()) - barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream) + barrier_all_on_stream(current_stream) def allgather( @@ -645,11 +625,9 @@ def allgather( nranks = pyrocshmem.rocshmem_n_pes() assert shard.numel() * nranks == full_symm.numel() assert shard.dtype == full_symm.dtype - ctx = get_pyrocshmem_device_ctx() - workgroups_per_rank = 1 # use only 8 ranks can achieve similay bandwidth - allgather_opt_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, shard.numel(), rank, nranks, barrier, + allgather_opt_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, shard.numel(), rank, nranks, barrier, BLOCK_SIZE=32 * 1024, num_warps=16, **launch_cooperative_grid_options()) @@ -662,7 +640,6 @@ def allgather_chunked_pull( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -678,16 +655,16 @@ def allgather_chunked_pull( full_symm[M_start:M_end, :].copy_(shard) current_stream = torch.cuda.current_stream() - barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream) + barrier_all_on_stream(current_stream) workgroups_per_rank = 2 BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth allgather_strided_chunked_pull_ctx_wrapper_kernel[(nranks * workgroups_per_rank, )]( - ctx, full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, + full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16) - barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream) + barrier_all_on_stream(current_stream) def allgather_chunked_pull_fused( @@ -698,7 +675,6 @@ def allgather_chunked_pull_fused( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -713,7 +689,7 @@ def allgather_chunked_pull_fused( BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth allgather_strided_chunked_pull_fused_kernel[(nranks * workgroups_per_rank, )]( - ctx, full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, + full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16, **launch_cooperative_grid_options()) @@ -726,7 +702,6 @@ def allgather_chunked_pull_packed_fused( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -740,7 +715,7 @@ def allgather_chunked_pull_packed_fused( BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 2 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similar bandwidth - allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N, + allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, @@ -766,13 +741,11 @@ def allgather_chunked( assert shard.is_cuda and full_symm.is_cuda assert shard.stride() == full_symm.stride() # just for simple implementation - ctx = get_pyrocshmem_device_ctx() - workgroups_per_rank = 1 BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 32 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth - allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N, + allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, From cc5e5b43768ec89afbc23cc64b60d65fcf776216 Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Wed, 18 Feb 2026 15:34:28 -0600 Subject: [PATCH 4/6] Fix test_allgather.py Signed-off-by: Eric Eaton --- python/triton_dist/test/amd/test_allgather.py | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/python/triton_dist/test/amd/test_allgather.py b/python/triton_dist/test/amd/test_allgather.py index c05948974..0b6491c46 100644 --- a/python/triton_dist/test/amd/test_allgather.py +++ b/python/triton_dist/test/amd/test_allgather.py @@ -59,6 +59,12 @@ def parse_args(): default=False, help="Enable profiling (default False)", ) + parser.add_argument( + "--noperf", + action="store_true", + default=False, + help="Disable performance measurement" + ) parser.add_argument("--stream_priority", "--stream_priority", default=0, type=int, help="stream priority, 0 or -1 supported") parser.add_argument("--N", "-N", type=int, default=1376) @@ -137,25 +143,30 @@ def parse_nbytes(nbytes: str): for name, fn_triton in triton_fns.items(): torch.distributed.barrier(TP_GROUP) fn_triton() + torch.cuda.synchronize() assert_bitwise_equal(A_full.view(8, -1, args.N), A_full_torch.view(8, -1, args.N)) torch.distributed.barrier(TP_GROUP) A_full.random_() - _run_id = os.environ.get("TORCHELASTIC_RUN_ID") - exp_name = f"allgather_{_run_id}" - with group_profile(exp_name, do_prof=args.profile, merge_group=TP_GROUP): - sleep_async(10) - _, duration_ms_torch = perf_func(fn_torch, args.iters, args.warmup_iters) - triton_durations_ms = {} - for name, fn_triton in triton_fns.items(): + if not args.noperf: + _run_id = os.environ.get("TORCHELASTIC_RUN_ID") + exp_name = f"allgather_{_run_id}" + with group_profile(exp_name, do_prof=args.profile, merge_group=TP_GROUP): sleep_async(10) - _, triton_durations_ms[name] = perf_func(fn_triton, args.iters, args.warmup_iters) + _, duration_ms_torch = perf_func(fn_torch, args.iters, args.warmup_iters) + triton_durations_ms = {} + for name, fn_triton in triton_fns.items(): + sleep_async(10) + _, triton_durations_ms[name] = perf_func(fn_triton, args.iters, args.warmup_iters) + + bw_gpbs_torch = nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms_torch * 1e3 - bw_gpbs_torch = nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms_torch * 1e3 + print(f"RANK #{TP_GROUP.rank()} AllGather torch: {duration_ms_torch:0.3f} ms/iter {bw_gpbs_torch:0.1f} GB/s.") + for name, duration_ms in triton_durations_ms.items(): + bw_gpbs_fn = (nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms * 1e3) + print(f" triton {duration_ms:0.3f} ms/iter {bw_gpbs_fn:0.1f} GB/s.") - print(f"RANK #{TP_GROUP.rank()} AllGather torch: {duration_ms_torch:0.3f} ms/iter {bw_gpbs_torch:0.1f} GB/s.") - for name, duration_ms in triton_durations_ms.items(): - bw_gpbs_fn = (nbytes_per_rank * (TP_GROUP.size() - 1) / 2**30 / duration_ms * 1e3) - print(f" triton {duration_ms:0.3f} ms/iter {bw_gpbs_fn:0.1f} GB/s.") + del A_full + del group_barrier finalize_distributed() From 90079dafeda3aa0b5f76d1e7bae513dcb0feffab Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Fri, 20 Feb 2026 14:27:03 -0600 Subject: [PATCH 5/6] Fix the UUID matching in amd_utils.py There was a change in ROCm 7+ that makes it harder to match up the torch and amdsmi devices. This change to use KFD makes operations like sleep_async work again. Signed-off-by: Eric Eaton --- python/triton_dist/amd_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/python/triton_dist/amd_utils.py b/python/triton_dist/amd_utils.py index 4e132a4ba..7388d0f96 100644 --- a/python/triton_dist/amd_utils.py +++ b/python/triton_dist/amd_utils.py @@ -28,6 +28,8 @@ import json import warnings import re +import os +import sys from threading import Lock from hip import hip @@ -206,7 +208,19 @@ def _get_gpu_uuid_by_physical_device_id(device_id: int): _ensure_amdsmi_initialized() devices = amdsmi.amdsmi_get_processor_handles() handle = devices[device_id] - return amdsmi.amdsmi_get_gpu_device_uuid(handle) + # Due to a change in how UUIDs are generated for CPX mode, amdsmi no longer reports any uuid value that + # matches HIP/pytorch. HIP gets the value from sysfs, and we can also get the value there by getting + # the KFD info from amdsmi and then probing the sysfs directly. + kfd_info = amdsmi.amdsmi_get_gpu_kfd_info(handle) + node_id = kfd_info["node_id"] + kfd_path = os.path.join("/sys/devices/virtual/kfd/kfd/topology/nodes", str(node_id), "properties") + key = "unique_id" + with open(kfd_path, "r") as fd: + for line in fd: + if line.startswith(key): + uuid_str = line[len(key)+1:] + uuid_str = hex(int(uuid_str)) + return uuid_str def torch_uuid_to_unique_id(torch_uuid: str) -> str: @@ -232,7 +246,8 @@ def get_uuid_by_physical_device_id(device_id: int | None = None): try: if has_amdsmi(): return _get_gpu_uuid_by_physical_device_id(device_id) - except Exception: + except Exception as e: + print(e, file=sys.stderr) warnings.warn("get_uuid_by_physical_device_id failed with amdsmi, try using rocm-smi") return _get_physical_gpu_uuid_rocm(device_id) From 9c6473ac641cb054ed2f500f643918cc12f379f3 Mon Sep 17 00:00:00 2001 From: Eric Eaton Date: Thu, 26 Feb 2026 17:21:11 -0600 Subject: [PATCH 6/6] Only use KFD with rocm 7 Signed-off-by: Eric Eaton --- python/triton_dist/amd_utils.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/python/triton_dist/amd_utils.py b/python/triton_dist/amd_utils.py index 7388d0f96..aa61219c1 100644 --- a/python/triton_dist/amd_utils.py +++ b/python/triton_dist/amd_utils.py @@ -208,19 +208,23 @@ def _get_gpu_uuid_by_physical_device_id(device_id: int): _ensure_amdsmi_initialized() devices = amdsmi.amdsmi_get_processor_handles() handle = devices[device_id] - # Due to a change in how UUIDs are generated for CPX mode, amdsmi no longer reports any uuid value that - # matches HIP/pytorch. HIP gets the value from sysfs, and we can also get the value there by getting - # the KFD info from amdsmi and then probing the sysfs directly. - kfd_info = amdsmi.amdsmi_get_gpu_kfd_info(handle) - node_id = kfd_info["node_id"] - kfd_path = os.path.join("/sys/devices/virtual/kfd/kfd/topology/nodes", str(node_id), "properties") - key = "unique_id" - with open(kfd_path, "r") as fd: - for line in fd: - if line.startswith(key): - uuid_str = line[len(key)+1:] - uuid_str = hex(int(uuid_str)) - return uuid_str + major_version = int(torch.version.hip.split('.')[0]) + if major_version >= 7: + # Due to a change in how UUIDs are generated for CPX mode, amdsmi no longer reports any uuid value that + # matches HIP/pytorch. HIP gets the value from sysfs, and we can also get the value there by getting + # the KFD info from amdsmi and then probing the sysfs directly. + kfd_info = amdsmi.amdsmi_get_gpu_kfd_info(handle) + node_id = kfd_info["node_id"] + kfd_path = os.path.join("/sys/devices/virtual/kfd/kfd/topology/nodes", str(node_id), "properties") + key = "unique_id" + with open(kfd_path, "r") as fd: + for line in fd: + if line.startswith(key): + uuid_str = line[len(key)+1:] + uuid_str = hex(int(uuid_str)) + return uuid_str + else: + return amdsmi.amdsmi_get_gpu_device_uuid(handle) def torch_uuid_to_unique_id(torch_uuid: str) -> str: