diff --git a/.github/workflows/amd-ci.yml b/.github/workflows/amd-ci.yml index 57f301e3b..eadbc3887 100644 --- a/.github/workflows/amd-ci.yml +++ b/.github/workflows/amd-ci.yml @@ -47,6 +47,8 @@ jobs: bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_distributed-notify-wait.py # ag gemm bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ag_gemm_intra_node.py 8192 11008 4096 + # ar gemm + bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_gemm_ar_intra_node.py 8192 11008 4096 # gemm rs bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_gemm_rs_intra_node.py 8192 4096 12288 # mori shmem api tests diff --git a/python/triton_dist/kernels/amd/gemm_allreduce.py b/python/triton_dist/kernels/amd/gemm_allreduce.py index e3b9980ca..4d4c58a9c 100644 --- a/python/triton_dist/kernels/amd/gemm_allreduce.py +++ b/python/triton_dist/kernels/amd/gemm_allreduce.py @@ -174,30 +174,40 @@ def consumer_all_reduce_kernel(symm_buf_ptr, tile_signal_ptr, M, N, stride_cm, s num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_in_group = GROUP_SIZE_M * tl.cdiv(N, BLOCK_SIZE_N) - for tile_id in range(pid, num_tiles, NUM_COMM_SMS): + + start_tile = rank + stride = world_size + remaining = num_tiles - start_tile + remaining = tl.maximum(remaining, 0) + max_tile_offset = tl.cdiv(remaining, stride) + + for tile_offset in range(pid, max_tile_offset, NUM_COMM_SMS): + tile_id = start_tile + tile_offset * stride pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M) - owner_rank = tile_id % world_size - if rank == owner_rank: - signal_base = tile_id * world_size - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - - final_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(world_size): - target_rank = (i + rank) % world_size - remote_c_ptr = dl.symm_at(symm_buf_ptr, target_rank) - token = dl.wait(tile_signal_ptr + signal_base + target_rank, 1, "sys", "acquire", waitValue=1) - remote_c_ptr = dl.consume_token(remote_c_ptr, token) - remote_c_ptrs = remote_c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn - remote_data = tl.load(remote_c_ptrs, mask=c_mask, other=0.0) - final_acc += remote_data - - c = final_acc.to(symm_buf_ptr.dtype.element_ty) - for remote_rank in range(world_size): - remote_buf_ptr = dl.symm_at(symm_buf_ptr, remote_rank) - remote_buf_ptrs = remote_buf_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn - tl.store(remote_buf_ptrs, c, mask=c_mask) + + signal_base = tile_id * world_size + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + remote_c_ptr = dl.symm_at(symm_buf_ptr, rank) + token = dl.wait(tile_signal_ptr + signal_base + rank, 1, "sys", "acquire", waitValue=1) + remote_c_ptr = dl.consume_token(remote_c_ptr, token) + remote_c_ptrs = remote_c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + final_acc = tl.load(remote_c_ptrs, mask=c_mask, other=0.0).to(tl.float32) + for i in range(1, world_size): + target_rank = (i + rank) % world_size + remote_c_ptr = dl.symm_at(symm_buf_ptr, target_rank) + token = dl.wait(tile_signal_ptr + signal_base + target_rank, 1, "sys", "acquire", waitValue=1) + remote_c_ptr = dl.consume_token(remote_c_ptr, token) + remote_c_ptrs = remote_c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + final_acc += tl.load(remote_c_ptrs, mask=c_mask, other=0.0) + + c = final_acc.to(symm_buf_ptr.dtype.element_ty) + for remote_rank in range(world_size): + remote_buf_ptr = dl.symm_at(symm_buf_ptr, remote_rank) + remote_buf_ptrs = remote_buf_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn + tl.store(remote_buf_ptrs, c, mask=c_mask) DEFAULT_GEMM_CONFIG = triton.Config(