Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/rocshmem
Submodule rocshmem updated 50 files
+9 −0 CMakeLists.txt
+3 −3 docs/api/env_variables.rst
+1 −1 docs/sphinx/requirements.in
+4 −4 docs/sphinx/requirements.txt
+2 −0 examples/CMakeLists.txt
+34 −0 examples/LL_MoE/CMakeLists.txt
+563 −0 examples/LL_MoE/LL_MoE.hpp
+166 −0 examples/LL_MoE/LL_MoE_Buffers.hpp
+207 −0 examples/LL_MoE/LL_MoE_Data.hpp
+666 −0 examples/LL_MoE/LL_MoE_Kernels.hpp
+139 −0 examples/LL_MoE/MoE_example.cpp
+23 −6 examples/util.h
+27 −2 include/rocshmem/rocshmem.hpp
+34 −7 scripts/functional_tests/driver.sh
+1 −1 src/atomic_return.cpp
+0 −14 src/bootstrap/bootstrap.cpp
+0 −1 src/gda/CMakeLists.txt
+10 −7 src/gda/backend_gda.cpp
+1 −0 src/gda/backend_gda.hpp
+2 −2 src/gda/bnxt/backend_gda_bnxt.cpp
+2 −1 src/gda/context_gda_tmpl_device.hpp
+0 −81 src/gda/endian.cpp
+56 −19 src/gda/endian.hpp
+43 −2 src/gda/ionic/backend_gda_ionic.cpp
+35 −0 src/gda/ionic/ionic_dv.h
+64 −47 src/gda/ionic/ionic_fw.h
+3 −0 src/gda/ionic/provider_gda_ionic.hpp
+80 −45 src/gda/ionic/queue_pair_ionic.cpp
+1 −1 src/gda/mlx5/backend_gda_mlx5.cpp
+3 −4 src/gda/mlx5/queue_pair_mlx5.cpp
+8 −8 src/gda/mlx5/segment_builder.cpp
+12 −2 src/gda/queue_pair.cpp
+6 −0 src/gda/queue_pair.hpp
+25 −6 src/ipc/backend_ipc.cpp
+1 −1 src/ipc/backend_ipc.hpp
+2 −0 src/ipc/context_ipc_device.cpp
+8 −18 src/ipc/context_ipc_tmpl_device.hpp
+1 −1 src/mpi_instance.cpp
+16 −0 src/rocshmem.cpp
+66 −5 src/rocshmem_gpu.cpp
+10 −0 src/util.cpp
+11 −0 src/util.hpp
+5 −0 tests/functional_tests/CMakeLists.txt
+217 −0 tests/functional_tests/flood_tester.cpp
+56 −0 tests/functional_tests/flood_tester.hpp
+1 −1 tests/functional_tests/team_reduction_tester.cpp
+31 −0 tests/functional_tests/tester.cpp
+6 −0 tests/functional_tests/tester.hpp
+15 −0 tests/functional_tests/tester_arguments.cpp
+4 −0 tests/unit_tests/CMakeLists.txt
29 changes: 6 additions & 23 deletions python/triton_dist/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,12 @@ def shmem_kernel_module_init_hook(*args, **kwargs) -> None:
backend = get_shmem_backend()

if backend == 'rocshmem':
import pyrocshmem
res = hip.hipModuleGetGlobal(kernel_module, b"ROCSHMEM_CTX_DEFAULT")
# dptr, bytes = res[1], res[2]
if res[0] == hip.hipError_t.hipSuccess:
"""
typedef struct rocshmem_ctx{
void *ctx_opaque;
void *team_opaque;
} rocshmem_ctx_t;
pyrocshmem.rocshmem_get_device_ctx only return the `ctx_opaque`.
`ROCSHMEM_CTX_DEFAULT` is a `rocshmem_ctx_t` struct, but only the `ctx_opaque` field needs to be updated on the device side.
(equal to `libshmem_device.set_rocshmem_ctx(ctx)` in the kernel)
"""
ctx_opaque_bytes = 8 # assuming 64-bit pointer
# get the host address of the `ctx_opaque` pointer.
ctx = pyrocshmem.rocshmem_get_device_ctx()
ctx_tensor = torch.tensor([ctx], dtype=torch.int64)
# update the device `ROCSHMEM_CTX_DEFAULT` struct's `ctx_opaque` field in the kernel module.
cp_res = hip.hipMemcpy(res[1], ctx_tensor.data_ptr(), ctx_opaque_bytes,
hip.hipMemcpyKind.hipMemcpyHostToDevice)
HIP_CHECK(cp_res)
else:
hip.hipGetLastError() # Discard the last error
has_rocshmem = "rocshmem" in kernel.asm.get('llir', '')
if has_rocshmem:
import pyrocshmem
ret = pyrocshmem.rocshmem_hipmodule_init(kernel_module, 0)
if ret != 0:
warnings.warn(f"rocshmem_hipmodule_init failed with code {ret}")
elif backend == 'mori_shmem':
# Initialize mori_shmem device symbols in this kernel module
import mori.shmem as mori_shmem
Expand Down
51 changes: 15 additions & 36 deletions python/triton_dist/kernels/amd/allgather.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,21 @@
import triton
import triton.language as tl
import triton_dist.language as dl
from triton.language.extra.hip.librocshmem_device import set_rocshmem_ctx
from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_with_ctx_on_stream,
from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_on_stream,
barrier_on_this_grid)
from triton_dist.kernels.amd.memcpy import memcpy_async_kernel
from triton_dist.utils import launch_cooperative_grid_options


