Skip to content

[Spec Decode][CUDA Graphs] Enables Eagle drafter support for FULL CUDA Graph mode#34880

Open
yiz-liu wants to merge 2 commits into
vllm-project:mainfrom
yiz-liu:full-spec
Open

[Spec Decode][CUDA Graphs] Enables Eagle drafter support for FULL CUDA Graph mode#34880
yiz-liu wants to merge 2 commits into
vllm-project:mainfrom
yiz-liu:full-spec

Conversation

@yiz-liu

@yiz-liu yiz-liu commented Feb 19, 2026

Copy link
Copy Markdown
Contributor

Purpose

As mentioned in vllm-project/vllm-ascend#5459 and #33341 , this PR enables Full CUDA Graph mode for the Eagle drafter model to improve performance.

The main changes include:

  1. CUDA Graph Integration: Wraps the drafter model with CUDAGraphWrapper during load_model and initializes the necessary keys for the dispatcher to manage graph-based execution.
  2. Graph Capture Support: Builds dummy attention metadata during the dummy_run phase, which is required for successful graph capture.
  3. Dispatch: For the first step, the draft model shares the same uniform_decode with target model and basically has the same batch_desc and cudagraph_mode, while for the following steps, the uniform_decode_query_len will be set to 1 and uniform_decode to True, making it possible to have separate cudagraph_keys.
  4. Metadata Correction: Corrects the memory address handling for query_start_loc and slot_mapping within the attention metadata.
  5. Bug Fix: Adjusts CUDA graph capture sizes to resolve a runtime error that occurred when num_speculative_tokens was set to 2. Also fix prepare_inputs_padded and prepare_next_token_ids_padded for padding issues.

Collectively, these changes allow the Eagle drafter to leverage the performance benefits of Full CUDA Graph mode, enhancing throughput for speculative decoding.

Test Plan

The feature was tested by running the model with the following configuration:

  • num_speculative_tokens=2 (and also validated with 3/4/5)
  • cudagraph_mode="FULL" (and also validated with FULL_DECODE_ONLY and FULL_AND_PIECEWISE)

Test Result

The model's acceptance rate in Full CUDA Graph mode is consistent with the results from eager mode.

For FULL_AND_PIECEWISE:

...
[cuda_graph.py:123] | Unpadded Tokens | Padded Tokens | Num Paddings | Runtime Mode | Count |
[cuda_graph.py:123] |-----------------|---------------|--------------|--------------|-------|
[cuda_graph.py:123] | 40              | 40            | 0            | FULL         | 28    |
[cuda_graph.py:123] | 42              | 50            | 8            | PIECEWISE    | 1     |
[cuda_graph.py:123] | 35              | 35            | 0            | FULL         | 1     |
...
--------------------------------------------------
total_num_output_tokens: 768
num_drafts: 402
num_draft_tokens: 1608
num_accepted_tokens: 370
mean acceptance length: 1.92
--------------------------------------------------
acceptance at token 0: 0.46
acceptance at token 1: 0.22
acceptance at token 2: 0.14
acceptance at token 3: 0.10

For FULL:

...
[cuda_graph.py:123] | Unpadded Tokens | Padded Tokens | Num Paddings | Runtime Mode | Count |
[cuda_graph.py:123] |-----------------|---------------|--------------|--------------|-------|
[cuda_graph.py:123] | 40              | 40            | 0            | FULL         | 29    |
[cuda_graph.py:123] | 5               | 5             | 0            | FULL         | 2     |
[cuda_graph.py:123] | 44              | 50            | 6            | FULL         | 1     |
...
--------------------------------------------------
total_num_output_tokens: 768
num_drafts: 402
num_draft_tokens: 1608
num_accepted_tokens: 365
mean acceptance length: 1.91
--------------------------------------------------
acceptance at token 0: 0.48
acceptance at token 1: 0.23
acceptance at token 2: 0.12
acceptance at token 3: 0.08

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables full CUDA graph support for the Eagle drafter model, which is a significant performance enhancement. The changes are well-structured, including necessary modifications for CUDA graph compatibility like in-place tensor updates and proper dummy run setup for graph capturing. I've identified one critical issue: a logging statement with incorrect formatting that will cause a TypeError at runtime. I've provided a suggestion to fix it. Overall, this is a great contribution.

Comment thread vllm/v1/worker/gpu_model_runner.py

@yiz-liu yiz-liu left a comment

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are some questions I am not entirely sure about. Any comments?

Comment thread vllm/v1/spec_decode/eagle.py Outdated
Comment on lines +373 to +381
if not self.speculative_config.enforce_eager:
# This is a temprary mapping open to discussions
# FULL_AND_PIECEWISE -> PIECEWISE, FULL_DECODE_ONLY -> FULL
# PIECEWISE -> PIECEWISE, FULL -> FULL
eagle_cudagraph_mode = (
CUDAGraphMode.PIECEWISE
if cudagraph_mode.has_piecewise_cudagraphs()
else cudagraph_mode.decode_mode()
)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any other thoughts on this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is updated by c79e3ac

Comment on lines +571 to +574
common_attn_metadata.query_start_loc[: batch_size + 1] = self.arange[
: batch_size + 1
]
common_attn_metadata.query_start_loc_cpu[: batch_size + 1] = torch.from_numpy(

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why we set query_start_loc or slot_mapping to a different buffer in the first place, but I assume it's always safe to use the original buffer.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was not considering consistent addressing since we didn't have full graphs yet.

Comment thread vllm/v1/spec_decode/eagle.py Outdated
Comment on lines +909 to +910
# NOTE: For CUDA Graph, we need the `num_reqs_padded` here
batch_size = common_attn_metadata.num_reqs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another core change, as the input_batch.num_reqs != common_attn_metadata.num_reqs after padding, and I wonder if there is a better way to deal with this?

@yiz-liu

yiz-liu commented Feb 19, 2026

Copy link
Copy Markdown
Contributor Author

@tomasruizt @LucasWilkinson Could you please take a look at this? Thanks!

Comment on lines 5588 to 5597
@@ -5575,7 +5593,6 @@ def _check_and_update_cudagraph_mode(
# sizes for decode and mixed prefill-decode.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
and self.uniform_decode_query_len > 1
):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I check the comments here, since we already have separate capture sizes now, we can remove this condition right? This is to solve the num_speculative_tokens=2 issue.

@tomasruizt

Copy link
Copy Markdown
Contributor

@benchislett

@tomasruizt

tomasruizt commented Feb 19, 2026

Copy link
Copy Markdown
Contributor

@yiz-liu Thanks a lot for this PR!

Edit: Perhaps the observation below is just a matter of enabling full CG also for method=draft_model. Let me know! :)

