[ROCm] AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes#40773
[ROCm] AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes#40773rbrugaro-amd wants to merge 3 commits into
Conversation
…llm-project#81 with reviewer fixes Merge vllm-project#37646 (AITER fused allreduce + RMSNorm pass) and EmbeddedLLM#81 (hidden_dim=7168 support + threshold tuning), then address maintainer feedback: - Add AITER compile range splitting in _set_compile_ranges() so the backend generates separate compiled graphs for fusion-eligible vs non-eligible token counts, matching the existing flashinfer pattern. - Convert RocmAiterAllReduceFusionPass to VllmFusionPatternMatcherPass and rewrite AITER patterns as VllmPatternReplacement subclasses. Add extra_check support to VllmPatternReplacement/register(). - Extract _compute_allreduce_max_token_num() shared helper used by both AllReduceFusionPass (flashinfer) and RocmAiterAllReduceFusionPass. - Simplify parallel_state.py aiter capture context per review suggestion. Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
There was a problem hiding this comment.
Code Review
This pull request implements fused All-Reduce and RMSNorm operations for ROCm platforms using the AITER library, including a new compilation pass (RocmAiterAllReduceFusionPass) and integration with the distributed graph capture. Feedback identifies a bug where the residual buffer in the fused pattern must be initialized to zero to avoid incorrect results from garbage data. Additionally, the lifecycle management of the AITER communicator needs correction, as destroying it in the pass's destructor causes runtime crashes during execution, and its initialization should be made idempotent.
| residual = torch.empty_like(input) | ||
| allreduce = op( | ||
| input_=input, | ||
| residual=residual, | ||
| weight=weight, | ||
| epsilon=epsilon, | ||
| ) | ||
| return allreduce[0], allreduce[1] |
There was a problem hiding this comment.
The AiterAllreduceFusedRMSNormPattern replacement uses torch.empty_like(input) for the residual buffer. Since the underlying AITER kernel performs an additive operation (allreduce(input) + residual), using an uninitialized buffer will lead to incorrect results containing garbage data. This buffer should be initialized to zeros for the non-additive case to ensure correctness.
| residual = torch.empty_like(input) | |
| allreduce = op( | |
| input_=input, | |
| residual=residual, | |
| weight=weight, | |
| epsilon=epsilon, | |
| ) | |
| return allreduce[0], allreduce[1] | |
| residual = torch.zeros_like(input) | |
| allreduce = op( | |
| input_=input, | |
| residual=residual, | |
| weight=weight, | |
| epsilon=epsilon, | |
| ) | |
| return allreduce[0], allreduce[1] |
| def __del__(self) -> None: | ||
| if getattr(self, "disabled", True): | ||
| return | ||
| with contextlib.suppress(Exception): | ||
| rocm_aiter_ops.destroy_aiter_allreduce() |
There was a problem hiding this comment.
The __del__ method destroys the global AITER allreduce communicator. This is highly problematic because the pass instance is typically garbage collected after the compilation phase, but the global communicator must remain alive for the execution of the compiled graph. Destroying it here will cause runtime crashes when the model is executed. Furthermore, in environments with multiple engine instances or models, deleting one pass will break the communicator for all others. The lifecycle of the communicator should be managed by a persistent entity like parallel_state or GroupCoordinator, not by a fusion pass instance.
| try: | ||
| from aiter.dist.device_communicators.custom_all_reduce import ( | ||
| CustomAllreduce as AiterCustomAllreduce, | ||
| ) | ||
|
|
||
| cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) | ||
| except Exception: | ||
| cls._CUSTOM_ALL_REDUCE = None |
There was a problem hiding this comment.
The initialize_aiter_allreduce method should check if the communicator is already initialized before creating a new one. This prevents redundant resource allocation and potential race conditions when multiple fusion passes are instantiated (e.g., during multiple compilation cycles or for different models in the same process).
| try: | |
| from aiter.dist.device_communicators.custom_all_reduce import ( | |
| CustomAllreduce as AiterCustomAllreduce, | |
| ) | |
| cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) | |
| except Exception: | |
| cls._CUSTOM_ALL_REDUCE = None | |
| if cls._CUSTOM_ALL_REDUCE is not None: | |
| return | |
| try: | |
| from aiter.dist.device_communicators.custom_all_reduce import ( | |
| CustomAllreduce as AiterCustomAllreduce, | |
| ) | |
| cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) | |
| except Exception: | |
| cls._CUSTOM_ALL_REDUCE = None |
Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
Add three fixes to RocmAiterAllReduceFusionPass that resolve a hang during CUDA graph capture with fuse_allreduce_rms + full_and_piecewise: 1. Register first_return_only pattern variant so the last transformer layer (which discards the residual) also gets fused, preventing mixed AITER/vllm allreduce during capture. 2. Override _trace_fn to use pm.fwd_only, matching the working patch and avoiding unnecessary view_to_reshape transforms during pattern tracing. 3. Add __del__ cleanup to destroy AITER allreduce resources and prevent IPC buffer leaks across compilations. Validated on: - Kimi-K2-Thinking-MXFP4 TP=4: 123 patterns/worker, GSM8K=0.948 - Qwen3-30B-A3B-FP8 TP=2: 97 patterns/worker, GSM8K=0.808 Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…iewer fixes
Merge #37646 (AITER fused allreduce + RMSNorm pass) and
EmbeddedLLM#81 (hidden_dim=7168 support + threshold tuning), then
address maintainer feedback: