[ROCm] AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes #40773

Closed
rbrugaro-amd wants to merge 3 commits into vllm-project:main from rbrugaro-amd:allreduce_rms_comb_37646_81

@rbrugaro-amd

Contributor

…iewer fixes

Merge #37646 (AITER fused allreduce + RMSNorm pass) and
EmbeddedLLM#81 (hidden_dim=7168 support + threshold tuning), then
address maintainer feedback:

  • Add AITER compile range splitting in _set_compile_ranges() so the backend generates separate compiled graphs for fusion-eligible vs non-eligible token counts, matching the existing flashinfer pattern.
  • Convert RocmAiterAllReduceFusionPass to VllmFusionPatternMatcherPass and rewrite AITER patterns as VllmPatternReplacement subclasses. Add extra_check support to VllmPatternReplacement/register().
  • Extract _compute_allreduce_max_token_num() shared helper used by both AllReduceFusionPass (flashinfer) and RocmAiterAllReduceFusionPass.
  • Simplify parallel_state.py aiter capture context per review suggestion.
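
The range splitting and the shared threshold helper can be pictured with a minimal sketch. This is not the code from this PR: `compute_max_token_num`, `split_compile_ranges`, the 8 MiB budget, and the inclusive `(start, end)` range representation are all hypothetical. It assumes the fusion threshold is a communication-buffer budget divided by the bytes per token, and that compile ranges are split at that threshold so each compiled graph is either entirely fusion-eligible or entirely not.

```python
def compute_max_token_num(max_size_bytes: int, hidden_dim: int, element_size: int) -> int:
    """Largest token count whose activations fit in the buffer (hypothetical helper)."""
    return max_size_bytes // (hidden_dim * element_size)


def split_compile_ranges(
    ranges: list[tuple[int, int]], threshold: int
) -> list[tuple[int, int]]:
    """Split inclusive (start, end) token ranges at `threshold` (hypothetical helper).

    A range that straddles the threshold becomes two ranges: one that is
    fusion-eligible (<= threshold) and one that is not (> threshold).
    """
    out: list[tuple[int, int]] = []
    for start, end in ranges:
        if start <= threshold < end:
            out.append((start, threshold))
            out.append((threshold + 1, end))
        else:
            out.append((start, end))
    return out


# Assumed numbers for illustration: 8 MiB budget, hidden_dim=7168, 2-byte elements.
max_tokens = compute_max_token_num(8 * 1024 * 1024, 7168, 2)
ranges = split_compile_ranges([(1, 8192)], max_tokens)
print(max_tokens, ranges)
```

With these assumed numbers the single range `(1, 8192)` is split once, at the computed threshold; a range that lies entirely on one side of the threshold is returned unchanged.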

Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
@mergify mergify Bot added ci/build rocm Related to AMD ROCm labels Apr 24, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 24, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request implements fused All-Reduce and RMSNorm operations for ROCm platforms using the AITER library, including a new compilation pass (RocmAiterAllReduceFusionPass) and integration with the distributed graph capture. Feedback identifies a bug where the residual buffer in the fused pattern must be initialized to zero to avoid incorrect results from garbage data. Additionally, the lifecycle management of the AITER communicator needs correction, as destroying it in the pass's destructor causes runtime crashes during execution, and its initialization should be made idempotent.

Comment on lines +937 to +944
residual = torch.empty_like(input)
allreduce = op(
    input_=input,
    residual=residual,
    weight=weight,
    epsilon=epsilon,
)
return allreduce[0], allreduce[1]

critical

The AiterAllreduceFusedRMSNormPattern replacement uses torch.empty_like(input) for the residual buffer. Since the underlying AITER kernel performs an additive operation (allreduce(input) + residual), using an uninitialized buffer will lead to incorrect results containing garbage data. This buffer should be initialized to zeros for the non-additive case to ensure correctness.

Suggested change
- residual = torch.empty_like(input)
+ residual = torch.zeros_like(input)
  allreduce = op(
      input_=input,
      residual=residual,
      weight=weight,
      epsilon=epsilon,
  )
  return allreduce[0], allreduce[1]
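
To make the concern concrete, here is a pure-Python reference of the contract the comment describes, namely that the kernel normalizes `allreduce(input) + residual`. It is only a sketch under that assumption and has not been checked against the AITER kernel; `rms_norm` and `fused_allreduce_add_rmsnorm` are hypothetical names, and the "garbage" list merely stands in for uninitialized memory.

```python
import math


def rms_norm(x, weight, eps):
    """Plain RMSNorm over one vector."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def fused_allreduce_add_rmsnorm(per_rank_inputs, residual, weight, eps):
    """Reference for the assumed contract: RMSNorm(allreduce(input) + residual)."""
    reduced = [sum(vals) for vals in zip(*per_rank_inputs)]  # sum across ranks
    new_residual = [r + a for r, a in zip(residual, reduced)]
    return rms_norm(new_residual, weight, eps), new_residual


inputs = [[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]]  # one vector per rank
weight = [1.0, 1.0, 1.0]
eps = 1e-6

# What the unfused graph computes: RMSNorm(allreduce(input)), no residual.
expected = rms_norm([1.5, 1.0, 5.0], weight, eps)

zeros = [0.0, 0.0, 0.0]
garbage = [7.0, -3.0, 42.0]  # stand-in for an uninitialized buffer

out_zero, _ = fused_allreduce_add_rmsnorm(inputs, zeros, weight, eps)
out_garbage, _ = fused_allreduce_add_rmsnorm(inputs, garbage, weight, eps)
```

Under this contract only the zero-filled residual reproduces the unfused result; any other buffer contents shift the normalized output.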

Comment on lines +1098 to +1102
def __del__(self) -> None:
    if getattr(self, "disabled", True):
        return
    with contextlib.suppress(Exception):
        rocm_aiter_ops.destroy_aiter_allreduce()

critical

The __del__ method destroys the global AITER allreduce communicator. This is highly problematic because the pass instance is typically garbage collected after the compilation phase, but the global communicator must remain alive for the execution of the compiled graph. Destroying it here will cause runtime crashes when the model is executed. Furthermore, in environments with multiple engine instances or models, deleting one pass will break the communicator for all others. The lifecycle of the communicator should be managed by a persistent entity like parallel_state or GroupCoordinator, not by a fusion pass instance.
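
One way to picture the suggested ownership model is the toy sketch below. Every class in it (`Communicator`, `GroupOwner`, `FusionPass`) is a hypothetical stand-in, not a vLLM or AITER type: the long-lived owner creates and destroys the communicator, while the short-lived pass only borrows it and defines no `__del__`.

```python
class Communicator:
    """Stand-in for a process-wide allreduce communicator (hypothetical)."""

    def __init__(self) -> None:
        self.alive = True

    def all_reduce(self, x):
        if not self.alive:
            raise RuntimeError("communicator destroyed")
        return x  # single-rank toy: allreduce is the identity

    def destroy(self) -> None:
        self.alive = False


class GroupOwner:
    """Long-lived owner, playing the role the review assigns to
    parallel_state or GroupCoordinator (hypothetical)."""

    def __init__(self) -> None:
        self.comm = Communicator()

    def shutdown(self) -> None:
        self.comm.destroy()


class FusionPass:
    """Short-lived compile-time object: borrows the communicator and
    deliberately defines no __del__ that would tear it down."""

    def __init__(self, owner: GroupOwner) -> None:
        self.comm = owner.comm


owner = GroupOwner()
fusion_pass = FusionPass(owner)
del fusion_pass  # the pass goes away after compilation

result = owner.comm.all_reduce([1, 2, 3])  # execution still works
owner.shutdown()  # teardown happens once, at the owner's end of life
```

Because the pass never destroys what it borrowed, collecting it after compilation has no effect on later execution, and several passes can share one communicator.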

Comment thread vllm/_aiter_ops.py
Comment on lines +1416 to +1423
try:
    from aiter.dist.device_communicators.custom_all_reduce import (
        CustomAllreduce as AiterCustomAllreduce,
    )

    cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
except Exception:
    cls._CUSTOM_ALL_REDUCE = None

high

The initialize_aiter_allreduce method should check if the communicator is already initialized before creating a new one. This prevents redundant resource allocation and potential race conditions when multiple fusion passes are instantiated (e.g., during multiple compilation cycles or for different models in the same process).

Suggested change
+ if cls._CUSTOM_ALL_REDUCE is not None:
+     return
  try:
      from aiter.dist.device_communicators.custom_all_reduce import (
          CustomAllreduce as AiterCustomAllreduce,
      )
      cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
  except Exception:
      cls._CUSTOM_ALL_REDUCE = None
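
The effect of the guard can be checked in isolation with a small stand-in. `AiterOps`, `initialize_allreduce`, and `make_comm` below are hypothetical and replace the real AITER constructor with a counting factory, so the sketch runs without ROCm.

```python
class AiterOps:
    """Toy holder for a process-wide communicator (hypothetical)."""

    _CUSTOM_ALL_REDUCE = None

    @classmethod
    def initialize_allreduce(cls, factory) -> None:
        # Idempotent: a second call is a no-op once a communicator exists.
        if cls._CUSTOM_ALL_REDUCE is not None:
            return
        try:
            cls._CUSTOM_ALL_REDUCE = factory()
        except Exception:
            cls._CUSTOM_ALL_REDUCE = None


created = []


def make_comm():
    """Counting factory standing in for the real communicator constructor."""
    created.append(object())
    return created[-1]


AiterOps.initialize_allreduce(make_comm)
first = AiterOps._CUSTOM_ALL_REDUCE
AiterOps.initialize_allreduce(make_comm)  # e.g. a second fusion pass instance
```

Note that the guard alone is not thread-safe, and a failed initialization leaves the slot as `None`, so a later call will retry.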

@rbrugaro-amd rbrugaro-amd changed the title [ROCm] AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes [ROCm] [WIP not ready]AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes Apr 24, 2026
attila-dusnoki-htec and others added 2 commits April 24, 2026 11:14
Signed-off-by: Attila Dusnoki <attila.dusnoki@htecgroup.com>
Add three fixes to RocmAiterAllReduceFusionPass that resolve a hang
during CUDA graph capture with fuse_allreduce_rms + full_and_piecewise:

1. Register first_return_only pattern variant so the last transformer
   layer (which discards the residual) also gets fused, preventing
   mixed AITER/vllm allreduce during capture.

2. Override _trace_fn to use pm.fwd_only, matching the working patch
   and avoiding unnecessary view_to_reshape transforms during pattern
   tracing.

3. Add __del__ cleanup to destroy AITER allreduce resources and prevent
   IPC buffer leaks across compilations.

Validated on:
- Kimi-K2-Thinking-MXFP4 TP=4: 123 patterns/worker, GSM8K=0.948
- Qwen3-30B-A3B-FP8 TP=2: 97 patterns/worker, GSM8K=0.808

Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@rbrugaro-amd rbrugaro-amd changed the title [ROCm] [WIP not ready]AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes [ROCm] AITER fused allreduce+rmsnorm: merge PRs #37646 + #81 with reviewer fixes Apr 24, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 6, 2026