I profiled your branch using the PyTorch profiler and found:

  • That the target model is running with full cuda graphs (CG) and dispatching the forward using cudaGraphLaunch.
  • However, the draft model is not using full cg: instead its launching many separate pytorch ops during the forward (see below, left circle is the target forward, and right circle is the draft forward).
sd-full-cg-profile

I attach the command I used to generate the trace as well as the PyTorch trace, which you can open in https://ui.perfetto.dev/.

Script: profile-4b-sd-0.6b.sh

Trace: rank0.1771496331250686317.pt.trace.json.gz

I assume that you are seeing no changes in performance whatsoever compared to main (TPOT, ITL). If correct, it means that some wiring up is still missing to enable full CG for the drafter. Let me know if I'm wrong or missing something.

@yiz-liu

yiz-liu commented Feb 19, 2026

Copy link
Copy Markdown
Contributor Author

@tomasruizt Weird, I'll look into this later, in the meanwhile scripts and profiling are attached below:
image
FULL.zip
data_parallel.py

python3 data_parallel.py \
--model="/home/weight/Qwen3-30B-A3B-FP8" \
-dp=1 \
-tp=2

@tomasruizt

Copy link
Copy Markdown
Contributor

For EAGLE3, I'm observing the same phenomenon. I used gpt-oss-20b + eagle3.

  • target model forward dispatches to cudaGraphLaunch, while
  • draft model forward dispatches to a bunch of small ops
sd-eagle3-full-cg-profile

Profiling script: profile-gpt-oss-20b-eagle3.sh
PyTorch trace: dp0_pp0_tp0_dcp0_ep0_rank0.1771505307158794037.pt.trace.json.gz

@yiz-liu

yiz-liu commented Feb 19, 2026

Copy link
Copy Markdown
Contributor Author

For EAGLE3, I'm observing the same phenomenon. I used gpt-oss-20b + eagle3.

  • target model forward dispatches to cudaGraphLaunch, while
  • draft model forward dispatches to a bunch of small ops

@tomasruizt Oh yeah I checked this scripts and profiling, I believe the behavior you're observing is due to the default CUDA graph mode, which resolves to FULL_AND_PIECEWISE. As I mentioned in this comment, this configuration correctly results in a PIECEWISE graph for the speculative decoding step. The resulting piecewise graph contains very few ops, which can make it appear as though no graph is active, but this is the expected outcome for that mode:

image

Could you please try explicitly setting the CUDA graph mode to FULL and re-running the profile?

This brings up a design question, I'll elaborate on my comment before: do you think we should change the default strategy to be more aggressive (i.e., prefer FULL over PIECEWISE)? My take is that with async scheduling now available, the host launch overhead for drafter model is likely masked by the target model's computation time, making PIECEWISE still a safe and robust default, which is exactly the reason I keep it. What are your thoughts?

@tomasruizt

tomasruizt commented Feb 19, 2026

Copy link
Copy Markdown
Contributor

If the target model runs in full cg, then the draft model should run in full cg if possible, right? The higher performance setting should be the default. What is the problem with setting full cg as a default?

@yiz-liu

yiz-liu commented Feb 19, 2026

Copy link
Copy Markdown
Contributor Author

If the target model runs in full cg, then the draft model should run in full cg if possible, right? The higher performance setting should be the default. What is the problem with setting full cg as a default?

Yeah that's a good point. No problem at all, my initial design was just trying to honor the existing default behavior for consistency. However, I agree that prioritizing performance is the right way. I'll go ahead and update the PR. Of course, for others who might have concerns, this is still open for discussion.

Thanks for the valuable feedback!

@benchislett

Copy link
Copy Markdown
Member

+1, behaviour should match the base model for consistency whenever possible. If base model uses full graphs for a certain shape, so should the drafter.

@Neo9061

Neo9061 commented Feb 19, 2026

Copy link
Copy Markdown

Thanks for the great work! Will the full CUDA can be applied to parallel-EAGLE as well? CC @benchislett

@yiz-liu

yiz-liu commented Feb 20, 2026

Copy link
Copy Markdown
Contributor Author

@tomasruizt @benchislett Hi, please see the latest commit for the unified CUDA Graph mode, the target model and drafter model should share the same behavior now.

@mergify

mergify Bot commented Feb 21, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yiz-liu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify

mergify Bot commented Feb 21, 2026

Copy link
Copy Markdown
Contributor

Hi @yiz-liu, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@mergify

mergify Bot commented Feb 21, 2026

Copy link
Copy Markdown
Contributor

Hi @yiz-liu, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@tomasruizt

Copy link
Copy Markdown
Contributor

@yiz-liu Are you able to generate the PyTorch profile? You can attach it once done as proof that the drafter uses cudaGraphLaunch. If for some reason you cannot, let me know, I probably can do it on Monday.

@tomasruizt

tomasruizt commented Mar 26, 2026

Copy link
Copy Markdown
Contributor

@yiz-liu Have you measured improvements in inference metrics? An acceptance length of ~1.9 might be a blocker to get measurable speedups. Perhaps method=draft_model can give you high acceptance lengths (5 and more) to measure speedups (Link)

@mergify

mergify Bot commented Mar 30, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yiz-liu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 30, 2026
yiz-liu added 2 commits April 1, 2026 15:33
…ificantly improving inference performance by reducing CPU overhead during the draft speculative steps.

1. CudagraphDispatcher
* Added a for_draft_model flag to allow specialized graph capture logic for speculative decoding.
* Updated initialize_cudagraph_keys to capture graphs up to max_num_tokens specifically for steps > 0.
* Set uniform_decode_query_len as a independent parameter as when steps > 0, it should be 1.

2. EAGLE Proposer Updates
* Model Wrapping: The draft model is now wrapped in CUDAGraphWrapper when FULL mode is enabled and padding is not disabled.
* Metadata Padding: Fixed a potential crash by padding spec_decode_metadata.cu_num_draft_tokens to match the padded batch size.
* Refined Dispatching: Updated _determine_batch_execution_and_padding to return and pass BatchDescriptor objects, ensuring the runtime uses the correct graph key.
* Capture Logic: Enhanced dummy_run to simulate the actual speculative decoding steps during the graph capture phase.

