Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/amd-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: "amd-gfx942-mi325"
timeout-minutes: 45
container:
image: rocm/pytorch:rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview
image: rocm/pytorch:rocm6.4.4_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview
options: >-
--device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
--volume /home/runner/.triton:/github/home/.triton
Expand All @@ -27,7 +27,7 @@ jobs:
run: |
rm -f /usr/local/bin/cmake
apt-get update -y && apt install -y libopenmpi-dev git cython3 ibverbs-utils openmpi-bin libopenmpi-dev libpci-dev libdw1 locales cmake
pip3 install -i https://test.pypi.org/simple hip-python>=6.3.0 # (or whatever Rocm version you have)
pip3 install -i https://test.pypi.org/simple hip-python>=6.4.4 # (or whatever Rocm version you have)
pip3 install pybind11
- name: Checkout
uses: actions/checkout@v4
Expand All @@ -45,6 +45,8 @@ jobs:
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_rocshmem_api.py
# distributed ops
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_distributed-notify-wait.py
# all-gather
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_allgather.py
# ag gemm
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ag_gemm_intra_node.py 8192 11008 4096
# gemm rs
Expand All @@ -58,7 +60,7 @@ jobs:
runs-on: "amd-gfx942-mi325"
timeout-minutes: 60
container:
image: rocm/pytorch:rocm6.3_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview
image: rocm/pytorch:rocm6.4.4_ubuntu22.04_py3.10_pytorch_release_2.5.1_preview
options: >-
--device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
--volume /home/runner/.triton:/github/home/.triton
Expand All @@ -73,7 +75,7 @@ jobs:
- name: Install dependencies
run: |
apt-get update -y && apt install -y libopenmpi-dev
pip3 install -i https://test.pypi.org/simple hip-python>=6.3.0 # (or whatever Rocm version you have)
pip3 install -i https://test.pypi.org/simple hip-python>=6.4.4 # (or whatever Rocm version you have)
pip3 install pybind11
- name: Checkout
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/rocshmem
Submodule rocshmem updated 149 files
23 changes: 21 additions & 2 deletions python/triton_dist/amd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import json
import warnings
import re
import os
import sys
from threading import Lock
from hip import hip

Expand Down Expand Up @@ -206,7 +208,23 @@ def _get_gpu_uuid_by_physical_device_id(device_id: int):
_ensure_amdsmi_initialized()
devices = amdsmi.amdsmi_get_processor_handles()
handle = devices[device_id]
return amdsmi.amdsmi_get_gpu_device_uuid(handle)
major_version = int(torch.version.hip.split('.')[0])
if major_version >= 7:
# Due to a change in how UUIDs are generated for CPX mode, amdsmi no longer reports any uuid value that
# matches HIP/pytorch. HIP gets the value from sysfs, and we can also get the value there by getting
# the KFD info from amdsmi and then probing the sysfs directly.
kfd_info = amdsmi.amdsmi_get_gpu_kfd_info(handle)
node_id = kfd_info["node_id"]
kfd_path = os.path.join("/sys/devices/virtual/kfd/kfd/topology/nodes", str(node_id), "properties")
key = "unique_id"
with open(kfd_path, "r") as fd:
for line in fd:
if line.startswith(key):
uuid_str = line[len(key)+1:]
uuid_str = hex(int(uuid_str))
return uuid_str
else:
return amdsmi.amdsmi_get_gpu_device_uuid(handle)


def torch_uuid_to_unique_id(torch_uuid: str) -> str:
Expand All @@ -232,7 +250,8 @@ def get_uuid_by_physical_device_id(device_id: int | None = None):
try:
if has_amdsmi():
return _get_gpu_uuid_by_physical_device_id(device_id)
except Exception:
except Exception as e:
print(e, file=sys.stderr)
warnings.warn("get_uuid_by_physical_device_id failed with amdsmi, try using rocm-smi")

return _get_physical_gpu_uuid_rocm(device_id)
Expand Down
71 changes: 22 additions & 49 deletions python/triton_dist/kernels/amd/allgather.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,23 @@

