feat(qwen3): DFlash speculative decoding #436

Open
xiaguan wants to merge 9 commits into main from feat/qwen3-speculative-decoding


xiaguan commented Jun 22, 2026

Collaborator

Summary

Adds DFlash speculative decoding to Qwen3-4B behind --dflash-draft-model-path, and optimizes it until every batch size matches or beats vLLM's DFlash on the 5090 while staying greedy-lossless. A clean port of the earlier WIP (#380) onto current main, with the draft/verify split refactored into a method-agnostic core.

The abstraction — speculative decode as an optimistic transaction. Every speculative method is the same transaction; only the propose step differs:

  1. Propose: the DFlash drafter emits K candidate tokens.
  2. Verify: the target runs one prefill-style forward over the K+1 span [current, draft_1..draft_K] and reports its argmax at every position.
  3. Accept: accept_greedy keeps the longest prefix matching the target argmax, then appends the target's own token at the first mismatch (or the bonus token after a full match). A verify step therefore always commits 1..=K+1 tokens.
  4. Commit / roll back: accepted tokens' KV is committed; the unused tail of the speculative reservation is LIFO-dropped.
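
The accept and roll-back steps above can be sketched as follows. This is a minimal illustration, not the code in speculative.rs: the signature of `accept_greedy` shown here and the `rollback_len` helper are assumptions made for the example, and tokens are plain `u32` ids.

```rust
/// Hypothetical greedy accept rule (illustrative signature, not the PR's).
///
/// `draft` holds the K proposed tokens. `target_argmax` holds the target's
/// argmax at each of the K+1 verify positions: entry `i` is the target's
/// choice for the token that follows position `i` of the span
/// `[current, draft_1..draft_K]`.
/// Returns the tokens to commit: the matching draft prefix plus one target token.
pub fn accept_greedy(draft: &[u32], target_argmax: &[u32]) -> Vec<u32> {
    assert_eq!(target_argmax.len(), draft.len() + 1, "verify span must be K+1 long");

    // Length of the longest draft prefix that agrees with the target.
    let matched = draft
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count();

    // Keep the matched prefix, then append the target's own token: the
    // correction at the first mismatch, or the bonus token after a full match.
    let mut committed = draft[..matched].to_vec();
    committed.push(target_argmax[matched]);
    committed
}

/// Number of speculative KV slots to drop from the tail of the reservation.
/// Assumes K+1 slots were reserved for the verify span (an assumption of
/// this sketch) and that `committed` is in 1..=K+1.
pub fn rollback_len(k: usize, committed: usize) -> usize {
    (k + 1) - committed
}
```

With K = 3, a full match commits four tokens and drops nothing, while a mismatch at the second draft token commits two tokens and drops two reserved slots.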

The draft↔verify boundary is a pure token span — hidden states never leave the proposer. That keeps the shared core (speculative.rs) method-agnostic. No proposer trait yet — deliberately concrete until a second method (n-gram / EAGLE) lands.

Key invariant. Speculative-on forces prefix caching off, so every request's target hidden context is captured during its own prefill — removing the need for a per-request drafter-ready handshake. Readiness is derived from the prefill capture-set.

Losslessness

Greedy decode is lossless up to bf16 numerical tie-flips — the same non-determinism that already affects plain greedy decode at genuine bifurcation points (cf. hf_golden_gate's MARGIN_TOL = 0.20). Evidence: multi-token accepts (runs up to 11 tokens) are bit-identical to baseline at every non-tie position; re-runs flip different prompts (non-deterministic ⇒ numerical, not a logic bug); the spec pick is always the #1/#2 token of the prefill kernel's own distribution, within 0.20 nat. The gate passes for the batched draft and the piecewise verify graph (dense ops replay bit-identically; attention is unchanged).
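
A regret check of the kind described above might look like the sketch below. It assumes the gate compares baseline and speculative token streams and, at the first divergence, measures how far the speculative pick sits below the reference argmax. `Verdict`, `regret`, and `classify` are hypothetical names; only the 0.20 tolerance comes from the text.

```rust
/// Tolerance in nats, mirroring the MARGIN_TOL = 0.20 value quoted above.
pub const MARGIN_TOL: f32 = 0.20;

/// Regret of choosing `picked` instead of the argmax, given the reference
/// log-probabilities at one position. Zero means `picked` is the argmax.
pub fn regret(ref_logprobs: &[f32], picked: usize) -> f32 {
    let best = ref_logprobs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    best - ref_logprobs[picked]
}

#[derive(Debug, PartialEq)]
pub enum Verdict {
    Identical,
    BenignTie { position: usize },
    Diverged { position: usize },
}

/// Classifies one prompt by its first divergence only (the limitation the
/// PR notes as a follow-up). `logprobs_at(pos)` is a hypothetical hook that
/// returns the reference distribution at `pos`. Only the common prefix of
/// the two streams is compared; a pure length difference is not flagged.
pub fn classify(
    baseline: &[usize],
    spec: &[usize],
    logprobs_at: impl Fn(usize) -> Vec<f32>,
) -> Verdict {
    match baseline.iter().zip(spec).position(|(b, s)| b != s) {
        None => Verdict::Identical,
        Some(position) => {
            let lp = logprobs_at(position);
            if regret(&lp, spec[position]) <= MARGIN_TOL {
                Verdict::BenignTie { position }
            } else {
                Verdict::Diverged { position }
            }
        }
    }
}
```

A pick 0.1 nat below the argmax counts as a benign tie; a pick 2.9 nat below it is reported as a real divergence.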

Performance

Goal: match or beat vLLM's DFlash across batch sizes on the 5090. Achieved: with the verify graph, c8 is about 23% ahead of vLLM, and c1 and c16 are within about 1.5% of it (table below).

Single-stream decode A/B (bs=1, tests/dflash_speculative_perf.rs): 1.82× on 5070 Ti (93.4 → 170.0 tok/s), 1.56× on 5090 (168.9 → 263.2).

Concurrent greedy throughput, openinfer vs vLLM with the same DFlash-b16 drafter (5090, sharegpt out128, tok/s):

| concurrency | OI plain | OI DFlash batched | OI DFlash +graph | vLLM DFlash |
|---|---|---|---|---|
| c1 | 170 | 237 | 274 | 278 |
| c8 | 1180 | 1346 | 1525 | 1240 |
| c16 | 2277 | 1868 | 1834 | 1846 |

Two fixes got here, both DFlash-internal (behind the unchanged DraftPlan→DraftResult seam, so neither is thrown away by a future EAGLE proposer):

  1. Batched draft (78e2160). The draft ran a per-request serial for loop — launch-bound (a skip-attention A/B showed attention compute <2%, so 24.8 ms/step at batch 16 was almost all kernel-launch overhead), which inverted the single-stream win (c16 −59%). Batching the dense ops into one N×block pass drops draft@batch16 24.8 → 5.6 ms and lifts c8/c16 past vLLM.
  2. Piecewise verify CUDA Graph (84b98f0, 1620c0c). Under greedy dflash the spec path had no CUDA Graph at all (the base-decode graph never fires); nsys showed 1296 ms of GPU-idle launch gap, 84% from dense-op launches (GEMM 48%, RMSNorm 15%), only 8% attention. Capturing the whole forward is infeasible — FlashInfer's paged-prefill attention freezes its KV-iteration count at record time, so a captured attention under-reads the growing verify KV and corrupts tokens past ~60. The fix is piecewise: split forward_layer_batch_paged into pre/attn/post, capture the dense ops (embedding, RMSNorm, every GEMM, SwiGLU, residual adds) into num_layers+1 segments and replay them per batch bucket, keeping attention eager. c1 250.9 → 274.3 (+9.3% from the graph alone), matching vLLM.
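
The piecewise layout can be pictured as an execution schedule. The sketch below only models the ordering, with no CUDA calls. It assumes segment 0 covers the embedding plus layer 0's pre-attention ops, segment i covers layer i-1's post-attention ops plus layer i's pre-attention ops, and the last segment covers the final layer's post-attention ops. That partition is one way to arrive at num_layers+1 segments and is not confirmed by the PR. `Step`, `verify_schedule`, and `pick_bucket` are hypothetical names.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Step {
    /// Replay captured dense segment `i` (a CUDA Graph in the real system).
    ReplayDense(usize),
    /// Run attention for layer `i` eagerly, outside any graph, so its
    /// KV-iteration count is recomputed as the verify context grows.
    EagerAttention(usize),
}

/// Interleaves num_layers + 1 dense segments with num_layers eager
/// attention calls.
pub fn verify_schedule(num_layers: usize) -> Vec<Step> {
    let mut steps = Vec::with_capacity(2 * num_layers + 1);
    for layer in 0..num_layers {
        steps.push(Step::ReplayDense(layer));
        steps.push(Step::EagerAttention(layer));
    }
    steps.push(Step::ReplayDense(num_layers));
    steps
}

/// Picks the smallest captured batch bucket that fits `batch`, or `None`
/// if the batch exceeds every bucket. The bucket sizes are up to the caller
/// and are not taken from the PR.
pub fn pick_bucket(buckets: &[usize], batch: usize) -> Option<usize> {
    buckets.iter().copied().filter(|&b| b >= batch).min()
}
```

For a two-layer model the schedule is dense 0, attention 0, dense 1, attention 1, dense 2. Only the dense steps depend on the batch bucket; attention always runs eagerly.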

Accept rate is not the gap: measured 9.1% (ours) vs 8.85% (vLLM) with the same drafter (mean 1.29 vs 1.42 draft tokens/step), and our pos-0 accept (60%) is actually higher.

Tests

  • speculative.rs accept-core units (full accept / prefix+correction / reject / multi-request column split).
  • scheduler::tests resolve truth-table: multi-token emission, stop-token suppression mid-span / at span start, max-tokens truncation, ignore_eos, per-request independence.
  • tests/dflash_speculative_gate.rs — regret-based losslessness gate (passes for serial, batched, and graph builds).
  • tests/dflash_speculative_perf.rs — single-stream A/B harness.
  • hf_golden_gate passes (bs1 / batched / cuda-graph / tp2); lm-eval gsm8k strict-match identical spec on/off.
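
The resolve truth-table can be illustrated with a small function. This is a hypothetical reading of the listed cases, not the scheduler's code. It assumes a single stop token, that a stop token is suppressed (not emitted) and ends the request, and that `remaining` is the number of tokens the request may still emit.

```rust
/// Resolves one committed span into the tokens actually emitted, plus a
/// flag saying whether the request is finished. Hypothetical helper.
pub fn resolve_span(
    committed: &[u32],
    stop_token: u32,
    ignore_eos: bool,
    remaining: usize,
) -> (Vec<u32>, bool) {
    let mut emitted = Vec::new();
    for &tok in committed {
        // max-tokens truncation: drop the rest of the span.
        if emitted.len() == remaining {
            return (emitted, true);
        }
        // Stop-token suppression: finish without emitting the stop token.
        if !ignore_eos && tok == stop_token {
            return (emitted, true);
        }
        emitted.push(tok);
    }
    // The span may end exactly on the token budget.
    let finished = emitted.len() == remaining;
    (emitted, finished)
}
```

Each request resolves its own span independently, so one request hitting a stop token mid-span does not affect the tokens emitted for its neighbours in the batch.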

Notes for review

  • First commit (#433) is a self-contained prep refactor — prunes redundant scheduler fake-loop tests and extracts the test module out of the 1.5k-line scheduler.rs. Can be split into its own PR if preferred.
  • The last three commits (84b98f0 / 1620c0c / 4e03b0e) are the piecewise verify CUDA Graph + its doc; the two before (0fa0a67 GPU-budget, 78e2160 batched draft) are the concurrent-throughput fix. All toxic-reviewed; no correctness/losslessness blocker.
  • 5090 validation — done. Correctness (hf_golden_gate bs1/batched/cuda-graph/tp2), losslessness gate (serial + batched + graph), single-stream 1.56×, gsm8k parity, and the concurrent A/B above.
  • Tracked follow-up (its own PR / issue after this lands): draft-side piecewise CUDA Graph — the draft (5 layers, dflash.rs) is the remaining ~16% of the launch gap; same pre/attn/post split, with its variable-length contiguous KV (DFlashLayerCache) handled at the eager attention boundary.
  • Known smaller follow-ups: (1) the losslessness gate regret-checks only the first divergence position per prompt; extend it to re-anchor after a benign tie; (2) executor.rs (~2.8k lines) and scheduler/tests.rs (~1.2k lines) exceed the 1k-line guideline (move the spec arms in execute_step_on_lane to executor/spec.rs); (3) tighten the DFlash KV-budget reservation from the per-request upper bound to the now-smaller lane-level footprint.

Design + findings: docs/models/qwen3/dflash-speculative-decoding.md.

🤖 Generated with Claude Code

xiaguan marked this pull request as ready for review June 22, 2026 04:20

chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ae71cc21da

Comment on lines +425 to +428
    for _ in 0..self.layers.len() {
        layers.push(DFlashLayerCache {
            k: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?,
            v: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?,

P1 Badge Account for per-request DFlash KV allocations

With --dflash-draft-model-path, every eligible request reaches new_request_state and allocates a separate K/V cache for each draft layer sized to max_cache_len; these buffers are outside the shared KvCacheManager. The scheduler admission and the startup memory profile only budget target-model KV blocks, so concurrent or long requests that pass admission can still allocate hundreds of MiB per request here and OOM during prefill/draft. Please include these DFlash buffers in admission/profiling or pool/cap them explicitly.

Comment on lines +95 to +96
    let max_cache_len =
        req.prompt_tokens.len() + req.max_output_tokens + dflash.model.block_size();

P2 Badge Preserve valid context-limit requests

For DFlash-enabled requests near the advertised context limit, this adds a full draft block to the request capacity before calling new_request_state, which rejects when prompt_tokens + max_output_tokens + block_size exceeds the drafter's max positions. The server still admits and advertises requests up to the target model's prompt + max_tokens limit, so valid requests in the final draft block of context fail at execution instead of generating or cleanly falling back. Clamp the lookahead to remaining capacity or reject/advertise the smaller DFlash limit at admission.

xiaguan commented Jun 22, 2026

Collaborator Author

Addressed both review findings in f446cfc.

P1 — draft GPU footprint was unbudgeted (OOM risk)

The draft model and its per-request KV/scratch live outside the paged KV pool, so they used to load on top of the full pool. DFlashMemoryReservation::from_config now reserves the footprint during memory profiling, before the pool is sized, split by how it scales:

  • Per-token, pool-scaling (draft KV + context/tail scratch + pending = 65 536 B/token) → folded into effective_bytes_per_block, shrinking the target block count. The pool itself stays allocated at the target-only bytes_per_block. Safe upper bound: the scheduler reserves pool blocks for each request's full prompt + max_tokens lifetime, which bounds the draft's attended tokens.
  • Fixed (draft weights ~1.1 GiB + block-sized per-request scratch across the decode batch + one in-fill block of per-request headroom) → added to the KV margin.

Measured on RTX 5070 Ti (16 GB, util 90%):

| | margin | KV budget | KV blocks |
|---|---|---|---|
| DFlash OFF | 150 MiB | 4113 MiB | 1828 |
| DFlash ON | 2972 MiB | 1389 MiB | 427 |

The fixed reservation lands in the margin; the ~1.44× per-token factor shrinks the rest. The TP path reserves nothing (DFlash is single-GPU) and rejects a draft path with a clear error.
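
The block arithmetic behind these numbers can be sketched as follows. The 65 536 B/token draft term and the MiB budgets come from this comment. The 16-token block size and the 147 456 B/token target KV cost are assumptions chosen because they reproduce the measured block counts; the PR does not state them. `kv_blocks` is a hypothetical helper, not `DFlashMemoryReservation`.

```rust
pub const MIB: u64 = 1024 * 1024;

/// Per-token draft footprint quoted above (draft KV + scratch + pending).
pub const DRAFT_BYTES_PER_TOKEN: u64 = 65_536;

/// ASSUMPTION: tokens per KV block. Not stated in the PR.
pub const BLOCK_TOKENS: u64 = 16;

/// ASSUMPTION: target-model KV bytes per token. Not stated in the PR.
pub const TARGET_BYTES_PER_TOKEN: u64 = 147_456;

/// Number of pool blocks that fit in the KV budget. With DFlash on, the
/// draft's per-token cost is folded into the effective bytes per block,
/// so fewer target blocks fit in the same budget.
pub fn kv_blocks(kv_budget_bytes: u64, dflash_on: bool) -> u64 {
    let per_token = if dflash_on {
        TARGET_BYTES_PER_TOKEN + DRAFT_BYTES_PER_TOKEN
    } else {
        TARGET_BYTES_PER_TOKEN
    };
    let effective_bytes_per_block = per_token * BLOCK_TOKENS;
    kv_budget_bytes / effective_bytes_per_block
}
```

Under these assumptions a 4113 MiB budget yields 1828 blocks with DFlash off and a 1389 MiB budget yields 427 blocks with it on. The ratio (147 456 + 65 536) / 147 456 is about 1.44, which matches the per-token factor quoted above.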

P2 — context-limit request crashed mid-prefill

The drafter's fixed-width in-fill block writes block_size positions past the committed length each step, so the DFlash-effective context is max_position_embeddings − block_size. max_context_tokens() returns that when DFlash is on, so a request that fits the target window but lands in the draft's final block is now rejected cleanly at admission instead of panicking.
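
A minimal sketch of that admission rule is shown below. The values max_position_embeddings = 40960 and block_size = 16 are assumptions that happen to reproduce the 40944 cap reported in the verification notes; the function signatures are illustrative and are not the PR's `max_context_tokens()`.

```rust
/// Effective context limit: the drafter writes `block_size` positions past
/// the committed length, so that much headroom is withheld when DFlash is on.
pub fn max_context_tokens(
    max_position_embeddings: usize,
    block_size: usize,
    dflash_on: bool,
) -> usize {
    if dflash_on {
        max_position_embeddings.saturating_sub(block_size)
    } else {
        max_position_embeddings
    }
}

/// Rejects a request at admission if its full lifetime exceeds the limit.
pub fn admit(prompt_tokens: usize, max_tokens: usize, limit: usize) -> Result<(), String> {
    let total = prompt_tokens + max_tokens;
    if total > limit {
        Err(format!("request needs {total} tokens, limit is {limit}"))
    } else {
        Ok(())
    }
}
```

With these values, a request of prompt 16 plus max_tokens 40936 (40952 total) is rejected, while prompt 16 plus max_tokens 40928 (40944 total) is admitted.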

Verification (local RTX 5070 Ti)

  • Reservation arithmetic unit test (per-token 65536, weights ~1.1 GiB, fixed scaling with decode batch).
  • New admission-rejection gate: a 40952-token request (prompt 16 + max_tokens 40936) is rejected at the 40944 cap — no panic.
  • Losslessness gate unchanged: 4/5 prompts bit-identical + 1 benign bf16 tie.
  • Single-stream A/B unchanged: 1.79× (was 1.82×, run noise — the decode path is untouched).
  • 50/50 lib tests, clippy clean on changed files.

Sharing the per-request scratch (logits-dominated) to reclaim most of the ~2.8 GiB reservation is documented as a tracked follow-up — it touches the draft forward and deserves its own validated PR.

xiaguan force-pushed the feat/qwen3-speculative-decoding branch from ba7d310 to bde4345 on June 22, 2026 07:16
xiaguan and others added 4 commits June 22, 2026 17:30
Model speculative decoding as an optimistic transaction: the DFlash drafter
proposes K tokens, one target verify forward over the K+1 span confirms them,
and we commit the longest argmax-matching prefix plus one bonus token, rolling
back the rest of the speculative KV reservation. Only the propose step is
method-specific; the draft/verify boundary is a pure token span so hidden
states never leave the proposer. The shared accept core (speculative.rs) is
method-agnostic; no proposer trait until a second method lands.

Enabled with --dflash-draft-model-path (Qwen3 only, TP=1, primary rank). The
server rejects the flag for other model lines rather than silently ignoring it.
Speculative-on forces prefix caching off so every request's target hidden
context is captured during prefill, removing the need for a per-request
drafter-ready handshake.

Greedy decode is lossless up to bf16 numerical tie-flips (the same
non-determinism that already affects plain greedy decode at genuine
bifurcations): multi-token accepts are bit-identical to baseline at every
non-tie position. Single-stream decode A/B on RTX 5070 Ti (bs=1): 93.4 -> 170.0
tok/s, 1.82x.

Tests: speculative.rs accept-core units, scheduler resolve truth-table
(multi-token emission / stop-token / max-tokens truncation),
tests/dflash_speculative_gate.rs (regret-based losslessness gate) and
dflash_speculative_perf.rs (single-stream A/B). Docs in
docs/models/qwen3/dflash-speculative-decoding.md.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
5090 results: hf_golden_gate passes (bs1/batched/cuda-graph/tp2), losslessness
gate passes, single-stream decode 168.9 -> 263.2 tok/s (1.56x; lower than the
5070 Ti's 1.82x because the 5090's higher bandwidth makes baseline decode less
memory-bound). lm-eval gsm8k (5-shot greedy) strict-match identical spec on/off
(0.86), flexible-extract within one question — task-level confirmation of
greedy losslessness.

Also drop the stale "5070 Ti" label hardcoded in the perf A/B test output, and
note the orthogonal frontend gap (per-request `seed` is rejected, not ignored).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ssion

DFlash's draft model and per-request KV/scratch live outside the paged KV
pool, so they were never budgeted — under concurrency or long contexts the
draft stole from the pool and risked OOM. And a request that fit the target
context window but landed in the draft's fixed-width in-fill block crashed
mid-prefill instead of being rejected.

Budget: DFlashMemoryReservation::from_config (reads the draft config) splits
the footprint into a per-token, pool-scaling term (draft KV + context/tail
scratch + pending) folded into the per-block budget so the target block count
shrinks, and a fixed term (draft weights + block-sized scratch across the
decode batch + one in-fill block of per-request headroom) added to the KV
margin. Reserved during memory profiling, before the pool is sized. Measured
on RTX 5070 Ti: margin 150->2972 MiB, KV pool 1828->427 blocks — the pool
now makes room instead of OOMing.

Context limit: max_context_tokens() returns max_position_embeddings minus
block_size when DFlash is on, so over-limit requests are rejected cleanly at
admission rather than crashing mid-prefill.

Tests: reservation arithmetic unit test (per-token 65536, weights ~1.1GiB);
admission-rejection gate; losslessness gate and 1.79x single-stream A/B
unchanged. Sharing the per-request scratch to reclaim most of the ~2.8GiB
reservation is a tracked follow-up.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
CI's `cargo fmt --all --check` was red since the feature commit. Format the
DFlash files (and the speculative/scheduler/core/kernels code touched by the
branch) to the canonical style. No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
xiaguan force-pushed the feat/qwen3-speculative-decoding branch from bde4345 to 326f716 on June 22, 2026 09:33
xiaguan and others added 2 commits June 22, 2026 17:40
Rebasing onto main surfaced two API drifts the dflash tests predated: Qwen3LaunchOptions gained batch_invariant (set false in the dflash perf/gate launch helpers), and start_engine_with_offload gained the dflash draft-path arg (pass None in the main-side batch-invariance tests scheduler_robustness/reject). Test-only, no behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
DFlash's draft ran a per-request serial for-loop — N full forwards = N×
kernel launches, launch-bound (a skip-attention A/B proved attention
compute is <2% of the draft). This inverted the single-stream win under
concurrent load (5090 greedy c16 −59% vs plain decode).

Batch the dense ops (rms_norm/GEMM/silu/MLP/embedding/logits) into one
pass over an N×block buffer — free, since cuBLAS takes any M and the ops
are already row-batched. The varlen ops (rope/KV-copy/attention) stay a
per-request loop slicing the batched buffers at row_offset; the two
DFlash-exclusive ops (dflash_qk_norm_rope_into /
single_prefill_nhd_noncausal_into) gain a row-offset param that advances
the device pointer to the slice (no CUDA-kernel change). A lane-level
DFlashBatchScratch replaces the per-request scratch.

5090 greedy A/B (sharegpt out128, same-session serial vs batched):
  c8  831 → 1346 tok/s  (vLLM 1240)
  c16 1013 → 1868 tok/s (vLLM 1846)
  draft@batch16 24.86 → 5.62 ms (draft_x 15.96 → 3.62)

Both c8/c16 now beat vLLM. Losslessness gate passes (bf16 tie-flips only,
regret ≤ 0.2). c1 (single-stream) is unchanged — batch=1 has no batching
win; its gap to vLLM is launch-bound draft overhead, since openinfer's
spec path (unlike base decode) isn't CUDA-Graph captured. Accept rate is
measured-equal to vLLM (9.1% vs 8.85%, same drafter), so CUDA-Graph draft
is the tracked next step.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
xiaguan commented Jun 22, 2026

Collaborator Author

Update — 78e2160: batched draft fixes concurrent throughput

The original draft ran a per-request serial loop (execute_dflash_draft called draft_logits once per active request), so at batch N the draft did N full forwards = N× kernel launches. A skip-attention A/B proved it was launch-bound (attention compute <2% of the draft), and this inverted the single-stream win under concurrent load (5090 greedy c16 −59% vs plain decode).

This commit batches the whole draft forward: the dense ops (rms_norm / GEMM / silu / MLP / embedding / logits) run once over an N×block buffer — free, since cuBLAS takes any M and the ops are already row-batched. The varlen ops (rope / KV-copy / attention) stay a per-request loop slicing the batched buffers at row_offset = i·block_size; the two DFlash-exclusive ops gained a row-offset param (no CUDA-kernel change). A lane-level DFlashBatchScratch replaces the per-request scratch.
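
The row_offset = i·block_size slicing can be illustrated on a host-side buffer. This is a sketch only: the real ops advance a device pointer, whereas this example slices a row-major `f32` slice, and `request_rows` and the `hidden` width parameter are hypothetical.

```rust
/// Returns the rows belonging to request `i` inside a row-major buffer of
/// shape (N * block_size) x hidden. Dense ops run once over the whole
/// buffer; varlen ops (rope, KV copy, attention) would each take one such
/// slice per request.
pub fn request_rows(buf: &[f32], request: usize, block_size: usize, hidden: usize) -> &[f32] {
    // First row of this request inside the batched buffer.
    let row_offset = request * block_size;
    let start = row_offset * hidden;
    &buf[start..start + block_size * hidden]
}
```

For two requests with block_size 2 and hidden 3, request 1 owns elements 6 through 11 of the 12-element buffer.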

5090 greedy A/B (sharegpt out128, same-session serial vs batched, same DFlash-b16 drafter):

| concurrency | DFlash serial | DFlash batched | vLLM DFlash |
|---|---|---|---|
| c1 | 245 | 237 | 278 |
| c8 | 831 | 1346 | 1240 |
| c16 | 1013 | 1868 | 1846 |

Both c8 / c16 now beat vLLM. Per-step draft@batch16 drops 24.86 → 5.62 ms (draft_x 15.96 → 3.62). The losslessness gate still passes (bf16 tie-flips only, regret ≤ 0.2).

c1 (single-stream) is unchanged — batch=1 has nothing to batch. Its remaining gap to vLLM (237 vs 278) is not accept: with the same drafter and spec config, measured greedy accept is 9.1% (ours) vs 8.85% (vLLM), and our pos-0 accept is actually higher. The gap is the draft's ~1 ms pure kernel-launch overhead (85 tiny kernels), since openinfer's spec path — unlike base decode — isn't CUDA-Graph captured. CUDA-Graph draft is the tracked next step (its own follow-up PR).

xiaguan and others added 3 commits June 23, 2026 01:31
DFlash verify ran batch_prefill(echo=true) which allocates fresh GPU
scratch every step (PrefillBuffers::new, embedding HiddenStates,
all_logits, PrefillPagedPlan upload) — moving pointers that CUDA Graph
capture cannot tolerate. This is the buffer-reuse refactor that precedes
capturing the verify forward into a graph (no capture yet this phase).

- PrefillPagedPlan: new_preallocated + update_batch_with_cta_tile_q
  refill the device buffers in place (memcpy, capacity-guarded) instead
  of clone_htod, keeping pointers stable. Host plan math extracted into a
  shared BatchPlanHost so new_batch_with_cta_tile_q is byte-identical.
- VerifyGraphBuffers: pre-allocates all verify scratch once at the
  worst-case max_batch*span shape; set_rows only moves seq_len.
- batch_prefill_into: buffer-reusing twin of batch_prefill, issues the
  identical kernel sequence (forward_layer_batch_paged reused verbatim)
  so verify logits/captured-hidden stay bit-equal. Verify never uses LoRA.
- lane: SpeculativeVerify routes through execute_dflash_verify; the old
  allocating verify path is gone. Normal prefill is untouched.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Capture the verify forward's dense ops into per-segment CUDA Graphs and
replay them, keeping attention eager. FlashInfer's paged-prefill attention
freezes its KV-iteration count at capture time, so capturing it corrupts
later tokens as the verify context grows (observed: garbage past ~token 60
once KV crosses a CTA_TILE_KV boundary). Every other op — embedding,
RMSNorm, all GEMMs, SwiGLU, residual adds (~84% of the per-step launch
gap) — is shape-stable in the fixed span-row layout and safe to replay.

Split forward_layer_batch_paged into pre_attn / attn / post_attn so the
verify path can interleave dense graph segments with eager attention; the
normal batch_prefill path calls all three in sequence, behaviour unchanged.

5090, Qwen3-4B + DFlash-b16, greedy c1: 250.9 -> 274.3 tok/s (+9.3% from
the graph alone; +15.7% over the pre-graph 237 baseline), matching vLLM
dflash (278). Losslessness gate passes (only bf16 tie-flips).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… gap

The launch gap is dense-op-dominated (nsys: 84% dense GEMM/norm, 8%
attention), not draft-specific — correcting the earlier "draft launch is
the c1 gap" read. Full-forward capture is infeasible (FlashInfer
paged-prefill attention freezes its KV-iteration count at record time;
captured attention corrupts tokens past ~60 as verify context grows). The
piecewise graph keeps attention eager and captures the dense segments.

5090 greedy: c1 250.9 -> 274.3 (+9.3% graph; ~vLLM 278), c8 1346 -> 1525,
c16 ~flat. All batch sizes now at or above vLLM. Draft-side piecewise graph
tracked as the next step.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>