3. GPUModelRunner
* Introduced supports_sd_full_graph to identify proposers (like EAGLE) that are compatible with FULL graph mode.
* Modified ExecuteModelState to track batch_desc, ensuring consistency between the target model and draft model.
* Ensured CommonAttentionMetadata is correctly passed to the drafter's warmup/dummy runs to facilitate accurate metadata building during capture.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Store `num_reqs_padded` on the GPU input batch and source it from
the runner so `prepare_next_token_ids_padded` can keep a shared 5-argument
API across proposer variants.

Use padded batch size only for output allocation and kernel launch shape,
while guarding Triton work with the actual request count so padded tail
rows remain inert and do not read stale fallback data.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@oreo-wjx

Copy link
Copy Markdown

will this pr be merged?

uniform_decode_query_len = self.uniform_decode_query_len
if uniform_decode_query_len is None:
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should build a seperate _bs_to_padded_graph_size for uniform_decode_query_len==1, it'll save some computation. Take num_speculative_tokens=3 as an example, bs=1 currently will be padded to 4, however, no pad is needed actually.

yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
…ttern (Step 1)

The MUSA-0090/0109 EagleFullLoopRunner captured the full N-step Eagle3
draft loop as ONE giant CUDAGraph. On the yeahdongcn70 (mcc 5.1.0)
toolchain the giant-chain capture corrupts the draft model's MCCL
barrier state at replay: accept rate collapses to 3.49% with positions
1-4 at exactly 0%, GPUs 4+6 firmware-reboot when the FP8 MoE override
is removed (full bisect at `generated/musa0203/`).

Upstream PR vllm-project#34880 ("Eagle drafter support for FULL
CUDA Graph mode") follows a structurally different design: wrap the
draft model with the standard `CUDAGraphWrapper` and dispatch per-step
via a `CudagraphDispatcher(for_draft_model=True)`. Each draft forward
step gets its own captured graph keyed by `BatchDescriptor` — the same
shape as the target model's PIECEWISE captures that already work on
MUSA. The shared `current_platform.get_global_graph_pool()` isn't the
issue; the giant-chain capture's baked-in inter-step state was.