import triton
import triton.language as tl
import triton_dist
import triton_dist.language as dl
from triton.language.extra.hip.librocshmem_device import set_rocshmem_ctx
from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_with_ctx_on_stream,
from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_on_stream,
barrier_on_this_grid)
from triton_dist.kernels.amd.memcpy import memcpy_async_kernel
from triton_dist.utils import launch_cooperative_grid_options


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_no_barrier_kernel(
ctx,
symm_ptr,
local_ptr,
N_per_rank,
rank,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)
npid_per_rank = npid // num_ranks
Expand All @@ -63,9 +61,8 @@ def allgather_no_barrier_kernel(
tl.store(dst_ptr + offs, val, mask)


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_kernel(
ctx,
symm_ptr,
local_ptr,
N_per_rank,
Expand All @@ -74,7 +71,6 @@ def allgather_kernel(
group_barrier_ptr,
BLOCK_SIZE: tl.constexpr,
):
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)
npid_per_rank = npid // num_ranks
Expand All @@ -95,9 +91,8 @@ def allgather_kernel(
barrier_all_kernel(rank, num_ranks, group_barrier_ptr)


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_kernel(
ctx,
symm_ptr, # (M, N), M = M_per_rank * num_ranks
local_ptr, # (M_per_rank, N)
M_per_rank,
Expand All @@ -114,7 +109,6 @@ def allgather_strided_chunked_kernel(
"""
with the local_ptr
"""
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)

Expand Down Expand Up @@ -181,7 +175,7 @@ def allgather_strided_chunked_kernel(
barrier_all_kernel(rank, num_ranks, group_barrier_ptr)


@triton.jit(do_not_specialize=["pid", "npid", "rank"])
@triton_dist.jit(do_not_specialize=["pid", "npid", "rank"])
def allgather_strided_chunked_pull_kernel(
pid,
npid,
Expand Down Expand Up @@ -312,7 +306,7 @@ def make_2x_ptrs(val0, val1):
return vals


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_pull_packed_kernel(
symm_ptr, # (M, N), M = M_per_rank * num_ranks
M_per_rank,
Expand Down Expand Up @@ -398,9 +392,8 @@ def allgather_strided_chunked_pull_packed_kernel(
tl.store(dst_ptrs[None, :, None] + offs[:, None, :], val, mask=mask[:, None, :])


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_pull_ctx_wrapper_kernel(
ctx,
symm_ptr, # (M, N), M = M_per_rank * num_ranks
M_per_rank,
N,
Expand All @@ -413,7 +406,6 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel(
BLOCK_SIZE_N: tl.constexpr, # should be as large as possible
SPLIT_N: tl.constexpr,
):
set_rocshmem_ctx(ctx)
pid = tl.program_id(0)
npid = tl.num_programs(0)
allgather_strided_chunked_pull_kernel(
Expand All @@ -434,9 +426,8 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel(
barrier_all_kernel(rank, num_ranks, group_barrier_ptr)


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_pull_fused_kernel(
ctx,
symm_ptr, # (M, N), M = M_per_rank * num_ranks
local_ptr, # (M_per_rank, N)
M_per_rank,
Expand All @@ -451,7 +442,6 @@ def allgather_strided_chunked_pull_fused_kernel(
BLOCK_SIZE_N: tl.constexpr, # should be as large as possible
SPLIT_N: tl.constexpr,
):
set_rocshmem_ctx(ctx)
pid = tl.program_id(0)
npid = tl.num_programs(0)

Expand Down Expand Up @@ -479,9 +469,8 @@ def allgather_strided_chunked_pull_fused_kernel(
barrier_all_kernel(rank, num_ranks, group_barrier_ptr)


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_chunked_pull_fused_packed_kernel(
ctx,
symm_ptr, # (M, N), M = M_per_rank * num_ranks
local_ptr, # (M_per_rank, N)
M_per_rank,
Expand All @@ -496,7 +485,6 @@ def allgather_chunked_pull_fused_packed_kernel(
BLOCK_SIZE_N: tl.constexpr, # should be as large as possible
SPLIT_N: tl.constexpr,
):
set_rocshmem_ctx(ctx)
pid = tl.program_id(0)

memcpy_async_kernel(local_ptr, symm_ptr + M_per_rank * stride_m * rank, M_per_rank * N, 8 * BLOCK_SIZE_N)
Expand All @@ -521,9 +509,8 @@ def allgather_chunked_pull_fused_packed_kernel(
barrier_all_kernel(rank, num_ranks, group_barrier_ptr)


@triton.jit(do_not_specialize=["rank"])
@triton_dist.jit(do_not_specialize=["rank"])
def allgather_opt_kernel(
ctx,
symm_ptr,
local_ptr,
N_per_rank,
Expand All @@ -533,7 +520,6 @@ def allgather_opt_kernel(
BLOCK_SIZE: tl.constexpr,
):
""" compared to allgather_kernel, this uses fewer CTAs to achieve the same level """
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)
total_tiles = tl.cdiv(N_per_rank, BLOCK_SIZE)
Expand Down Expand Up @@ -617,22 +603,16 @@ def allgather_ipc_kernel(
tl.store(dst_ptr + offs, val, mask)


@functools.lru_cache()
def get_pyrocshmem_device_ctx():
return pyrocshmem.rocshmem_get_device_ctx()


def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor, barrier: torch.Tensor):
def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor):
rank = pyrocshmem.rocshmem_my_pe()
nranks = pyrocshmem.rocshmem_n_pes()
assert shard.numel() * nranks == full_symm.numel()
assert shard.dtype == full_symm.dtype
ctx = get_pyrocshmem_device_ctx()
current_stream = torch.cuda.current_stream()
barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream)
allgather_no_barrier_kernel[(nranks * 4, )](ctx, full_symm, shard, shard.numel(), rank, nranks,
barrier_all_on_stream(current_stream)
allgather_no_barrier_kernel[(nranks * 4, )](full_symm, shard, shard.numel(), rank, nranks,
BLOCK_SIZE=32 * 1024, **launch_cooperative_grid_options())
barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream)
barrier_all_on_stream(current_stream)


def allgather(
Expand All @@ -645,11 +625,9 @@ def allgather(
nranks = pyrocshmem.rocshmem_n_pes()
assert shard.numel() * nranks == full_symm.numel()
assert shard.dtype == full_symm.dtype
ctx = get_pyrocshmem_device_ctx()

workgroups_per_rank = 1
# using only 8 ranks can achieve similar bandwidth
allgather_opt_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, shard.numel(), rank, nranks, barrier,
allgather_opt_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, shard.numel(), rank, nranks, barrier,
BLOCK_SIZE=32 * 1024, num_warps=16,
**launch_cooperative_grid_options())

Expand All @@ -662,7 +640,6 @@ def allgather_chunked_pull(
SPLIT_N: int,
workgroups_per_rank: int = 2,
):
ctx = get_pyrocshmem_device_ctx()
assert full_symm.dim() == 2 and shard.dim() == 2
M_per_rank, N = shard.shape
rank = pyrocshmem.rocshmem_my_pe()
Expand All @@ -678,16 +655,16 @@ def allgather_chunked_pull(
full_symm[M_start:M_end, :].copy_(shard)

current_stream = torch.cuda.current_stream()
barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream)
barrier_all_on_stream(current_stream)

workgroups_per_rank = 2
BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N)
BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_strided_chunked_pull_ctx_wrapper_kernel[(nranks * workgroups_per_rank, )](
ctx, full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier,
full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier,
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16)
barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream)
barrier_all_on_stream(current_stream)


def allgather_chunked_pull_fused(
Expand All @@ -698,7 +675,6 @@ def allgather_chunked_pull_fused(
SPLIT_N: int,
workgroups_per_rank: int = 2,
):
ctx = get_pyrocshmem_device_ctx()
assert full_symm.dim() == 2 and shard.dim() == 2
M_per_rank, N = shard.shape
rank = pyrocshmem.rocshmem_my_pe()
Expand All @@ -713,7 +689,7 @@ def allgather_chunked_pull_fused(
BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_strided_chunked_pull_fused_kernel[(nranks * workgroups_per_rank, )](
ctx, full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier,
full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier,
group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16,
**launch_cooperative_grid_options())

Expand All @@ -726,7 +702,6 @@ def allgather_chunked_pull_packed_fused(
SPLIT_N: int,
workgroups_per_rank: int = 2,
):
ctx = get_pyrocshmem_device_ctx()
assert full_symm.dim() == 2 and shard.dim() == 2
M_per_rank, N = shard.shape
rank = pyrocshmem.rocshmem_my_pe()
Expand All @@ -740,7 +715,7 @@ def allgather_chunked_pull_packed_fused(
BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N)
BLOCK_SIZE_M = max(1, 2 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N,
allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N,
shard.stride(0), shard.stride(1), rank,
nranks, grid_barrier, group_barrier,
BLOCK_SIZE_M=BLOCK_SIZE_M,
Expand All @@ -766,13 +741,11 @@ def allgather_chunked(
assert shard.is_cuda and full_symm.is_cuda
assert shard.stride() == full_symm.stride() # just for simple implementation

ctx = get_pyrocshmem_device_ctx()

workgroups_per_rank = 1
BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N)
BLOCK_SIZE_M = max(1, 32 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N,
allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N,
shard.stride(0), shard.stride(1), rank, nranks,
group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N,
Expand Down
Loading
Loading