@triton.jit(do_not_specialize=["rank"])
def allgather_no_barrier_kernel(
ctx,
symm_ptr,
local_ptr,
N_per_rank,
rank,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)
npid_per_rank = npid // num_ranks
Expand All @@ -65,7 +62,6 @@ def allgather_no_barrier_kernel(

@triton.jit(do_not_specialize=["rank"])
def allgather_kernel(
ctx,
symm_ptr,
local_ptr,
N_per_rank,
Expand All @@ -74,7 +70,6 @@ def allgather_kernel(
group_barrier_ptr,
BLOCK_SIZE: tl.constexpr,
):
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)
npid_per_rank = npid // num_ranks
Expand All @@ -97,7 +92,7 @@ def allgather_kernel(

@triton.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_kernel(
ctx,

symm_ptr, # (M, N), M = M_per_rank * num_ranks
local_ptr, # (M_per_rank, N)
M_per_rank,
Expand All @@ -114,7 +109,6 @@ def allgather_strided_chunked_kernel(
"""
with the local_ptr
"""
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)

Expand Down Expand Up @@ -400,7 +394,6 @@ def allgather_strided_chunked_pull_packed_kernel(

@triton.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_pull_ctx_wrapper_kernel(
ctx,
symm_ptr, # (M, N), M = M_per_rank * num_ranks
M_per_rank,
N,
Expand All @@ -413,7 +406,6 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel(
BLOCK_SIZE_N: tl.constexpr, # should be as large as possible
SPLIT_N: tl.constexpr,
):
set_rocshmem_ctx(ctx)
pid = tl.program_id(0)
npid = tl.num_programs(0)
allgather_strided_chunked_pull_kernel(
Expand All @@ -436,7 +428,7 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel(

@triton.jit(do_not_specialize=["rank"])
def allgather_strided_chunked_pull_fused_kernel(
ctx,

symm_ptr, # (M, N), M = M_per_rank * num_ranks
local_ptr, # (M_per_rank, N)
M_per_rank,
Expand All @@ -451,7 +443,6 @@ def allgather_strided_chunked_pull_fused_kernel(
BLOCK_SIZE_N: tl.constexpr, # should be as large as possible
SPLIT_N: tl.constexpr,
):
set_rocshmem_ctx(ctx)
pid = tl.program_id(0)
npid = tl.num_programs(0)

Expand Down Expand Up @@ -481,7 +472,7 @@ def allgather_strided_chunked_pull_fused_kernel(

@triton.jit(do_not_specialize=["rank"])
def allgather_chunked_pull_fused_packed_kernel(
ctx,

symm_ptr, # (M, N), M = M_per_rank * num_ranks
local_ptr, # (M_per_rank, N)
M_per_rank,
Expand All @@ -496,7 +487,6 @@ def allgather_chunked_pull_fused_packed_kernel(
BLOCK_SIZE_N: tl.constexpr, # should be as large as possible
SPLIT_N: tl.constexpr,
):
set_rocshmem_ctx(ctx)
pid = tl.program_id(0)

memcpy_async_kernel(local_ptr, symm_ptr + M_per_rank * stride_m * rank, M_per_rank * N, 8 * BLOCK_SIZE_N)
Expand All @@ -523,7 +513,7 @@ def allgather_chunked_pull_fused_packed_kernel(

@triton.jit(do_not_specialize=["rank"])
def allgather_opt_kernel(
ctx,

symm_ptr,
local_ptr,
N_per_rank,
Expand All @@ -533,7 +523,6 @@ def allgather_opt_kernel(
BLOCK_SIZE: tl.constexpr,
):
""" compared to allgather_kernel, this uses fewer CTAs to achieve the same level """
set_rocshmem_ctx(ctx)
npid = tl.num_programs(0)
pid = tl.program_id(0)
total_tiles = tl.cdiv(N_per_rank, BLOCK_SIZE)
Expand Down Expand Up @@ -618,21 +607,16 @@ def allgather_ipc_kernel(


@functools.lru_cache()
def get_pyrocshmem_device_ctx():
    """Return the result of ``pyrocshmem.rocshmem_get_device_ctx()``, memoized.

    The function takes no arguments, so ``functools.lru_cache`` makes the
    underlying pyrocshmem call once and hands back the stored value on every
    later call.
    """
    device_ctx = pyrocshmem.rocshmem_get_device_ctx()
    return device_ctx


def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor, barrier: torch.Tensor):
    """Gather ``shard`` from every PE into ``full_symm`` with barriers outside the kernel.

    The sequence is: barrier on the current stream, launch
    ``allgather_no_barrier_kernel``, barrier on the current stream again.

    Args:
        full_symm: destination tensor; must hold ``nranks * shard.numel()``
            elements and share ``shard``'s dtype (both asserted below).
        shard: this PE's contribution.
        barrier: not used by this implementation; kept so existing callers
            that pass it keep working.
    """
    rank = pyrocshmem.rocshmem_my_pe()
    nranks = pyrocshmem.rocshmem_n_pes()
    assert shard.numel() * nranks == full_symm.numel()
    assert shard.dtype == full_symm.dtype

    current_stream = torch.cuda.current_stream()

    # Barrier before the launch; the kernel itself does no barrier.
    barrier_all_on_stream(current_stream)
    # Grid of 4 programs per rank; the kernel no longer takes a ctx argument.
    allgather_no_barrier_kernel[(nranks * 4, )](
        full_symm,
        shard,
        shard.numel(),
        rank,
        nranks,
        BLOCK_SIZE=32 * 1024,
        **launch_cooperative_grid_options(),
    )
    # Barrier again after the launch.
    barrier_all_on_stream(current_stream)


def allgather(
Expand All @@ -645,11 +629,10 @@ def allgather(
nranks = pyrocshmem.rocshmem_n_pes()
assert shard.numel() * nranks == full_symm.numel()
assert shard.dtype == full_symm.dtype
ctx = get_pyrocshmem_device_ctx()

workgroups_per_rank = 1
# using only 8 ranks can achieve similar bandwidth
allgather_opt_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, shard.numel(), rank, nranks, barrier,
allgather_opt_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, shard.numel(), rank, nranks, barrier,
BLOCK_SIZE=32 * 1024, num_warps=16,
**launch_cooperative_grid_options())

Expand All @@ -662,7 +645,6 @@ def allgather_chunked_pull(
SPLIT_N: int,
workgroups_per_rank: int = 2,
):
ctx = get_pyrocshmem_device_ctx()
assert full_symm.dim() == 2 and shard.dim() == 2
M_per_rank, N = shard.shape
rank = pyrocshmem.rocshmem_my_pe()
Expand All @@ -678,16 +660,16 @@ def allgather_chunked_pull(
full_symm[M_start:M_end, :].copy_(shard)

current_stream = torch.cuda.current_stream()
barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream)
barrier_all_on_stream(current_stream)

workgroups_per_rank = 2
BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N)
BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_strided_chunked_pull_ctx_wrapper_kernel[(nranks * workgroups_per_rank, )](
ctx, full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier,
full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier,
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16)
barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream)
barrier_all_on_stream(current_stream)


def allgather_chunked_pull_fused(
Expand All @@ -698,7 +680,6 @@ def allgather_chunked_pull_fused(
SPLIT_N: int,
workgroups_per_rank: int = 2,
):
ctx = get_pyrocshmem_device_ctx()
assert full_symm.dim() == 2 and shard.dim() == 2
M_per_rank, N = shard.shape
rank = pyrocshmem.rocshmem_my_pe()
Expand All @@ -713,7 +694,7 @@ def allgather_chunked_pull_fused(
BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_strided_chunked_pull_fused_kernel[(nranks * workgroups_per_rank, )](
ctx, full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier,
full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier,
group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16,
**launch_cooperative_grid_options())

Expand All @@ -726,7 +707,6 @@ def allgather_chunked_pull_packed_fused(
SPLIT_N: int,
workgroups_per_rank: int = 2,
):
ctx = get_pyrocshmem_device_ctx()
assert full_symm.dim() == 2 and shard.dim() == 2
M_per_rank, N = shard.shape
rank = pyrocshmem.rocshmem_my_pe()
Expand All @@ -740,7 +720,7 @@ def allgather_chunked_pull_packed_fused(
BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N)
BLOCK_SIZE_M = max(1, 2 * 1024 // BLOCK_SIZE_N)
# use only 8 ranks can achieve similar bandwidth
allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N,
allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N,
shard.stride(0), shard.stride(1), rank,
nranks, grid_barrier, group_barrier,
BLOCK_SIZE_M=BLOCK_SIZE_M,
Expand All @@ -766,13 +746,12 @@ def allgather_chunked(
assert shard.is_cuda and full_symm.is_cuda
assert shard.stride() == full_symm.stride() # just for simple implementation

ctx = get_pyrocshmem_device_ctx()

workgroups_per_rank = 1
BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N)
BLOCK_SIZE_M = max(1, 32 * 1024 // BLOCK_SIZE_N)
# using only 8 ranks can achieve similar bandwidth
allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N,
allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N,
shard.stride(0), shard.stride(1), rank, nranks,
group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N,
Expand Down
16 changes: 0 additions & 16 deletions python/triton_dist/language/extra/hip/librocshmem_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,6 @@
ROCSHMEM_SIGNAL_ADD = 1


@core.extern
def set_rocshmem_ctx(ctx, _semantic=None):
    """Emit an extern call to ``rocshmem_set_ctx`` in ``librocshmem_device``.

    ``ctx`` is cast to a void pointer before being passed. The call is
    declared with ``is_pure=False``.

    Args:
        ctx: value to cast to ``void*`` and pass as the only argument.
        _semantic: forwarded unchanged to ``tl.cast`` and ``extern_call``.

    Returns:
        Whatever ``extern_call`` returns for this call.
    """
    void_ptr_ty = tl.pointer_type(tl.void)
    ctx_as_void_ptr = tl.cast(ctx, void_ptr_ty, _semantic=_semantic)
    # Single overload: one void* argument mapped to the `rocshmem_set_ctx` symbol.
    overloads = {
        (void_ptr_ty, ): ("rocshmem_set_ctx", ()),
    }
    return extern_call(
        "librocshmem_device",
        "",
        [ctx_as_void_ptr],
        overloads,
        is_pure=False,
        _semantic=_semantic,
    )


void_ptr = core.pointer_type(core.void)


Expand Down
6 changes: 0 additions & 6 deletions python/triton_dist/language/extra/libshmem_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@
])


@_shmem_module.dispatch
def set_rocshmem_ctx(ctx):
    """ROCSHMEM only.

    The body here is a placeholder (``...``).
    NOTE(review): the real implementation is presumably resolved by
    ``_shmem_module.dispatch`` for the active backend - confirm against the
    dispatcher before relying on this.
    """
    ...


@_shmem_module.dispatch
def my_pe():
"""Both NVSHMEM and ROCSHMEM"""
Expand Down
6 changes: 2 additions & 4 deletions python/triton_dist/test/amd/test_put_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from triton_dist.utils import finalize_distributed, initialize_distributed, NVSHMEM_SIGNAL_DTYPE

@triton_dist.jit(do_not_specialize=["my_pe", "dst_pe"])
def simple_put_signal_test(data, message, nelem, sig_addr, my_pe, dst_pe, ctx):
libshmem_device.set_rocshmem_ctx(ctx)
def simple_put_signal_test(data, message, nelem, sig_addr, my_pe, dst_pe):
pid = tl.program_id(0)
thread_id = tid(0)

Expand Down Expand Up @@ -49,9 +48,8 @@ def simple_put_signal_test(data, message, nelem, sig_addr, my_pe, dst_pe, ctx):
torch.cuda.synchronize()

print(RANK, '->', dst_pe)
ctx = pyrocshmem.rocshmem_get_device_ctx()

simple_put_signal_test[(1, )](data, message, elems, sig_addr, RANK, dst_pe, ctx)
simple_put_signal_test[(1, )](data, message, elems, sig_addr, RANK, dst_pe)
pyrocshmem.rocshmem_barrier_all()
torch.cuda.synchronize()

Expand Down
Loading