This commit deletes the runner and its monkey-patch (Step 1 of the
PR vllm-project#34880 backport):

- vllm_musa/v1/spec_decode/eagle_full_loop_runner.py (912 lines)
- vllm_musa/v1/spec_decode/spec_info.py (259 lines)
- vllm_musa/v1/spec_decode/attn_backend_array.py (265 lines)
- vllm_musa/patches/vllm__v1__spec_decode__llm_base_proposer.patch.py
  (244 lines, dead POC default-off)

And reduces these to shims:

- vllm_musa/patches/vllm__v1__spec_decode__eagle.patch.py
  (431 -> 33 lines; only preserves the MUSA-Triton kernel prime that
  must run before upstream's llm_base_proposer binds the broken kernel)
- vllm_musa/v1/spec_decode/__init__.py (19 -> 16 lines)

Net diff: -2121 +41 = -2080. The runner-default-OFF state shipped in
e932fdc already gives the validated 70.524 tok/s baseline on
yeahdongcn70. With the runner removed, env var
VLLM_MUSA_EAGLE_RUNNER is now a no-op (kept undocumented for
backward-compat tooling; future PR vllm-project#34880-style runner will use its
own gating).

Subsequent commits will wire in PR vllm-project#34880's CUDAGraphWrapper +
per-step dispatch pattern; this one establishes the clean baseline.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
…tch additive patches (Step 2)

Two additive patches from upstream vllm-project#34880:

1. cudagraph_dispatcher.py: add `for_draft_model` flag to
   `CudagraphDispatcher.__init__`. When set, `initialize_cudagraph_keys`
   registers an extra set of FULL-mode capture keys with
   `uniform_decode_query_len=1` (one token per request, matching each
   Eagle decode step). Threads `uniform_decode_query_len` through
   `dispatch` and `_create_padded_batch_descriptor` so a caller can
   override per-call.

2. gpu_input_batch.py: add `num_reqs_padded` property to `InputBatch`
   so the spec-decode draft path can request a padded batch_size for
   FULL cudagraph capture keys without changing the scheduler request
   count. Backed by `_num_reqs_padded: int | None` which is reset by
   `condense()` and `refresh_metadata()`.

Both patches are additive — no caller in v0.20.0-dev passes the new
flags, so behavior is unchanged at runtime. They establish the
infrastructure for the follow-up llm_base_proposer + gpu_model_runner
wiring that activates FULL-mode draft capture (Step 3).

When upstream PR vllm-project#34880 merges, these patches become no-ops (the
strings won't match anymore) and can be deleted.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
…(Step 3)

Wires the CUDAGraphWrapper(FULL) draft-model capture path. Three new
patch files target the v0.20.0-dev EagleProposer implementation (which
lives in llm_base_proposer.py, not eagle.py as in PR vllm-project#34880's base).

vllm__v1__spec_decode__llm_base_proposer.patch.py (11 hunks):

  1. Import CUDAGraphWrapper + BatchDescriptor.
  2. CudagraphDispatcher(self.vllm_config) -> (..., for_draft_model=True).
  3. initialize_cudagraph_keys: pass target's cudagraph_mode through
     directly (was restricted to PIECEWISE). With for_draft_model=True
     the dispatcher now also registers FULL-mode keys at qdl=1 when
     target has FULL captures.
  4. load_model: wrap self.model with CUDAGraphWrapper(runtime_mode=FULL)
     when target uses FULL captures. The wrapper is a no-op when the
     runtime mode at call time doesn't match FULL — safe with PIECEWISE
     target (our current SOTA config) until target switches to
     FULL_AND_PIECEWISE.
  5. _determine_batch_execution_and_padding returns 4-tuple now
     (adds batch_desc slot); threads uniform_decode + uniform_decode_query_len
     through.
  6. propose()'s first set_forward_context call passes batch_descriptor=batch_desc.
  7. propose()'s decode-loop set_forward_context call passes
     batch_descriptor=batch_desc with uniform_decode_query_len=1.
  8. dummy_run() 4-tuple unpack.

vllm__v1__spec_decode__dflash.patch.py (1 hunk):
  DFlashProposer.dummy_run 4-tuple unpack.

vllm__v1__spec_decode__utils.patch.py (2 hunks):
  eagle_prepare_inputs_padded_kernel: early-exit guard for query_len==0
  rows + reuse q_end_idx. (Independent kernel safety fix from same PR.)

Skipped from PR vllm-project#34880 (lower-priority for MUSA's PIECEWISE workload,
can be added later if FULL_AND_PIECEWISE perf tuning needs them):

  - target_model_batch_desc parameter to propose() — requires
    gpu_model_runner caller change which adds ~12 more hunks.
  - dummy_run two-pass capture variant (matters only at boot capture
    when target uses FULL_AND_PIECEWISE).
  - prepare_next_token_ids_padded / prepare_inputs_padded padding
    edge cases (matters for padded-drafter batches > 1).

With these patches MUSA's PIECEWISE-mode SOTA (current production
config) is unchanged at runtime — the new code paths only activate
when target uses FULL_AND_PIECEWISE or FULL_DECODE_ONLY. The
infrastructure is in place for follow-up FULL-mode validation
(MUSA-0205 candidate).

Total Step 3 diff: +178 lines across 3 patch files (5 patches total
for Steps 2+3, 23 hunks).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
Step 3 commit b0f6560 wrapped self.model with CUDAGraphWrapper(FULL)
in load_model when target uses FULL captures, but missed the matching
PR vllm-project#34880 hunk that unwraps the wrapper before the Eagle3* isinstance
check in propose(). Without the unwrap, propose() AssertionErrors at
first inference because `isinstance(self.model, Eagle3LlamaForCausalLM)`
returns False (self.model is now a CUDAGraphWrapper).

Boot succeeds (wrap happens during model load), but the first chat
completion fails with `AssertionError` at llm_base_proposer.py:438
and kills the engine. Repro: M2.5+Eagle3 SOTA cookbook → France smoke
FAIL, all 5 runs return 0 tok/s, engine dies with WorkerProc shutdown.

Fix: 12th hunk in vllm__v1__spec_decode__llm_base_proposer.patch.py
adds the unwrap-before-isinstance pattern from PR vllm-project#34880's eagle.py
hunk @@ -423,8 +415,12 @@.

Confirmed via dry-run that all 12 hunks have unique anchors and the
patched llm_base_proposer.py compile-checks clean.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
…(Step 3)

Wires the CUDAGraphWrapper(FULL) draft-model capture path. Three new
patch files target the v0.20.0-dev EagleProposer implementation (which
lives in llm_base_proposer.py, not eagle.py as in PR vllm-project#34880's base).

vllm__v1__spec_decode__llm_base_proposer.patch.py (11 hunks):

  1. Import CUDAGraphWrapper + BatchDescriptor.
  2. CudagraphDispatcher(self.vllm_config) -> (..., for_draft_model=True).
  3. initialize_cudagraph_keys: pass target's cudagraph_mode through
     directly (was restricted to PIECEWISE). With for_draft_model=True
     the dispatcher now also registers FULL-mode keys at qdl=1 when
     target has FULL captures.
  4. load_model: wrap self.model with CUDAGraphWrapper(runtime_mode=FULL)
     when target uses FULL captures. The wrapper is a no-op when the
     runtime mode at call time doesn't match FULL — safe with PIECEWISE
     target (our current SOTA config) until target switches to
     FULL_AND_PIECEWISE.
  5. _determine_batch_execution_and_padding returns 4-tuple now
     (adds batch_desc slot); threads uniform_decode + uniform_decode_query_len
     through.
  6. propose()'s first set_forward_context call passes batch_descriptor=batch_desc.
  7. propose()'s decode-loop set_forward_context call passes
     batch_descriptor=batch_desc with uniform_decode_query_len=1.
  8. dummy_run() 4-tuple unpack.

vllm__v1__spec_decode__dflash.patch.py (1 hunk):
  DFlashProposer.dummy_run 4-tuple unpack.

vllm__v1__spec_decode__utils.patch.py (2 hunks):
  eagle_prepare_inputs_padded_kernel: early-exit guard for query_len==0
  rows + reuse q_end_idx. (Independent kernel safety fix from same PR.)

Skipped from PR vllm-project#34880 (lower-priority for MUSA's PIECEWISE workload,
can be added later if FULL_AND_PIECEWISE perf tuning needs them):

  - target_model_batch_desc parameter to propose() — requires
    gpu_model_runner caller change which adds ~12 more hunks.
  - dummy_run two-pass capture variant (matters only at boot capture
    when target uses FULL_AND_PIECEWISE).
  - prepare_next_token_ids_padded / prepare_inputs_padded padding
    edge cases (matters for padded-drafter batches > 1).

With these patches MUSA's PIECEWISE-mode SOTA (current production
config) is unchanged at runtime — the new code paths only activate
when target uses FULL_AND_PIECEWISE or FULL_DECODE_ONLY. The
infrastructure is in place for follow-up FULL-mode validation
(MUSA-0205 candidate).

Total Step 3 diff: +178 lines across 3 patch files (5 patches total
for Steps 2+3, 23 hunks).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
Step 3 commit b0f6560 wrapped self.model with CUDAGraphWrapper(FULL)
in load_model when target uses FULL captures, but missed the matching
PR vllm-project#34880 hunk that unwraps the wrapper before the Eagle3* isinstance
check in propose(). Without the unwrap, propose() AssertionErrors at
first inference because `isinstance(self.model, Eagle3LlamaForCausalLM)`
returns False (self.model is now a CUDAGraphWrapper).

Boot succeeds (wrap happens during model load), but the first chat
completion fails with `AssertionError` at llm_base_proposer.py:438
and kills the engine. Repro: M2.5+Eagle3 SOTA cookbook → France smoke
FAIL, all 5 runs return 0 tok/s, engine dies with WorkerProc shutdown.

Fix: 12th hunk in vllm__v1__spec_decode__llm_base_proposer.patch.py
adds the unwrap-before-isinstance pattern from PR vllm-project#34880's eagle.py
hunk @@ -423,8 +415,12 @@.

Confirmed via dry-run that all 12 hunks have unique anchors and the
patched llm_base_proposer.py compile-checks clean.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
… draft

Completes the PR vllm-project#34880 backport on top of Step 3. Without this, the
load_model CUDAGraphWrapper(FULL) wrap engages but tries to capture at
request time, tripping validate_cudagraph_capturing_enabled.

Two changes:

