diff --git a/3rdparty/rocshmem b/3rdparty/rocshmem index f5940f6b9..5211e8c5c 160000 --- a/3rdparty/rocshmem +++ b/3rdparty/rocshmem @@ -1 +1 @@ -Subproject commit f5940f6b9a5412a58297bef309a97ae69ae7dd7a +Subproject commit 5211e8c5cac246bfea0aa2f7b0a082205537e1fd diff --git a/python/triton_dist/jit.py b/python/triton_dist/jit.py index 30b24349f..d0fa49ded 100644 --- a/python/triton_dist/jit.py +++ b/python/triton_dist/jit.py @@ -64,29 +64,12 @@ def shmem_kernel_module_init_hook(*args, **kwargs) -> None: backend = get_shmem_backend() if backend == 'rocshmem': - import pyrocshmem - res = hip.hipModuleGetGlobal(kernel_module, b"ROCSHMEM_CTX_DEFAULT") - # dptr, bytes = res[1], res[2] - if res[0] == hip.hipError_t.hipSuccess: - """ - typedef struct rocshmem_ctx{ - void *ctx_opaque; - void *team_opaque; - } rocshmem_ctx_t; - pyrocshmem.rocshmem_get_device_ctx only return the `ctx_opaque`. - `ROCSHMEM_CTX_DEFAULT` is a `rocshmem_ctx_t` struct, but only the `ctx_opaque` field needs to be updated on the device side. - (equal to `libshmem_device.set_rocshmem_ctx(ctx)` in the kernel) - """ - ctx_opaque_bytes = 8 # assuming 64-bit pointer - # get the host address of the `ctx_opaque` pointer. - ctx = pyrocshmem.rocshmem_get_device_ctx() - ctx_tensor = torch.tensor([ctx], dtype=torch.int64) - # update the device `ROCSHMEM_CTX_DEFAULT` struct's `ctx_opaque` field in the kernel module. 
- cp_res = hip.hipMemcpy(res[1], ctx_tensor.data_ptr(), ctx_opaque_bytes, - hip.hipMemcpyKind.hipMemcpyHostToDevice) - HIP_CHECK(cp_res) - else: - hip.hipGetLastError() # Discard the last error + has_rocshmem = "rocshmem" in kernel.asm.get('llir', '') + if has_rocshmem: + import pyrocshmem + ret = pyrocshmem.rocshmem_hipmodule_init(kernel_module, 0) + if ret != 0: + warnings.warn(f"rocshmem_hipmodule_init failed with code {ret}") elif backend == 'mori_shmem': # Initialize mori_shmem device symbols in this kernel module import mori.shmem as mori_shmem diff --git a/python/triton_dist/kernels/amd/allgather.py b/python/triton_dist/kernels/amd/allgather.py index 903bfd45a..5aed57c01 100644 --- a/python/triton_dist/kernels/amd/allgather.py +++ b/python/triton_dist/kernels/amd/allgather.py @@ -30,8 +30,7 @@ import triton import triton.language as tl import triton_dist.language as dl -from triton.language.extra.hip.librocshmem_device import set_rocshmem_ctx -from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_with_ctx_on_stream, +from triton_dist.kernels.amd.common_ops import (barrier_all_kernel, barrier_all_on_stream, barrier_on_this_grid) from triton_dist.kernels.amd.memcpy import memcpy_async_kernel from triton_dist.utils import launch_cooperative_grid_options @@ -39,7 +38,6 @@ @triton.jit(do_not_specialize=["rank"]) def allgather_no_barrier_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -47,7 +45,6 @@ def allgather_no_barrier_kernel( num_ranks: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) npid_per_rank = npid // num_ranks @@ -65,7 +62,6 @@ def allgather_no_barrier_kernel( @triton.jit(do_not_specialize=["rank"]) def allgather_kernel( - ctx, symm_ptr, local_ptr, N_per_rank, @@ -74,7 +70,6 @@ def allgather_kernel( group_barrier_ptr, BLOCK_SIZE: tl.constexpr, ): - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) npid_per_rank = npid // num_ranks @@ 
-97,7 +92,7 @@ def allgather_kernel( @triton.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_kernel( - ctx, + symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -114,7 +109,6 @@ def allgather_strided_chunked_kernel( """ with the local_ptr """ - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) @@ -400,7 +394,6 @@ def allgather_strided_chunked_pull_packed_kernel( @triton.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_ctx_wrapper_kernel( - ctx, symm_ptr, # (M, N), M = M_per_rank * num_ranks M_per_rank, N, @@ -413,7 +406,6 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) npid = tl.num_programs(0) allgather_strided_chunked_pull_kernel( @@ -436,7 +428,7 @@ def allgather_strided_chunked_pull_ctx_wrapper_kernel( @triton.jit(do_not_specialize=["rank"]) def allgather_strided_chunked_pull_fused_kernel( - ctx, + symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -451,7 +443,6 @@ def allgather_strided_chunked_pull_fused_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) npid = tl.num_programs(0) @@ -481,7 +472,7 @@ def allgather_strided_chunked_pull_fused_kernel( @triton.jit(do_not_specialize=["rank"]) def allgather_chunked_pull_fused_packed_kernel( - ctx, + symm_ptr, # (M, N), M = M_per_rank * num_ranks local_ptr, # (M_per_rank, N) M_per_rank, @@ -496,7 +487,6 @@ def allgather_chunked_pull_fused_packed_kernel( BLOCK_SIZE_N: tl.constexpr, # should be as large as possible SPLIT_N: tl.constexpr, ): - set_rocshmem_ctx(ctx) pid = tl.program_id(0) memcpy_async_kernel(local_ptr, symm_ptr + M_per_rank * stride_m * rank, M_per_rank * N, 8 * BLOCK_SIZE_N) @@ -523,7 +513,7 @@ def allgather_chunked_pull_fused_packed_kernel( 
@triton.jit(do_not_specialize=["rank"]) def allgather_opt_kernel( - ctx, + symm_ptr, local_ptr, N_per_rank, @@ -533,7 +523,6 @@ def allgather_opt_kernel( BLOCK_SIZE: tl.constexpr, ): """ compared to allagther_kernel, this use less CTAs to achieve to the same level """ - set_rocshmem_ctx(ctx) npid = tl.num_programs(0) pid = tl.program_id(0) total_tiles = tl.cdiv(N_per_rank, BLOCK_SIZE) @@ -618,21 +607,15 @@ def allgather_ipc_kernel( -@functools.lru_cache() -def get_pyrocshmem_device_ctx(): - return pyrocshmem.rocshmem_get_device_ctx() - - def allgather_no_fuse(full_symm: torch.Tensor, shard: torch.Tensor, barrier: torch.Tensor): rank = pyrocshmem.rocshmem_my_pe() nranks = pyrocshmem.rocshmem_n_pes() assert shard.numel() * nranks == full_symm.numel() assert shard.dtype == full_symm.dtype - ctx = get_pyrocshmem_device_ctx() current_stream = torch.cuda.current_stream() - barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream) - allgather_no_barrier_kernel[(nranks * 4, )](ctx, full_symm, shard, shard.numel(), rank, nranks, + barrier_all_on_stream(current_stream) + allgather_no_barrier_kernel[(nranks * 4, )](full_symm, shard, shard.numel(), rank, nranks, BLOCK_SIZE=32 * 1024, **launch_cooperative_grid_options()) - barrier_all_with_ctx_on_stream(ctx, rank, nranks, barrier, current_stream) + barrier_all_on_stream(current_stream) def allgather( @@ -645,11 +629,10 @@ def allgather( nranks = pyrocshmem.rocshmem_n_pes() assert shard.numel() * nranks == full_symm.numel() assert shard.dtype == full_symm.dtype - ctx = get_pyrocshmem_device_ctx() workgroups_per_rank = 1 # use only 8 ranks can achieve similay bandwidth - allgather_opt_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, shard.numel(), rank, nranks, barrier, + allgather_opt_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, shard.numel(), rank, nranks, barrier, BLOCK_SIZE=32 * 1024, num_warps=16, **launch_cooperative_grid_options()) @@ -662,7 +645,6 @@ def allgather_chunked_pull( 
SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -678,16 +660,16 @@ def allgather_chunked_pull( full_symm[M_start:M_end, :].copy_(shard) current_stream = torch.cuda.current_stream() - barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream) + barrier_all_on_stream(current_stream) workgroups_per_rank = 2 BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth allgather_strided_chunked_pull_ctx_wrapper_kernel[(nranks * workgroups_per_rank, )]( - ctx, full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, + full_symm, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16) - barrier_all_with_ctx_on_stream(ctx, rank, nranks, group_barrier, current_stream) + barrier_all_on_stream(current_stream) def allgather_chunked_pull_fused( @@ -698,7 +680,6 @@ def allgather_chunked_pull_fused( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -713,7 +694,7 @@ def allgather_chunked_pull_fused( BLOCK_SIZE_M = max(1, 16 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth allgather_strided_chunked_pull_fused_kernel[(nranks * workgroups_per_rank, )]( - ctx, full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, + full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, num_warps=16, **launch_cooperative_grid_options()) @@ -726,7 +707,6 @@ def 
allgather_chunked_pull_packed_fused( SPLIT_N: int, workgroups_per_rank: int = 2, ): - ctx = get_pyrocshmem_device_ctx() assert full_symm.dim() == 2 and shard.dim() == 2 M_per_rank, N = shard.shape rank = pyrocshmem.rocshmem_my_pe() @@ -740,7 +720,7 @@ def allgather_chunked_pull_packed_fused( BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 2 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similar bandwidth - allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N, + allgather_chunked_pull_fused_packed_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, grid_barrier, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, @@ -766,13 +746,12 @@ def allgather_chunked( assert shard.is_cuda and full_symm.is_cuda assert shard.stride() == full_symm.stride() # just for simple implementation - ctx = get_pyrocshmem_device_ctx() workgroups_per_rank = 1 BLOCK_SIZE_N = triton.next_power_of_2(N // SPLIT_N) BLOCK_SIZE_M = max(1, 32 * 1024 // BLOCK_SIZE_N) # use only 8 ranks can achieve similay bandwidth - allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](ctx, full_symm, shard, M_per_rank, N, + allgather_strided_chunked_kernel[(nranks * workgroups_per_rank, )](full_symm, shard, M_per_rank, N, shard.stride(0), shard.stride(1), rank, nranks, group_barrier, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, SPLIT_N=SPLIT_N, diff --git a/python/triton_dist/language/extra/hip/librocshmem_device.py b/python/triton_dist/language/extra/hip/librocshmem_device.py index aa294da67..abf4f175f 100644 --- a/python/triton_dist/language/extra/hip/librocshmem_device.py +++ b/python/triton_dist/language/extra/hip/librocshmem_device.py @@ -42,22 +42,6 @@ ROCSHMEM_SIGNAL_ADD = 1 -@core.extern -def set_rocshmem_ctx(ctx, _semantic=None): - return extern_call( - "librocshmem_device", - "", - [ - tl.cast(ctx, tl.pointer_type(tl.void), 
_semantic=_semantic), - ], - { - (tl.pointer_type(tl.void), ): ("rocshmem_set_ctx", ()), - }, - is_pure=False, - _semantic=_semantic, - ) - - void_ptr = core.pointer_type(core.void) diff --git a/python/triton_dist/language/extra/libshmem_device.py b/python/triton_dist/language/extra/libshmem_device.py index 7113188ec..15d8d7e91 100644 --- a/python/triton_dist/language/extra/libshmem_device.py +++ b/python/triton_dist/language/extra/libshmem_device.py @@ -39,12 +39,6 @@ ]) -@_shmem_module.dispatch -def set_rocshmem_ctx(ctx): - """ROCSHMEM only""" - ... - - @_shmem_module.dispatch def my_pe(): """Both NVSHMEM and ROCSHMEM""" diff --git a/python/triton_dist/test/amd/test_put_signal.py b/python/triton_dist/test/amd/test_put_signal.py index 6cbcbf372..144a8f0ac 100644 --- a/python/triton_dist/test/amd/test_put_signal.py +++ b/python/triton_dist/test/amd/test_put_signal.py @@ -11,8 +11,7 @@ from triton_dist.utils import finalize_distributed, initialize_distributed, NVSHMEM_SIGNAL_DTYPE @triton_dist.jit(do_not_specialize=["my_pe", "dst_pe"]) -def simple_put_signal_test(data, message, nelem, sig_addr, my_pe, dst_pe, ctx): - libshmem_device.set_rocshmem_ctx(ctx) +def simple_put_signal_test(data, message, nelem, sig_addr, my_pe, dst_pe): pid = tl.program_id(0) thread_id = tid(0) @@ -49,9 +48,8 @@ def simple_put_signal_test(data, message, nelem, sig_addr, my_pe, dst_pe, ctx): torch.cuda.synchronize() print(RANK, '->', dst_pe) -ctx = pyrocshmem.rocshmem_get_device_ctx() -simple_put_signal_test[(1, )](data, message, elems, sig_addr, RANK, dst_pe, ctx) +simple_put_signal_test[(1, )](data, message, elems, sig_addr, RANK, dst_pe) pyrocshmem.rocshmem_barrier_all() torch.cuda.synchronize() diff --git a/python/triton_dist/test/amd/test_rocshmem_api.py b/python/triton_dist/test/amd/test_rocshmem_api.py index 20c7c7d95..b74244431 100644 --- a/python/triton_dist/test/amd/test_rocshmem_api.py +++ b/python/triton_dist/test/amd/test_rocshmem_api.py @@ -46,8 +46,7 @@ def 
test_rocshmem_device(): @triton_dist.jit - def _rocshmem_device(comm_buf, ctx, ptr): - libshmem_device.set_rocshmem_ctx(ctx) + def _rocshmem_device(comm_buf, ptr): mype = dl.rank() npes = dl.num_ranks() @@ -58,9 +57,7 @@ def _rocshmem_device(comm_buf, ctx, ptr): tl.store(comm_buf, npes) @triton_dist.jit - def _rocshmem_put(ptr, ctx): - libshmem_device.set_rocshmem_ctx(ctx) - + def _rocshmem_put(ptr): mype = libshmem_device.my_pe() npes = libshmem_device.n_pes() peer = (mype + 1) % npes @@ -68,9 +65,7 @@ def _rocshmem_put(ptr, ctx): libshmem_device.int_p(ptr, mype, peer) @triton_dist.jit - def _rocshmem_get_put_symm_at(local_ptr, ctx): - libshmem_device.set_rocshmem_ctx(ctx) - + def _rocshmem_get_put_symm_at(local_ptr): mype = libshmem_device.my_pe() npes = libshmem_device.n_pes() pid = tl.program_id(axis=0) @@ -87,11 +82,10 @@ def _rocshmem_get_put_symm_at(local_ptr, ctx): mype = pyrocshmem.rocshmem_my_pe() npes = pyrocshmem.rocshmem_n_pes() - ctx = pyrocshmem.rocshmem_get_device_ctx() comm_buf = pyrocshmem.rocshmem_create_tensor((2, ), torch.int32) torch.distributed.barrier() - _rocshmem_device[(1, )](comm_buf, ctx, comm_buf.data_ptr()) + _rocshmem_device[(1, )](comm_buf, comm_buf.data_ptr()) torch.distributed.barrier() torch.cuda.synchronize() @@ -106,7 +100,7 @@ def _rocshmem_get_put_symm_at(local_ptr, ctx): put_buf = pyrocshmem.rocshmem_create_tensor((1, ), torch.int32) torch.distributed.barrier() - _rocshmem_put[(1, )](put_buf, ctx) + _rocshmem_put[(1, )](put_buf) torch.distributed.barrier() torch.cuda.synchronize() @@ -121,7 +115,7 @@ def _rocshmem_get_put_symm_at(local_ptr, ctx): mype:nelems_per_rank * (mype + 1)]) torch.distributed.barrier() - _rocshmem_get_put_symm_at[(1, )](put_bufs, ctx) + _rocshmem_get_put_symm_at[(1, )](put_bufs) torch.distributed.barrier() torch.cuda.synchronize() @@ -139,7 +133,7 @@ def _rocshmem_get_put_symm_at(local_ptr, ctx): def test_rocshmem_basic(): @triton.jit - def _rocshmem_basic(comm_buf, ctx, mype, npes): + def 
_rocshmem_basic(comm_buf, mype, npes): tl.store(comm_buf, mype) comm_buf += 1 tl.store(comm_buf, npes) @@ -149,11 +143,10 @@ def _rocshmem_basic(comm_buf, ctx, mype, npes): mype = pyrocshmem.rocshmem_my_pe() npes = pyrocshmem.rocshmem_n_pes() - ctx = pyrocshmem.rocshmem_get_device_ctx() comm_buf = pyrocshmem.rocshmem_create_tensor((2, ), torch.int32) # torch.distributed.barrier() - _rocshmem_basic[(1, )](comm_buf, ctx, mype, npes) + _rocshmem_basic[(1, )](comm_buf, mype, npes) # torch.distributed.barrier() pyrocshmem.rocshmem_barrier_all_on_stream(torch.cuda.current_stream().cuda_stream) torch.cuda.synchronize() @@ -169,6 +162,102 @@ def _rocshmem_basic(comm_buf, ctx, mype, npes): print(f"✅ _rocshmem_basic #{mype} pass") +def test_rocshmem_graph_capture(): + """ + Test CUDA/HIP graph capture compatibility with rocshmem_hipmodule_init. + + This test validates that rocshmem_hipmodule_init() works correctly within + torch.cuda.graph() capture, which requires device-to-device copy operations + instead of legacy stream operations (Issue #2741). 
+ """ + print("**test_rocshmem_graph_capture start!") + + # IMPORTANT: Define the kernel INSIDE this function to ensure it's NOT cached + # This forces rocshmem_hipmodule_init to be called during graph capture + @triton_dist.jit + def _graph_test_kernel(result_buf): + """Kernel for graph capture test - must be unique to avoid caching""" + mype = libshmem_device.my_pe() + npes = libshmem_device.n_pes() + # Write rank-specific value to verify execution + tl.store(result_buf, mype + 42) + + mype = pyrocshmem.rocshmem_my_pe() + npes = pyrocshmem.rocshmem_n_pes() + + # Create result buffer + result_buf = pyrocshmem.rocshmem_create_tensor((1,), torch.int32) + result_buf.fill_(0) # Initialize to 0 + + print(f"[Rank {mype}] Creating CUDA graph for rocshmem_hipmodule_init test...") + + # Create a CUDA graph + graph = torch.cuda.CUDAGraph() + + # Use a dedicated stream for graph capture + stream = torch.cuda.Stream() + + with torch.cuda.stream(stream): + # DO NOT warm up - we want rocshmem_hipmodule_init to be called + # DURING graph capture to test if it's graph compatible + + # Begin graph capture IMMEDIATELY + print(f"[Rank {mype}] Beginning graph capture (first kernel launch will call rocshmem_hipmodule_init)...") + graph.capture_begin() + + try: + # Launch kernel within graph capture + # This is the FIRST launch, so rocshmem_hipmodule_init will be called NOW + # With OLD implementation: hipMemcpyFromSymbol uses legacy stream -> SHOULD FAIL + # With NEW implementation: device-to-device copy -> SHOULD SUCCEED + _graph_test_kernel[(1,)](result_buf) + + # End graph capture + graph.capture_end() + print(f"[Rank {mype}] ✅ Graph capture completed successfully") + print(f"[Rank {mype}] ✅ rocshmem_hipmodule_init worked within graph capture!") + + except Exception as e: + print(f"[Rank {mype}] ❌ Graph capture FAILED: {e}") + print(f"[Rank {mype}] ❌ This proves rocshmem_hipmodule_init is NOT graph compatible") + try: + graph.capture_end() + except Exception: + pass # best-effort cleanup; the capture failure itself is re-raised below + raise + + # Synchronize 
before replay + torch.cuda.synchronize() + + # Reset buffer before graph replay + result_buf.fill_(0) + + # Replay the captured graph + print(f"[Rank {mype}] Replaying captured graph...") + graph.replay() + torch.cuda.synchronize() + + # Verify result + expected_value = mype + 42 + actual_value = result_buf.item() + + print(f"[Rank {mype}] Graph replay result: expected={expected_value}, actual={actual_value}") + + try: + torch.testing.assert_close( + result_buf, + torch.tensor([expected_value], dtype=torch.int32, device="cuda"), + atol=0, + rtol=0 + ) + except Exception as e: + print(f"[Rank {mype}] ❌ Graph capture test FAILED - result mismatch") + raise e + else: + print(f"[Rank {mype}] ✅ Graph replay verification PASSED") + print(f"[Rank {mype}] ✅ CONFIRMED: rocshmem_hipmodule_init is CUDA graph compatible!") + + def test_rocshmem_memcpy(): print("**rocshmem memcpy start!") @@ -242,6 +331,8 @@ def parse_args(): test_rocshmem_device() + test_rocshmem_graph_capture() + test_rocshmem_memcpy() ctx = get_torch_prof_ctx(args.profile) diff --git a/shmem/rocshmem_bind/pyrocshmem/python/_pyrocshmem/__init__.pyi b/shmem/rocshmem_bind/pyrocshmem/python/_pyrocshmem/__init__.pyi index cdef1f341..62950a4cb 100755 --- a/shmem/rocshmem_bind/pyrocshmem/python/_pyrocshmem/__init__.pyi +++ b/shmem/rocshmem_bind/pyrocshmem/python/_pyrocshmem/__init__.pyi @@ -99,10 +99,6 @@ def rocshmem_ptr(dest: np.intp, pe: np.int32) -> np.intp: ## for device side api -def rocshmem_get_device_ctx() -> np.intp: - ... 
- - # torch related diff --git a/shmem/rocshmem_bind/pyrocshmem/src/pyrocshmem.cc b/shmem/rocshmem_bind/pyrocshmem/src/pyrocshmem.cc index aa554cf7c..1cfd4e06d 100755 --- a/shmem/rocshmem_bind/pyrocshmem/src/pyrocshmem.cc +++ b/shmem/rocshmem_bind/pyrocshmem/src/pyrocshmem.cc @@ -109,8 +109,17 @@ PYBIND11_MODULE(_pyrocshmem, m) { m.def("rocshmem_barrier_all_on_stream", [](intptr_t stream) { rocshmem_barrier_all_on_stream((hipStream_t)stream); }); - m.def("rocshmem_get_device_ctx", - []() -> int64_t { return (int64_t)rocshmem_get_device_ctx(); }); + m.def("rocshmem_hipmodule_init", + [](intptr_t module, intptr_t stream) -> int { + return rocshmem_hipmodule_init( + (hipModule_t)module, + stream == 0 ? nullptr : (hipStream_t)stream + ); + }, + py::arg("module"), + py::arg("stream") = 0, + "Initialize rocSHMEM device context for a specific HIP module. " + "This is required for CUDA graph compatibility."); m.def("rocshmem_get_uniqueid", []() { rocshmem_uniqueid_t uid; CHECK_ROCSHMEM(rocshmem_get_uniqueid(&uid)); diff --git a/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc b/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc index 9b7a4c01e..efc8a4fc2 100644 --- a/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc +++ b/shmem/rocshmem_bind/runtime/rocshmem_wrapper.cc @@ -44,11 +44,6 @@ rocshmem_ptr_wrapper(void *dest, int pe) { return rocshmem_ptr(dest, pe); } -__device__ void __attribute__((visibility("default"))) -rocshmem_set_ctx(void *ctx) { - ROCSHMEM_CTX_DEFAULT.ctx_opaque = ctx; -} - __device__ void __attribute__((visibility("default"))) rocshmem_putmem_signal_wrapper(void *dest, const void *source, size_t nbytes, uint64_t *sig_addr, uint64_t signal, int sig_op, diff --git a/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh b/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh index d1797c564..7854fc5d8 100755 --- a/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh +++ b/shmem/rocshmem_bind/scripts/build_rocshmem_device_bc.sh @@ -39,12 +39,12 @@ 
SOURCE_MAP=( ["${ROCSHMEM_SRC}/src/ipc/context_ipc_device_coll.cpp"]="rocshmem_context_ipc_device_coll.bc" ["${ROCSHMEM_SRC}/src/ipc_policy.cpp"]="rocshmem_ipc_policy.bc" ["${ROCSHMEM_SRC}/src/gda/context_gda_device.cpp"]="rocshmem_context_gda_device.bc" + ["${ROCSHMEM_SRC}/src/gda/context_gda_device_coll.cpp"]="rocshmem_context_gda_device_coll.bc" ["${ROCSHMEM_SRC}/src/gda/backend_gda.cpp"]="rocshmem_backend_gda.bc" ["${ROCSHMEM_SRC}/src/gda/queue_pair.cpp"]="rocshmem_queue_pair.bc" ["${ROCSHMEM_SRC}/src/gda/ionic/queue_pair_ionic.cpp"]="rocshmem_queue_pair_ionic.bc" ["${ROCSHMEM_SRC}/src/gda/mlx5/queue_pair_mlx5.cpp"]="rocshmem_queue_pair_mlx5.bc" ["${ROCSHMEM_SRC}/src/gda/mlx5/segment_builder.cpp"]="rocshmem_segment_builder.bc" - ["${ROCSHMEM_SRC}/src/gda/endian.cpp"]="rocshmem_endian.bc" ["${ROCSHMEM_SRC}/src/team.cpp"]="rocshmem_team.bc" ["${ROCSHMEM_SRC}/src/sync/abql_block_mutex.cpp"]="rocshmem_abql_block_mutex.bc" ["${ROCSHMEM_SRC}/src/util.cpp"]="rocshmem_util.bc" diff --git a/tutorials/03a-inter-node-allgather.py b/tutorials/03a-inter-node-allgather.py index d23b9c677..f8d4b82ed 100644 --- a/tutorials/03a-inter-node-allgather.py +++ b/tutorials/03a-inter-node-allgather.py @@ -61,8 +61,9 @@ class AllGatherContext: @triton_dist.jit(do_not_specialize=["rank", "signal_value"]) def all_gather_push_1d_kernel(symm_ptr, bytes_per_rank, symm_flag, - WORLD_SIZE: tl.constexpr, rank, signal_value, ctx): - libshmem_device.set_rocshmem_ctx(ctx) + WORLD_SIZE: tl.constexpr, rank, signal_value): + # Context is now automatically initialized via rocshmem_hipmodule_init() in JIT compiler + # No need to manually set context pid = tl.program_id(0) thread_idx = tid(0) # there are WORLD_SIZE programs processing different data. 
@@ -99,15 +100,13 @@ def all_gather_push_1d_kernel(symm_ptr, bytes_per_rank, symm_flag, def all_gather_push_1d(ctx: AllGatherContext, symm_buffer: torch.Tensor): ctx.signal_value += 1 - rctx = pyrocshmem.rocshmem_get_device_ctx() all_gather_push_1d_kernel[(ctx.num_ranks, )]( symm_buffer, symm_buffer.nbytes // ctx.num_ranks, ctx.symm_signals[ctx.signal_value % 2], ctx.num_ranks, ctx.rank, - ctx.signal_value, - rctx + ctx.signal_value ) return symm_buffer