llm_base_proposer.patch.py: replace the small dummy_run unpack hunk
with a full dummy_run rewrite that accepts an optional
common_attn_metadata, runs 2 forward passes when is_graph_capturing
(matching propose's first-forward + decode-step shape), builds
per-layer attn metadata via the existing build_for_drafting helper,
and threads batch_descriptor=batch_desc to set_forward_context.

gpu_model_runner.patch.py: wires the caller side (5 hunks):
  1. __init__: supports_sd_full_graph flag (False default).
  2. EagleProposer init: set supports_sd_full_graph based on
     `not disable_padded_drafter_batch`.
  3. _dummy_run: name the second tuple slot from
     `self._build_attention_metadata` as `spec_decode_cm`.
  4. _dummy_run: extend the `use_cudagraphs` predicate so FULL
     captures are enabled when `supports_sd_full_graph` (was
     PIECEWISE-only).
  5. _dummy_run: pass `spec_decode_cm` as `common_attn_metadata`
     to `drafter.dummy_run(...)`.

Plus a small change in llm_base_proposer.patch.py Hunk 7: derive
`uniform_decode = common_attn_metadata.max_query_len == 1` locally
in propose() instead of taking it as an external arg. This avoids
the additional ~5-7 hunks in gpu_model_runner.py that would otherwise
be needed to plumb `target_model_batch_desc` through `execute_model`
-> `propose_draft_token_ids` -> `drafter.propose(...)`.

This commit also obsoletes the prior gpu_model_runner.patch.py
(MUSA-0109 `VLLM_MUSA_DRAFT_COPY_DEFAULT_STREAM` workaround that
targeted the now-deleted EagleFullLoopRunner). The filename is
reused for the new content.

Total backport surface: 29 hunks across 6 patch files. When upstream
PR vllm-project#34880 merges (or v0.20.0-dev rebases onto it), every "old" anchor
will stop matching and the entire backport becomes a no-op deletable
in one PR.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
Step 4 (289c0b7) renamed `attn_metadata, _ = self._build_attention_metadata(...)`
to `attn_metadata, spec_decode_cm = ...`, but the assignment is inside an
`if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:` block.
When that branch doesn't fire (PIECEWISE without force_attention), the
spec_decode_cm variable was never assigned but the drafter.dummy_run call
still references it -> UnboundLocalError at profile_run time.

Add the missing PR vllm-project#34880 hunk that declares `spec_decode_cm = None` at
the top of _dummy_run alongside the existing `attn_metadata = None`
declaration (PR vllm-project#34880 line 5379, sits adjacent to the existing decl).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
Add PR vllm-project#34880's matching in-place hunk for the captured replay path.
Without it the propose() loop reassigns:
  common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
which creates a fresh tensor each call. The captured FA kernel baked
in the pre-capture pointer at boot time and reads stale data at replay
-> draft tokens past position 0 collapse (Pos 1-4 fell to 7/3/1/0.6%
on yeahdongcn70 cookbook).

PR vllm-project#34880 changes to:
  common_attn_metadata.query_start_loc[: batch_size + 1] = self.arange[...]
which is an in-place write into the persistent buffer the captured
graph already references. Same fix for query_start_loc_cpu.

Test: reboot with VLLM_MUSA_DRAFT_FULL_WRAP=1 and rerun cookbook bench.
Expected: accept rate recovers, Pos 1-4 distribution matches eager.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
The captured-replay accept-rate regression that caused us to default the
wrap to OFF was caused by the propose() loop reassigning
common_attn_metadata.query_start_loc instead of writing in-place. The
matching PR vllm-project#34880 hunk (8b4d41c) fixed that, so the wrap is now
correctness-equivalent to eager and faster.

Validation on yeahdongcn70 cookbook (4k/1k BS=1 greedy, [6] capture):

  config                       tok/s   accept   Pos 0/1/2/3/4
  wrap=0 baseline              70.4    47.95%   68/53/47/38/34
  wrap=1 (this commit + qsl)   77.6    47.95%   68/53/47/38/34

Position decay matches eager exactly (bit-equivalent draft outputs).
+10.2% over baseline, captured FULL graphs for both target and draft.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 28, 2026
Cherry-pick vllm_musa/_dispatcher_override.py from origin/xd/op_perf
plus the __init__.py hook. Gated by VLLM_MUSA_OP_PERF_OVERRIDES=1
(default OFF).

Prior A/B on yeahdongcn60 showed wash for BS=1 (35.42 vs 35.14) and
BS>=16 regressions. Worth retesting on yeahdongcn70 mcc 5.1.0 +
PR vllm-project#34880 backport.

The dispatcher uses torchada.replace_op_impl to swap upstream
_C/_moe_C impls behind the dispatcher table, preserving fusion-pass
matchers that key on upstream op names. Covered ops:
  - rotary_embedding (MUSA-0154)
  - fused_qk_norm_rope (MUSA-0157)
  - silu_and_mul_per_block_quant (MUSA-0158)
  - rms_norm_dynamic_per_token_quant (MUSA-0159)
  - merge_attn_states (MUSA-0161)
  - topk_softmax (MUSA-0162)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 29, 2026
…ph runner for M2.5

Drive the Eagle3 draft loop under a captured CUDAGraph on MUSA by
backporting upstream vLLM PR vllm-project#34880's per-step CudagraphDispatcher +
CUDAGraphWrapper(FULL) pattern, replacing the previous (deleted)
EagleFullLoopRunner approach.

Changes:
- Delete EagleFullLoopRunner and its attn_backend_array / spec_info
  scaffolding; the eagle.patch.py override is reduced to the additive
  hooks the new pattern needs.
- Additive patches (vllm_musa/patches/): cudagraph_dispatcher (per-step
  dispatch keyed on the draft's runtime mode), gpu_input_batch,
  spec_decode/{dflash,utils}, and llm_base_proposer wiring.
- Boot-time CUDAGraphWrapper(FULL) capture for the draft model, wired
  through gpu_model_runner; gated behind VLLM_MUSA_DRAFT_FULL_WRAP
  (default ON).
- In-place query_start_loc update so the captured draft loop replays
  with the correct per-step offsets.

Net effect on MiniMax-M2.5 + Eagle3 BS=1 decode (yeahdongcn70, mcc 5.1.0):
runner-ON measured +10.2% over runner-OFF. Supersedes the
EagleFullLoopRunner path from PR #43, which did not boot on this toolchain.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 29, 2026
Cherry-pick vllm_musa/_dispatcher_override.py from origin/xd/op_perf
plus the __init__.py hook. Gated by VLLM_MUSA_OP_PERF_OVERRIDES=1
(default OFF).

Prior A/B on yeahdongcn60 showed wash for BS=1 (35.42 vs 35.14) and
BS>=16 regressions. Worth retesting on yeahdongcn70 mcc 5.1.0 +
PR vllm-project#34880 backport.

The dispatcher uses torchada.replace_op_impl to swap upstream
_C/_moe_C impls behind the dispatcher table, preserving fusion-pass
matchers that key on upstream op names. Covered ops:
  - rotary_embedding (MUSA-0154)
  - fused_qk_norm_rope (MUSA-0157)
  - silu_and_mul_per_block_quant (MUSA-0158)
  - rms_norm_dynamic_per_token_quant (MUSA-0159)
  - merge_attn_states (MUSA-0161)
  - topk_softmax (MUSA-0162)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 29, 2026
…ph runner for M2.5

Drive the Eagle3 draft loop under a captured CUDAGraph on MUSA by
backporting upstream vLLM PR vllm-project#34880's per-step CudagraphDispatcher +
CUDAGraphWrapper(FULL) pattern, replacing the previous (deleted)
EagleFullLoopRunner approach.

Changes:
- Delete EagleFullLoopRunner and its attn_backend_array / spec_info
  scaffolding; the eagle.patch.py override is reduced to the additive
  hooks the new pattern needs.
- Additive patches (vllm_musa/patches/): cudagraph_dispatcher (per-step
  dispatch keyed on the draft's runtime mode), gpu_input_batch,
  spec_decode/{dflash,utils}, and llm_base_proposer wiring.
- Boot-time CUDAGraphWrapper(FULL) capture for the draft model, wired
  through gpu_model_runner; gated behind VLLM_MUSA_DRAFT_FULL_WRAP
  (default ON).
- In-place query_start_loc update so the captured draft loop replays
  with the correct per-step offsets.

Net effect on MiniMax-M2.5 + Eagle3 BS=1 decode (yeahdongcn70, mcc 5.1.0):
runner-ON measured +10.2% over runner-OFF. Supersedes the
EagleFullLoopRunner path from PR #43, which did not boot on this toolchain.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 29, 2026
test_rope.py (standalone harness):
- Rename test_*->check_* so pytest does not try to execute these
  shape-parametrized functions as tests (fixture-resolution failure / broken
  CI collection — high).
- Enter set_current_vllm_config() inside main() instead of at import time, so
  importing the module (e.g. pytest collection) doesn't leak a global vLLM
  config into other tests in the same process (medium).
- Fix the multi-replay docstring to match behaviour (replays the same captured
  inputs, NaN/Inf stability check) rather than the inaccurate "rotating
  positions" claim (low).

setup.py:
- Gate the mate -mllvm load-cluster flags behind mcc > 5.0.0 (same version gate
  as SLP) with a VLLM_MUSA_DISABLE_LOAD_CLUSTER opt-out, so an older/unsupported
  mcc/LLVM that doesn't recognise these -mllvm options doesn't hard-fail the
  build (medium). No change on mcc 5.1.0 (the validated build) — the flags stay
  on, so the 103.5 tok/s SOTA is unaffected.

Dropped: the query_len==0 token_indices_to_sample guard in
vllm__v1__spec_decode__utils.patch.py is a faithful backport of upstream vLLM
PR vllm-project#34880's padded-inputs kernel; query_len==0 only occurs for padded rows that
are masked downstream, and the M2.5+Eagle3 SOTA validated correct output, so
forking upstream's kernel here isn't warranted.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
yeahdongcn added a commit to MooreThreads/vllm-musa that referenced this pull request May 29, 2026
* MUSA-0203: PR vllm-project#34880 backport — Eagle3 draft-loop CUDAGraph runner for M2.5

Drive the Eagle3 draft loop under a captured CUDAGraph on MUSA by
backporting upstream vLLM PR vllm-project#34880's per-step CudagraphDispatcher +
CUDAGraphWrapper(FULL) pattern, replacing the previous (deleted)
EagleFullLoopRunner approach.

Changes:
- Delete EagleFullLoopRunner and its attn_backend_array / spec_info
  scaffolding; the eagle.patch.py override is reduced to the additive
  hooks the new pattern needs.
- Additive patches (vllm_musa/patches/): cudagraph_dispatcher (per-step
  dispatch keyed on the draft's runtime mode), gpu_input_batch,
  spec_decode/{dflash,utils}, and llm_base_proposer wiring.
- Boot-time CUDAGraphWrapper(FULL) capture for the draft model, wired
  through gpu_model_runner; gated behind VLLM_MUSA_DRAFT_FULL_WRAP
  (default ON).
- In-place query_start_loc update so the captured draft loop replays
  with the correct per-step offsets.

Net effect on MiniMax-M2.5 + Eagle3 BS=1 decode (yeahdongcn70, mcc 5.1.0):
runner-ON measured +10.2% over runner-OFF. Supersedes the
EagleFullLoopRunner path from PR #43, which did not boot on this toolchain.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0203: mcc 5.1.0 perf flags (re-enable SLP + fast-math + load-cluster) + drop unused paged_attention

On mcc 5.1.0 the vllm-musa base setup.py disables SLP vectorization
(`-mllvm -vectorize-slp=false`, a workaround) and omits the fast-math /
load-clustering hints mate's JIT compiles with. Restoring them is a large,
measured perf lever for M2.5 + Eagle3 — NOT perf-neutral as first assumed.

A/B (TP8+EP8, FULL cudagraph + cudagraph_capture_sizes=[6], same env,
yeahdongcn70, mcc 5.1.0, random 4k/1k BS=1 greedy, warm median):

  base flags (SLP off):   80.0 tok/s   accept 49.5%   accept_len 3.47
  these flags (SLP on):  104.4 tok/s   accept 70.1%   accept_len 4.51
  => +30.5%

The gain is a numerics effect: re-enabling SLP + `-fno-signed-zeros` changes
FP accumulation in the .mu kernels, lifting spec-decode acceptance uniformly
across every draft-tree position — including position 0 (the draft's
single-token accuracy: 0.73 -> 0.84). It is NOT a launch-config or raw-compute
effect (mean_tpot is ~unchanged; acceptance is the mover).

Flags added (matching mate/mate/jit/gemm_ops.py CUDA_FLAGS):
  -fno-signed-zeros
  -mllvm -mtgpu-load-cluster-mutation=1
  -mllvm --num-dwords-of-load-in-mutation=64
  re-enable SLP on mcc 5.1.0+ (gated behind VLLM_MUSA_DISABLE_SLP=1 to opt out)

paged_attention_v1/v2 are dropped from the build: CUDA-only and unused on MUSA
(vllm uses FlashAttention via mate), AND removing them relieves the compile
pressure that triggers an mcc clang-14 frontend segfault on
fused_layernorm_dynamic_per_token_quant.mu once the load-cluster flags are on.
Their torch_bindings ops.impl refs are stripped so vllm import doesn't
reference unbuilt symbols (ops.def schema left intact).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0203: rope JIT unit + captured-replay test

Single-rank parity test against vLLM's pure-PyTorch reference
(`RotaryEmbedding.forward_static`) across 9 shapes x 3 modes:

  EAGER          single call, JIT vs reference
  CAPTURED       JIT inside one CUDAGraph capture+replay vs reference
  MULTI_REPLAY   32x captured-graph replay, NaN/Inf check

Shapes cover the per-rank configurations in our acceptance matrix:
Eagle3 draft TP=8 (q=3), M2.5 target TP=8 (q=6), Qwen3-8B-FP8 TP=8
(q=4), Qwen3-30B-A3B-FP8 TP=2 (q=16), boundary cases (q=1, q=2).

Run on the authorized MUSA container with single GPU 0 visible:

  MUSA_VISIBLE_DEVICES=0 python3 tests/jit_kernel/test_rope.py

Result on yeahdongcn70 + PR #49 + the override-restored fix: all 27
(9x3) PASS with bit-exact parity (q_max_diff=0.0000e+00). Used during
MUSA-0203 investigation to definitively rule out the rope kernel as
the source of the M2.5+Eagle3 runner-ON hang.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0203: address PR #50 review comments

test_rope.py (standalone harness):
- Rename test_*->check_* so pytest does not try to execute these
  shape-parametrized functions as tests (fixture-resolution failure / broken
  CI collection — high).
- Enter set_current_vllm_config() inside main() instead of at import time, so
  importing the module (e.g. pytest collection) doesn't leak a global vLLM
  config into other tests in the same process (medium).
- Fix the multi-replay docstring to match behaviour (replays the same captured
  inputs, NaN/Inf stability check) rather than the inaccurate "rotating
  positions" claim (low).

setup.py:
- Gate the mate -mllvm load-cluster flags behind mcc > 5.0.0 (same version gate
  as SLP) with a VLLM_MUSA_DISABLE_LOAD_CLUSTER opt-out, so an older/unsupported
  mcc/LLVM that doesn't recognise these -mllvm options doesn't hard-fail the
  build (medium). No change on mcc 5.1.0 (the validated build) — the flags stay
  on, so the 103.5 tok/s SOTA is unaffected.

Dropped: the query_len==0 token_indices_to_sample guard in
vllm__v1__spec_decode__utils.patch.py is a faithful backport of upstream vLLM
PR vllm-project#34880's padded-inputs kernel; query_len==0 only occurs for padded rows that
are masked downstream, and the M2.5+Eagle3 SOTA validated correct output, so
forking upstream's kernel here isn't warranted.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* add rope benchmark

Signed-off-by: Xiaodong Ye <yexiaodong60@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <yexiaodong60@mthreads.com>
Co-authored-by: Xiaodong Ye <yexiaodong60@mthreads.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
yeahdongcn pushed a commit to MooreThreads/vllm-musa that referenced this pull request May 29, 2026
…_DFLASH_FULL_WRAP)

The dflash perf gain plateaus at the eager floor (+47% on gemma-4) because the
draft transformer forward runs uncaptured: dflash hits use_dflash() (before
use_eagle()), so supports_sd_full_graph was never set and the draft's per-step
kernel-launch overhead (~50% of a target forward, vs ~8% of compute for a
5-layer draft) is not amortized. Set supports_sd_full_graph=True for dflash too,
gated behind VLLM_MUSA_DFLASH_FULL_WRAP, to engage the existing PR vllm-project#34880
CUDAGraphWrapper(FULL) path on the dflash draft. Buffer-stability of
set_inputs_first_pass (query_start_loc/seq_lens are fresh tensors) is the next
step if capture corrupts acceptance.

Part of MUSA-0400 / MUSA-0403.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
yeahdongcn added a commit to MooreThreads/vllm-musa that referenced this pull request May 30, 2026
…aft-loop CUDAGraph capture) (#52)

* MUSA-0402: accept common_attn_metadata in DFlashProposer.dummy_run

The MUSA gpu_model_runner patch (Hunk 5) passes
common_attn_metadata=spec_decode_cm to drafter.dummy_run for every
proposer, but the upstream DFlashProposer.dummy_run override does not
declare it (parallel drafting uses a single pass and does not consume
it), so enabling dflash spec-decode boots with
`TypeError: dummy_run() got an unexpected keyword argument
'common_attn_metadata'`. Add the (ignored) kwarg to the dflash
dummy_run signature patch; CommonAttentionMetadata is already imported
in dflash.py.

Part of MUSA-0400 (dflash + gemma-4-31B-it on vllm-musa).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0401/0402: mate FLASH_ATTN for non-causal head-512 + durable dflash patch

dflash spec-decode requires non-causal attention on the target model; on MUSA
V1 the only non-causal-capable decoder backend is mate FLASH_ATTN (TORCH_SDPA
is ViT-only, TRITON_ATTN rejects non-causal). gemma-4's full-attention layers
use head_dim 512, which MUSAFlashAttentionBackend rejected at
supports_head_size <= 256. Relax the gate to <= 512 so mate's FA3 wrapper
handles those layers (select via VLLM_ATTENTION_BACKEND=FLASH_ATTN).

Also fix a latent patch-application gap: vllm.v1.spec_decode.dflash is not
import-resolvable during the plugin-load apply_patches() pass, so the dflash
dummy_run signature patch (MUSA-0402) was silently skipped. apply_patches now
takes force=; check_and_update_config re-applies (force=True) when dflash is
active — main process, before spawn workers fork — so the patch lands on the
installed dflash.py the workers import.

Part of MUSA-0400.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0401/0402: declare non-causal + mm-prefix support on MUSA FLASH_ATTN

MUSAFlashAttentionBackend extends the base AttentionBackend, which defaults
supports_non_causal() and supports_mm_prefix() to False, so
validate_configuration rejected FLASH_ATTN for (a) dflash's non-causal verify
attention and (b) gemma-4's mm_prefix_lm multimodal layers — even though mate's
FA3 wrapper handles non-causal (causal= is plumbed through flash_attn_varlen_func)
and the partial-mm path is dormant for text-only serving. Override both to True
so mate FLASH_ATTN can be selected for gemma-4 + dflash.

mm-prefix correctness for actual image/audio input on mate FA is unverified;
this targets the text-only dflash workload.

Part of MUSA-0400.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0405: align MUSA FLASH_ATTN capability declarations with verified mate behavior

Probe-verified (generated/musa0400/probe_fa_caps.py) the MUSA FLASH_ATTN
capability classmethods against mate flash_attn_varlen_func + torch references:
- supports_per_head_quant_scales True -> False: mate rejects fp8 Q/K inputs
  ("inputs must be float16 or bfloat16"); per-head FP8 quant-scale attention is
  not supported (consistent with flash_attn_supports_fp8()=False and
  get_fp8_dtype_for_flashattn raising NotImplementedError).
- supports_batch_invariance (inherited base False) -> True: mate FA
  deterministic=True is bit-invariant (solo-vs-batched max_abs_err=0.0),
  matching upstream FlashAttentionBackend.
- supports_mm_prefix: documented as an intentional TEXT-ONLY over-claim (mate
  exposes no partial/arbitrary mask, so partial-mm bidirectional attention is not
  genuinely supported; only safe because the dflash workload is text-only).
- supports_non_causal=True, supports_sink=True: probe-confirmed correct (kept).

Root cause: MUSA backends extend the base AttentionBackend (not the upstream
concrete class), so they override only a subset of supports_*/is_* and inherit
base defaults for the rest -> capability drift in both directions.

Part of MUSA-0400 / MUSA-0405.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0403: opt-in dflash draft-loop FULL CUDAGraph capture (VLLM_MUSA_DFLASH_FULL_WRAP)

The dflash perf gain plateaus at the eager floor (+47% on gemma-4) because the
draft transformer forward runs uncaptured: dflash hits use_dflash() (before
use_eagle()), so supports_sd_full_graph was never set and the draft's per-step
kernel-launch overhead (~50% of a target forward, vs ~8% of compute for a
5-layer draft) is not amortized. Set supports_sd_full_graph=True for dflash too,
gated behind VLLM_MUSA_DFLASH_FULL_WRAP, to engage the existing PR vllm-project#34880
CUDAGraphWrapper(FULL) path on the dflash draft. Buffer-stability of
set_inputs_first_pass (query_start_loc/seq_lens are fresh tensors) is the next
step if capture corrupts acceptance.

Part of MUSA-0400 / MUSA-0403.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0403: dflash draft-loop FULL capture — dummy_run boot-capture + query_len align

Phase 4b pieces 1+2, so dflash's draft forward becomes FULL-capturable (the +47%
eager floor is draft-launch-bound; capture is the lever toward ~2x):
- dflash dummy_run rewrite: dispatch at uniform_decode_query_len=1+num_spec
  (dflash is a parallel block proposer, not Eagle's 1-token chain step), build
  dflash's non-causal per-layer attn metadata, and pass batch_descriptor=batch_desc
  to set_forward_context so the FULL graph registers under the inference key
  (fixes the capture-at-inference "capturing at an inappropriate time" error).
- initialize_cudagraph_keys: register + dispatch draft FULL keys at the
  parallel-drafting query_len (1+num_spec), not 1.

Next (piece 3) if captured replay collapses acceptance: make
set_inputs_first_pass query_start_loc/seq_lens buffer-stable.

Part of MUSA-0400 / MUSA-0403.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0403: dflash dummy_run uniform_decode=False to match inference key

The patched inference propose() dispatches with uniform_decode =
common_attn_metadata.max_query_len == 1; for dflash common_attn_metadata is the
non-causal new_cad (max_query_len = 1 + num_speculative_tokens > 1) so inference
uniform_decode is always False. The dummy_run captured under uniform=True ->
batch_descriptor key mismatch -> capture-at-inference. Match it.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* MUSA-0403: dflash draft capture default-on + block-aligned cudagraph coercion

Flip VLLM_MUSA_DFLASH_FULL_WRAP to default-on (opt out with =0): the draft-loop
FULL CUDAGraph capture lifts the fair compile ratio 1.47x->1.83x (prose) and to
4.09x (predictable). Acceptance stays healthy (no per-position collapse), so the
set_inputs_first_pass buffer-stability concern is resolved.

platform.check_and_update_config now coerces dflash cudagraph to pure FULL +
block-aligned capture sizes (default the single BS=1 block of 1+num_spec), so
vLLM default non-block-aligned sizes ([1,2,4,8,16,...]) no longer crash the
draft capture with an illegal memory access. Multi-block (BS>1) capture: MUSA-0406.

* MUSA-0403: dflash capture coercion forces single BS=1 block [9]

The block-aligned filter kept whatever multiples of the block leaked through
vLLM's default capture-size set, which are all large (72=BS8, 144, ...) — a BS=1
9-token decode then padded to 72 and ran ~35 tok/s instead of 48. Force exactly
[1+num_spec] (the proven, no-padding size). Multi-block (BS>1): MUSA-0406.

* MUSA-0406: extend dflash capture to block-aligned BS=1,2,4,8 [9,18,36,72]

Multi-block (BS>1) draft capture: the [72,144,...] default-on run proved pure
FULL captures multi-block sizes without crashing, so coerce to the small
block-aligned set [9,18,36,72] (BS 1,2,4,8) instead of forcing [9]. Each BS=N
decode uses the exact size-N*block graph (no padding). To be validated by a
concurrency sweep.

* Address PR #52 review: mm_prefix comment accuracy + stable use_dflash() detection

- flash_attn.py (augment HIGH): correct the misleading supports_mm_prefix comment.
  mm_prefix=True is REQUIRED for multimodal-model backend selection (gemma-4) and was
  VALIDATED on Qwen2.5-VL (image input correct, MUSA-0404 regression); the prior
  'WOULD be incorrect' was unverified speculation and is disproven for the patterns
  these models use. Residual risk (arbitrary partial 2D mask) documented as untested.
- platform.py (augment MEDIUM): detect dflash via the stable use_dflash() API instead
  of the 'method' string (which can change representation across vLLM versions), with
  a string fallback, so the dflash patch re-apply is never silently skipped.

---------

Co-authored-by: Xiaodong Ye <yexiaodong60@mthreads.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

8 participants