diff --git a/docs/index.md b/docs/index.md index 8f41a7f6..9ffd5baa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,6 +30,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `models/qwen3/roadmap.md` | Qwen3-4B roadmap (2026-06 review): line is the maturity bar; #220 RoPE OOB, batched greedy sampling (#307), mixed greedy/non-greedy sampling (#284), and pegaflow KV offload (#316) are landed; open set is zero TP coverage, zero-adapter-only LoRA gate, dropped prefix-cache observability, stale docs, and YaRN #8 follow-up. | | `models/qwen3/model-crate.md` | `openinfer-qwen3-4b` owns Qwen3 config/weights/executor/scheduler/tests/kernel plan; root sees generic `EngineHandle`; split-K retuned to `256/64`, with 4k/64 serving TPOT p50 at `6.46ms` on RTX 5090. | | `models/qwen3/prefix-cache.md` | Prefix caching on by default for Qwen3-4B: full-block kvbm radix matching at the executor, suffix-only prefill. Repeated ~1900-token prompt TTFT 141.8 → 16.3ms p50 (8.7×); warm TTFT ≈ TPOT + ~5ms setup. Includes the RoPE scalar-path corruption fix and the drain-the-stream TTFT measurement pitfall. | +| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (bit-identical multi-token accepts; lm-eval gsm8k strict-match identical spec on/off). Single-stream decode 1.82× on 5070 Ti, 1.56× on 5090. Concurrent throughput fixed by batching the draft forward, then a piecewise verify CUDA Graph (dense ops captured, attention eager) closed single-stream: 5090 greedy c1 274 ≈ vLLM 278, c8 1525 > 1240, c16 1834 ≈ 1846, so every concurrency level is now at parity with vLLM (within ~1.5%) or ahead of it. Accept rate measured roughly equal (9.1% vs 8.85%, same drafter); draft-side piecewise graph tracked next. Proposer trait deferred to EAGLE. 
| | `models/qwen3/accuracy-gate.md` | Qwen3-4B instance of the logits golden gate (`tests/hf_golden_gate.rs`): 48 teacher-forced sequences / 816 positions vs a stored HF bf16 golden, replayed over bs=1 / batched eager / CUDA-graph. Strict guards: regret check + mean ≤ 0.06 + p99 ≤ 0.20; absolute max printed but not asserted (coverage-unstable). Methodology in `subsystems/correctness/`. | | `models/qwen3/kernels-crate.md` | Phase 1 split implemented and 5090-verified: Qwen3-4B kernel surface lives in `openinfer-kernels`; release build, test-target compile, accuracy gate, and bench snapshot pass. | | `models/qwen3/tp-design.md` | Qwen3 tensor-parallel design: `TP=2` milestone scope plus the controller/worker broadcast execution model, request identity, and coarse-grained step protocol for future TP/MoE work. | diff --git a/docs/models/qwen3/dflash-speculative-decoding.md b/docs/models/qwen3/dflash-speculative-decoding.md new file mode 100644 index 00000000..fd67f33c --- /dev/null +++ b/docs/models/qwen3/dflash-speculative-decoding.md @@ -0,0 +1,143 @@ +# DFlash Speculative Decoding (Qwen3-4B) + +**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). 
The drafter's out-of-pool footprint (weights + per-request KV/scratch) is reserved during memory profiling so the KV pool shrinks to fit instead of OOMing under load, and requests landing in the draft's final in-fill block are rejected cleanly at admission rather than crashing mid-prefill. **Concurrent throughput is now competitive after batching the draft forward.** The draft used to run a per-request **serial** `for` loop that was launch-bound (a skip-attention A/B showed attention compute <2%, so 24.8 ms/step at batch 16 was almost all kernel-launch overhead), which inverted the single-stream win (c16 −59%). Batching the dense ops into one N×block pass drops draft@batch16 to **5.6 ms** and lifts 5090 greedy throughput to **c8 1346 / c16 1868 tok/s, both now beating vLLM's 1240 / 1846**. The single-stream gap is now closed by a **piecewise CUDA Graph** over the verify forward: 5090 greedy **c1 237 → 274 tok/s (+16%), within ~1.5% of vLLM dflash (278)**; c8 1346 → 1525 and c16 roughly flat at 1834 (c8 stays well above vLLM's 1240; c16 is within 1% of vLLM's 1846). Under greedy dflash the spec path ran with no CUDA Graph at all (the base-decode graph never fires), so nsys saw 1296 ms of GPU-idle launch gap: **~84% from dense-op kernel launches** (GEMM alone 48%), only 8% from attention. Capturing the *whole* forward fails: FlashInfer's paged-prefill attention freezes its KV-iteration count at capture time, so a captured attention under-reads the growing verify KV and corrupts tokens past ~60. The fix is **piecewise**: capture the dense ops (embedding, RMSNorm, every GEMM, SwiGLU, residual adds) into per-segment graphs and replay them, keeping **attention eager**. Accept rate is *not* the gap (9.1% vs vLLM's 8.85%, same drafter). Draft-side piecewise graph is the tracked next step. See Performance § for the A/B tables. + +Last touched: 2026-06 + +## The abstraction: speculative decode = optimistic transaction + +Every speculative method is the same transaction; only *propose* differs. + +1. 
**Propose** — a method-specific drafter emits K candidate tokens. (DFlash is the only proposer today; an n-gram / EAGLE proposer would slot in here.) +2. **Verify** — the target model runs ONE prefill-style forward over the K+1 span `[current_token, draft_1, …, draft_K]` and reports its argmax at every position (echo=true). +3. **Accept** — `accept_greedy` keeps the longest prefix where each draft matches the target argmax, then appends the target's own token at the first mismatch (the "bonus"). A verify step therefore always commits `1..=K+1` tokens — at least one token of progress even when every draft is rejected. +4. **Commit / roll back** — accepted tokens' KV is committed (`apply_speculative`); the unused tail of the speculative reservation is LIFO-dropped (`revert_schedule`). + +The draft↔verify boundary is a **pure token span**. Hidden states never cross it — they stay inside the proposer (`dflash.rs` / `dflash_lane.rs`). This is what lets the shared core (`speculative.rs`: `accept_greedy`, `build_verify_results`) be method-agnostic, and it is why there is deliberately **no proposer trait yet**: a trait with one impl is premature. Add it when a second method lands. + +Shared core lives in `openinfer-qwen3-4b/src/speculative.rs`; the transaction wiring is in `openinfer-qwen3-4b/src/executor/spec.rs`; the KV transaction primitives (`schedule_speculative` / `apply_speculative` / `speculative_view` / `revert_schedule`) are in `openinfer-kv-cache/src/pool.rs` delegating to kvbm `scheduled.rs`. + +## Key invariant: readiness comes from prefill capture, not a handshake + +Speculative-on **forces prefix caching off**. With no prefix reuse, every eligible request's target hidden context is captured during its own prefill — so there is no per-request "drafter ready" handshake. Readiness is derived from the prefill capture-set plus the completed flag. 
This is the load-bearing simplification; if prefix caching were allowed, a cache-hit request would skip the prefill that seeds the drafter, and the invariant would break. + +Draft-seed alignment: after a verify accepting `m` drafts, the committed span positions with *known* hidden states are `[current_token, draft_1..draft_m]` = `m+1` rows. The next current_token (`target_argmax[m]`) is freshly predicted and its hidden was never forwarded — exactly like the post-prefill first token. `record_verify_dflash_context` appends `accepted_tokens.len()` (= `m+1`) rows and advances `kv_position` by the same. This is the subtlest part of the feature; it is defended by crash-early asserts (`append_pending_context` overflow/dim/range checks). + +## Losslessness: lossless up to bf16 ties + +The greedy oracle is "spec-on equals spec-off token-for-token". In practice the two diverge only at genuine bifurcation points, and only because of bf16 numerical noise — **not** a logic bug. Evidence: + +- The verify forward builds committed KV incrementally across batched spans (M=16-ish), while a one-shot prefill is a single M=1 forward. cuBLAS picks different GEMM tilings → ULP-scale reduction-order differences → an argmax flip at a near-tie. This is the same class as the `hf_golden_gate` `MARGIN_TOL = 0.20` tolerance. +- Multi-token acceptance (runs observed up to 11 tokens) is **bit-identical** to baseline at every non-tie position — proving the span-position-≥1 path is correct. +- Re-running flips **different** prompts each time (non-deterministic ⇒ bf16, not a deterministic logic error); one observed flip had regret 0.000 (an exact tie). +- The spec pick is always the #1 or #2 token of the prefill kernel's own distribution, within 0.20 nat of #1. 
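Two small pieces carry most of the losslessness argument: the greedy accept rule from the transaction section and the regret measure behind the 0.20 nat band. The sketch below is a minimal illustration, not the code in `speculative.rs`. The names `accept_greedy_sketch`, `regret_nats`, and `is_benign_tie_flip` are hypothetical, and the slice-based signatures are assumptions; only the `MARGIN_TOL = 0.20` value and the "longest matching prefix plus one bonus token" rule come from the text.

```rust
/// Tolerance quoted in the text (nats).
const MARGIN_TOL: f32 = 0.20;

/// Hypothetical stand-in for `accept_greedy`; the real signature is not shown here.
/// `drafts` holds the K proposed tokens, `target_argmax` holds the target's argmax
/// at each of the K+1 verify positions `[current_token, draft_1, .., draft_K]`.
fn accept_greedy_sketch(drafts: &[u32], target_argmax: &[u32]) -> Vec<u32> {
    assert_eq!(target_argmax.len(), drafts.len() + 1, "verify span is K+1 wide");
    // m = length of the longest prefix where draft i matches the target argmax at position i.
    let m = drafts
        .iter()
        .zip(target_argmax.iter())
        .take_while(|(d, t)| d == t)
        .count();
    // Commit the m accepted drafts, then the target's own token at the first
    // mismatch (the "bonus"), so 1..=K+1 tokens are always committed.
    let mut committed = drafts[..m].to_vec();
    committed.push(target_argmax[m]);
    committed
}

/// Regret of the spec pick against a baseline distribution: the gap in
/// log-probability between the baseline's top token and the picked token.
fn regret_nats(baseline_logprobs: &[f32], spec_pick: usize) -> f32 {
    let best = baseline_logprobs
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    best - baseline_logprobs[spec_pick]
}

/// A divergence counts as a benign tie-flip when the regret stays inside the band.
fn is_benign_tie_flip(regret: f32) -> bool {
    regret <= MARGIN_TOL
}
```

With drafts `[5, 6, 7]` and target argmax `[5, 6, 9, 1]`, the sketch commits `[5, 6, 9]`: two accepted drafts plus the bonus token. When every draft is rejected it still commits one token, which is the progress guarantee the transaction relies on.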
+ +### The losslessness gate (`tests/dflash_speculative_gate.rs`) + +The gate runs a baseline (spec off, logprobs on) and a spec engine on the same prompts, then at the first token where they diverge it measures the spec pick's **regret** against the prefill kernel's own distribution. Within `MARGIN_TOL = 0.20` nat ⇒ benign tie-flip (pass); clearly worse ⇒ real bug (fail). The realistic bug classes (KV misalignment, mask leak) push the pick far outside 0.20 nat, so the gate has teeth — it is not tautological. + +**Known scope limit:** the gate regret-checks only the *first* divergence position per prompt, then continues. A bug that stays within the tie band at the first bifurcation but corrupts later positions could slip. Acceptable for v1; the next iteration should re-anchor and regret-check the following few positions after a benign-tie classification. + +## Performance + +Single-stream decode is where speculative decoding pays off directly: plain decode is memory-bound (one target forward per token); spec amortizes that forward over the accepted run. A/B harness: `tests/dflash_speculative_perf.rs` (bs=1, 256 tokens, `ignore_eos`, one warm-up discarded). + +| Config | RTX 5070 Ti, bs=1 | RTX 5090, bs=1 | +| --- | --- | --- | +| spec OFF (plain decode) | 93.4 tok/s | 168.9 tok/s | +| spec ON (DFlash) | 170.0 tok/s | 263.2 tok/s | +| **speedup** | **1.82×** | **1.56×** | + +The speedup is smaller on the 5090: its higher memory bandwidth makes the baseline decode less memory-bound, so amortizing the target forward buys less. Both builds use CUDA 13.x (the 5090's default 12.9 has the documented cuBLAS N=1025 cliff). + +### Concurrent throughput: the draft loop is serial and launch-bound + +Under concurrent load the single-stream win inverts. 
Same greedy harness (`temperature=0`, sharegpt prompts, 128 out tokens), openinfer vs vLLM with the **same** DFlash-b16 drafter, RTX 5090: + +| concurrency | OI plain | OI DFlash serial | OI DFlash batched | **OI DFlash +graph** | vLLM DFlash | +| --- | --- | --- | --- | --- | --- | +| c1 | 170 | 245 | 237 | **274** | 278 | +| c8 | 1180 | 831 | 1346 | **1525** | 1240 | +| c16 | 2277 | 1013 | 1868 | **1834** | 1846 | + +(tok/s, sharegpt out128, RTX 5090, greedy. The serial draft inverted the win — vLLM degraded gracefully while openinfer nearly halved; batching the draft restored c8/c16 past vLLM. The **+graph** column adds the piecewise verify CUDA Graph (this branch): it closes c1 to vLLM and lifts c8, with c16 flat — see "Single-stream gap" below. Caveat: the serial→batched columns are a same-session A/B; the +graph column's c1 is a clean same-session A/B (251 fixed-buffer eager → 274 graph), while c8/c16 are single runs against the prior batched baseline, i.e. a no-regression check rather than a tight A/B.) + +#### Root cause (serial draft) and the fix (batched draft, landed) + +Accept length is **not** the cause — both engines accept ~equally (same drafter + greedy = longest-prefix match). The cause was the **per-request serial draft loop**: `execute_dflash_draft` (`dflash_lane.rs`) called `draft_logits` once per request in a `for` loop, while verify runs ONE batched target forward over all spans. Per-step draft timing by batch size, **before vs after batching** (5090, instrumented): + +| batch | draft serial | draft batched | verify | draft serial scaling | draft batched scaling | +| --- | --- | --- | --- | --- | --- | +| 1 | 1.56 ms | 1.55 ms | 8.1 ms | 1.00× | 1.00× | +| 8 | 12.45 ms | 3.17 ms | 9.6 ms | 7.99× | 2.04× | +| 16 | 24.86 ms | **5.62 ms** | 13.6 ms | **15.96×** | **3.62×** | + +The serial draft scaled **exactly linearly** (`draft_x` 16.00 at batch 16); batched draft scales sub-linearly (3.62×). 
At batch 16 the draft drops from 65% of the step to ~29%, and step time roughly halves → c16 throughput doubles. + +**Why it was launch-bound, not compute-bound.** A skip-attention A/B (short-circuit `single_prefill` to a cheap copy, keep every other kernel) barely moved the serial draft — batch 16: 24.36 ms vs 24.81 ms, so attention compute is **<2%**. The 24.8 ms was almost entirely per-request serial **kernel-launch overhead**: each request's draft is ~90 tiny 16-token kernels (5 layers × 35 dense GEMMs + varlen copy/rope/KV), compute ≈ 0. + +**The fix that landed: batch the whole draft forward (N requests in one pass), killing the N× launch.** `draft_logits_batched` (`dflash.rs`) runs the dense ops (rms_norm / GEMM / silu / MLP / embedding / logits) **once over an N×block buffer** — free, since cuBLAS takes any M and the ops are already row-batched (N×35 → 35 launches, no CUDA-kernel change). The varlen ops (tail concat / k·v GEMM / rope / KV-copy / attention) stay a per-request loop slicing the batched buffers at `row_offset = i·block_size` (the two DFlash-exclusive ops `dflash_qk_norm_rope_into` / `single_prefill_nhd_noncausal_into` advance the device pointer to the slice). A lane-level `DFlashBatchScratch` (sized for the largest batch bucket) replaces the per-request scratch — folding in the reservation-reclaim follow-up. Losslessness gate still passes (bf16 tie-flips only). + +#### Single-stream gap (c1): closed by a piecewise verify CUDA Graph + +c1 trailed vLLM (237 vs 278), and it is **not accept**: with the *same* drafter and spec config (vLLM `--speculative-config '{"method":"dflash",…,"num_speculative_tokens":16}'`), measured greedy accept is **9.1% (ours) vs 8.85% (vLLM)** — mean 1.29 vs 1.42 draft tokens/step, our pos-0 accept (60%) actually higher. 9% is the drafter's floor on sharegpt free-text (bimodal: structured spans accept up to 15, free text 0), shared by both engines. + +The gap is **launch exposure**. 
Under greedy dflash the spec path runs with *no* CUDA Graph: the base-decode graph (`execute_decode`) never fires because greedy routes through `SpeculativeDecode`, not `batch_decode`. nsys on c1 (single stream, fully serial): wall 7999 ms, GPU **busy** 6703 ms (memory-bound weight reads — irreducible, the same for vLLM), GPU **idle 1296 ms**, ~79% of it sub-3 µs kernel-launch gaps. Attributing that idle by kernel type (`gap_by_kernel.py`): **dense ops 84%** (GEMM 48%, RMSNorm 15%, embedding / KV-copy / residual-add / SwiGLU the rest), **attention only 8%**. Dense dominates because verify is 36 layers of GEMMs vs the draft's 5 — so the launch gap lives mostly in *verify*, not the draft. (This corrects the earlier read that pinned c1 on draft launch; that predated the batched draft and the per-kernel-type nsys breakdown.) + +**Why the whole forward can't go in one graph.** A first attempt captured the entire verify forward; output was correct up to ~token 60, then garbage. Root cause: FlashInfer's paged-prefill attention derives its KV-iteration count (`num_iterations`, from `kv_len`) and that loop bound is **frozen when the graph is recorded**. The verify context grows every step, so once it crosses the captured `CTA_TILE_KV` (~64) boundary the replayed attention under-reads KV. Base-decode's graph is safe only because its *decode* kernel's KV loop is purely device-driven; *prefill*'s is not. FlashInfer ships no graph-safe prefill variant; vLLM hits the same wall and keeps attention out of its piecewise cudagraph. + +**The fix: piecewise graph.** Keep attention **eager**; capture only the dense ops, whose dims depend on the fixed `span` row count, never on KV length. `forward_layer_batch_paged` is split into `pre_attn` / `attn` / `post_attn`, and the verify forward becomes `num_layers+1` dense graph segments — `[embed+L0.pre] [L0.attn eager] [L0.post+L1.pre] … [L_last.post+lm_head]` — captured once per batch bucket and replayed (`verify_graph.rs`). 
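The segment layout can be made concrete with a small model of the schedule. This is only a sketch of the ordering described above, with no CUDA involved: the `Segment` enum and `verify_schedule` function are hypothetical names, and the op labels are plain strings rather than anything from `verify_graph.rs`.

```rust
/// One step of the piecewise verify forward (illustrative model only).
#[derive(Debug, PartialEq)]
enum Segment {
    /// Dense ops captured together into one graph and replayed.
    Dense(Vec<String>),
    /// Attention for the given layer, launched eagerly every step.
    EagerAttention(usize),
}

/// Builds the order `[embed+L0.pre] [L0.attn] [L0.post+L1.pre] .. [L_last.post+lm_head]`.
fn verify_schedule(num_layers: usize) -> Vec<Segment> {
    assert!(num_layers > 0, "need at least one layer");
    let mut schedule = Vec::with_capacity(2 * num_layers + 1);
    // Dense ops accumulate here until the next attention boundary closes the segment.
    let mut pending = vec!["embed".to_string()];
    for layer in 0..num_layers {
        pending.push(format!("L{}.pre", layer));
        schedule.push(Segment::Dense(std::mem::take(&mut pending)));
        schedule.push(Segment::EagerAttention(layer));
        pending.push(format!("L{}.post", layer));
    }
    pending.push("lm_head".to_string());
    schedule.push(Segment::Dense(pending));
    schedule
}

/// Number of captured (dense) segments in a schedule.
fn dense_segment_count(schedule: &[Segment]) -> usize {
    schedule
        .iter()
        .filter(|s| matches!(s, Segment::Dense(_)))
        .count()
}
```

For a 36-layer target this yields 37 dense segments interleaved with 36 eager attention calls, which is the `num_layers+1` count quoted above. Every interior dense segment pairs one layer's post-attention ops with the next layer's pre-attention ops, so each attention boundary costs one extra graph launch rather than a full set of kernel launches.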
The ping-pong residual swap sits inside the captured segments, so its pointer alternation is baked into each graph and reproduces on every replay regardless of layer parity: `run_or_capture` re-runs the CPU swap only on the capture step, and the one eager op (attention) touches just `q/k/v_batch` / `attn_output`, never the swapped `hidden`. + +Result (5090, greedy, same-session A/B): fixed-buffer eager **250.9 → +graph 274.3 tok/s (+9.3% from the graph alone)**; 237 → 274 (+16%) over the pre-graph batched baseline, **within ~1.5% of vLLM's 278**. Concurrent (no-regression check): c8 1346 → 1525, c16 1868 → 1834 (c8 is above vLLM's 1240; c16 is within 1% of vLLM's 1846). Losslessness gate passes (bf16 tie-flips only): the dense ops replay bit-identically, and the eager attention is unchanged. + +**Draft-side piecewise graph is the tracked next step.** The draft (5 layers, `dflash.rs`) is the other ~16% of the launch gap; it needs the same pre/attn/post split, with its variable-length contiguous KV (`DFlashLayerCache`) handled at the eager attention boundary. Tracked as its own PR after this one lands. + +The EAGLE proposer trait is still deferred to when EAGLE actually lands (see "no proposer trait yet" above); both the batched draft and the future CUDA-Graph draft are DFlash-internal changes behind the unchanged `DraftPlan→DraftResult` seam, so neither is thrown away by EAGLE. + +## Task-level accuracy parity (lm-eval gsm8k) + +Token-level losslessness should imply task-level parity. Confirmed on the 5090 with `lm-eval` gsm8k (5-shot, greedy, `local-completions` against the openinfer server, 50 questions): + +| | flexible-extract | strict-match | +| --- | --- | --- | +| spec OFF | 0.86 | 0.86 | +| spec ON (DFlash) | 0.88 | 0.86 | + +`strict-match` is identical; `flexible-extract` differs by one question (within the ±0.05 stderr), the same single bf16 tie-flip the losslessness gate sees. DFlash does not change task accuracy. 
+ +Harness note: openinfer's `/v1/completions` rejects a per-request `seed` field (`"per-request seed is not supported by this engine"` → 400), which the OpenAI/lm-eval client sends by default. For the eval the client was patched to drop `seed`; making the frontend accept-and-ignore `seed` under greedy is a separate, unrelated improvement (not part of this change). + +## GPU memory budget & context limit + +DFlash's draft model and its per-request KV/scratch live **outside** the paged KV pool (`KvCacheManager`), so they must be reserved *before* the pool is sized or they silently steal from it and OOM under concurrency / long contexts. `DFlashMemoryReservation::from_config` (cheap — reads the draft `config.json`) splits the footprint by how it scales and the budget charges each in the matching place: + +- **Per-token, pool-scaling** (draft KV + the prompt-tracking context scratch + the in-fill tail scratch + pending context, **65 536 B/token**) → folded into `effective_bytes_per_block`, so the *target* block count shrinks while the pool itself stays allocated at the target-only `bytes_per_block`. This is a safe upper bound: the scheduler reserves pool blocks for each request's whole lifetime (`prompt + max_tokens`), and the draft attends at most that many tokens, so billing the draft per pool-token over-covers. The draft KV and tail scratch are sized one in-fill block past that lifetime, so the fixed term also reserves `max_decode_batch × block_size` of that per-request headroom. +- **Fixed, pool-independent** (draft weights ~1.1 GiB + the block-sized per-request scratch across the decode batch — logits-dominated, ~6.5 MiB × 256) → added to the KV `margin`. 
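The per-token half of that accounting can be checked with a few lines of arithmetic. The sketch below is illustrative only and is not `DFlashMemoryReservation::from_config`. The 65 536 B/token draft figure comes from the text; the 16-token block size and the target KV cost per token (36 layers × K and V × 8 KV heads × head_dim 128 × 2 bytes for bf16) are assumptions, chosen because they reproduce the measured block counts in this section. All function names are hypothetical.

```rust
const MIB: u64 = 1024 * 1024;

/// Assumption: tokens per paged-KV block (not stated in this document).
const TOKENS_PER_BLOCK: u64 = 16;
/// Assumption: target KV bytes per token = layers * (K and V) * kv_heads * head_dim * bf16 bytes.
const TARGET_KV_BYTES_PER_TOKEN: u64 = 36 * 2 * 8 * 128 * 2;
/// From the text: draft KV plus scratch billed per pool token.
const DRAFT_BYTES_PER_TOKEN: u64 = 65_536;

/// Bytes the pool actually allocates per block (target KV only).
fn target_bytes_per_block() -> u64 {
    TARGET_KV_BYTES_PER_TOKEN * TOKENS_PER_BLOCK
}

/// Bytes charged against the KV budget per block once the draft's
/// per-token footprint is folded in.
fn effective_bytes_per_block() -> u64 {
    target_bytes_per_block() + DRAFT_BYTES_PER_TOKEN * TOKENS_PER_BLOCK
}

/// Whole blocks that fit in a KV budget at a given per-block charge.
fn kv_blocks(kv_budget_bytes: u64, bytes_per_block: u64) -> u64 {
    kv_budget_bytes / bytes_per_block
}
```

Under these assumptions a block costs 2.25 MiB for the target alone and 3.25 MiB with the draft folded in, a factor of about 1.44. Feeding in the RTX 5070 Ti KV budgets measured in this section, 4113 MiB with DFlash off gives 1828 blocks and 1389 MiB with DFlash on gives 427 blocks. The fixed term is not modelled here; it only moves the margin and therefore the budget passed to `kv_blocks`.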
+ +Measured on the RTX 5070 Ti (16 GB, util 90%), the pool correctly makes room for the draft: + +| | margin | KV budget | KV blocks | +| --- | --- | --- | --- | +| DFlash OFF | 150 MiB | 4113 MiB | 1828 | +| DFlash ON | 2972 MiB | 1389 MiB | 427 | + +The fixed reservation lands exactly in the margin (+2822 MiB) and the per-token factor (≈1.44×) shrinks the remaining blocks. Before this, those ~2.8 GiB loaded on top of the full 1828-block pool → OOM under load. + +**Context limit.** The drafter's fixed-width in-fill block writes `block_size` positions past the committed length each step, so the DFlash-effective context is `max_position_embeddings − block_size`. `max_context_tokens()` returns that when DFlash is on, so a request that fits the target window but lands in the draft's final block is **rejected cleanly at admission** instead of crashing mid-prefill (`tests/dflash_speculative_gate.rs::dflash_request_in_draft_headroom_is_rejected_not_panicked`). + +## Constraints & open follow-ups + +- **TP=1, primary rank only.** The DFlash lane runs on the worker thread of the primary rank; the launch path gates `!(dflash && lora)` and `tp_size == 1`. The server fails loud (`--dflash-draft-model-path` is rejected for non-Qwen3 model lines rather than silently ignored). +- **Second proposer (EAGLE) → introduce the trait then.** A proposer trait wants two real methods to fix its shape; DFlash stays concrete until EAGLE lands. The batched-draft perf work (Performance §) is orthogonal — it's a DFlash-internal change behind the unchanged `DraftPlan→DraftResult` seam, so it can land before or with the trait without being redone, and EAGLE's autoregressive attention won't reuse DFlash's block attention anyway. +- **File size.** `executor.rs` (~2.8k lines) and `scheduler/tests.rs` (~1.2k lines) exceed the 1k guideline; the spec arms in `execute_step_on_lane` are candidates to move into `executor/spec.rs` in a follow-up. 
+- **5090 validation: done** for correctness, single-stream, and concurrent throughput: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes (batched draft *and* piecewise verify graph), single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Concurrent throughput is **fixed** (Performance §): batching the draft beat vLLM at c8/c16, then the **piecewise verify CUDA Graph** closed c1: **c1 274 ≈ vLLM 278, c8 1525 > 1240, c16 1834 ≈ 1846**, so every concurrency level is now at parity with vLLM (within ~1.5%) or ahead of it. Draft-side piecewise graph (the remaining ~16% launch gap) is tracked next. +- **Frontend `seed`.** `/v1/completions` 400s on a per-request `seed` instead of accepting-and-ignoring it under greedy (surfaced while wiring lm-eval); orthogonal to spec decode, worth a small separate fix. +- **Reclaim the per-request reservation.** The fixed term is dominated by the block-sized scratch billed per decode-batch slot (~6.5 MiB × 256 ≈ 1.6 GiB), most of it the transient logits buffer (`vocab × block`). The per-token term carries the context/pending scratch, which today reallocates to prompt length on the first draft and never shrinks (sized for `[0, context_len)` but only `[committed_len, +block]` is live). Both are billed conservatively because they were per-request. **Partly reclaimed by the batched-draft PR:** the per-request scratch is gone; `DFlashBatchScratch` is allocated once at lane level, sharing the transient logits/dense buffers across the whole batch instead of per request. The reservation accounting (`from_config`) still bills the old per-request upper bound (a safe over-estimate, so no OOM); tightening it to the now-smaller lane-level footprint is the remaining follow-up. + +### Review blockers (correctness/usability, independent of the perf work) + +Issues surfaced in PR review, largely independent of the draft batching (#2's DFlash wrappers were fixed alongside it): + +1. 
**Unified path silently skips DFlash readiness.** `StepCommand::Unified` (`executor.rs`) captures no DFlash hidden state, and only the plain `execute_prefill` post-step marks a request draft-ready (`executor.rs` ~1766). So when active + pending fuse into a Unified step (`scheduler/plan.rs` — the normal mixed-load path), greedy requests routed through Unified never become draft-ready and never recover: DFlash silently no-ops for them. No wrong tokens, but the feature quietly disables itself under mixed load. Crash-early or capture-in-Unified, don't degrade silently. +2. **Stream-override race (partly fixed).** The DFlash `qk_norm_rope` / `single_prefill` wrappers (`attention.rs`) now use `active_cu_stream(ctx)` (fixed in the batched-draft PR). `copy_hidden_rows_into` (`elementwise.rs:209`) still uses `ctx.stream.cu_stream()` instead of the repo convention (`tensor.rs:43`) — under Green-Context / split-stream decode overlap this remains a planted race. +3. **`gemm_lt_pin_tune` is not a real warmup.** It only pins the heuristic (`linear.cu:497`); it never executes a `cublasLtMatmul` the way the old `gemm_lt_tune_cuda` (`linear.cu:431`) did, so the first real matmul can land inside CUDA-graph capture. `batch_invariance_decode_gemm_graph` backstops it, but the warmup should actually run the matmul. 
diff --git a/openinfer-core/src/ops.rs b/openinfer-core/src/ops.rs index dc522350..389ea51a 100644 --- a/openinfer-core/src/ops.rs +++ b/openinfer-core/src/ops.rs @@ -15,6 +15,7 @@ pub use attention::{ pub use openinfer_kernels::ops::{ GEMM_LT_MAX_N, LoraDecodeGroupedProjection, accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, argmax, argmax_batch_bf16_into, bf16_hidden_to_f32_into, + copy_hidden_rows_into, copy_hidden_token_range_into, dflash_qk_norm_rope_into, embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, fused_add_rms_norm_into, gather_hidden_tokens_into, gemm, gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, gemm_into_checked, gemm_lt_tune, @@ -23,7 +24,8 @@ pub use openinfer_kernels::ops::{ qk_norm_partial_rope_batched_decode_hd256_into, rms_norm, rms_norm_batch_offset_into, rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, - scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, write_vec_into, + scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, + single_prefill_nhd_noncausal_into, write_vec_into, }; #[cfg(not(feature = "kernel-call-trace"))] pub use openinfer_kernels::ops::{ diff --git a/openinfer-core/src/ops/paged_plan.rs b/openinfer-core/src/ops/paged_plan.rs index 1d162951..d2770939 100644 --- a/openinfer-core/src/ops/paged_plan.rs +++ b/openinfer-core/src/ops/paged_plan.rs @@ -147,6 +147,53 @@ impl PrefillPagedPlan { }) } + /// Pre-allocate a worst-case-sized plan to be refilled in place by + /// [`Self::update_batch_with_cta_tile_q`] (graph-stable buffer reuse). 
+ pub fn new_preallocated( + ctx: &DeviceContext, + max_total_tokens: usize, + max_total_pages: usize, + max_batch: usize, + max_tiles: usize, + ) -> Result<Self> { + Ok(Self { + inner: openinfer_kernels::ops::PrefillPagedPlan::new_preallocated( + ctx, + max_total_tokens, + max_total_pages, + max_batch, + max_tiles, + )?, + }) + } + + /// Refill a pre-allocated plan in place (no allocation, pointers unchanged). + #[allow(clippy::too_many_arguments)] + pub fn update_batch_with_cta_tile_q( + &mut self, + ctx: &DeviceContext, + page_indices: &[Vec], + last_page_lens: &[usize], + start_positions: &[usize], + seq_lens: &[usize], + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + cta_tile_q_override: i32, + ) -> Result<()> { + self.inner.update_batch_with_cta_tile_q( + ctx, + page_indices, + last_page_lens, + start_positions, + seq_lens, + num_q_heads, + num_kv_heads, + head_dim, + cta_tile_q_override, + ) + } + pub fn page_indices_d(&self) -> &CudaSlice { self.inner.page_indices_d() } diff --git a/openinfer-kernels/csrc/shared/elementwise.cu b/openinfer-kernels/csrc/shared/elementwise.cu index 92de04eb..7193b145 100644 --- a/openinfer-kernels/csrc/shared/elementwise.cu +++ b/openinfer-kernels/csrc/shared/elementwise.cu @@ -63,6 +63,43 @@ __global__ void gather_hidden_tokens_kernel( } } +__global__ void copy_hidden_rows_kernel( + const __nv_bfloat16 *__restrict__ src, + __nv_bfloat16 *__restrict__ dst, + int src_hidden_dim, + int dst_hidden_dim, + int row_offset, + int rows, + int seq_len) { + int total = rows * seq_len; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total; + idx += gridDim.x * blockDim.x) { + int token = idx / rows; + int row = idx % rows; + dst[(size_t)token * dst_hidden_dim + row_offset + row] = + src[(size_t)token * src_hidden_dim + row]; + } +} + +__global__ void copy_hidden_token_range_kernel( + const __nv_bfloat16 *__restrict__ src, + __nv_bfloat16 *__restrict__ dst, + int hidden_dim, + int src_token_offset, + int 
dst_token_offset, + int token_count) { + int total = hidden_dim * token_count; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total; + idx += gridDim.x * blockDim.x) { + int token = idx / hidden_dim; + int row = idx % hidden_dim; + dst[(size_t)(dst_token_offset + token) * hidden_dim + row] = + src[(size_t)(src_token_offset + token) * hidden_dim + row]; + } +} + __global__ void scaled_add_rows_indexed_kernel( const __nv_bfloat16 *__restrict__ delta, float scale, @@ -305,6 +342,53 @@ CUresult gather_hidden_tokens_cuda( return (CUresult)cudaGetLastError(); } +CUresult copy_hidden_rows_cuda( + const __nv_bfloat16 *src, + __nv_bfloat16 *dst, + int src_hidden_dim, + int dst_hidden_dim, + int row_offset, + int rows, + int seq_len, + cudaStream_t stream) { + if (src == nullptr || dst == nullptr || src_hidden_dim <= 0 || + dst_hidden_dim <= 0 || row_offset < 0 || rows <= 0 || seq_len <= 0 || + rows > src_hidden_dim || row_offset + rows > dst_hidden_dim) { + return CUDA_ERROR_INVALID_VALUE; + } + int total = rows * seq_len; + int block = 256; + int grid = (total + block - 1) / block; + copy_hidden_rows_kernel<<<grid, block, 0, stream>>>( + src, dst, src_hidden_dim, dst_hidden_dim, row_offset, rows, seq_len); + return (CUresult)cudaGetLastError(); +} + +CUresult copy_hidden_token_range_cuda( + const __nv_bfloat16 *src, + __nv_bfloat16 *dst, + int hidden_dim, + int src_token_offset, + int dst_token_offset, + int token_count, + int src_seq_len, + int dst_seq_len, + cudaStream_t stream) { + if (src == nullptr || dst == nullptr || hidden_dim <= 0 || + src_token_offset < 0 || dst_token_offset < 0 || token_count <= 0 || + src_seq_len <= 0 || dst_seq_len <= 0 || + src_token_offset + token_count > src_seq_len || + dst_token_offset + token_count > dst_seq_len) { + return CUDA_ERROR_INVALID_VALUE; + } + int total = hidden_dim * token_count; + int block = 256; + int grid = (total + block - 1) / block; + copy_hidden_token_range_kernel<<<grid, block, 0, stream>>>( + src, dst, hidden_dim, src_token_offset, 
dst_token_offset, token_count); + return (CUresult)cudaGetLastError(); +} + CUresult scaled_add_rows_indexed_cuda( const __nv_bfloat16 *delta, float scale, diff --git a/openinfer-kernels/csrc/shared/paged_attention.cu b/openinfer-kernels/csrc/shared/paged_attention.cu index 4506a60f..01f4e82c 100644 --- a/openinfer-kernels/csrc/shared/paged_attention.cu +++ b/openinfer-kernels/csrc/shared/paged_attention.cu @@ -607,6 +607,70 @@ int single_prefill_cuda( reinterpret_cast(stream))); } +int single_prefill_nhd_noncausal_cuda( + // Q and output (HiddenStates token-major: [seq_len, q_dim]) + void* q, + void* output, + // Contiguous KV cache (HiddenStates token-major: [max_seq_len, kv_dim]) + void* k_cache, + void* v_cache, + int32_t num_qo_heads, + int32_t num_kv_heads, + int32_t head_dim, + int32_t seq_len, + int32_t kv_len, + int32_t max_seq_len, + float sm_scale, + void* stream) +{ + if (q == nullptr || output == nullptr || k_cache == nullptr || v_cache == nullptr || + num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128 || + seq_len <= 0 || kv_len <= 0 || max_seq_len < kv_len) { + return static_cast(cudaErrorInvalidValue); + } + + uint32_t q_stride_n = num_qo_heads * head_dim; + uint32_t q_stride_h = head_dim; + uint32_t kv_stride_n = num_kv_heads * head_dim; + uint32_t kv_stride_h = head_dim; + + PrefillParamsT params( + reinterpret_cast(q), + reinterpret_cast(k_cache), + reinterpret_cast(v_cache), + /*maybe_custom_mask=*/nullptr, + reinterpret_cast(output), + /*lse=*/nullptr, + /*maybe_alibi_slopes=*/nullptr, + num_qo_heads, + num_kv_heads, + static_cast(seq_len), + static_cast(kv_len), + q_stride_n, + q_stride_h, + kv_stride_n, + kv_stride_h, + static_cast(head_dim), + /*window_left=*/-1, + /*logits_soft_cap=*/0.0f, + sm_scale, + /*rope_scale=*/1.0f, + /*rope_theta=*/1e6f); + + return static_cast( + SinglePrefillWithKVCacheDispatched< + /*HEAD_DIM_QK=*/128, + /*HEAD_DIM_VO=*/128, + PosEncodingMode::kNone, + /*USE_FP16_QK_REDUCTION=*/false, + MaskMode::kNone, + 
Variant, + PrefillParamsT>( + params, + /*tmp=*/nullptr, + reinterpret_cast<cudaStream_t>(stream))); +} + // --------------------------------------------------------------------------- // Single-request prefill for HEAD_DIM=256 — wraps FlashInfer SinglePrefillWithKVCache. // diff --git a/openinfer-kernels/csrc/shared/prefill_attention.cu b/openinfer-kernels/csrc/shared/prefill_attention.cu index a7b24b66..0c69522b 100644 --- a/openinfer-kernels/csrc/shared/prefill_attention.cu +++ b/openinfer-kernels/csrc/shared/prefill_attention.cu @@ -93,6 +93,90 @@ __global__ void prefill_qk_norm_rope_kernel( data[offset] = result; } +__global__ void dflash_qk_norm_rope_kernel( + __nv_bfloat16* __restrict__ q, + __nv_bfloat16* __restrict__ k, + const __nv_bfloat16* __restrict__ q_norm_weight, + const __nv_bfloat16* __restrict__ k_norm_weight, + const __nv_bfloat16* __restrict__ cos_cache, + const __nv_bfloat16* __restrict__ sin_cache, + int num_q_heads, + int num_kv_heads, + int head_dim, + int q_len, + int k_len, + int q_start_pos, + int k_start_pos, + float eps, + int cos_max_pos +) { + int head_global = blockIdx.x; + int token = blockIdx.y; + int d = threadIdx.x; + + bool is_q = (head_global < num_q_heads); + int local_heads = is_q ? num_q_heads : num_kv_heads; + int seq_len = is_q ? q_len : k_len; + if (token >= seq_len) return; + + int head_local = is_q ? head_global : (head_global - num_q_heads); + if (head_local >= local_heads) return; + + __nv_bfloat16* data = is_q ? q : k; + int dim_stride = local_heads * head_dim; + const __nv_bfloat16* norm_w = is_q ? q_norm_weight : k_norm_weight; + int pos = (is_q ? 
q_start_pos : k_start_pos) + token; + if (pos < 0 || pos >= cos_max_pos) __trap(); + + int offset = token * dim_stride + head_local * head_dim + d; + float val = __bfloat162float(data[offset]); + + float sq = warp_reduce_sum(val * val); + int warp_id = d / WARP_SIZE; + int lane_id = d % WARP_SIZE; + __shared__ float warp_sums[4]; + if (lane_id == 0) warp_sums[warp_id] = sq; + __syncthreads(); + + __shared__ float s_inv_rms; + if (warp_id == 0) { + float v = (lane_id < 4) ? warp_sums[lane_id] : 0.0f; + float total = warp_reduce_sum(v); + if (lane_id == 0) s_inv_rms = rsqrtf(total / head_dim + eps); + } + __syncthreads(); + + __nv_bfloat16 normed = __float2bfloat16(val * s_inv_rms); + float normed_f = __bfloat162float(normed) * __bfloat162float(norm_w[d]); + + __shared__ __nv_bfloat16 smem[HEAD_DIM]; + smem[d] = __float2bfloat16(normed_f); + __syncthreads(); + + int half = head_dim / 2; + __nv_bfloat16 result; + if (d < half) { + float lo = __bfloat162float(smem[d]); + float hi = __bfloat162float(smem[d + half]); + float c = __bfloat162float(cos_cache[pos * head_dim + d]); + float s = __bfloat162float(sin_cache[pos * head_dim + d]); + float lo_cos = __bfloat162float(__float2bfloat16(lo * c)); + float hi_sin = __bfloat162float(__float2bfloat16(hi * s)); + result = __float2bfloat16(lo_cos - hi_sin); + } else { + int pair_d = d - half; + float lo = __bfloat162float(smem[pair_d]); + float hi = __bfloat162float(smem[d]); + float c = __bfloat162float(cos_cache[pos * head_dim + pair_d]); + float s = __bfloat162float(sin_cache[pos * head_dim + pair_d]); + float lo_sin = __bfloat162float(__float2bfloat16(lo * s)); + float hi_cos = __bfloat162float(__float2bfloat16(hi * c)); + result = __float2bfloat16(lo_sin + hi_cos); + } + + data[offset] = result; +} + extern "C" { // ============================================================================ @@ -136,4 +220,38 @@ void qk_norm_rope_batched_decode_cuda( ); } +int dflash_qk_norm_rope_cuda( + __nv_bfloat16* q, + __nv_bfloat16* 
k, + const __nv_bfloat16* q_norm_weight, + const __nv_bfloat16* k_norm_weight, + const __nv_bfloat16* cos_cache, + const __nv_bfloat16* sin_cache, + int num_q_heads, + int num_kv_heads, + int head_dim, + int q_len, + int k_len, + int q_start_pos, + int k_start_pos, + float rms_eps, + int cos_max_pos, + cudaStream_t stream +) { + if (q == nullptr || k == nullptr || q_norm_weight == nullptr || + k_norm_weight == nullptr || cos_cache == nullptr || sin_cache == nullptr || + num_q_heads <= 0 || num_kv_heads <= 0 || head_dim != HEAD_DIM || + q_len <= 0 || k_len <= 0 || q_start_pos < 0 || k_start_pos < 0 || + q_start_pos + q_len > cos_max_pos || k_start_pos + k_len > cos_max_pos) { + return static_cast(cudaErrorInvalidValue); + } + + dim3 grid(num_q_heads + num_kv_heads, q_len > k_len ? q_len : k_len); + dflash_qk_norm_rope_kernel<<>>( + q, k, q_norm_weight, k_norm_weight, cos_cache, sin_cache, + num_q_heads, num_kv_heads, head_dim, q_len, k_len, + q_start_pos, k_start_pos, rms_eps, cos_max_pos); + return static_cast(cudaGetLastError()); +} + } // extern "C" diff --git a/openinfer-kernels/src/ffi/shared.rs b/openinfer-kernels/src/ffi/shared.rs index aff4e5e2..c8de7a63 100644 --- a/openinfer-kernels/src/ffi/shared.rs +++ b/openinfer-kernels/src/ffi/shared.rs @@ -30,6 +30,29 @@ unsafe extern "C" { stream: CUstream, ) -> CUresult; + pub fn copy_hidden_rows_cuda( + src: *const Half, + dst: *mut Half, + src_hidden_dim: i32, + dst_hidden_dim: i32, + row_offset: i32, + rows: i32, + seq_len: i32, + stream: CUstream, + ) -> CUresult; + + pub fn copy_hidden_token_range_cuda( + src: *const Half, + dst: *mut Half, + hidden_dim: i32, + src_token_offset: i32, + dst_token_offset: i32, + token_count: i32, + src_seq_len: i32, + dst_seq_len: i32, + stream: CUstream, + ) -> CUresult; + pub fn fused_add_rms_norm_cuda( hidden: *mut Half, residual: *const Half, @@ -249,6 +272,25 @@ unsafe extern "C" { stream: CUstream, ); + pub fn dflash_qk_norm_rope_cuda( + q: *mut Half, + k: *mut Half, + 
q_norm_weight: *const Half, + k_norm_weight: *const Half, + cos_cache: *const Half, + sin_cache: *const Half, + num_q_heads: i32, + num_kv_heads: i32, + head_dim: i32, + q_len: i32, + k_len: i32, + q_start_pos: i32, + k_start_pos: i32, + rms_eps: f32, + cos_max_pos: i32, + stream: CUstream, + ) -> i32; + // Scatter contiguous KV → paged layout (one layer, FlashInfer prefill append). pub fn paged_kv_scatter_cuda( kv_data: *const Half, @@ -375,6 +417,21 @@ unsafe extern "C" { stream: CUstream, ) -> i32; + pub fn single_prefill_nhd_noncausal_cuda( + q: *const Half, + output: *mut Half, + k_cache: *const Half, + v_cache: *const Half, + num_qo_heads: i32, + num_kv_heads: i32, + head_dim: i32, + seq_len: i32, + kv_len: i32, + max_seq_len: i32, + sm_scale: f32, + stream: CUstream, + ) -> i32; + // Paged attention decode (FlashInfer BatchDecode, no partition-KV). pub fn paged_attention_decode_cuda( q: *const Half, diff --git a/openinfer-kernels/src/ops.rs b/openinfer-kernels/src/ops.rs index fa8c5362..2d78930e 100644 --- a/openinfer-kernels/src/ops.rs +++ b/openinfer-kernels/src/ops.rs @@ -15,9 +15,10 @@ mod norm; mod sampling; pub use attention::{ - PrefillPagedPlan, paged_attention_batch_decode_hd256_into, paged_attention_batch_decode_into, - paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into, - qk_norm_partial_rope_batched_decode_hd256_into, qk_norm_rope_batch_decode_into, + PrefillPagedPlan, dflash_qk_norm_rope_into, paged_attention_batch_decode_hd256_into, + paged_attention_batch_decode_into, paged_attention_batch_decode_split_kv_into, + prefill_attention_paged_into, qk_norm_partial_rope_batched_decode_hd256_into, + qk_norm_rope_batch_decode_into, single_prefill_nhd_noncausal_into, }; #[cfg(feature = "kimi-k2")] pub use deepep::{ @@ -27,11 +28,11 @@ pub use deepep::{ pub use deepseek_v2_lite::*; pub use elementwise::{ accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, bf16_hidden_to_f32_into, - extract_vec, extract_vec_into, 
extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, - gather_hidden_tokens_into, repeat_f32_for_reduce_scatter_into, scale_f32_in_place, - scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, - scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, - silu_mul_fused_batch_into, write_vec_into, + copy_hidden_rows_into, copy_hidden_token_range_into, extract_vec, extract_vec_into, + extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, gather_hidden_tokens_into, + repeat_f32_for_reduce_scatter_into, scale_f32_in_place, scaled_add_batch_into, + scaled_add_rows_indexed_into, scaled_add_rows_into, scaled_add_rows_token_range_into, + silu_mul_batch, silu_mul_batch_into, silu_mul_fused_batch_into, write_vec_into, }; pub use embedding::{embedding_batch, embedding_batch_vocab_shard, embedding_decode_into}; #[cfg(feature = "kimi-k2")] diff --git a/openinfer-kernels/src/ops/attention.rs b/openinfer-kernels/src/ops/attention.rs index 99516a13..55b301d3 100644 --- a/openinfer-kernels/src/ops/attention.rs +++ b/openinfer-kernels/src/ops/attention.rs @@ -215,6 +215,210 @@ impl PrefillPagedPlan { num_kv_heads: usize, head_dim: usize, cta_tile_q_override: i32, + ) -> Result { + let host = BatchPlanHost::compute( + page_indices, + last_page_lens, + start_positions, + seq_lens, + num_q_heads, + num_kv_heads, + head_dim, + cta_tile_q_override, + )?; + + // Upload all to GPU + Ok(Self { + page_indices_d: ctx.stream.clone_htod(&host.all_page_indices)?, + page_indptr_d: ctx.stream.clone_htod(&host.page_indptr)?, + last_page_len_d: ctx.stream.clone_htod(&host.last_page_lens_i32)?, + batch_indices_d: ctx.stream.clone_htod(&host.batch_indices)?, + positions_d: ctx.stream.clone_htod(&host.positions)?, + q_indptr_d: ctx.stream.clone_htod(&host.q_indptr)?, + request_indices_d: ctx.stream.clone_htod(&host.request_indices_v)?, + qo_tile_indices_d: ctx.stream.clone_htod(&host.qo_tile_indices_v)?, + kv_tile_indices_d: 
ctx.stream.clone_htod(&host.kv_tile_indices_v)?, + kv_chunk_size_d: ctx.stream.clone_htod(&host.kv_chunk_sizes)?, + total_num_rows_d: ctx.stream.clone_htod(&[host.total_tokens as u32])?, + num_tiles: host.num_tiles, + batch_size: host.batch_size as i32, + total_tokens: host.total_tokens, + cta_tile_q: host.cta_tile_q as i32, + }) + } + + /// Allocate a worst-case-sized plan once, to be refilled in place by + /// [`Self::update_batch_with_cta_tile_q`]. Buffer pointers stay fixed across + /// updates so a CUDA Graph captured against them remains valid on replay. + /// Scalar fields start at 0; an unfilled plan must not be used for a forward. + pub fn new_preallocated( + ctx: &DeviceContext, + max_total_tokens: usize, + max_total_pages: usize, + max_batch: usize, + max_tiles: usize, + ) -> Result { + Ok(Self { + page_indices_d: ctx.stream.alloc_zeros(max_total_pages)?, + page_indptr_d: ctx.stream.alloc_zeros(max_batch + 1)?, + last_page_len_d: ctx.stream.alloc_zeros(max_batch)?, + batch_indices_d: ctx.stream.alloc_zeros(max_total_tokens)?, + positions_d: ctx.stream.alloc_zeros(max_total_tokens)?, + q_indptr_d: ctx.stream.alloc_zeros(max_batch + 1)?, + request_indices_d: ctx.stream.alloc_zeros(max_tiles)?, + qo_tile_indices_d: ctx.stream.alloc_zeros(max_tiles)?, + kv_tile_indices_d: ctx.stream.alloc_zeros(max_tiles)?, + kv_chunk_size_d: ctx.stream.alloc_zeros(max_batch)?, + total_num_rows_d: ctx.stream.alloc_zeros(1)?, + num_tiles: 0, + batch_size: 0, + total_tokens: 0, + cta_tile_q: 0, + }) + } + + /// Recompute the host-side batch metadata and `memcpy_htod` it into the + /// pre-allocated device buffers (no allocation, no pointer change). The host + /// computation is identical to [`Self::new_batch_with_cta_tile_q`]; only the + /// upload differs (overwrite in place vs. fresh `clone_htod`). + /// + /// `memcpy_htod` copies `src.len()` elements and tolerates a larger + /// destination, so the worst-case allocation may exceed the actual fill. 
+ #[allow(clippy::too_many_arguments)] + pub fn update_batch_with_cta_tile_q( + &mut self, + ctx: &DeviceContext, + page_indices: &[Vec], + last_page_lens: &[usize], + start_positions: &[usize], + seq_lens: &[usize], + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + cta_tile_q_override: i32, + ) -> Result<()> { + let host = BatchPlanHost::compute( + page_indices, + last_page_lens, + start_positions, + seq_lens, + num_q_heads, + num_kv_heads, + head_dim, + cta_tile_q_override, + )?; + + anyhow::ensure!( + host.all_page_indices.len() <= self.page_indices_d.len(), + "verify plan page_indices ({}) exceeds preallocated capacity ({})", + host.all_page_indices.len(), + self.page_indices_d.len(), + ); + anyhow::ensure!( + host.page_indptr.len() <= self.page_indptr_d.len(), + "verify plan page_indptr ({}) exceeds preallocated capacity ({})", + host.page_indptr.len(), + self.page_indptr_d.len(), + ); + anyhow::ensure!( + host.last_page_lens_i32.len() <= self.last_page_len_d.len(), + "verify plan last_page_lens ({}) exceeds preallocated capacity ({})", + host.last_page_lens_i32.len(), + self.last_page_len_d.len(), + ); + anyhow::ensure!( + host.batch_indices.len() <= self.batch_indices_d.len(), + "verify plan batch_indices ({}) exceeds preallocated capacity ({})", + host.batch_indices.len(), + self.batch_indices_d.len(), + ); + anyhow::ensure!( + host.positions.len() <= self.positions_d.len(), + "verify plan positions ({}) exceeds preallocated capacity ({})", + host.positions.len(), + self.positions_d.len(), + ); + anyhow::ensure!( + host.q_indptr.len() <= self.q_indptr_d.len(), + "verify plan q_indptr ({}) exceeds preallocated capacity ({})", + host.q_indptr.len(), + self.q_indptr_d.len(), + ); + anyhow::ensure!( + host.request_indices_v.len() <= self.request_indices_d.len(), + "verify plan tiles ({}) exceeds preallocated capacity ({})", + host.request_indices_v.len(), + self.request_indices_d.len(), + ); + anyhow::ensure!( + host.kv_chunk_sizes.len() <= 
self.kv_chunk_size_d.len(), + "verify plan kv_chunk_sizes ({}) exceeds preallocated capacity ({})", + host.kv_chunk_sizes.len(), + self.kv_chunk_size_d.len(), + ); + + ctx.stream + .memcpy_htod(&host.all_page_indices, &mut self.page_indices_d)?; + ctx.stream + .memcpy_htod(&host.page_indptr, &mut self.page_indptr_d)?; + ctx.stream + .memcpy_htod(&host.last_page_lens_i32, &mut self.last_page_len_d)?; + ctx.stream + .memcpy_htod(&host.batch_indices, &mut self.batch_indices_d)?; + ctx.stream + .memcpy_htod(&host.positions, &mut self.positions_d)?; + ctx.stream + .memcpy_htod(&host.q_indptr, &mut self.q_indptr_d)?; + ctx.stream + .memcpy_htod(&host.request_indices_v, &mut self.request_indices_d)?; + ctx.stream + .memcpy_htod(&host.qo_tile_indices_v, &mut self.qo_tile_indices_d)?; + ctx.stream + .memcpy_htod(&host.kv_tile_indices_v, &mut self.kv_tile_indices_d)?; + ctx.stream + .memcpy_htod(&host.kv_chunk_sizes, &mut self.kv_chunk_size_d)?; + ctx.stream + .memcpy_htod(&[host.total_tokens as u32], &mut self.total_num_rows_d)?; + + self.num_tiles = host.num_tiles; + self.batch_size = host.batch_size as i32; + self.total_tokens = host.total_tokens; + self.cta_tile_q = host.cta_tile_q as i32; + Ok(()) + } +} + +/// Host-side batch-prefill metadata, computed identically for the fresh +/// (`new_batch_with_cta_tile_q`) and in-place (`update_batch_with_cta_tile_q`) +/// paths so the two never diverge. 
+struct BatchPlanHost { + all_page_indices: Vec, + page_indptr: Vec, + last_page_lens_i32: Vec, + batch_indices: Vec, + positions: Vec, + q_indptr: Vec, + request_indices_v: Vec, + qo_tile_indices_v: Vec, + kv_tile_indices_v: Vec, + kv_chunk_sizes: Vec, + num_tiles: i32, + batch_size: usize, + total_tokens: usize, + cta_tile_q: usize, +} + +impl BatchPlanHost { + #[allow(clippy::too_many_arguments)] + fn compute( + page_indices: &[Vec], + last_page_lens: &[usize], + start_positions: &[usize], + seq_lens: &[usize], + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + cta_tile_q_override: i32, ) -> Result { let batch_size = page_indices.len(); assert_eq!(batch_size, last_page_lens.len()); @@ -281,23 +485,21 @@ impl PrefillPagedPlan { } let num_tiles = request_indices_v.len() as i32; - // Upload all to GPU Ok(Self { - page_indices_d: ctx.stream.clone_htod(&all_page_indices)?, - page_indptr_d: ctx.stream.clone_htod(&page_indptr)?, - last_page_len_d: ctx.stream.clone_htod(&last_page_lens_i32)?, - batch_indices_d: ctx.stream.clone_htod(&batch_indices)?, - positions_d: ctx.stream.clone_htod(&positions)?, - q_indptr_d: ctx.stream.clone_htod(&q_indptr)?, - request_indices_d: ctx.stream.clone_htod(&request_indices_v)?, - qo_tile_indices_d: ctx.stream.clone_htod(&qo_tile_indices_v)?, - kv_tile_indices_d: ctx.stream.clone_htod(&kv_tile_indices_v)?, - kv_chunk_size_d: ctx.stream.clone_htod(&kv_chunk_sizes)?, - total_num_rows_d: ctx.stream.clone_htod(&[total_tokens as u32])?, + all_page_indices, + page_indptr, + last_page_lens_i32, + batch_indices, + positions, + q_indptr, + request_indices_v, + qo_tile_indices_v, + kv_tile_indices_v, + kv_chunk_sizes, num_tiles, - batch_size: batch_size as i32, + batch_size, total_tokens, - cta_tile_q: cta_tile_q as i32, + cta_tile_q, }) } } @@ -497,6 +699,143 @@ pub fn qk_norm_rope_batch_decode_into( } } +/// QK RMSNorm + RoPE for one DFlash request's draft block. 
+/// +/// `q` is a row sub-range of a batched buffer: `q_row_offset` rows precede this +/// request's `q_seq_len` query rows. The kernel still sees a single-request Q +/// shape — we just advance the device pointer to the request's slice. `k` is the +/// request's own varlen tail scratch (whole buffer), so it needs no offset. +#[allow(clippy::too_many_arguments)] +pub fn dflash_qk_norm_rope_into( + ctx: &DeviceContext, + q: &mut HiddenStates, + q_row_offset: usize, + q_seq_len: usize, + k: &mut HiddenStates, + q_norm_weight: &DeviceVec, + k_norm_weight: &DeviceVec, + cos_cache: &DeviceVec, + sin_cache: &DeviceVec, + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + q_start_pos: usize, + k_start_pos: usize, + rms_eps: f32, +) -> Result<()> { + assert_eq!(q.hidden_dim, num_q_heads * head_dim); + assert_eq!(k.hidden_dim, num_kv_heads * head_dim); + assert_eq!(q_norm_weight.len, head_dim); + assert_eq!(k_norm_weight.len, head_dim); + assert!( + q_row_offset + q_seq_len <= q.seq_len, + "dflash_qk_norm_rope q row range [{}..{}) exceeds seq_len {}", + q_row_offset, + q_row_offset + q_seq_len, + q.seq_len + ); + + let (q_ptr, _gq) = q.data.device_ptr_mut(&ctx.stream); + let q_ptr = q_ptr + (q_row_offset * q.hidden_dim * std::mem::size_of::()) as u64; + let (k_ptr, _gk) = k.data.device_ptr_mut(&ctx.stream); + let (qn_ptr, _gqn) = q_norm_weight.data.device_ptr(&ctx.stream); + let (kn_ptr, _gkn) = k_norm_weight.data.device_ptr(&ctx.stream); + let (cos_ptr, _gc) = cos_cache.data.device_ptr(&ctx.stream); + let (sin_ptr, _gs) = sin_cache.data.device_ptr(&ctx.stream); + + let result = unsafe { + ffi::dflash_qk_norm_rope_cuda( + q_ptr as *mut ffi::Half, + k_ptr as *mut ffi::Half, + qn_ptr as *const ffi::Half, + kn_ptr as *const ffi::Half, + cos_ptr as *const ffi::Half, + sin_ptr as *const ffi::Half, + num_q_heads as i32, + num_kv_heads as i32, + head_dim as i32, + q_seq_len as i32, + k.seq_len as i32, + q_start_pos as i32, + k_start_pos as i32, + rms_eps, + 
(cos_cache.data.len() / head_dim) as i32, + crate::tensor::active_cu_stream(ctx), + ) + }; + if result != 0 { + anyhow::bail!("dflash_qk_norm_rope_cuda failed with error {result}"); + } + Ok(()) +} + +/// Non-causal prefill attention for one DFlash request's draft block. +/// +/// `q` and `output` share the SAME row sub-range of batched buffers: request +/// `i` owns rows `[row_offset, row_offset + q_seq_len)` in both, because the +/// draft writes each request's attention output back into the row slot its +/// queries came from. The k/v caches are the request's own whole buffers. The +/// kernel sees a single-request shape — we advance the Q/output device pointers +/// to the request's slice. +#[allow(clippy::too_many_arguments)] +pub fn single_prefill_nhd_noncausal_into( + ctx: &DeviceContext, + q: &HiddenStates, + row_offset: usize, + q_seq_len: usize, + k_cache: &HiddenStates, + v_cache: &HiddenStates, + output: &mut HiddenStates, + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + kv_len: usize, +) -> Result<()> { + assert_eq!(q.hidden_dim, num_q_heads * head_dim); + assert_eq!(output.hidden_dim, q.hidden_dim); + assert_eq!(output.seq_len, q.seq_len); + assert_eq!(k_cache.hidden_dim, num_kv_heads * head_dim); + assert_eq!(v_cache.hidden_dim, k_cache.hidden_dim); + assert_eq!(v_cache.seq_len, k_cache.seq_len); + assert!(kv_len <= k_cache.seq_len); + assert!( + row_offset + q_seq_len <= q.seq_len, + "single_prefill row range [{}..{}) exceeds seq_len {}", + row_offset, + row_offset + q_seq_len, + q.seq_len + ); + + // q and output share row_offset (asserted same seq_len/hidden_dim above). 
+ let byte_offset = (row_offset * q.hidden_dim * std::mem::size_of::()) as u64; + let (q_ptr, _gq) = q.data.device_ptr(&ctx.stream); + let q_ptr = q_ptr + byte_offset; + let (k_ptr, _gk) = k_cache.data.device_ptr(&ctx.stream); + let (v_ptr, _gv) = v_cache.data.device_ptr(&ctx.stream); + let (out_ptr, _go) = output.data.device_ptr_mut(&ctx.stream); + let out_ptr = out_ptr + byte_offset; + let result = unsafe { + ffi::single_prefill_nhd_noncausal_cuda( + q_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + k_ptr as *const ffi::Half, + v_ptr as *const ffi::Half, + num_q_heads as i32, + num_kv_heads as i32, + head_dim as i32, + q_seq_len as i32, + kv_len as i32, + k_cache.seq_len as i32, + 1.0f32 / (head_dim as f32).sqrt(), + crate::tensor::active_cu_stream(ctx), + ) + }; + if result != 0 { + anyhow::bail!("single_prefill_nhd_noncausal_cuda failed with error {result}"); + } + Ok(()) +} + /// Batched QK RMSNorm + partial RoPE for Qwen3.5 HD256 decode. /// /// Reads Q from interleaved `q_full` ([q, gate] per head), writes prepared Q into `q`, diff --git a/openinfer-kernels/src/ops/elementwise.rs b/openinfer-kernels/src/ops/elementwise.rs index 6956ab47..ab37b8bc 100644 --- a/openinfer-kernels/src/ops/elementwise.rs +++ b/openinfer-kernels/src/ops/elementwise.rs @@ -176,6 +176,90 @@ pub fn gather_hidden_tokens_into( Ok(()) } +pub fn copy_hidden_rows_into( + ctx: &DeviceContext, + src: &HiddenStates, + dst: &mut HiddenStates, + row_offset: usize, +) -> Result<()> { + assert!( + row_offset + src.hidden_dim <= dst.hidden_dim, + "row range [{}..{}) exceeds destination hidden_dim {}", + row_offset, + row_offset + src.hidden_dim, + dst.hidden_dim + ); + assert_eq!( + src.seq_len, dst.seq_len, + "copy_hidden_rows_into seq_len mismatch: src {}, dst {}", + src.seq_len, dst.seq_len + ); + + let (src_ptr, _gs) = src.data.device_ptr(&ctx.stream); + let (dst_ptr, _gd) = dst.data.device_ptr_mut(&ctx.stream); + let result = unsafe { + ffi::copy_hidden_rows_cuda( + src_ptr as *const 
ffi::Half, + dst_ptr as *mut ffi::Half, + src.hidden_dim as i32, + dst.hidden_dim as i32, + row_offset as i32, + src.hidden_dim as i32, + src.seq_len as i32, + ctx.stream.cu_stream(), + ) + }; + result.result()?; + Ok(()) +} + +pub fn copy_hidden_token_range_into( + ctx: &DeviceContext, + src: &HiddenStates, + src_token_offset: usize, + dst: &mut HiddenStates, + dst_token_offset: usize, + token_count: usize, +) -> Result<()> { + assert_eq!( + src.hidden_dim, dst.hidden_dim, + "copy_hidden_token_range_into hidden_dim mismatch: src {}, dst {}", + src.hidden_dim, dst.hidden_dim + ); + assert!( + src_token_offset + token_count <= src.seq_len, + "source token range [{}..{}) exceeds seq_len {}", + src_token_offset, + src_token_offset + token_count, + src.seq_len + ); + assert!( + dst_token_offset + token_count <= dst.seq_len, + "destination token range [{}..{}) exceeds seq_len {}", + dst_token_offset, + dst_token_offset + token_count, + dst.seq_len + ); + + let (src_ptr, _gs) = src.data.device_ptr(&ctx.stream); + let (dst_ptr, _gd) = dst.data.device_ptr_mut(&ctx.stream); + let result = unsafe { + ffi::copy_hidden_token_range_cuda( + src_ptr as *const ffi::Half, + dst_ptr as *mut ffi::Half, + src.hidden_dim as i32, + src_token_offset as i32, + dst_token_offset as i32, + token_count as i32, + src.seq_len as i32, + dst.seq_len as i32, + ctx.stream.cu_stream(), + ) + }; + result.result()?; + Ok(()) +} + pub fn scaled_add_rows_indexed_into( ctx: &DeviceContext, delta: &HiddenStates, diff --git a/openinfer-kv-cache/src/pool.rs b/openinfer-kv-cache/src/pool.rs index 77d08256..f8797a05 100644 --- a/openinfer-kv-cache/src/pool.rs +++ b/openinfer-kv-cache/src/pool.rs @@ -262,7 +262,9 @@ impl LoadReservation { /// Per-request KV state wrapping `SchedulableSequence`. /// /// Lifecycle: `schedule_prefill → prefill_view/pages → forward → apply_prefill`, -/// then `schedule_decode → decode_view/pages → forward → apply_decode` in a loop. 
+/// then either `schedule_decode → decode_view → forward → apply_decode` or +/// `schedule_speculative → speculative_view → forward → apply_speculative` in a +/// loop (`revert_schedule` undoes a reservation whose step failed). pub struct RequestKv { seq: SchedulableSequence<()>, } @@ -299,6 +301,19 @@ impl RequestKv { self.seq.schedule_decode(&pool.block_manager) } + /// Reserve KV for a speculative verify step covering `num_draft_tokens` + /// consecutive positions (current dangling token + draft candidates). + /// [`Self::apply_speculative`] commits the accepted prefix; on any failure + /// [`Self::revert_schedule`] returns the reservation. + pub fn schedule_speculative( + &mut self, + num_draft_tokens: usize, + pool: &BlockPool, + ) -> Result<(), ScheduleError> { + self.seq + .schedule_speculative(num_draft_tokens, &pool.block_manager) + } + // ── Views (for forward pass) ─────────────────────────────────────── /// Build an immutable `KvView` for prefill. @@ -327,6 +342,20 @@ impl RequestKv { ) } + /// Build an immutable `KvView` for speculative verification: the verifier + /// forwards `num_draft_tokens` consecutive positions (current dangling token + /// followed by draft candidates). The view covers the post-step KV extent; + /// [`Self::apply_speculative`] later commits only the accepted prefix and + /// releases excess draft capacity. Same exact page-row contract as + /// [`Self::prefill_view`]. + pub fn speculative_view(&self, num_draft_tokens: usize) -> KvView { + KvView::new( + self.step_page_indices(num_draft_tokens), + self.seq.kv_position() + num_draft_tokens, + self.seq.block_size(), + ) + } + // ── Apply (register blocks, advance position) ────────────────────── pub fn apply_prefill(&mut self, token: u32, pool: &BlockPool) -> anyhow::Result<()> { @@ -350,6 +379,26 @@ impl RequestKv { .map_err(|e| anyhow::anyhow!("apply_decode: {e}")) } + /// Commit the accepted prefix of a speculative verify step. 
kvbm keeps the + /// `accepted_tokens` KV and LIFO-releases the rejected draft blocks. + pub fn apply_speculative( + &mut self, + accepted_tokens: &[u32], + pool: &BlockPool, + ) -> anyhow::Result { + self.seq + .apply_speculative(accepted_tokens, &pool.block_manager) + .map_err(|e| anyhow::anyhow!("apply_speculative: {e}")) + } + + /// Undo a scheduled-but-unapplied KV reservation (e.g. a speculative + /// schedule whose forward or apply failed) and return its blocks to the pool. + pub fn revert_schedule(&mut self) -> anyhow::Result<()> { + self.seq + .revert_schedule() + .map_err(|e| anyhow::anyhow!("revert_schedule: {e}")) + } + pub fn release(&mut self) -> anyhow::Result<()> { self.seq .release() diff --git a/openinfer-qwen3-4b/src/config.rs b/openinfer-qwen3-4b/src/config.rs index c02b140e..abffe4d7 100644 --- a/openinfer-qwen3-4b/src/config.rs +++ b/openinfer-qwen3-4b/src/config.rs @@ -38,6 +38,29 @@ pub(crate) struct Config { pub(crate) stop_token_ids: Vec, } +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct DFlashConfig { + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) num_target_layers: usize, + pub(crate) head_dim: usize, + pub(crate) vocab_size: usize, + pub(crate) rms_norm_eps: f32, + pub(crate) rope_theta: f32, + pub(crate) max_position_embeddings: usize, + pub(crate) block_size: usize, + pub(crate) dflash_config: DFlashInnerConfig, +} + +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct DFlashInnerConfig { + pub(crate) mask_token_id: u32, + pub(crate) target_layer_ids: Vec, +} + fn default_max_position_embeddings() -> usize { 40960 } @@ -117,6 +140,85 @@ impl Config { } } +impl DFlashConfig { + pub(crate) fn from_file(model_path: &str) -> Result { + let config_path = format!("{}/config.json", model_path); + let content = fs::read_to_string(&config_path)?; + 
Ok(serde_json::from_str(&content)?) + } + + pub(crate) fn validate_for_target(&self, target: &Config) -> Result<()> { + anyhow::ensure!( + self.hidden_size == target.hidden_size, + "DFlash hidden_size {} does not match target {}", + self.hidden_size, + target.hidden_size + ); + anyhow::ensure!( + self.num_target_layers == target.num_hidden_layers, + "DFlash num_target_layers {} does not match target layers {}", + self.num_target_layers, + target.num_hidden_layers + ); + anyhow::ensure!( + self.num_attention_heads == target.num_attention_heads + && self.num_key_value_heads == target.num_key_value_heads + && self.head_dim == target.head_dim, + "DFlash attention geometry does not match target" + ); + anyhow::ensure!( + self.vocab_size == target.vocab_size, + "DFlash vocab_size {} does not match target {}", + self.vocab_size, + target.vocab_size + ); + anyhow::ensure!( + self.rope_theta == target.rope_theta, + "DFlash rope_theta {} does not match target {}", + self.rope_theta, + target.rope_theta + ); + anyhow::ensure!( + self.max_position_embeddings >= target.max_position_embeddings, + "DFlash max_position_embeddings {} is smaller than target {}", + self.max_position_embeddings, + target.max_position_embeddings + ); + anyhow::ensure!( + self.block_size >= 2, + "DFlash block_size must be >= 2, got {}", + self.block_size + ); + anyhow::ensure!( + self.dflash_config.mask_token_id < target.vocab_size as u32, + "DFlash mask_token_id {} is outside target vocab_size {}", + self.dflash_config.mask_token_id, + target.vocab_size + ); + anyhow::ensure!( + self.dflash_config.target_layer_ids.len() == self.num_hidden_layers, + "DFlash target_layer_ids length {} does not match draft layers {}", + self.dflash_config.target_layer_ids.len(), + self.num_hidden_layers + ); + anyhow::ensure!( + self.dflash_config + .target_layer_ids + .iter() + .all(|&layer| layer < target.num_hidden_layers), + "DFlash target_layer_ids must be within target layer count" + ); + anyhow::ensure!( + 
self.dflash_config + .target_layer_ids + .windows(2) + .all(|pair| pair[0] < pair[1]), + "DFlash target_layer_ids must be strictly increasing" + ); + Ok(()) + } +} + impl TensorParallelConfig { pub(crate) fn validate_for(self, config: &Config) -> Result<()> { if self.world_size == 0 { diff --git a/openinfer-qwen3-4b/src/dflash.rs b/openinfer-qwen3-4b/src/dflash.rs new file mode 100644 index 00000000..96b80d79 --- /dev/null +++ b/openinfer-qwen3-4b/src/dflash.rs @@ -0,0 +1,1046 @@ +use anyhow::{Context, Result}; +use cudarc::driver::CudaSlice; +use log::debug; + +use crate::config::DFlashConfig; +use crate::weights::{Attention, MLP, Qwen3Model, TransformerBlock}; +use openinfer_core::ops; +use openinfer_core::tensor::HiddenStates; +use openinfer_core::tensor::{DeviceContext, DeviceMatrix, DeviceVec}; +use openinfer_core::weight_loader::{ + deserialize_shards, load_shard_info, load_tensor_1d, load_tensor_2d, mmap_shards, + precompute_rope, +}; + +pub(crate) struct DFlashDraftModel { + config: DFlashConfig, + layers: Vec, + norm: DeviceVec, + hidden_norm: DeviceVec, + fc: DeviceMatrix, + cos_cache: DeviceVec, + sin_cache: DeviceVec, +} + +pub(crate) struct DFlashRequestState { + layers: Vec, + pending_context: DFlashPendingContext, + /// Projected target context for the current draft round. Computed once from + /// `pending_context` and read by every layer's tail concat, so it lives with + /// the request (the batched scratch only holds one request's varlen tail). + context: DFlashContextScratch, + committed_len: usize, + max_cache_len: usize, +} + +/// GPU memory DFlash needs on top of the target KV pool, derived from the draft +/// config so the KV budget can reserve it *before* the draft model loads (the +/// draft buffers live outside the paged `KvCacheManager`). 
Split by how it scales: +/// +/// - `kv_bytes_per_token` scales with the KV pool (billed by shrinking the target +/// block count): the draft's own KV cache plus the per-request context-projection +/// and pending-context buffers, which currently persist at prompt length per +/// request (see `dflash-speculative-decoding.md` — collapsing that persistence +/// is a tracked follow-up that would shrink this term to the draft KV alone). +/// - `fixed_bytes` does not scale with the pool (billed via the memory margin): +/// the draft weights plus the lane-level batched scratch sized for the whole +/// decode batch. +/// +// TODO: the draft scratch is now a single lane-level `DFlashBatchScratch` +// allocated once (dense buffers sized `max_batch * block_size`, plus one shared +// varlen tail), not a per-request buffer. The per-token `tail_scratch` term and +// the per-request `block_headroom` tail term are therefore over-estimates — kept +// as a conservative upper bound until the accounting is retuned against the +// batched allocation. +pub(crate) struct DFlashMemoryReservation { + pub(crate) kv_bytes_per_token: usize, + pub(crate) fixed_bytes: usize, +} + +impl DFlashMemoryReservation { + pub(crate) fn from_path(draft_path: &str, max_decode_batch_size: usize) -> Result { + let config = DFlashConfig::from_file(draft_path)?; + Ok(Self::from_config(&config, max_decode_batch_size)) + } + + fn from_config(config: &DFlashConfig, max_decode_batch_size: usize) -> Self { + const BF16: usize = 2; + let hidden = config.hidden_size; + let kv_dim = config.num_key_value_heads * config.head_dim; + let q_dim = config.num_attention_heads * config.head_dim; + let inter = config.intermediate_size; + let capture_layers = config.dflash_config.target_layer_ids.len(); + + // Per-sequence-token, pool-scaling buffers. 
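+ // Worked example (illustrative only; assumes the Qwen3-4B-DFlash-b16
+ // geometry pinned by the unit test in this file: hidden = 2560,
+ // kv_dim = 1024, 5 draft layers, 5 capture layers):
+ //   draft_kv        = 5 * 2 * 1024 * 2       = 20_480
+ //   context_scratch = 2 * 2560 * 2           = 10_240
+ //   tail_scratch    = (2560 + 2 * 1024) * 2  =  9_216
+ //   pending         = 2560 * 5 * 2           = 25_600
+ //   kv_bytes_per_token                       = 65_536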
+ let draft_kv = config.num_hidden_layers * 2 * kv_dim * BF16; // DFlashLayerCache k+v + // Scratch split by what it tracks: `context_*` grows with the committed + // prefix; `tail_*` (tail_input + k_tail + v_tail) grows with the in-fill + // tail, which is one block past the prefix. + let context_scratch = 2 * hidden * BF16; // context_projected + context_hidden + let tail_scratch = (hidden + 2 * kv_dim) * BF16; // tail_input + k_tail + v_tail + let pending = hidden * capture_layers * BF16; // context_feature_dim + let kv_bytes_per_token = draft_kv + context_scratch + tail_scratch + pending; + + // Lane-level batched dense scratch: every dense buffer is sized for the + // whole decode batch (`max_batch * block_size` rows), allocated once. + // Same total magnitude as the old per-request scratch summed over the + // batch, but now one contiguous allocation. + let dense_scratch_per_block_row = + BF16 * (config.vocab_size + 5 * hidden + 2 * q_dim + 3 * inter); + let scratch_total = dense_scratch_per_block_row * config.block_size * max_decode_batch_size; + + // Draft weights (5 transformer layers + the context projection), +10% slack + // for norms, rope caches, and allocator alignment. + let per_layer = BF16 + * (hidden * (q_dim + 2 * kv_dim) // qkv_proj + + q_dim * hidden // o_proj + + hidden * 2 * inter // gate_up_proj + + inter * hidden); // down_proj + let fc = BF16 * hidden * (hidden * capture_layers); // context projection + let weights = per_layer * config.num_hidden_layers + fc; + let weights = weights + weights / 10; + + // The durable draft KV and the tail scratch are sized to `context + + // block_size` — one in-fill block past the lifetime the KV pool reserves + // for the request. The per-token term bills only the pool's tokens, so + // reserve that one-block headroom per concurrently decoding request to + // keep the reservation an upper bound. 
+ let block_headroom = max_decode_batch_size * config.block_size * (draft_kv + tail_scratch); + + Self { + kv_bytes_per_token, + fixed_bytes: weights + scratch_total + block_headroom, + } + } +} + +struct DFlashLayerCache { + k: HiddenStates, + v: HiddenStates, +} + +struct DFlashPendingContext { + buffer: HiddenStates, + len: usize, + capacity: usize, +} + +/// Per-request projected context. The fc projection + hidden_norm turn the +/// captured target hidden context into draft hidden space once per draft round; +/// every layer's tail concat reads `context_hidden`, so it must persist across +/// the layer loop and therefore lives in the request (not the shared scratch). +struct DFlashContextScratch { + max_context_len: usize, + context_projected: HiddenStates, + context_hidden: HiddenStates, +} + +/// Lane-level batched draft scratch, allocated once for the whole decode batch. +/// +/// Dense buffers (`hidden`, `normed`, `q_batch`, `attn_output`, the MLP buffers, +/// and `logits`) hold `max_batch * block_size` rows so the GEMM / rms_norm / +/// silu / add / logits / embedding ops run ONCE over the batched buffer. The +/// varlen tail buffers (`tail_input`, `k_tail`, `v_tail`) stay sized for a single +/// request and are reused inside the per-request loop, because their ops (tail +/// concat, k/v GEMMs, rope, KV copy, attention) still loop per request — Step 2 +/// will batch those via CUDA-kernel changes. +pub(crate) struct DFlashBatchScratch { + max_batch_block_rows: usize, + max_tail_len: usize, + block_token_ids_h: Vec, + token_ids_d: CudaSlice, + hidden: HiddenStates, + hidden_out: HiddenStates, + normed: HiddenStates, + q_batch: HiddenStates, + attn_output: HiddenStates, + o_buf: HiddenStates, + gate_out: HiddenStates, + up_out: HiddenStates, + act_out: HiddenStates, + logits_normed: HiddenStates, + logits: HiddenStates, + // Shared single-request varlen tail scratch (reused inside the per-request loop). 
+ tail_input: HiddenStates, + k_tail: HiddenStates, + v_tail: HiddenStates, +} + +impl DFlashRequestState { + pub(crate) fn pending_context_len(&self) -> Option { + (self.pending_context.len > 0).then_some(self.pending_context.len) + } +} + +impl DFlashPendingContext { + fn new(ctx: &DeviceContext, hidden_dim: usize, capacity: usize) -> Result { + anyhow::ensure!( + capacity > 0, + "DFlash pending context capacity must be non-zero" + ); + let mut buffer = HiddenStates::zeros(ctx, hidden_dim, capacity)?; + buffer.seq_len = 0; + Ok(Self { + buffer, + len: 0, + capacity, + }) + } + + fn append_from( + &mut self, + ctx: &DeviceContext, + src: &HiddenStates, + src_token_offset: usize, + token_count: usize, + max_capacity: usize, + ) -> Result<()> { + let required_len = self + .len + .checked_add(token_count) + .context("DFlash pending context length overflow")?; + anyhow::ensure!( + required_len <= max_capacity, + "DFlash pending context length {} exceeds request capacity {}", + required_len, + max_capacity + ); + self.ensure_capacity(ctx, required_len, max_capacity)?; + self.buffer.seq_len = self.capacity; + ops::copy_hidden_token_range_into( + ctx, + src, + src_token_offset, + &mut self.buffer, + self.len, + token_count, + )?; + self.len = required_len; + self.buffer.seq_len = self.len; + Ok(()) + } + + fn ensure_capacity( + &mut self, + ctx: &DeviceContext, + required_len: usize, + max_capacity: usize, + ) -> Result<()> { + if required_len <= self.capacity { + return Ok(()); + } + let doubled = self + .capacity + .checked_mul(2) + .context("DFlash pending context capacity overflow")?; + let new_capacity = required_len.max(doubled).min(max_capacity); + anyhow::ensure!( + new_capacity >= required_len, + "DFlash pending context capacity {} cannot fit {} tokens", + new_capacity, + required_len + ); + let mut next = HiddenStates::zeros(ctx, self.buffer.hidden_dim, new_capacity)?; + if self.len > 0 { + self.buffer.seq_len = self.capacity; + 
ops::copy_hidden_token_range_into(ctx, &self.buffer, 0, &mut next, 0, self.len)?; + } + next.seq_len = self.len; + self.buffer = next; + self.capacity = new_capacity; + Ok(()) + } + + fn activate_for_read(&mut self) { + self.buffer.seq_len = self.len; + } + + fn clear(&mut self) { + self.len = 0; + self.buffer.seq_len = 0; + } +} + +impl DFlashContextScratch { + fn new(ctx: &DeviceContext, hidden_size: usize, max_context_len: usize) -> Result { + Ok(Self { + max_context_len, + context_projected: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + context_hidden: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + }) + } + + fn ensure_capacity( + &mut self, + ctx: &DeviceContext, + hidden_size: usize, + context_len: usize, + ) -> Result<()> { + if context_len > self.max_context_len { + *self = Self::new(ctx, hidden_size, context_len)?; + } + self.context_projected.seq_len = context_len; + self.context_hidden.seq_len = context_len; + Ok(()) + } +} + +impl DFlashBatchScratch { + fn new( + ctx: &DeviceContext, + config: &DFlashConfig, + max_decode_batch_size: usize, + ) -> Result { + anyhow::ensure!( + max_decode_batch_size > 0, + "DFlash batch scratch needs a non-zero batch size" + ); + let block_size = config.block_size; + let hidden_size = config.hidden_size; + let q_dim = config.num_attention_heads * config.head_dim; + let kv_dim = config.num_key_value_heads * config.head_dim; + let inter_dim = config.intermediate_size; + // Dense buffers span the whole decode batch so the dense ops run once. + let batch_rows = block_size * max_decode_batch_size; + // The shared varlen tail starts at one block (no committed context yet) + // and grows on demand via `ensure_tail_capacity`. 
+ let tail_capacity = block_size; + + Ok(Self { + max_batch_block_rows: batch_rows, + max_tail_len: tail_capacity, + block_token_ids_h: vec![config.dflash_config.mask_token_id; batch_rows], + token_ids_d: ctx.stream.alloc_zeros(batch_rows)?, + hidden: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + hidden_out: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + normed: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + q_batch: HiddenStates::zeros(ctx, q_dim, batch_rows)?, + attn_output: HiddenStates::zeros(ctx, q_dim, batch_rows)?, + o_buf: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + gate_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + up_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + act_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + logits_normed: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + logits: HiddenStates::zeros(ctx, config.vocab_size, batch_rows)?, + tail_input: HiddenStates::zeros(ctx, hidden_size, tail_capacity)?, + k_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, + v_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, + }) + } + + /// Point every dense buffer at the active `batch_block_rows = active_batch * + /// block_size` prefix. Allocated for the max decode batch, so this only + /// shrinks `seq_len`; it never reallocates. 
+ fn activate_dense(&mut self, batch_block_rows: usize) { + assert!( + batch_block_rows <= self.max_batch_block_rows, + "DFlash batched draft {} block rows exceeds scratch capacity {}", + batch_block_rows, + self.max_batch_block_rows + ); + self.hidden.seq_len = batch_block_rows; + self.hidden_out.seq_len = batch_block_rows; + self.normed.seq_len = batch_block_rows; + self.q_batch.seq_len = batch_block_rows; + self.attn_output.seq_len = batch_block_rows; + self.o_buf.seq_len = batch_block_rows; + self.gate_out.seq_len = batch_block_rows; + self.up_out.seq_len = batch_block_rows; + self.act_out.seq_len = batch_block_rows; + self.logits_normed.seq_len = batch_block_rows; + self.logits.seq_len = batch_block_rows; + } + + /// Size the shared varlen tail buffers for one request's `tail_len = + /// context_len + block_size`, growing the allocation if needed. + fn ensure_tail_capacity( + &mut self, + ctx: &DeviceContext, + config: &DFlashConfig, + tail_len: usize, + ) -> Result<()> { + if tail_len > self.max_tail_len { + let hidden_size = config.hidden_size; + let kv_dim = config.num_key_value_heads * config.head_dim; + self.tail_input = HiddenStates::zeros(ctx, hidden_size, tail_len)?; + self.k_tail = HiddenStates::zeros(ctx, kv_dim, tail_len)?; + self.v_tail = HiddenStates::zeros(ctx, kv_dim, tail_len)?; + self.max_tail_len = tail_len; + } + self.tail_input.seq_len = tail_len; + self.k_tail.seq_len = tail_len; + self.v_tail.seq_len = tail_len; + Ok(()) + } +} + +impl DFlashDraftModel { + pub(crate) fn from_safetensors_for_target( + ctx: &DeviceContext, + model_path: &str, + target: &Qwen3Model, + ) -> Result { + let config = DFlashConfig::from_file(model_path) + .with_context(|| format!("load DFlash config from {model_path}"))?; + config.validate_for_target(target.config())?; + + let (shard_paths, weight_map) = load_shard_info(model_path)?; + debug!( + "Loading DFlash drafter from {model_path}: {} shard(s)", + shard_paths.len() + ); + let mmaps = 
mmap_shards(&shard_paths)?; + let shards = deserialize_shards(&mmaps)?; + + let mut layers = Vec::with_capacity(config.num_hidden_layers); + for layer_idx in 0..config.num_hidden_layers { + let prefix = format!("layers.{layer_idx}"); + + let q_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.q_proj.weight"), + )?; + let k_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.k_proj.weight"), + )?; + let v_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.v_proj.weight"), + )?; + let q_dim = q_proj.rows; + let kv_dim = k_proj.rows; + let qkv_proj = DeviceMatrix::vstack(ctx, &[&q_proj, &k_proj, &v_proj])?; + drop(q_proj); + drop(k_proj); + drop(v_proj); + + let gate_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.gate_proj.weight"), + )?; + let up_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.up_proj.weight"), + )?; + let gate_up_proj = DeviceMatrix::vstack(ctx, &[&gate_proj, &up_proj])?; + drop(gate_proj); + drop(up_proj); + + layers.push(TransformerBlock { + input_layernorm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.input_layernorm.weight"), + )?, + attention: Attention { + qkv_proj, + o_proj: load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.o_proj.weight"), + )?, + q_norm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.q_norm.weight"), + )?, + k_norm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.k_norm.weight"), + )?, + q_dim, + kv_dim, + }, + post_attention_layernorm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.post_attention_layernorm.weight"), + )?, + mlp: MLP { + gate_up_proj, + down_proj: load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.down_proj.weight"), + )?, + }, + }); + } + + let norm = load_tensor_1d(ctx, 
&shards, &weight_map, "norm.weight")?; + let hidden_norm = load_tensor_1d(ctx, &shards, &weight_map, "hidden_norm.weight")?; + let fc = load_tensor_2d(ctx, &shards, &weight_map, "fc.weight")?; + let (cos_cache, sin_cache) = precompute_rope( + ctx, + config.head_dim, + config.max_position_embeddings, + config.rope_theta, + )?; + ctx.sync()?; + + Ok(Self { + config, + layers, + norm, + hidden_norm, + fc, + cos_cache, + sin_cache, + }) + } + + pub(crate) fn block_size(&self) -> usize { + self.config.block_size + } + + /// Largest sequence position the draft can cache. `validate_for_target` + /// guarantees this is `>=` the target's, but the draft's per-step in-fill + /// block writes `block_size` transient positions past the committed length, + /// so the usable context is `max_position_embeddings - block_size`. + pub(crate) fn max_position_embeddings(&self) -> usize { + self.config.max_position_embeddings + } + + pub(crate) fn mask_token_id(&self) -> u32 { + self.config.dflash_config.mask_token_id + } + + pub(crate) fn target_layer_ids(&self) -> &[usize] { + &self.config.dflash_config.target_layer_ids + } + + pub(crate) fn tune_gemm_algos(&self, target: &Qwen3Model) -> Result<()> { + let ctx = target.device_ctx(); + let block_size = self.block_size().min(ops::GEMM_LT_MAX_N); + let hidden = self.config.hidden_size; + let q_dim = self.config.num_attention_heads * self.config.head_dim; + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let context_dim = self.context_feature_dim(); + + let fc_samples = [(&self.fc, 0)]; + for n in 1..=block_size { + ops::gemm_lt_tune(ctx, &fc_samples, hidden, n)?; + } + + let kv_samples: Vec<_> = self + .layers + .iter() + .flat_map(|layer| { + [ + (&layer.attention.qkv_proj, q_dim), + (&layer.attention.qkv_proj, q_dim + kv_dim), + ] + }) + .collect(); + let min_tail_n = self.block_size() + 1; + let max_tail_n = (self.block_size() * 2).min(ops::GEMM_LT_MAX_N); + for n in min_tail_n..=max_tail_n { + 
ops::gemm_lt_tune(ctx, &kv_samples, kv_dim, n)?; + } + + log::info!( + "Qwen3 DFlash cublasLt tuned: fc M={} K={} N=1..{}, kv M={} K={} N={}..{}", + hidden, + context_dim, + block_size, + kv_dim, + hidden, + min_tail_n, + max_tail_n, + ); + Ok(()) + } + + /// Allocate the lane-level batched draft scratch once, sized for the whole + /// decode batch. The per-request `DFlashRequestState` no longer owns scratch. + pub(crate) fn new_batch_scratch( + &self, + ctx: &DeviceContext, + max_decode_batch_size: usize, + ) -> Result { + DFlashBatchScratch::new(ctx, &self.config, max_decode_batch_size) + } + + pub(crate) fn new_request_state( + &self, + ctx: &DeviceContext, + max_cache_len: usize, + ) -> Result { + anyhow::ensure!( + max_cache_len <= self.config.max_position_embeddings, + "DFlash request cache length {} exceeds max_position_embeddings {}", + max_cache_len, + self.config.max_position_embeddings + ); + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let mut layers = Vec::with_capacity(self.layers.len()); + for _ in 0..self.layers.len() { + layers.push(DFlashLayerCache { + k: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?, + v: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?, + }); + } + Ok(DFlashRequestState { + layers, + pending_context: DFlashPendingContext::new( + ctx, + self.context_feature_dim(), + self.config.block_size.min(max_cache_len), + )?, + context: DFlashContextScratch::new( + ctx, + self.config.hidden_size, + self.config.block_size, + )?, + committed_len: 0, + max_cache_len, + }) + } + + pub(crate) fn append_pending_context( + &self, + ctx: &DeviceContext, + state: &mut DFlashRequestState, + captured_hidden: &HiddenStates, + token_offset: usize, + token_count: usize, + ) -> Result<()> { + anyhow::ensure!(token_count > 0, "DFlash context append needs tokens"); + anyhow::ensure!( + captured_hidden.hidden_dim == self.context_feature_dim(), + "DFlash captured hidden dim {} does not match expected {}", + 
captured_hidden.hidden_dim, + self.context_feature_dim() + ); + anyhow::ensure!( + token_offset + token_count <= captured_hidden.seq_len, + "DFlash captured hidden token range exceeds source" + ); + let required_committed_len = state + .committed_len + .checked_add(state.pending_context.len) + .and_then(|len| len.checked_add(token_count)) + .and_then(|len| len.checked_add(self.block_size())) + .context("DFlash pending context cache length overflow")?; + anyhow::ensure!( + required_committed_len <= state.max_cache_len, + "DFlash pending context would exceed cache: committed={}, pending={}, append={}, block={}, max={}", + state.committed_len, + state.pending_context.len, + token_count, + self.block_size(), + state.max_cache_len + ); + state.pending_context.append_from( + ctx, + captured_hidden, + token_offset, + token_count, + state.max_cache_len, + )?; + Ok(()) + } + + /// Batched draft forward over all active requests at once. + /// + /// The *dense* ops (embedding, rms_norm, q / o / gate_up / down GEMMs, silu, + /// add, fused_add_rms_norm, logits) run ONCE over an `active_batch * + /// block_size` batched buffer. The *varlen* ops (context projection, tail + /// concat, k/v GEMMs, rope, KV copy, attention) still loop per request, + /// slicing each request's `block_size` rows at offset `i * block_size` in the + /// batched buffers — those are Step 2/3's job to batch via CUDA-kernel changes. + /// + /// Returns the batched logits (`active_batch * block_size` rows): request `i` + /// owns rows `[i * block_size, (i + 1) * block_size)`. 
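+ /// + /// Illustrative example, assuming `block_size = 16` and three active + /// requests: the returned logits have `3 * 16 = 48` rows, and request 1 + /// owns rows `[16, 32)`.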
+ pub(crate) fn draft_logits_batched<'a>( + &self, + target: &Qwen3Model, + states: &mut [&mut DFlashRequestState], + current_tokens: &[u32], + scratch: &'a mut DFlashBatchScratch, + ) -> Result<&'a HiddenStates> { + let ctx = target.device_ctx(); + let active_batch = states.len(); + anyhow::ensure!( + active_batch > 0, + "DFlash batched draft needs active requests" + ); + anyhow::ensure!( + states.len() == current_tokens.len(), + "DFlash batched draft: {} states vs {} current tokens", + states.len(), + current_tokens.len() + ); + let block_size = self.block_size(); + let batch_block_rows = active_batch * block_size; + + // Each request's committed context length for this round; advancing + // `committed_len` is deferred until after the layer loop (the rope start + // positions and KV write offsets read the pre-advance value). + let mut context_lens = Vec::with_capacity(active_batch); + for (i, state) in states.iter().enumerate() { + let Some(context_len) = state.pending_context_len() else { + anyhow::bail!( + "DFlash draft requested before target hidden context is available (request slot {i})" + ); + }; + let tail_len = context_len + block_size; + anyhow::ensure!( + state.committed_len + tail_len <= state.max_cache_len, + "DFlash draft cache overflow: committed={}, tail={}, max={}", + state.committed_len, + tail_len, + state.max_cache_len + ); + context_lens.push(context_len); + } + + scratch.activate_dense(batch_block_rows); + + // Build the batched token id buffer: each request's block is + // [current_token, mask, mask, ...]. + scratch.block_token_ids_h[..batch_block_rows].fill(self.mask_token_id()); + for (i, &current_token) in current_tokens.iter().enumerate() { + scratch.block_token_ids_h[i * block_size] = current_token; + } + // token_ids_d holds `max_batch * block_size` ids; copy only the active + // prefix. The embedding kernel reads `out.seq_len = batch_block_rows` ids + // from the buffer start, so the active prefix is what it consumes. 
+ let mut token_ids_dst = scratch.token_ids_d.slice_mut(..batch_block_rows); + ctx.stream.memcpy_htod( + &scratch.block_token_ids_h[..batch_block_rows], + &mut token_ids_dst, + )?; + target.get_embeddings_batch_into(&scratch.token_ids_d, &mut scratch.hidden)?; + + // Per-request context projection: varlen (each request's committed + // prefix differs), persisted in the request so every layer can read it. + for (i, state) in states.iter_mut().enumerate() { + let context_len = context_lens[i]; + state + .context + .ensure_capacity(ctx, self.config.hidden_size, context_len)?; + state.pending_context.activate_for_read(); + self.project_context_into(ctx, &state.pending_context.buffer, &mut state.context)?; + state.pending_context.clear(); + } + + let hidden_size = self.config.hidden_size; + let q_dim = self.config.num_attention_heads * self.config.head_dim; + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let inter_dim = self.config.intermediate_size; + debug_assert_eq!(scratch.hidden.hidden_dim, hidden_size); + debug_assert_eq!(scratch.q_batch.hidden_dim, q_dim); + debug_assert_eq!(scratch.k_tail.hidden_dim, kv_dim); + debug_assert_eq!(scratch.gate_out.hidden_dim, inter_dim); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + // Dense: input layernorm over the whole batch. + ops::rms_norm_batch_into( + ctx, + &scratch.hidden, + &layer.input_layernorm, + self.config.rms_norm_eps, + &mut scratch.normed, + ); + + // Dense: Q projection over the whole batch (per-token, no cross-request + // mixing). Computed before the per-request loop reads `normed`, and + // before the post-attention norm overwrites it. + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + 0, + q_dim, + &scratch.normed, + &mut scratch.q_batch, + ); + + // Per-request varlen attention: tail concat, k/v GEMMs, rope, KV copy, + // single-request prefill. 
Each request slices its `block_size` rows at + // offset `i * block_size` of the batched `normed`/`q_batch`/`attn_output`. + for (i, state) in states.iter_mut().enumerate() { + let context_len = context_lens[i]; + let tail_len = context_len + block_size; + let row_offset = i * block_size; + scratch.ensure_tail_capacity(ctx, &self.config, tail_len)?; + + // tail_input = [context_hidden(context_len) | normed_block(block_size)]. + ops::copy_hidden_token_range_into( + ctx, + &state.context.context_hidden, + 0, + &mut scratch.tail_input, + 0, + context_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &scratch.normed, + row_offset, + &mut scratch.tail_input, + context_len, + block_size, + )?; + + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim, + kv_dim, + &scratch.tail_input, + &mut scratch.k_tail, + ); + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim + kv_dim, + kv_dim, + &scratch.tail_input, + &mut scratch.v_tail, + ); + + ops::dflash_qk_norm_rope_into( + ctx, + &mut scratch.q_batch, + row_offset, + block_size, + &mut scratch.k_tail, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + context_len, + state.committed_len, + self.config.rms_norm_eps, + )?; + + let cache = &mut state.layers[layer_idx]; + ops::copy_hidden_token_range_into( + ctx, + &scratch.k_tail, + 0, + &mut cache.k, + state.committed_len, + tail_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &scratch.v_tail, + 0, + &mut cache.v, + state.committed_len, + tail_len, + )?; + ops::single_prefill_nhd_noncausal_into( + ctx, + &scratch.q_batch, + row_offset, + block_size, + &cache.k, + &cache.v, + &mut scratch.attn_output, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + tail_len, + )?; + } + + // Dense: o_proj + residual + 
post-attention norm + MLP over the batch. + ops::gemm_into( + ctx, + &layer.attention.o_proj, + &scratch.attn_output, + &mut scratch.o_buf, + ); + openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( + ctx, + &mut scratch.hidden, + &scratch.o_buf, + &layer.post_attention_layernorm, + self.config.rms_norm_eps, + &mut scratch.normed, + )?; + + ops::gemm_rows_into( + ctx, + &layer.mlp.gate_up_proj, + 0, + inter_dim, + &scratch.normed, + &mut scratch.gate_out, + ); + ops::gemm_rows_into( + ctx, + &layer.mlp.gate_up_proj, + inter_dim, + inter_dim, + &scratch.normed, + &mut scratch.up_out, + ); + ops::silu_mul_batch_into( + ctx, + &scratch.gate_out, + &scratch.up_out, + &mut scratch.act_out, + )?; + ops::gemm_into( + ctx, + &layer.mlp.down_proj, + &scratch.act_out, + &mut scratch.o_buf, + ); + ops::add_batch_into( + ctx, + &scratch.hidden, + &scratch.o_buf, + &mut scratch.hidden_out, + )?; + std::mem::swap(&mut scratch.hidden, &mut scratch.hidden_out); + } + + for (i, state) in states.iter_mut().enumerate() { + state.committed_len += context_lens[i]; + } + self.compute_logits_with_target_head_into(target, scratch)?; + Ok(&scratch.logits) + } + + fn context_feature_dim(&self) -> usize { + self.config.hidden_size * self.target_layer_ids().len() + } + + fn project_context_into( + &self, + ctx: &DeviceContext, + context_features: &HiddenStates, + context: &mut DFlashContextScratch, + ) -> Result<()> { + ops::gemm_into( + ctx, + &self.fc, + context_features, + &mut context.context_projected, + ); + ops::rms_norm_batch_into( + ctx, + &context.context_projected, + &self.hidden_norm, + self.config.rms_norm_eps, + &mut context.context_hidden, + ); + Ok(()) + } + + fn compute_logits_with_target_head_into( + &self, + target: &Qwen3Model, + scratch: &mut DFlashBatchScratch, + ) -> Result<()> { + let ctx = target.device_ctx(); + ops::rms_norm_batch_into( + ctx, + &scratch.hidden, + &self.norm, + self.config.rms_norm_eps, + &mut scratch.logits_normed, + ); + ops::gemm_into( + 
ctx, + target.output_projection(), + &scratch.logits_normed, + &mut scratch.logits, + ); + Ok(()) + } +} + +#[cfg(test)] +pub(crate) fn validate_dflash_config_for_target( + dflash_path: &str, + target_config: &crate::config::Config, +) -> Result { + let config = DFlashConfig::from_file(dflash_path)?; + config.validate_for_target(target_config)?; + Ok(config) +} + +#[cfg(test)] +mod tests { + use super::validate_dflash_config_for_target; + use crate::config::Config; + use std::path::Path; + + #[test] + fn downloaded_dflash_config_matches_qwen3_4b() { + let target_path = std::env::var("OPENINFER_TEST_MODEL_PATH") + .unwrap_or_else(|_| "/data/models/Qwen3-4B".to_string()); + let dflash_path = std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") + .unwrap_or_else(|_| "/data/models/Qwen3-4B-DFlash-b16".to_string()); + if !Path::new(&target_path).join("config.json").exists() + || !Path::new(&dflash_path).join("config.json").exists() + { + eprintln!( + "skipping DFlash config test; set OPENINFER_TEST_MODEL_PATH and OPENINFER_DFLASH_TEST_MODEL_PATH" + ); + return; + } + + let target = Config::from_file(&target_path).expect("target config"); + let dflash = validate_dflash_config_for_target(&dflash_path, &target) + .expect("DFlash config should match target"); + + assert_eq!(dflash.block_size, 16); + assert_eq!(dflash.dflash_config.mask_token_id, 151669); + assert_eq!( + dflash.dflash_config.target_layer_ids, + vec![1, 9, 17, 25, 33] + ); + + // Pin the memory reservation the KV budget bills against. The per-token + // term (draft KV 5*2*1024*2 + scratch-context (3*2560+2*1024)*2 + pending + // 2560*5*2) drives the ~12% block haircut; a layer-count or geometry + // regression here would silently over/under-reserve and risk OOM. 
+ let reservation = + super::DFlashMemoryReservation::from_config(&dflash, /*max_decode_batch*/ 256); + assert_eq!( + reservation.kv_bytes_per_token, 65_536, + "draft KV(20480) + scratch-ctx(19456) + pending(25600) per token" + ); + // Weights (~1.1 GiB) dominate the fixed term at batch=1; the block-sized + // per-request scratch (~6.5 MiB, logits-heavy) plus the one-block KV/tail + // headroom (~0.5 MiB) add across the decode batch. + let fixed_batch1 = super::DFlashMemoryReservation::from_config(&dflash, 1).fixed_bytes; + assert!( + (1_150_000_000..1_220_000_000).contains(&fixed_batch1), + "draft weights ~1.1GiB, got {fixed_batch1}" + ); + assert!( + (2_900_000_000..3_000_000_000).contains(&reservation.fixed_bytes), + "weights + 256 * (~6.5MiB scratch + ~0.5MiB block headroom), got {}", + reservation.fixed_bytes + ); + } +} diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 95fe3d4b..cf8ed126 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -18,6 +18,20 @@ use openinfer_kv_cache::{ }; use openinfer_kv_offload::{LoadHandle, OffloadConfig, OffloadEngine}; +mod dflash_lane; +mod dflash_prefill; +mod spec; + +use crate::dflash::DFlashDraftModel; +use crate::speculative::{ + DraftPlan, DraftResult, DraftStepItem, VerifyPlan, VerifyResult, VerifyStepItem, + build_verify_results, +}; +use dflash_lane::DFlashLaneState; +use dflash_prefill::{DFlashPrefillAction, dflash_prefill_action}; + +use crate::verify_graph::VerifyGraphBuffers; + #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct RequestId(pub(crate) u64); @@ -264,8 +278,26 @@ fn execute_step_on_lane( .iter() .map(|req| req.lora_adapter.as_deref()) .collect(); - let (logits, all_position_logits) = - lane.execute_prefill(&prompts, kv_views, &lora_adapters, *echo)?; + // When DFlash is loaded, capture target hidden states for eligible + // requests so they can seed the draft model after prefill finishes. 
+ let capture_requested = lane.should_capture_dflash_prefill_context(requests); + let capture_layer_ids = if capture_requested { + lane.dflash_capture_layer_ids() + } else { + None + }; + let (logits, all_position_logits, captured_hidden) = lane.execute_prefill( + &prompts, + kv_views, + &lora_adapters, + *echo, + capture_layer_ids.as_deref(), + )?; + let dflash_context_captured_requests = lane.record_prefill_dflash_context( + requests, + capture_requested, + captured_hidden.as_ref(), + )?; if collect_result { let params: Vec<&SamplingParams> = requests.iter().map(|r| &r.params).collect(); let tokens = lane.select_step_tokens(&logits, &params, *sample_seed)?; @@ -278,6 +310,7 @@ fn execute_step_on_lane( all_position_logits.as_ref(), *echo, )?, + dflash_context_captured_requests, })) } else { Ok(WorkerStepOutcome::Ack) @@ -396,11 +429,12 @@ fn execute_step_on_lane( // Launch prefill on prefill partition stream. unsafe { set_stream_override(prefill_stream.0) }; - let (prefill_logits, _) = lane.execute_prefill( + let (prefill_logits, _, _) = lane.execute_prefill( &prefill_prompts, prefill_kv_views, &prefill_lora_adapters, false, + None, )?; clear_stream_override(); @@ -458,6 +492,19 @@ fn execute_step_on_lane( Ok(WorkerStepOutcome::Ack) } } + StepCommand::SpeculativeVerify { requests, kv_views } => { + // One target forward over each request's K+1 draft span with a + // speculative KV view. The fixed-buffer verify path computes all- + // position logits (accept_greedy needs the target's posterior at + // each span position) and captures the target hidden states (at the + // DFlash layers) to seed the next draft, all into reused, + // pointer-stable scratch (`VerifyGraphBuffers`). 
+ let result = lane.execute_dflash_verify(requests, kv_views)?; + Ok(WorkerStepOutcome::SpeculativeVerify(result)) + } + StepCommand::SpeculativeDraft { requests } => Ok(WorkerStepOutcome::SpeculativeDraft( + lane.execute_dflash_draft(requests)?, + )), } } @@ -595,6 +642,10 @@ pub struct DecodeRequestResult { pub struct PrefillResult { pub requests: Vec, + /// Requests whose DFlash target context was captured this prefill step. + /// Empty unless speculative decoding is enabled. The executor folds these + /// into its `dflash_ready_requests` set once the prompt is fully prefilled. + pub dflash_context_captured_requests: Vec, } pub struct DecodeResult { @@ -619,6 +670,27 @@ pub(crate) trait ModelExecutor: Send { fn execute_decode(&mut self, plan: DecodePlan<'_>) -> Result; fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result; + /// Run one speculative draft round (propose `K` tokens per request). Only + /// meaningful when [`Self::speculative_enabled`] is true. + fn execute_speculative_draft(&mut self, _plan: DraftPlan<'_>) -> Result { + anyhow::bail!("speculative draft is not implemented for this executor") + } + + /// Verify a draft span with one target forward and accept the greedy prefix. + fn execute_speculative_verify(&mut self, _plan: VerifyPlan<'_>) -> Result { + anyhow::bail!("speculative verification is not implemented for this executor") + } + + /// Whether a draft model is loaded and speculative decoding is active. + fn speculative_enabled(&self) -> bool { + false + } + + /// Whether `request_id` has captured draft context and can be drafted. + fn speculative_request_ready(&self, _request_id: RequestId) -> bool { + false + } + fn load_lora_adapter(&mut self, request: &LoadLoraAdapterRequest) -> Result<()> { anyhow::bail!( "Qwen3 LoRA adapter loading is not implemented yet: name={}, path={}", @@ -732,6 +804,13 @@ pub struct Qwen3Executor { /// In-flight async prefill state. 
Populated by the SplitConcurrent step, /// consumed by `poll_async_prefill`. async_prefill: Option, + /// DFlash draft metadata; `Some` once a draft model is loaded into the + /// primary lane. Speculative decoding is enabled iff this is set. + speculative: Option, + /// Requests whose DFlash context is captured and ready to draft. A request + /// enters this set when its prompt finishes prefilling with captured target + /// context, and leaves on retire or a plain (non-speculative) decode. + dflash_ready_requests: HashSet, } /// State for an in-flight async prefill on the prefill overlap stream. @@ -768,10 +847,15 @@ impl Qwen3Executor { model: Qwen3Model, offload_opts: &Qwen3OffloadOptions, max_prefill_tokens: usize, + dflash_kv_bytes_per_token: usize, memory_options: Qwen3MemoryOptions, ) -> Result { - let (model, budget) = - profile_kv_budget_on_worker(model, max_prefill_tokens, memory_options)?; + let (model, budget) = profile_kv_budget_on_worker( + model, + max_prefill_tokens, + dflash_kv_bytes_per_token, + memory_options, + )?; let kv_mgr = KvCacheManager::new( &model.device_ctx().stream, budget.num_layers, @@ -809,6 +893,8 @@ impl Qwen3Executor { l1_retention_disabled: false, overlap: None, async_prefill: None, + speculative: None, + dflash_ready_requests: HashSet::new(), }) } @@ -824,6 +910,7 @@ impl Qwen3Executor { Qwen3LoraOptions::default(), Qwen3OffloadOptions::disabled(), crate::scheduler::DEFAULT_MAX_PREFILL_TOKENS, + None, Qwen3MemoryOptions::default(), ) } @@ -835,9 +922,10 @@ impl Qwen3Executor { lora_options: Qwen3LoraOptions, offload_options: Qwen3OffloadOptions, max_prefill_tokens: usize, + dflash_draft_path: Option<&str>, memory_options: Qwen3MemoryOptions, ) -> Result { - let memory_options = memory_options.validate()?; + let mut memory_options = memory_options.validate()?; let lora_options = lora_options.validate()?; anyhow::ensure!( !device_ordinals.is_empty(), @@ -860,11 +948,36 @@ impl Qwen3Executor { max_lora_rank: 
lora_options.max_lora_rank, }, )?; - let mut executor = - Self::single(model, &offload_options, max_prefill_tokens, memory_options)?; + // The DFlash draft model loads after profiling but lives outside the + // paged KV pool, so reserve its footprint up front from the draft + // config: fixed bytes (weights + block scratch) via the margin, and + // pool-scaling per-token bytes folded into the block budget. + let dflash_kv_bytes_per_token = match dflash_draft_path { + Some(path) => { + let reservation = crate::dflash::DFlashMemoryReservation::from_path( + path, + *BATCH_BUCKETS.last().unwrap(), + )?; + memory_options.kv_cache_memory_margin_bytes += reservation.fixed_bytes; + reservation.kv_bytes_per_token + } + None => 0, + }; + let mut executor = Self::single( + model, + &offload_options, + max_prefill_tokens, + dflash_kv_bytes_per_token, + memory_options, + )?; executor.lora_options = lora_options; return Ok(executor); } + anyhow::ensure!( + dflash_draft_path.is_none(), + "speculative decoding requires the single-GPU path (got {} devices)", + device_ordinals.len() + ); let world_size = device_ordinals.len(); let mut models = Vec::with_capacity(world_size); @@ -887,8 +1000,9 @@ impl Qwen3Executor { let mut profiled_models = Vec::with_capacity(world_size); let mut budgets = Vec::with_capacity(world_size); for model in models { + // DFlash is single-GPU only, so the TP path reserves nothing for it. let (model, budget) = - profile_kv_budget_on_worker(model, max_prefill_tokens, memory_options)?; + profile_kv_budget_on_worker(model, max_prefill_tokens, 0, memory_options)?; profiled_models.push(model); budgets.push(budget); } @@ -987,6 +1101,8 @@ impl Qwen3Executor { l1_retention_disabled: false, overlap: None, async_prefill: None, + speculative: None, + dflash_ready_requests: HashSet::new(), }) } @@ -1068,6 +1184,33 @@ impl Qwen3Executor { } } + /// Enable speculative decoding by loading a DFlash draft model into the + /// primary lane. 
+ /// + /// Requires the single-GPU topology (tensor parallel shards KV per rank) and + /// is incompatible with KV offload. Disables the prefix cache: speculative + /// capture needs clean, uncached target hidden states for every prompt + /// token, and a prefix-cache hit skips the forward that would produce them. + pub fn load_dflash_draft_model(&mut self, draft_path: &str) -> Result<()> { + anyhow::ensure!( + self.workers.is_empty(), + "speculative decoding requires the single-GPU path (got {} extra ranks)", + self.workers.len() + ); + anyhow::ensure!( + self.offload.is_none(), + "speculative decoding is not supported together with KV offload" + ); + let meta = self.primary.load_dflash(draft_path.to_string())?; + log::info!( + "Qwen3 DFlash speculative decoding enabled: draft block size {}", + meta.block_size + ); + self.prefix_cache_enabled = false; + self.speculative = Some(meta); + Ok(()) + } + /// Whether KV offload is active on this executor. pub fn offload_enabled(&self) -> bool { self.offload.is_some() @@ -1293,6 +1436,7 @@ impl Qwen3Executor { fn profile_kv_budget_on_worker( model: Qwen3Model, max_prefill_tokens: usize, + dflash_kv_bytes_per_token: usize, memory_options: Qwen3MemoryOptions, ) -> Result<(Qwen3Model, KvBudget)> { let handle = thread::Builder::new() @@ -1309,6 +1453,7 @@ fn profile_kv_budget_on_worker( let budget = model.profiled_kv_budget( max_prefill_tokens, *BATCH_BUCKETS.last().unwrap(), + dflash_kv_bytes_per_token, memory_options, )?; Ok((model, budget)) @@ -1379,7 +1524,15 @@ impl ModelExecutor for Qwen3Executor { } fn max_context_tokens(&self) -> usize { - self.metadata.config.max_position_embeddings + let target = self.metadata.config.max_position_embeddings; + match &self.speculative { + // The draft's fixed-width in-fill block writes `block_size` positions + // past the committed length each step, so a request may use at most + // `draft.max_pos - block_size` tokens before the draft cache would + // overflow. 
Reject the rest at admission instead of crashing mid-prefill. + Some(meta) => target.min(meta.max_position_embeddings.saturating_sub(meta.block_size)), + None => target, + } } fn max_decode_batch_size(&self) -> usize { @@ -1418,6 +1571,10 @@ impl ModelExecutor for Qwen3Executor { } } self.saved_cursor.remove(&request_id); + if self.speculative.is_some() { + self.dflash_ready_requests.remove(&request_id); + self.primary.drop_dflash_request(request_id)?; + } Ok(()) } @@ -1582,6 +1739,28 @@ impl ModelExecutor for Qwen3Executor { for req_result in &result.requests { self.apply_prefill_result(req_result)?; } + // A request becomes draft-ready once its prompt is fully prefilled with + // captured target context. Partial chunks stay pending; ineligible + // requests drop any stale worker state. + if self.speculative.is_some() { + for req_result in &result.requests { + let captured = result + .dflash_context_captured_requests + .contains(&req_result.request_id); + match dflash_prefill_action(captured, req_result.completed) { + DFlashPrefillAction::MarkReady => { + self.dflash_ready_requests.insert(req_result.request_id); + } + DFlashPrefillAction::KeepPending => { + self.dflash_ready_requests.remove(&req_result.request_id); + } + DFlashPrefillAction::Drop => { + self.dflash_ready_requests.remove(&req_result.request_id); + self.primary.drop_dflash_request(req_result.request_id)?; + } + } + } + } // 5. Offload the blocks this prefill just sealed (post-step-sync). for req_result in &result.requests { self.save_sealed_blocks(req_result.request_id); @@ -1634,6 +1813,15 @@ impl ModelExecutor for Qwen3Executor { .expect("request must exist after decode"); rkv.apply_decode(req_result.token, self.kv_mgr.pool())?; } + // A plain decode advances the sequence outside the speculative path, so + // any captured draft context is now stale — drop it. 
+ if self.speculative.is_some() { + for req_result in &result.requests { + if self.dflash_ready_requests.remove(&req_result.request_id) { + self.primary.drop_dflash_request(req_result.request_id)?; + } + } + } // 5. Offload any block this decode step just sealed (post-step-sync). for req_result in &result.requests { self.save_sealed_blocks(req_result.request_id); @@ -1642,6 +1830,22 @@ impl ModelExecutor for Qwen3Executor { Ok(result) } + fn execute_speculative_draft(&mut self, plan: DraftPlan<'_>) -> Result { + self.execute_speculative_draft_impl(plan) + } + + fn execute_speculative_verify(&mut self, plan: VerifyPlan<'_>) -> Result { + self.execute_speculative_verify_impl(plan) + } + + fn speculative_enabled(&self) -> bool { + self.speculative.is_some() + } + + fn speculative_request_ready(&self, request_id: RequestId) -> bool { + self.dflash_ready_requests.contains(&request_id) + } + fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result { // 1. Create RequestKvs for prefill requests (first chunk only), clamp // chunk budgets, schedule KV for this step's tokens @@ -2038,6 +2242,19 @@ impl Drop for Qwen3Executor { } } +/// What the executor learns about the draft model after loading it on the +/// worker: the draft block size (`K` candidates per round) and which target +/// layers feed the draft (the worker captures these; kept for diagnostics). +#[derive(Clone, Debug)] +struct DFlashMeta { + block_size: usize, + /// Draft's max cacheable position; with the `block_size` in-fill headroom + /// this caps the DFlash-effective context to `max_position_embeddings - block_size`. + max_position_embeddings: usize, + #[allow(dead_code)] + target_layer_ids: Vec, +} + struct LocalQwen3Lane { model: Qwen3Model, kv_buffer: KvBuffer, @@ -2046,6 +2263,16 @@ struct LocalQwen3Lane { sample_scratch: openinfer_sample::SampleScratch, /// In-flight prefill from a previous SplitConcurrent step (not yet synced). 
inflight_prefill: Option, + /// DFlash draft lane (the draft model + per-request draft state). `None` + /// unless speculative decoding is enabled; only the primary rank carries it. + dflash: Option, + /// Fixed, pre-allocated scratch for the DFlash verify forward. Lazily built + /// on the first verify step (its shape depends on the loaded draft model's + /// block size and the target's capture layers). Pointer-stable for the + /// upcoming verify CUDA Graph. + verify_bufs: Option, + /// KV pool block count — the worst-case page-list bound for `verify_bufs`. + total_blocks: usize, } /// Stored state for an async prefill that was launched but not yet synced. @@ -2101,9 +2328,36 @@ impl LocalQwen3Lane { bufs, sample_scratch, inflight_prefill: None, + dflash: None, + verify_bufs: None, + total_blocks, }) } + /// Load the DFlash draft model into this lane (primary rank only). The draft + /// model is built here on the worker thread because it reads the co-located + /// target model's embeddings and head. + fn load_dflash(&mut self, draft_path: &str) -> Result { + let model = DFlashDraftModel::from_safetensors_for_target( + self.model.device_ctx(), + draft_path, + &self.model, + )?; + model.tune_gemm_algos(&self.model)?; + let meta = DFlashMeta { + block_size: model.block_size(), + max_position_embeddings: model.max_position_embeddings(), + target_layer_ids: model.target_layer_ids().to_vec(), + }; + let max_decode_batch_size = *BATCH_BUCKETS.last().unwrap(); + self.dflash = Some(DFlashLaneState::new( + self.model.device_ctx(), + model, + max_decode_batch_size, + )?); + Ok(meta) + } + fn bind(&self) -> Result { bind_model_thread(&self.model)?; tune_decode_gemm_algos(&self.model)?; @@ -2142,7 +2396,12 @@ impl LocalQwen3Lane { false, )?; - Ok(PrefillResult { requests: results }) + // Split-concurrent prefill never runs with DFlash (capture needs the + // synchronous result), so no context is captured here. 
+ Ok(PrefillResult { + requests: results, + dflash_context_captured_requests: Vec::new(), + }) } /// Pick one token per logits column (batched argmax for greedy rows, @@ -2202,7 +2461,8 @@ impl LocalQwen3Lane { kv_views: &[KvView], lora_adapters: &[Option<&str>], echo: bool, - ) -> Result<(HiddenStates, Option)> { + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option, Option)> { self.model.batch_prefill( prompts, kv_views, @@ -2210,9 +2470,73 @@ impl LocalQwen3Lane { self.kv_buffer.buffer(), &self.layout, echo, + capture_layer_ids, ) } + /// DFlash verify forward over each request's `block_size`-token span, using + /// the fixed pre-allocated [`VerifyGraphBuffers`] (no per-step allocation). + /// Numerically equivalent to the `batch_prefill(echo=true)` verify path it + /// replaces; the buffers are lazily built on first use. + fn execute_dflash_verify( + &mut self, + requests: &[VerifyStepItem], + kv_views: &[KvView], + ) -> Result { + let capture_layer_ids = self.dflash_capture_layer_ids().ok_or_else(|| { + anyhow::anyhow!("DFlash verify requested but no draft model is loaded") + })?; + let block_size = self + .dflash + .as_ref() + .expect("DFlash present when capture layers exist") + .model + .block_size(); + + if self.verify_bufs.is_none() { + let max_batch = *BATCH_BUCKETS.last().unwrap(); + self.verify_bufs = Some(VerifyGraphBuffers::new( + &self.model, + max_batch, + block_size, + capture_layer_ids.len(), + self.total_blocks, + )?); + } + + // Take the buffers out of `self` so the forward (borrows `&self.model`, + // `&mut bufs`) and the subsequent sampling (`&mut self.sample_scratch`) + // and context record (`&mut self.dflash`) don't alias a `self` borrow. 
+ let mut bufs = self.verify_bufs.take().expect("verify buffers just set"); + let result = (|| -> Result { + let spans: Vec<&[u32]> = requests.iter().map(VerifyStepItem::as_slice).collect(); + self.model.batch_prefill_into( + &spans, + kv_views, + self.kv_buffer.buffer(), + &self.layout, + &capture_layer_ids, + &mut bufs, + )?; + + let total_tokens: usize = requests.iter().map(|req| req.as_slice().len()).sum(); + let greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; total_tokens]; + let target_tokens = self.select_step_tokens(bufs.all_logits(), &params, 0)?; + let request_results = build_verify_results(requests, &target_tokens)?; + self.record_verify_dflash_context( + requests, + &request_results, + Some(bufs.captured_hidden()), + )?; + Ok(VerifyResult { + requests: request_results, + }) + })(); + self.verify_bufs = Some(bufs); + result + } + fn execute_decode( &mut self, token_ids: &[u32], @@ -2302,6 +2626,16 @@ enum StepCommand { decode_stream: crate::green_ctx::SendStream, sample_seed: u64, }, + /// Speculative verify: one target forward over each request's `K + 1` draft + /// span (with a speculative KV view), capturing target hidden states for the + /// next draft round. Greedy argmax per position drives [`accept_greedy`]. + SpeculativeVerify { + requests: Vec, + kv_views: Vec, + }, + /// Speculative draft: roll the DFlash draft model forward one block per + /// request. Uses the draft's own KV — no target KV views. + SpeculativeDraft { requests: Vec }, } impl StepCommand { @@ -2311,6 +2645,8 @@ impl StepCommand { Self::Decode { .. } => "decode", Self::Unified { .. } => "unified", Self::SplitConcurrent { .. } => "split_concurrent", + Self::SpeculativeVerify { .. } => "speculative_verify", + Self::SpeculativeDraft { .. 
} => "speculative_draft", } } } @@ -2340,6 +2676,18 @@ enum WorkerCommand { ResolvePrefill { resp: channel::Sender>, }, + /// Load the DFlash draft model into the primary lane (built on the worker + /// thread because it reads the co-located target model). + LoadDflash { + draft_path: String, + resp: channel::Sender>, + }, + /// Drop a request's DFlash draft state (request retired, or it fell back to + /// a plain decode that advanced the sequence outside the speculative path). + DropDflash { + request_id: RequestId, + resp: channel::Sender>, + }, Shutdown, } @@ -2357,6 +2705,8 @@ enum WorkerStepOutcome { /// query this to check if prefill is done without blocking. prefill_event: SendEvent, }, + SpeculativeVerify(VerifyResult), + SpeculativeDraft(DraftResult), } impl WorkerStepOutcome { @@ -2367,6 +2717,8 @@ impl WorkerStepOutcome { Self::Decode(_) => "decode", Self::Unified(_) => "unified", Self::SplitDecodeReady { .. } => "split_decode_ready", + Self::SpeculativeVerify(_) => "speculative_verify", + Self::SpeculativeDraft(_) => "speculative_draft", } } } @@ -2420,6 +2772,14 @@ impl RankWorker { let result = lane.resolve_inflight_prefill(); let _ = resp.send(result); } + WorkerCommand::LoadDflash { draft_path, resp } => { + let result = lane.load_dflash(&draft_path); + let _ = resp.send(result); + } + WorkerCommand::DropDflash { request_id, resp } => { + lane.drop_dflash_request(request_id); + let _ = resp.send(Ok(())); + } WorkerCommand::Shutdown => break, } } @@ -2464,6 +2824,35 @@ impl RankWorker { Ok(resp_rx) } + /// Load the DFlash draft model into this worker's lane and return its + /// metadata. Blocks until the worker finishes loading. 
+ fn load_dflash(&self, draft_path: String) -> Result { + let (resp_tx, resp_rx) = channel::bounded(1); + self.tx + .send(WorkerCommand::LoadDflash { + draft_path, + resp: resp_tx, + }) + .map_err(|_| anyhow::anyhow!("worker channel closed on load_dflash"))?; + resp_rx + .recv() + .map_err(|_| anyhow::anyhow!("worker dropped load_dflash response"))? + } + + /// Drop a request's DFlash state. Blocks until the worker acknowledges. + fn drop_dflash_request(&self, request_id: RequestId) -> Result<()> { + let (resp_tx, resp_rx) = channel::bounded(1); + self.tx + .send(WorkerCommand::DropDflash { + request_id, + resp: resp_tx, + }) + .map_err(|_| anyhow::anyhow!("worker channel closed on drop_dflash"))?; + resp_rx + .recv() + .map_err(|_| anyhow::anyhow!("worker dropped drop_dflash response"))? + } + fn load_lora_adapter( &self, name: String, diff --git a/openinfer-qwen3-4b/src/executor/dflash_lane.rs b/openinfer-qwen3-4b/src/executor/dflash_lane.rs new file mode 100644 index 00000000..3ff01161 --- /dev/null +++ b/openinfer-qwen3-4b/src/executor/dflash_lane.rs @@ -0,0 +1,319 @@ +//! Worker-side DFlash draft lane: the draft model plus per-request draft state. +//! +//! This lives on the worker thread next to the target model because the draft +//! rollout reads the target's embeddings/head and its captured hidden states. +//! The draft/verify boundary stays a pure token span — the hidden states are +//! private to this lane (`pending_context`), never crossing to the scheduler. 
+ +use std::collections::HashMap; + +use anyhow::Result; +use openinfer_core::sampler::SamplingParams; +use openinfer_core::tensor::HiddenStates; + +use super::dflash_prefill::{dflash_prefill_can_capture, should_capture_dflash_prefill_context}; +use super::{LocalQwen3Lane, PrefillStepItem, RequestId}; +use crate::dflash::{DFlashBatchScratch, DFlashDraftModel, DFlashRequestState}; +use crate::speculative::{ + DraftRequestResult, DraftResult, DraftStepItem, VerifyRequestResult, VerifyStepItem, +}; +use openinfer_core::tensor::DeviceContext; + +pub(super) struct DFlashLaneState { + pub(super) model: DFlashDraftModel, + pub(super) requests: HashMap, + /// Lane-level batched draft scratch, allocated once for the whole decode + /// batch so the dense draft ops run once instead of once per request. + scratch: DFlashBatchScratch, + verified_draft_tokens: usize, + accepted_draft_tokens: usize, +} + +impl DFlashLaneState { + pub(super) fn new( + ctx: &DeviceContext, + model: DFlashDraftModel, + max_decode_batch_size: usize, + ) -> Result { + let scratch = model.new_batch_scratch(ctx, max_decode_batch_size)?; + Ok(Self { + model, + requests: HashMap::new(), + scratch, + verified_draft_tokens: 0, + accepted_draft_tokens: 0, + }) + } +} + +impl LocalQwen3Lane { + /// Target layers whose hidden states the draft model consumes (None when + /// DFlash is not loaded). + pub(super) fn dflash_capture_layer_ids(&self) -> Option> { + self.dflash + .as_ref() + .map(|dflash| dflash.model.target_layer_ids().to_vec()) + } + + pub(super) fn should_capture_dflash_prefill_context( + &self, + requests: &[PrefillStepItem], + ) -> bool { + let Some(dflash) = self.dflash.as_ref() else { + return false; + }; + should_capture_dflash_prefill_context(requests, |request_id| { + dflash.requests.contains_key(&request_id) + }) + } + + /// Fold target hidden states captured during prefill into each eligible + /// request's pending context. Returns the requests that now have context. 
+ pub(super) fn record_prefill_dflash_context( + &mut self, + requests: &[PrefillStepItem], + capture_requested: bool, + captured_hidden: Option<&HiddenStates>, + ) -> Result> { + let Some(captured_hidden) = captured_hidden else { + anyhow::ensure!( + !capture_requested, + "DFlash prefill context capture was requested but no hidden states were returned" + ); + return Ok(Vec::new()); + }; + anyhow::ensure!( + capture_requested, + "DFlash prefill hidden states were returned without a capture request" + ); + let Some(dflash) = self.dflash.as_mut() else { + anyhow::bail!("DFlash prefill context record requested without DFlash"); + }; + let expected_tokens: usize = requests.iter().map(|req| req.chunk_tokens).sum(); + anyhow::ensure!( + captured_hidden.seq_len == expected_tokens, + "DFlash prefill captured {} hidden rows for {} scheduled tokens", + captured_hidden.seq_len, + expected_tokens + ); + let ctx = self.model.device_ctx().clone(); + let mut captured_requests = Vec::new(); + let mut token_offset = 0usize; + for req in requests { + let pending_exists = dflash.requests.contains_key(&req.request_id); + if dflash_prefill_can_capture(req, pending_exists) { + // Admission already caps the request at `draft.max_pos - block_size` + // (see `max_context_tokens`), so this `.min` is a defensive floor: + // it keeps the draft KV alloc within the draft's max positions even + // if a caller bypasses admission. 
+ let max_cache_len = + (req.prompt_tokens.len() + req.max_output_tokens + dflash.model.block_size()) + .min(dflash.model.max_position_embeddings()); + let mut state = match dflash.requests.remove(&req.request_id) { + Some(state) => state, + None => dflash.model.new_request_state(&ctx, max_cache_len)?, + }; + let pending_len = state.pending_context_len().unwrap_or(0); + anyhow::ensure!( + pending_len == req.chunk_start, + "DFlash prefill context for {:?} is discontinuous: pending={}, chunk_start={}", + req.request_id, + pending_len, + req.chunk_start + ); + dflash.model.append_pending_context( + &ctx, + &mut state, + captured_hidden, + token_offset, + req.chunk_tokens, + )?; + dflash.requests.insert(req.request_id, state); + captured_requests.push(req.request_id); + } else { + dflash.requests.remove(&req.request_id); + } + token_offset += req.chunk_tokens; + } + Ok(captured_requests) + } + + /// Seed the next draft round from a verify step: append the target hidden + /// states for the *accepted* span positions to each request's pending + /// context, and log the cumulative acceptance rate. 
+ pub(super) fn record_verify_dflash_context( + &mut self, + requests: &[VerifyStepItem], + results: &[VerifyRequestResult], + captured_hidden: Option<&HiddenStates>, + ) -> Result<()> { + let Some(captured_hidden) = captured_hidden else { + anyhow::bail!("DFlash verify context capture requested but no hidden states returned"); + }; + let Some(dflash) = self.dflash.as_mut() else { + anyhow::bail!("DFlash verify context record requested without DFlash"); + }; + anyhow::ensure!( + requests.len() == results.len(), + "DFlash verify result count {} does not match request count {}", + results.len(), + requests.len() + ); + let expected_tokens: usize = requests.iter().map(|req| req.token_ids.len()).sum(); + anyhow::ensure!( + captured_hidden.seq_len == expected_tokens, + "DFlash verify captured {} hidden rows for {} scheduled tokens", + captured_hidden.seq_len, + expected_tokens + ); + let ctx = self.model.device_ctx().clone(); + let mut token_offset = 0usize; + for (req, result) in requests.iter().zip(results) { + anyhow::ensure!( + req.request_id == result.request_id, + "DFlash verify result {:?} does not match request {:?}", + result.request_id, + req.request_id + ); + let mut state = dflash.requests.remove(&req.request_id).ok_or_else(|| { + anyhow::anyhow!("missing DFlash state after verify for {:?}", req.request_id) + })?; + // Only the accepted prefix's target hidden states are valid context + // for the next draft; rejected drafts had the wrong continuation. 
+ dflash.model.append_pending_context( + &ctx, + &mut state, + captured_hidden, + token_offset, + result.accepted_tokens.len(), + )?; + dflash.requests.insert(req.request_id, state); + dflash.verified_draft_tokens += req.token_ids.len().saturating_sub(1); + dflash.accepted_draft_tokens += result.matched_draft_tokens; + let rate = if dflash.verified_draft_tokens == 0 { + 0.0 + } else { + dflash.accepted_draft_tokens as f64 / dflash.verified_draft_tokens as f64 + }; + log::debug!( + "Qwen3 DFlash request={} accepted_draft={} committed_tokens={} cumulative_accept_rate={:.3}", + req.request_id.get(), + result.matched_draft_tokens, + result.accepted_tokens.len(), + rate, + ); + token_offset += req.token_ids.len(); + } + Ok(()) + } + + /// Roll out one draft span per request: draft forward + greedy argmax over + /// the block. Returns the verify span `[current_token, draft_1, …]`. + pub(super) fn execute_dflash_draft( + &mut self, + requests: &[DraftStepItem], + ) -> Result { + anyhow::ensure!( + !requests.is_empty(), + "DFlash draft requested without active requests" + ); + for req in requests { + anyhow::ensure!( + req.params.is_greedy(), + "DFlash draft currently supports greedy sampling only" + ); + } + + // Take the lane out of `self` so the draft forward (which borrows + // `dflash.model`/`dflash.scratch`) and the argmax (which borrows + // `self.sample_scratch`) don't collide on a `self` borrow. + let Some(mut dflash) = self.dflash.take() else { + anyhow::bail!("DFlash draft requested but DFlash is not loaded"); + }; + let result = (|| -> Result> { + // Pull every active request's state out of the map so the batched + // forward can hold `&mut` to all of them at once. Re-inserted below. 
+ let mut taken: Vec<(RequestId, DFlashRequestState)> = + Vec::with_capacity(requests.len()); + for req in requests { + let state = dflash.requests.remove(&req.request_id).ok_or_else(|| { + anyhow::anyhow!("missing DFlash state for {:?}", req.request_id) + })?; + taken.push((req.request_id, state)); + } + + let block_size = dflash.model.block_size(); + let current_tokens: Vec = requests.iter().map(|req| req.current_token).collect(); + let DFlashLaneState { + model, + scratch, + requests: state_map, + .. + } = &mut dflash; + let mut state_refs: Vec<&mut DFlashRequestState> = + taken.iter_mut().map(|(_, state)| state).collect(); + + let sampled = { + let draft_logits = model.draft_logits_batched( + &self.model, + &mut state_refs, + &current_tokens, + scratch, + )?; + let draft_len = draft_logits.seq_len; + anyhow::ensure!( + draft_len == requests.len() * block_size, + "DFlash batched draft produced {} logits rows for {} requests x block {}", + draft_len, + requests.len(), + block_size + ); + let greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; draft_len]; + self.select_step_tokens(draft_logits, &params, 0)? + }; + + // Re-insert every request's state before splitting the result. + for (request_id, state) in taken { + state_map.insert(request_id, state); + } + + anyhow::ensure!( + sampled.len() == requests.len() * block_size, + "DFlash batched draft sampled {} tokens for {} requests x block {}", + sampled.len(), + requests.len(), + block_size + ); + + // Split the batched samples per request: request `i` owns rows + // `[i * block_size, (i + 1) * block_size)`. Verify span = [current + // dangling token, draft_1, …, draft_{K}]. 
+ let mut outputs = Vec::with_capacity(requests.len()); + for (i, req) in requests.iter().enumerate() { + let block = &sampled[i * block_size..(i + 1) * block_size]; + anyhow::ensure!( + block.len() >= 2, + "DFlash draft block {} has fewer than 2 tokens", + i + ); + let mut token_ids = Vec::with_capacity(block.len()); + token_ids.push(req.current_token); + token_ids.extend(block[1..].iter().copied()); + outputs.push(DraftRequestResult { + request_id: req.request_id, + token_ids, + }); + } + Ok(outputs) + })(); + self.dflash = Some(dflash); + Ok(DraftResult { requests: result? }) + } + + pub(super) fn drop_dflash_request(&mut self, request_id: RequestId) { + if let Some(dflash) = self.dflash.as_mut() { + dflash.requests.remove(&request_id); + } + } +} diff --git a/openinfer-qwen3-4b/src/executor/dflash_prefill.rs b/openinfer-qwen3-4b/src/executor/dflash_prefill.rs new file mode 100644 index 00000000..fb3099d4 --- /dev/null +++ b/openinfer-qwen3-4b/src/executor/dflash_prefill.rs @@ -0,0 +1,82 @@ +//! DFlash prefill-capture eligibility predicates. +//! +//! A request can seed the DFlash draft only if its prefill produces clean target +//! hidden states: greedy, no LoRA, no prefix-cache hit, no echo, no logprobs. + +use super::{PrefillStepItem, RequestId}; + +/// Whether a prefill request is eligible to capture DFlash target context. +pub(super) fn dflash_prefill_supported(req: &PrefillStepItem) -> bool { + req.lora_adapter.is_none() + && req.cached_tokens == 0 + && req.logprobs == 0 + && !req.echo + && req.params.is_greedy() +} + +/// Eligible AND continuous: either the first chunk, or a later chunk whose +/// earlier chunks already captured context (no gaps in the pending buffer). +pub(super) fn dflash_prefill_can_capture( + req: &PrefillStepItem, + pending_state_exists: bool, +) -> bool { + dflash_prefill_supported(req) && (req.chunk_start == 0 || pending_state_exists) +} + +/// Capture hidden states during this prefill step iff any request is eligible. 
+pub(super) fn should_capture_dflash_prefill_context( + requests: &[PrefillStepItem], + pending_state_exists: impl Fn(RequestId) -> bool, +) -> bool { + !requests.is_empty() + && requests + .iter() + .any(|req| dflash_prefill_can_capture(req, pending_state_exists(req.request_id))) +} + +/// What to do with a request's DFlash state after a prefill step. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(super) enum DFlashPrefillAction { + /// Context captured and prefill finished → ready to draft. + MarkReady, + /// Context captured but more chunks remain → keep the pending state. + KeepPending, + /// Ineligible → drop any stale state. + Drop, +} + +pub(super) fn dflash_prefill_action( + captured_context: bool, + completed: bool, +) -> DFlashPrefillAction { + match (captured_context, completed) { + (true, true) => DFlashPrefillAction::MarkReady, + (true, false) => DFlashPrefillAction::KeepPending, + (false, _) => DFlashPrefillAction::Drop, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prefill_action_table() { + assert_eq!( + dflash_prefill_action(true, true), + DFlashPrefillAction::MarkReady + ); + assert_eq!( + dflash_prefill_action(true, false), + DFlashPrefillAction::KeepPending + ); + assert_eq!( + dflash_prefill_action(false, true), + DFlashPrefillAction::Drop + ); + assert_eq!( + dflash_prefill_action(false, false), + DFlashPrefillAction::Drop + ); + } +} diff --git a/openinfer-qwen3-4b/src/executor/spec.rs b/openinfer-qwen3-4b/src/executor/spec.rs new file mode 100644 index 00000000..0ec42881 --- /dev/null +++ b/openinfer-qwen3-4b/src/executor/spec.rs @@ -0,0 +1,188 @@ +//! Executor-side speculative-decode orchestration: the optimistic KV +//! transaction around a verify forward (schedule → forward → accept/commit or +//! roll back) and the thin draft dispatch. +//! +//! The forward itself runs on the worker lane (see [`super::dflash_lane`]); this +//! module owns only the KV bookkeeping the executor thread is responsible for. 
+ +use anyhow::Result; + +use crate::speculative::{DraftPlan, DraftResult, VerifyPlan, VerifyResult}; + +use super::{Qwen3Executor, RequestId, StepCommand, WorkerStepOutcome}; + +impl Qwen3Executor { + pub(super) fn execute_speculative_verify_impl( + &mut self, + plan: VerifyPlan<'_>, + ) -> Result<VerifyResult> { + anyhow::ensure!( + self.speculative.is_some(), + "speculative verification requested but no draft model is loaded" + ); + for req in plan.requests { + anyhow::ensure!( + !req.as_slice().is_empty(), + "speculative verify request {:?} has an empty verify span", + req.request_id + ); + anyhow::ensure!( + req.params.is_greedy(), + "speculative verification currently supports greedy sampling only" + ); + anyhow::ensure!( + self.dflash_ready_requests.contains(&req.request_id), + "speculative verification requested before DFlash state is ready for {:?}", + req.request_id + ); + anyhow::ensure!( + self.request_kvs.contains_key(&req.request_id), + "missing RequestKv for {:?}", + req.request_id + ); + } + + // Reserve KV slots for each request's full K+1 verify span. Roll back + // every prior reservation if any single one fails (all-or-nothing).
+ let mut scheduled = Vec::with_capacity(plan.requests.len()); + for req in plan.requests { + let span_len = req.as_slice().len(); + let rkv = self + .request_kvs + .get_mut(&req.request_id) + .expect("RequestKv was validated before speculative scheduling"); + if let Err(e) = rkv.schedule_speculative(span_len, self.kv_mgr.pool()) { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "schedule_speculative failed for {:?}: {e}", + req.request_id + )); + } + scheduled.push(req.request_id); + } + + let kv_views = plan + .requests + .iter() + .map(|req| self.request_kvs[&req.request_id].speculative_view(req.as_slice().len())) + .collect(); + + let step = StepCommand::SpeculativeVerify { + requests: plan.requests.to_vec(), + kv_views, + }; + let outcome = match self.run_step(&step) { + Ok(outcome) => outcome, + Err(e) => { + self.revert_speculative_schedules(&scheduled); + return Err(e); + } + }; + let result = match outcome { + WorkerStepOutcome::SpeculativeVerify(result) => result, + other => { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "speculative verify returned unexpected: {}", + other.kind() + )); + } + }; + if result.requests.len() != plan.requests.len() { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "speculative verify returned {} request results for {} requests", + result.requests.len(), + plan.requests.len() + )); + } + for (req, req_result) in plan.requests.iter().zip(&result.requests) { + if req.request_id != req_result.request_id { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "speculative verify returned request {:?} for {:?}", + req_result.request_id, + req.request_id + )); + } + } + + // Commit the accepted prefix of each request's KV and free the rest. + // On a mid-loop failure, only the not-yet-applied requests roll back + // (applied ones already committed and cannot be reverted). 
+ let mut applied = Vec::with_capacity(result.requests.len()); + for req_result in &result.requests { + let rkv = self + .request_kvs + .get_mut(&req_result.request_id) + .expect("request must exist after speculative verify"); + if let Err(e) = rkv.apply_speculative(&req_result.accepted_tokens, self.kv_mgr.pool()) { + let unapplied = scheduled + .iter() + .copied() + .filter(|request_id| !applied.contains(request_id)) + .collect::<Vec<_>>(); + self.revert_speculative_schedules(&unapplied); + return Err(anyhow::anyhow!( + "apply_speculative failed for {:?}: {e}", + req_result.request_id + )); + } + applied.push(req_result.request_id); + } + for req_result in &result.requests { + self.save_sealed_blocks(req_result.request_id); + } + + Ok(result) + } + + pub(super) fn execute_speculative_draft_impl( + &mut self, + plan: DraftPlan<'_>, + ) -> Result<DraftResult> { + anyhow::ensure!( + self.speculative.is_some(), + "speculative draft requested but no draft model is loaded" + ); + for req in plan.requests { + anyhow::ensure!( + req.params.is_greedy(), + "speculative draft currently supports greedy sampling only" + ); + anyhow::ensure!( + self.dflash_ready_requests.contains(&req.request_id), + "speculative draft requested before DFlash state is ready for {:?}", + req.request_id + ); + } + let step = StepCommand::SpeculativeDraft { + requests: plan.requests.to_vec(), + }; + match self.run_step(&step)? { + WorkerStepOutcome::SpeculativeDraft(result) => Ok(result), + other => Err(anyhow::anyhow!( + "speculative draft returned unexpected: {}", + other.kind() + )), + } + } + + /// Roll back speculative KV reservations. Each request reverts its own + /// reservation independently (the LIFO block discipline is intra-sequence, + /// via RAII); the reverse order here just mirrors schedule order and is + /// cosmetic.
+ fn revert_speculative_schedules(&mut self, request_ids: &[RequestId]) { + for request_id in request_ids.iter().rev().copied() { + let Some(rkv) = self.request_kvs.get_mut(&request_id) else { + log::warn!( + "missing RequestKv while reverting speculative schedule for {request_id:?}" + ); + continue; + }; + if let Err(error) = rkv.revert_schedule() { + log::warn!("failed to revert speculative schedule for {request_id:?}: {error}"); + } + } + } +} diff --git a/openinfer-qwen3-4b/src/lib.rs b/openinfer-qwen3-4b/src/lib.rs index e0d82010..a6bb433f 100644 --- a/openinfer-qwen3-4b/src/lib.rs +++ b/openinfer-qwen3-4b/src/lib.rs @@ -5,16 +5,19 @@ mod batch_decode_buffers; mod batch_decode_dag; pub mod batch_decode_trace; mod config; +mod dflash; mod executor; pub(crate) mod green_ctx; pub mod kernel_bench; mod lora; mod prefill; mod scheduler; +mod speculative; mod unified_forward; +mod verify_graph; mod weights; -use std::path::Path; +use std::path::{Path, PathBuf}; use anyhow::Result; use log::{info, warn}; @@ -153,7 +156,7 @@ pub fn probe_model(model_path: &Path) -> Result> { /// Qwen3 startup policy — the TP→device mapping and the LoRA↔CUDA-Graph /// exclusion — and dispatches to the right low-level entry. That policy lives /// with the model instead of leaking into the server. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] pub struct Qwen3LaunchOptions { /// CUDA device for single-GPU loads (ignored when `tp_size > 1`). pub device_ordinal: usize, @@ -170,6 +173,9 @@ pub struct Qwen3LaunchOptions { /// How prefill and decode share the GPU (`--decode-overlap`). pub decode_overlap: DecodeOverlap, pub batch_invariant: bool, + /// `Some` enables DFlash speculative decoding with this drafter model. + /// Single-GPU only and mutually exclusive with LoRA and KV offload. + pub dflash_draft_model_path: Option, } /// Start the Qwen3 engine from server-facing [`Qwen3LaunchOptions`]. 
@@ -204,6 +210,10 @@ pub fn launch(model_path: &Path, options: Qwen3LaunchOptions) -> Result { info!( @@ -231,6 +241,7 @@ pub fn launch(model_path: &Path, options: Qwen3LaunchOptions) -> Result Result Result, ) -> Result { let EngineLoadOptions { enable_cuda_graph, @@ -281,6 +295,12 @@ pub fn start_engine_with_offload( .ok_or_else(|| anyhow::anyhow!("model path must be valid UTF-8"))?; ensure_batch_invariant_supported(decode_overlap, batch_invariant)?; apply_batch_invariant_policy(batch_invariant); + let dflash_draft_model_path = dflash_draft_model_path + .map(|path| { + path.to_str() + .ok_or_else(|| anyhow::anyhow!("DFlash draft model path must be valid UTF-8")) + }) + .transpose()?; scheduler::start_qwen3( model_path, enable_cuda_graph, @@ -291,6 +311,7 @@ pub fn start_engine_with_offload( max_prefill_tokens, memory_options, decode_overlap, + dflash_draft_model_path, ) } diff --git a/openinfer-qwen3-4b/src/prefill.rs b/openinfer-qwen3-4b/src/prefill.rs index e9d6163f..0d348247 100644 --- a/openinfer-qwen3-4b/src/prefill.rs +++ b/openinfer-qwen3-4b/src/prefill.rs @@ -73,6 +73,23 @@ impl PrefillBuffers { attn_output: HiddenStates::zeros(ctx, q_dim, seq_len)?, }) } + + /// Point every scratch buffer's logical row count at `rows` without + /// reallocating. Used by the fixed-buffer verify path (see + /// [`crate::verify_graph`]); the buffers must have been allocated for at + /// least `rows`. 
+ pub(super) fn set_rows(&mut self, rows: usize) { + self.hidden_out.seq_len = rows; + self.normed.seq_len = rows; + self.q_batch.seq_len = rows; + self.k_batch.seq_len = rows; + self.v_batch.seq_len = rows; + self.o_buf.seq_len = rows; + self.gate_out.seq_len = rows; + self.up_out.seq_len = rows; + self.act_out.seq_len = rows; + self.attn_output.seq_len = rows; + } } impl Qwen3Model { @@ -100,8 +117,25 @@ impl Qwen3Model { Ok(out) } + /// Embed a device-resident token buffer into a pre-allocated output, with + /// no host round-trip or allocation — used by the DFlash draft rollout's + /// graph-stable scratch. + pub(super) fn get_embeddings_batch_into( + &self, + token_ids_gpu: &cudarc::driver::CudaSlice, + out: &mut HiddenStates, + ) -> Result<()> { + anyhow::ensure!( + out.hidden_dim == self.config.hidden_size, + "embedding output hidden_dim {} does not match model hidden_size {}", + out.hidden_dim, + self.config.hidden_size + ); + ops::embedding_batch(&self.ctx, &self.embed_tokens, token_ids_gpu, out) + } + #[allow(clippy::too_many_arguments)] - fn forward_layer_batch_paged( + pub(crate) fn forward_layer_batch_paged( &self, layer_idx: usize, layer: &TransformerBlock, @@ -112,10 +146,24 @@ impl Qwen3Model { lora_groups: &[DeviceLoraTokenGroup<'_>], bufs: &mut PrefillBuffers, ) -> Result<()> { - let num_heads = self.local_num_attention_heads(); - let num_kv_heads = self.local_num_key_value_heads(); - let head_dim = self.config.head_dim; + self.forward_layer_pre_attn(layer_idx, layer, hidden, lora_groups, bufs)?; + self.forward_layer_attn(layer_idx, layer, kv_buffer, layout, plan, bufs)?; + self.forward_layer_post_attn(layer_idx, layer, hidden, lora_groups, bufs)?; + Ok(()) + } + /// Pre-attention dense ops: input RMSNorm + fused QKV projections (+ LoRA). + /// Reads `hidden`; writes `bufs.normed` / `bufs.q_batch` / `bufs.k_batch` / + /// `bufs.v_batch`. 
Graph-safe — shapes depend only on the fixed row count, not + /// on KV length — so the verify piecewise CUDA Graph captures it. + pub(crate) fn forward_layer_pre_attn( + &self, + layer_idx: usize, + layer: &TransformerBlock, + hidden: &HiddenStates, + lora_groups: &[DeviceLoraTokenGroup<'_>], + bufs: &mut PrefillBuffers, + ) -> Result<()> { // 1. RMSNorm → bufs.normed ops::rms_norm_batch_into( &self.ctx, @@ -176,6 +224,26 @@ impl Qwen3Model { &mut bufs.v_batch, 0, )?; + Ok(()) + } + + /// The attention op: q/k norm + RoPE + paged KV append + paged attention. + /// This is the ONLY part of the layer whose KV iteration count tracks the + /// (growing) context length, so the verify piecewise graph keeps it EAGER — + /// capturing it would freeze the KV length at capture time (`num_iterations` + /// in FlashInfer's prefill kernel is fixed when the graph is recorded). + pub(crate) fn forward_layer_attn( + &self, + layer_idx: usize, + layer: &TransformerBlock, + kv_buffer: &cudarc::driver::CudaSlice, + layout: &openinfer_core::kv_pool::KvLayout, + plan: &PrefillPagedPlan, + bufs: &mut PrefillBuffers, + ) -> Result<()> { + let num_heads = self.local_num_attention_heads(); + let num_kv_heads = self.local_num_key_value_heads(); + let head_dim = self.config.head_dim; // 3. Paged prefill: norm+RoPE → append K/V to paged → batch attention ops::prefill_attention_paged_into( @@ -197,7 +265,21 @@ impl Qwen3Model { head_dim, self.config.rms_norm_eps, )?; + Ok(()) + } + /// Post-attention dense ops: O projection + residual + MLP + final residual + /// add. Reads `bufs.attn_output` and `hidden`; writes the layer output back + /// into `hidden` via the ping-pong buffer swap. Graph-safe (no KV-length + /// dependence) — captured into the verify piecewise CUDA Graph. 
+ pub(crate) fn forward_layer_post_attn( + &self, + layer_idx: usize, + layer: &TransformerBlock, + hidden: &mut HiddenStates, + lora_groups: &[DeviceLoraTokenGroup<'_>], + bufs: &mut PrefillBuffers, + ) -> Result<()> { // 4. O projection → bufs.o_buf (as o_batch) ops::gemm_into( &self.ctx, @@ -338,6 +420,13 @@ impl Qwen3Model { /// /// If `echo` is true, also returns all-position logits as a /// `HiddenStates [vocab_size, total_tokens]` for prompt logprobs. + /// Batch prefill forward. + /// + /// `capture_layer_ids`, when set, copies the residual-stream hidden states + /// after the listed (strictly increasing) transformer layers into an extra + /// `[hidden_size * layers, total_tokens]` buffer returned as the third tuple + /// element. This feeds the DFlash draft model its target context; `None` + /// behaves identically to a plain prefill and returns `None` there. pub(crate) fn batch_prefill( &self, prompts: &[&[u32]], @@ -346,7 +435,8 @@ impl Qwen3Model { kv_buffer: &CudaSlice, layout: &KvLayout, echo: bool, - ) -> Result<(HiddenStates, Option)> { + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option, Option)> { let batch_size = prompts.len(); assert_eq!(batch_size, kv_views.len()); assert_eq!(batch_size, lora_adapters.len()); @@ -385,8 +475,14 @@ impl Qwen3Model { )?; // Forward through all layers - let hidden = - self.process_all_layers_batch_multi(hidden, layout, kv_buffer, &plan, &lora_groups)?; + let (hidden, captured_hidden) = self.process_all_layers_batch_multi( + hidden, + layout, + kv_buffer, + &plan, + &lora_groups, + capture_layer_ids, + )?; // All-position logits for echo (before we extract last-token logits) let all_logits = if echo { @@ -412,7 +508,7 @@ impl Qwen3Model { defer_drop(plan); } - Ok((logits, all_logits)) + Ok((logits, all_logits, captured_hidden)) } fn process_all_layers_batch_multi( @@ -422,12 +518,35 @@ impl Qwen3Model { kv_buffer: &cudarc::driver::CudaSlice, plan: &PrefillPagedPlan, lora_groups: 
&[DeviceLoraTokenGroup<'_>], - ) -> Result { + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option)> { let total_tokens = hidden.seq_len; let inter_dim = self.local_intermediate_size(); let q_dim = self.local_q_dim(); let kv_dim = self.local_kv_dim(); + let capture_layer_ids = capture_layer_ids.unwrap_or(&[]); + anyhow::ensure!( + capture_layer_ids.windows(2).all(|pair| pair[0] < pair[1]), + "target hidden capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + capture_layer_ids + .iter() + .all(|&layer| layer < self.layers.len()), + "target hidden capture layer id out of range" + ); + let mut captured_hidden = if capture_layer_ids.is_empty() { + None + } else { + Some(HiddenStates::zeros( + &self.ctx, + self.config.hidden_size * capture_layer_ids.len(), + total_tokens, + )?) + }; + let mut next_capture = 0usize; + let mut bufs = PrefillBuffers::new( &self.ctx, self.config.hidden_size, @@ -448,6 +567,18 @@ impl Qwen3Model { lora_groups, &mut bufs, )?; + if capture_layer_ids.get(next_capture) == Some(&layer_idx) { + let out = captured_hidden + .as_mut() + .expect("capture buffer exists when ids are non-empty"); + ops::copy_hidden_rows_into( + &self.ctx, + &hidden, + out, + next_capture * self.config.hidden_size, + )?; + next_capture += 1; + } } // Defer drop of PrefillBuffers in SM-partition mode. 
@@ -455,6 +586,6 @@ impl Qwen3Model { defer_drop(bufs); } - Ok(hidden) + Ok((hidden, captured_hidden)) } } diff --git a/openinfer-qwen3-4b/src/scheduler.rs b/openinfer-qwen3-4b/src/scheduler.rs index 6476eab8..71c91877 100644 --- a/openinfer-qwen3-4b/src/scheduler.rs +++ b/openinfer-qwen3-4b/src/scheduler.rs @@ -28,7 +28,9 @@ use openinfer_core::engine::{ use openinfer_core::sampler::SamplingParams; use self::effects::apply_effects; -use self::plan::{ExecutionArtifacts, ExecutionPlan, build_next_plan, execute_plan}; +use self::plan::{ + ExecutionArtifacts, ExecutionPlan, build_next_plan, execute_plan, should_speculative_decode, +}; use self::resolve::resolve_step; // ── Internal types ────────────────────────────────────────────────────── @@ -134,6 +136,7 @@ fn take_prefill_chunks( // ── Entry point ───────────────────────────────────────────────────────── +#[allow(clippy::too_many_arguments)] pub(crate) fn start_qwen3( model_path: &str, enable_cuda_graph: bool, @@ -144,6 +147,7 @@ pub(crate) fn start_qwen3( max_prefill_tokens: usize, memory_options: Qwen3MemoryOptions, decode_overlap: crate::DecodeOverlap, + dflash_draft_model_path: Option<&str>, ) -> Result { let mut executor = Qwen3Executor::from_runtime_with_lora_options( model_path, @@ -152,10 +156,18 @@ pub(crate) fn start_qwen3( Qwen3LoraOptions::default(), offload_options, max_prefill_tokens, + dflash_draft_model_path, memory_options, )?; executor.set_no_prefix_cache(no_prefix_cache); executor.enable_decode_overlap(decode_overlap)?; + // Speculative decoding loads its draft model after the target is up (the + // draft is built against the target's embeddings/head) and forces the + // prefix cache off, so it must follow set_no_prefix_cache. Its GPU footprint + // was already reserved during profiling from the draft path passed above. 
+ if let Some(draft_path) = dflash_draft_model_path { + executor.load_dflash_draft_model(draft_path)?; + } Ok(start_with_executor(executor, seed, max_prefill_tokens)) } @@ -179,6 +191,7 @@ pub(crate) fn start_qwen3_with_lora_control( lora_options, offload_options, max_prefill_tokens, + None, memory_options, )?; executor.set_no_prefix_cache(no_prefix_cache); @@ -482,7 +495,7 @@ fn scheduler_loop( take_prefill_chunks(&mut prefilling, max_prefill_tokens) }; - let Some(plan) = build_next_plan(!active.is_empty(), pending) else { + let Some(plan) = runtime_plan(&executor, &active, pending) else { continue; }; @@ -684,7 +697,7 @@ fn scheduler_loop_with_lora_control( continue; } - let Some(plan) = build_next_plan(!active.is_empty(), pending) else { + let Some(plan) = runtime_plan(&executor, &active, pending) else { continue; }; let failure_targets = failure_targets_for(&active, &plan); @@ -1055,6 +1068,26 @@ fn send_unknown_lora_rejection(req: &PendingRequest) { }); } +/// Choose the step plan, preferring a speculative-decode step when the whole +/// active batch is draft-ready. Prefill of new arrivals still takes priority — +/// a speculative step only runs when there is nothing to prefill, so the two +/// never mix in one step. 
+fn runtime_plan( + executor: &impl ModelExecutor, + active: &[ActiveRequestState], + pending: Vec, +) -> Option { + if should_speculative_decode(executor, active) { + if pending.is_empty() { + Some(ExecutionPlan::SpeculativeDecode) + } else { + Some(ExecutionPlan::Prefill { pending }) + } + } else { + build_next_plan(!active.is_empty(), pending) + } +} + fn failure_targets_for( active: &[ActiveRequestState], plan: &self::plan::ExecutionPlan, @@ -1067,6 +1100,9 @@ fn failure_targets_for( self::plan::ExecutionPlan::Decode => { targets.extend(active.iter().map(active_failure_target)); } + self::plan::ExecutionPlan::SpeculativeDecode => { + targets.extend(active.iter().map(active_failure_target)); + } self::plan::ExecutionPlan::Unified { pending } => { targets.extend(active.iter().map(active_failure_target)); targets.extend(pending.iter().map(pending_failure_target)); diff --git a/openinfer-qwen3-4b/src/scheduler/effects.rs b/openinfer-qwen3-4b/src/scheduler/effects.rs index 5a8eac92..11fd176d 100644 --- a/openinfer-qwen3-4b/src/scheduler/effects.rs +++ b/openinfer-qwen3-4b/src/scheduler/effects.rs @@ -69,6 +69,20 @@ pub(super) enum DecodeEffect { logprob: Option, completion_tokens: usize, }, + /// Commit several accepted speculative tokens and keep the request running. + EmitManyAndContinue { + request_id: RequestId, + tokens: Vec, + completion_tokens: usize, + }, + /// Commit several accepted speculative tokens, then finish — a stop token or + /// the max-output budget was hit partway through the accepted span. 
+ EmitManyAndFinish { + request_id: RequestId, + tokens: Vec, + finish_reason: FinishReason, + completion_tokens: usize, + }, } pub(super) struct StepEffects { @@ -192,6 +206,81 @@ pub(super) fn apply_effects( req.generated_count = completion_tokens; } } + DecodeEffect::EmitManyAndContinue { + request_id, + tokens, + completion_tokens, + } => { + let Some(index) = active.iter().position(|req| req.request_id == request_id) else { + continue; + }; + let req = &mut active[index]; + let mut sent = true; + for &token in &tokens { + if req + .token_tx + .send(TokenEvent::Token { + id: token, + logprob: None, + }) + .is_err() + { + sent = false; + break; + } + } + if sent { + req.last_token = *tokens + .last() + .expect("EmitManyAndContinue must carry at least one token"); + req.generated_count = completion_tokens; + } else { + debug!( + "request dropped: client disconnected: request_id={:?} tokens_generated={}", + request_id, completion_tokens + ); + let _ = executor.drop_request(request_id); + to_retire.push(index); + } + } + DecodeEffect::EmitManyAndFinish { + request_id, + tokens, + finish_reason, + completion_tokens, + } => { + let Some(index) = active.iter().position(|req| req.request_id == request_id) else { + continue; + }; + let req = &active[index]; + debug!( + "request finished: request_id={:?} prompt_tokens={} completion_tokens={} finish_reason={:?}", + request_id, req.prompt_len, completion_tokens, finish_reason + ); + let mut sent = true; + for &token in &tokens { + if req + .token_tx + .send(TokenEvent::Token { + id: token, + logprob: None, + }) + .is_err() + { + sent = false; + break; + } + } + if sent { + let _ = req.token_tx.send(TokenEvent::Finished { + finish_reason, + prompt_tokens: req.prompt_len, + completion_tokens, + }); + } + let _ = executor.drop_request(request_id); + to_retire.push(index); + } } } to_retire.sort_unstable(); diff --git a/openinfer-qwen3-4b/src/scheduler/plan.rs b/openinfer-qwen3-4b/src/scheduler/plan.rs index 7163e982..5f985919 
100644 --- a/openinfer-qwen3-4b/src/scheduler/plan.rs +++ b/openinfer-qwen3-4b/src/scheduler/plan.rs @@ -5,13 +5,22 @@ use crate::executor::{ DecodePlan, DecodeResult, DecodeStepItem, ModelExecutor, PrefillPlan, PrefillResult, PrefillStepItem, UnifiedPlan, UnifiedResult, }; +use crate::speculative::{ + DraftPlan, DraftRequestResult, DraftStepItem, VerifyPlan, VerifyResult, VerifyStepItem, +}; use super::{ActiveRequestState, PendingRequest}; pub(super) enum ExecutionPlan { - Prefill { pending: Vec }, + Prefill { + pending: Vec, + }, Decode, - Unified { pending: Vec }, + /// Draft + verify the whole active batch (all requests are draft-ready). + SpeculativeDecode, + Unified { + pending: Vec, + }, } pub(super) enum ExecutionArtifacts { @@ -26,6 +35,9 @@ pub(super) enum ExecutionArtifacts { Decode { result: DecodeResult, }, + SpeculativeDecode { + verify: VerifyResult, + }, Unified { pending: Vec, result: UnifiedResult, @@ -86,6 +98,22 @@ pub(super) fn execute_plan( sort_decode_results(&mut result.requests); Ok(ExecutionArtifacts::Decode { result }) } + ExecutionPlan::SpeculativeDecode => { + // Two executor calls per step: draft proposes K tokens per request, + // then a single target forward verifies the K+1 span. Both index by + // request_id; sorting keeps draft and verify results aligned. 
+ let draft_requests = build_speculative_draft_items(active); + let mut draft = executor.execute_speculative_draft(DraftPlan { + requests: &draft_requests, + })?; + draft.requests.sort_by_key(|result| result.request_id); + let verify_requests = build_speculative_verify_items(active, &draft.requests); + let mut verify = executor.execute_speculative_verify(VerifyPlan { + requests: &verify_requests, + })?; + verify.requests.sort_by_key(|result| result.request_id); + Ok(ExecutionArtifacts::SpeculativeDecode { verify }) + } ExecutionPlan::Unified { pending } => { let scheduled_at_unix_s = openinfer_core::engine::unix_now_s(); let pending_indices: Vec<usize> = (0..pending.len()).collect(); @@ -108,6 +136,52 @@ pub(super) fn execute_plan( } } +/// All-or-nothing: speculate the whole active batch only when every request is +/// draft-ready and greedy (no LoRA, no logprobs). A single non-ready request +/// falls the batch back to plain decode rather than running a mixed step. +pub(super) fn should_speculative_decode( + executor: &impl ModelExecutor, + active: &[ActiveRequestState], +) -> bool { + executor.speculative_enabled() + && !active.is_empty() + && active.iter().all(|req| { + executor.speculative_request_ready(req.request_id) + && req.lora_adapter.is_none() + && req.logprobs == 0 + && req.params.is_greedy() + }) +} + +fn build_speculative_draft_items(active: &[ActiveRequestState]) -> Vec<DraftStepItem> { + active + .iter() + .map(|r| DraftStepItem::new(r.request_id, r.last_token, r.params)) + .collect() +} + +fn build_speculative_verify_items( + active: &[ActiveRequestState], + draft_results: &[DraftRequestResult], +) -> Vec<VerifyStepItem> { + draft_results + .iter() + .map(|draft| { + let active = active + .iter() + .find(|req| req.request_id == draft.request_id) + .expect("draft request_id must exist in active set"); + // Clamp the verify span to the request's remaining output budget so + // a long accepted run can't overshoot max_tokens.
+ let remaining = active.max_tokens.saturating_sub(active.generated_count); + assert!(remaining > 0, "active request must have output budget"); + let mut token_ids = draft.token_ids.clone(); + token_ids.truncate(remaining); + VerifyStepItem::new(draft.request_id, token_ids, active.params) + }) + .collect() +} + fn build_prefill_items(pending: &[PendingRequest], indices: &[usize]) -> Vec<PrefillStepItem> { indices .iter() @@ -179,6 +253,37 @@ mod tests { } } + fn active(generated_count: usize, max_tokens: usize) -> ActiveRequestState { + let (token_tx, _rx) = openinfer_core::engine::TokenSink::standalone(); + ActiveRequestState { + request_id: RequestId::new(7), + lora_adapter: None, + token_tx, + last_token: 42, + generated_count, + max_tokens, + prompt_len: 10, + params: SamplingParams::default(), + logprobs: 0, + } + } + + #[test] + fn speculative_verify_items_clamp_to_remaining_output_budget() { + let active = [active(24, 32)]; + let draft = DraftRequestResult { + request_id: RequestId::new(7), + token_ids: (0..16).collect(), + }; + + let verify = build_speculative_verify_items(&active, &[draft]); + + assert_eq!(verify.len(), 1); + // 32 - 24 = 8 remaining → the 16-token span truncates to 8. + assert_eq!(verify[0].as_slice().len(), 8); + assert_eq!(verify[0].as_slice(), (0..8).collect::<Vec<_>>()); + } + // The plan selector is the whole batch-formation policy: what the scheduler // does each tick is fully determined by (have_active, has_pending). Pin the // 2×2 truth table so a policy regression can't slip through silently.
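Reviewer note: a minimal standalone sketch of the budget clamp that `build_speculative_verify_items` applies, for trying the arithmetic outside the crate. `Budget` and `clamp_verify_span` are hypothetical names introduced only for illustration. The patch itself reads `max_tokens` and `generated_count` from `ActiveRequestState` and asserts that the remaining budget is nonzero, which this sketch deliberately does not do.

```rust
/// Hypothetical stand-in for the two `ActiveRequestState` fields the clamp reads.
struct Budget {
    generated_count: usize,
    max_tokens: usize,
}

/// Truncate a drafted span to the request's remaining output budget.
/// `saturating_sub` keeps the remaining count at zero (instead of
/// underflowing) if the request has already used up its budget.
fn clamp_verify_span(mut token_ids: Vec<u32>, budget: &Budget) -> Vec<u32> {
    let remaining = budget.max_tokens.saturating_sub(budget.generated_count);
    // `Vec::truncate` is a no-op when the span is already short enough.
    token_ids.truncate(remaining);
    token_ids
}
```

With `generated_count = 24` and `max_tokens = 32`, a 16-token draft is cut to its first 8 tokens, which is the case the unit test in this hunk pins.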
diff --git a/openinfer-qwen3-4b/src/scheduler/resolve.rs b/openinfer-qwen3-4b/src/scheduler/resolve.rs index 1f91dcd5..61058b1a 100644 --- a/openinfer-qwen3-4b/src/scheduler/resolve.rs +++ b/openinfer-qwen3-4b/src/scheduler/resolve.rs @@ -1,4 +1,5 @@ use crate::executor::{DecodeRequestResult, ModelExecutor, PrefillRequestResult}; +use crate::speculative::VerifyRequestResult; use openinfer_core::engine::FinishReason; use super::effects::{DecodeEffect, PendingEffect, PromptEchoEffect, ScheduledEffect, StepEffects}; @@ -22,6 +23,12 @@ pub(super) fn resolve_step( pending: Vec::new(), decode: resolve_decode_outputs(executor, active, &result.requests), }, + ExecutionArtifacts::SpeculativeDecode { verify } => StepEffects { + scheduled: Vec::new(), + prompt_echoes: Vec::new(), + pending: Vec::new(), + decode: resolve_speculative_outputs(executor, active, &verify.requests), + }, ExecutionArtifacts::Unified { pending, result, @@ -39,6 +46,54 @@ pub(super) fn resolve_step( } } +/// Turn each request's accepted speculative span into a decode effect. A span +/// commits 1..=K+1 tokens at once; we walk it in order so a stop token or the +/// max-output budget truncates exactly where it lands (the executor already +/// suppressed nothing — stop handling lives here, mirroring single-token decode). 
+pub(super) fn resolve_speculative_outputs( + executor: &impl ModelExecutor, + active: &[ActiveRequestState], + request_results: &[VerifyRequestResult], +) -> Vec { + request_results + .iter() + .map(|result| { + let req = active + .iter() + .find(|req| req.request_id == result.request_id) + .expect("speculative request_id must exist in active set"); + let mut emitted = Vec::new(); + let mut completion_tokens = req.generated_count; + for &token in &result.accepted_tokens { + completion_tokens += 1; + let is_eos = !req.params.ignore_eos && executor.is_stop_token(token); + if is_eos { + return DecodeEffect::EmitManyAndFinish { + request_id: result.request_id, + tokens: emitted, + finish_reason: FinishReason::Stop, + completion_tokens, + }; + } + emitted.push(token); + if completion_tokens >= req.max_tokens { + return DecodeEffect::EmitManyAndFinish { + request_id: result.request_id, + tokens: emitted, + finish_reason: FinishReason::Length, + completion_tokens, + }; + } + } + DecodeEffect::EmitManyAndContinue { + request_id: result.request_id, + tokens: emitted, + completion_tokens, + } + }) + .collect() +} + fn resolve_prefill_outputs( executor: &impl ModelExecutor, pending: Vec, diff --git a/openinfer-qwen3-4b/src/scheduler/tests.rs b/openinfer-qwen3-4b/src/scheduler/tests.rs index 72eec8d2..bcd96113 100644 --- a/openinfer-qwen3-4b/src/scheduler/tests.rs +++ b/openinfer-qwen3-4b/src/scheduler/tests.rs @@ -28,6 +28,7 @@ struct FakeExecutor { loaded_lora_adapters: HashSet, dropped: Arc>>, prefetch_offers: Arc>>, + stop_token: Option, } impl FakeExecutor { @@ -44,9 +45,15 @@ impl FakeExecutor { loaded_lora_adapters: HashSet::new(), dropped, prefetch_offers: Arc::new(Mutex::new(Vec::new())), + stop_token: None, } } + fn with_stop_token(mut self, token: u32) -> Self { + self.stop_token = Some(token); + self + } + fn with_decode_failure(mut self) -> Self { self.fail_decode_once = true; self @@ -125,8 +132,8 @@ impl ModelExecutor for FakeExecutor { self.available_blocks } 
- fn is_stop_token(&self, _token_id: u32) -> bool { - false + fn is_stop_token(&self, token_id: u32) -> bool { + self.stop_token == Some(token_id) } fn drop_request(&mut self, request_id: RequestId) -> Result<()> { @@ -174,6 +181,7 @@ impl ModelExecutor for FakeExecutor { .iter() .map(|req| self.fake_prefill_result(req)) .collect(), + dflash_context_captured_requests: Vec::new(), }) } @@ -1019,3 +1027,195 @@ fn lora_control_waits_until_scheduler_idle() { .expect_err("adapter load should be a stub error"); assert!(matches!(error, EngineControlError::OperationFailed(_))); } + +// ── Speculative span resolution (multi-token emission) ────────────────────── +// resolve_speculative_outputs walks an accepted span [t_0, .., t_m] and decides, +// per request, whether to emit-and-continue or truncate at a stop token / the +// max-output budget. These were GPU-test-only; the truth table below pins the +// branch behaviour with the existing FakeExecutor (only is_stop_token matters). + +use crate::speculative::VerifyRequestResult; +use openinfer_core::engine::FinishReason; + +fn spec_active( + id: u64, + generated_count: usize, + max_tokens: usize, + ignore_eos: bool, +) -> ActiveRequestState { + let (token_tx, _rx) = TokenSink::standalone(); + ActiveRequestState { + request_id: RequestId(id), + lora_adapter: None, + token_tx, + last_token: 1, + generated_count, + max_tokens, + prompt_len: 16, + params: SamplingParams { + ignore_eos, + ..SamplingParams::default() + }, + logprobs: 0, + } +} + +fn spec_result(id: u64, accepted: Vec) -> VerifyRequestResult { + VerifyRequestResult { + request_id: RequestId(id), + matched_draft_tokens: accepted.len().saturating_sub(1), + accepted_tokens: accepted, + } +} + +const SPEC_EOS: u32 = 99; + +#[test] +fn speculative_full_span_accept_continues() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 3, 100, false)]; + let results = [spec_result(1, vec![10, 11, 12, 
13])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [ + effects::DecodeEffect::EmitManyAndContinue { + request_id, + tokens, + completion_tokens, + }, + ] => { + assert_eq!(*request_id, RequestId(1)); + assert_eq!(tokens, &vec![10, 11, 12, 13]); + assert_eq!( + *completion_tokens, + 3 + 4, + "completion = prior generated + span len" + ); + } + _ => panic!("expected EmitManyAndContinue"), + } +} + +#[test] +fn speculative_stop_token_midspan_finishes_and_suppresses_eos() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 5, 100, false)]; + // EOS lands at span position 2; tokens before it are emitted, EOS is not. + let results = [spec_result(1, vec![10, 11, SPEC_EOS, 13])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [ + effects::DecodeEffect::EmitManyAndFinish { + tokens, + finish_reason, + completion_tokens, + .. + }, + ] => { + assert_eq!( + tokens, + &vec![10, 11], + "EOS itself is suppressed from emission" + ); + assert!(matches!(finish_reason, FinishReason::Stop)); + assert_eq!( + *completion_tokens, + 5 + 3, + "the EOS still counts toward completion length" + ); + } + _ => panic!("expected EmitManyAndFinish(Stop)"), + } +} + +#[test] +fn speculative_stop_token_at_span_start_emits_nothing() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 5, 100, false)]; + let results = [spec_result(1, vec![SPEC_EOS, 11, 12])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [ + effects::DecodeEffect::EmitManyAndFinish { + tokens, + finish_reason, + completion_tokens, + .. 
+ }, + ] => { + assert!(tokens.is_empty(), "stop at position 0 emits no tokens"); + assert!(matches!(finish_reason, FinishReason::Stop)); + assert_eq!(*completion_tokens, 5 + 1); + } + _ => panic!("expected EmitManyAndFinish(Stop)"), + } +} + +#[test] +fn speculative_max_tokens_truncates_midspan() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + // generated 8, budget 10 -> only 2 more tokens fit; the span offers 4. + let active = [spec_active(1, 8, 10, false)]; + let results = [spec_result(1, vec![10, 11, 12, 13])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [ + effects::DecodeEffect::EmitManyAndFinish { + tokens, + finish_reason, + completion_tokens, + .. + }, + ] => { + assert_eq!( + tokens, + &vec![10, 11], + "the budget-hitting token is emitted, the rest dropped" + ); + assert!(matches!(finish_reason, FinishReason::Length)); + assert_eq!( + *completion_tokens, 10, + "completion stops exactly at max_tokens" + ); + } + _ => panic!("expected EmitManyAndFinish(Length)"), + } +} + +#[test] +fn speculative_ignore_eos_does_not_stop() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 0, 100, true)]; + let results = [spec_result(1, vec![SPEC_EOS, SPEC_EOS])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [effects::DecodeEffect::EmitManyAndContinue { tokens, .. 
}] => { + assert_eq!( + tokens, + &vec![SPEC_EOS, SPEC_EOS], + "ignore_eos passes stop tokens through" + ); + } + _ => panic!("expected EmitManyAndContinue"), + } +} + +#[test] +fn speculative_resolves_each_request_independently() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 0, 100, false), spec_active(2, 0, 100, false)]; + let results = [ + spec_result(1, vec![10, 11]), // continues + spec_result(2, vec![20, SPEC_EOS]), // finishes on EOS + ]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + assert!(matches!( + &effects[0], + effects::DecodeEffect::EmitManyAndContinue { request_id, .. } if *request_id == RequestId(1) + )); + assert!(matches!( + &effects[1], + effects::DecodeEffect::EmitManyAndFinish { request_id, finish_reason: FinishReason::Stop, .. } + if *request_id == RequestId(2) + )); +} diff --git a/openinfer-qwen3-4b/src/speculative.rs b/openinfer-qwen3-4b/src/speculative.rs new file mode 100644 index 00000000..fe5f8a7d --- /dev/null +++ b/openinfer-qwen3-4b/src/speculative.rs @@ -0,0 +1,264 @@ +//! Method-agnostic core of **greedy** speculative decoding for Qwen3. +//! +//! Speculative decoding is an optimistic-concurrency transaction over the +//! decode loop: *propose* a span of `K` cheap draft tokens, *verify* them with a +//! single target forward over the `K + 1` span positions, *accept* the longest +//! prefix the target agrees with (plus one bonus token), then *commit* the +//! accepted KV and roll back the rejected draft KV. Only the **propose** step +//! varies between methods (n-gram lookup, DFlash draft model, EAGLE, …); verify, +//! accept, and the KV transaction are shared. +//! +//! This module owns the shared half. The draft/verify boundary is a **pure +//! token span** — a model proposer's hidden states never cross it; they stay +//! inside the proposer (see [`crate::dflash`]). DFlash is the only proposer +//! 
+//! today and is kept concrete; a proposer trait is deferred until a second +//! implementation (n-gram / EAGLE) validates the shape. +//! +//! What is *not* generic yet: [`accept_greedy`] returns argmax-based acceptance, +//! which is the *greedy* rule. Sampling-correct acceptance would need the target +//! and draft distributions, touching both the verify path and the proposer — so +//! it is left until a sampling method actually lands. + +use anyhow::Result; + +use crate::executor::RequestId; +use openinfer_core::sampler::SamplingParams; + +/// One request's verify span: the current dangling token followed by the draft +/// candidates (`token_ids[0]` is the confirmed last token, `token_ids[1..]` are +/// the `K` drafts). Token-only by construction — the proposer that produced the +/// drafts keeps any hidden state to itself. +#[derive(Clone)] +pub(crate) struct VerifyStepItem { + pub(crate) request_id: RequestId, + pub(crate) token_ids: Vec<u32>, + pub(crate) params: SamplingParams, +} + +impl VerifyStepItem { + pub(crate) fn new(request_id: RequestId, token_ids: Vec<u32>, params: SamplingParams) -> Self { + Self { + request_id, + token_ids, + params, + } + } + + pub(crate) fn as_slice(&self) -> &[u32] { + &self.token_ids + } +} + +pub(crate) struct VerifyPlan<'a> { + pub requests: &'a [VerifyStepItem], +} + +#[derive(Clone, Debug)] +pub(crate) struct VerifyRequestResult { + pub request_id: RequestId, + /// Number of draft candidates accepted before the posterior bonus. + pub matched_draft_tokens: usize, + /// Tokens to commit: the accepted draft prefix followed by the target's + /// posterior token at the first mismatch (or the block-end continuation + /// when every draft is accepted). Always `1..=K + 1` tokens, so a verify + /// step always makes at least one token of progress. The scheduler still + /// owns stop-token suppression before client emission. 
+ pub accepted_tokens: Vec<u32>, +} + +pub(crate) struct VerifyResult { + pub requests: Vec<VerifyRequestResult>, +} + +/// One request's draft request: the proposer continues from `current_token`. +#[derive(Clone)] +pub(crate) struct DraftStepItem { + pub(crate) request_id: RequestId, + pub(crate) current_token: u32, + pub(crate) params: SamplingParams, +} + +impl DraftStepItem { + pub(crate) fn new(request_id: RequestId, current_token: u32, params: SamplingParams) -> Self { + Self { + request_id, + current_token, + params, + } + } +} + +pub(crate) struct DraftPlan<'a> { + pub requests: &'a [DraftStepItem], +} + +#[derive(Clone, Debug)] +pub(crate) struct DraftRequestResult { + pub request_id: RequestId, + /// Verify-span tokens: current dangling token first, then draft candidates. + pub token_ids: Vec<u32>, +} + +pub(crate) struct DraftResult { + pub requests: Vec<DraftRequestResult>, +} + +/// Greedy speculative acceptance — the shared seam every method reuses. +/// +/// * `proposed` — the `K` candidate tokens from the proposer. +/// * `target_argmax` — the target model's greedy token at each of the `K + 1` +/// verify positions. `target_argmax[i]` is the model's prediction *after* +/// consuming verify input `i`; `target_argmax[0]` follows the last confirmed +/// token and `target_argmax[K]` is the model's own continuation after the +/// whole candidate run. +/// +/// Returns the longest accepted prefix of `proposed` followed by exactly one +/// model token (the correction at the first divergence, or the bonus +/// continuation when every candidate is accepted) — always `1..=K + 1` tokens. +/// +/// # Panics +/// Panics (debug builds) if `target_argmax.len() != proposed.len() + 1`. 
+#[must_use] +pub(crate) fn accept_greedy(proposed: &[u32], target_argmax: &[u32]) -> Vec<u32> { + debug_assert_eq!( + target_argmax.len(), + proposed.len() + 1, + "verify must produce one greedy token per candidate plus a bonus" + ); + let n = num_accepted(proposed, target_argmax); + let mut committed = Vec::with_capacity(n + 1); + committed.extend_from_slice(&proposed[..n]); + // The model's own token at the first divergence (or the bonus continuation + // when the whole run was accepted). `n <= proposed.len() < target_argmax.len()` + // so this index is always valid. + committed.push(target_argmax[n]); + committed +} + +/// Length of the accepted prefix: leading drafts whose token matches the +/// target's argmax. +fn num_accepted(proposed: &[u32], target_argmax: &[u32]) -> usize { + let mut i = 0; + while i < proposed.len() && proposed[i] == target_argmax[i] { + i += 1; + } + i +} + +/// Batched greedy acceptance over a verify forward's flattened per-position +/// argmax. `target_tokens` is the concatenation of each request's `K + 1` +/// posterior columns, in `requests` order. Each request applies the shared +/// [`accept_greedy`] over its own span. +pub(crate) fn build_verify_results( + requests: &[VerifyStepItem], + target_tokens: &[u32], +) -> Result<Vec<VerifyRequestResult>> { + let mut outputs = Vec::with_capacity(requests.len()); + let mut offset = 0usize; + for req in requests { + let span_len = req.token_ids.len(); + anyhow::ensure!( + span_len > 0, + "speculative verify request {:?} has an empty verify span", + req.request_id + ); + let end = offset + span_len; + anyhow::ensure!( + end <= target_tokens.len(), + "speculative target-token result is shorter than the verify span" + ); + let posterior = &target_tokens[offset..end]; + // proposed = the K drafts (span minus the leading confirmed token); + // posterior = the K + 1 argmax columns. accept_greedy ties them together. 
+ let accepted_tokens = accept_greedy(&req.token_ids[1..], posterior); + outputs.push(VerifyRequestResult { + request_id: req.request_id, + matched_draft_tokens: accepted_tokens.len() - 1, + accepted_tokens, + }); + offset = end; + } + anyhow::ensure!( + offset == target_tokens.len(), + "unused speculative target-token result columns: used {offset}, total {}", + target_tokens.len() + ); + Ok(outputs) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accepts_full_run_plus_bonus() { + let proposed = [10u32, 11, 12]; + let argmax = [10u32, 11, 12, 13]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![10, 11, 12, 13]); + assert_eq!(num_accepted(&proposed, &argmax), 3); + } + + #[test] + fn accepts_prefix_then_correction() { + let proposed = [10u32, 11, 99]; + let argmax = [10u32, 11, 22, 33]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![10, 11, 22]); + assert_eq!(num_accepted(&proposed, &argmax), 2); + } + + #[test] + fn rejects_first_candidate_commits_one() { + let proposed = [10u32, 11, 12]; + let argmax = [7u32, 8, 9, 10]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![7]); + assert_eq!(num_accepted(&proposed, &argmax), 0); + } + + #[test] + fn empty_proposal_commits_model_token() { + let proposed: [u32; 0] = []; + let argmax = [42u32]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![42]); + assert_eq!(num_accepted(&proposed, &argmax), 0); + } + + #[test] + fn always_commits_at_least_one_token() { + let proposed = [1u32, 2]; + let argmax = [9u32, 9, 9]; + assert!(!accept_greedy(&proposed, &argmax).is_empty()); + } + + #[test] + fn batched_accepts_matching_prefix_plus_posterior_bonus() { + let req = VerifyStepItem::new( + RequestId(7), + vec![10, 11, 12, 13], + SamplingParams::default(), + ); + let results = build_verify_results(&[req], &[11, 12, 99, 100]).expect("verify results"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].request_id, RequestId(7)); + assert_eq!(results[0].matched_draft_tokens, 2); + 
assert_eq!(results[0].accepted_tokens, vec![11, 12, 99]); + } + + #[test] + fn batched_all_match_still_adds_block_end_posterior() { + let req = VerifyStepItem::new(RequestId(8), vec![20, 21, 22], SamplingParams::default()); + let results = build_verify_results(&[req], &[21, 22, 23]).expect("verify results"); + assert_eq!(results[0].matched_draft_tokens, 2); + assert_eq!(results[0].accepted_tokens, vec![21, 22, 23]); + } + + #[test] + fn batched_multi_request_splits_columns_by_span() { + let a = VerifyStepItem::new(RequestId(1), vec![5, 6], SamplingParams::default()); + let b = VerifyStepItem::new(RequestId(2), vec![7, 8, 9], SamplingParams::default()); + // a: posterior [6, 100] -> accept draft 6, bonus 100. b: posterior [8, 77, 0] + // -> accept draft 8, correction 77. + let results = build_verify_results(&[a, b], &[6, 100, 8, 77, 0]).expect("verify results"); + assert_eq!(results[0].accepted_tokens, vec![6, 100]); + assert_eq!(results[1].accepted_tokens, vec![8, 77]); + } +} diff --git a/openinfer-qwen3-4b/src/verify_graph.rs b/openinfer-qwen3-4b/src/verify_graph.rs new file mode 100644 index 00000000..af6d4b39 --- /dev/null +++ b/openinfer-qwen3-4b/src/verify_graph.rs @@ -0,0 +1,422 @@ +//! Fixed, pre-allocated buffers for the DFlash speculative *verify* forward. +//! +//! The verify forward runs a target prefill over each active request's `span = +//! num_speculative_tokens + 1` token block (see [`super::executor`]'s +//! `SpeculativeVerify` handler). The default [`Qwen3Model::batch_prefill`] path +//! allocates fresh GPU scratch every step (`PrefillBuffers::new`, the embedding +//! `HiddenStates`, `all_logits`, and the `PrefillPagedPlan` upload). That churns +//! `cuMemAllocAsync`/`cuMemFreeAsync` and — more importantly — hands CUDA Graph +//! capture moving pointers. +//! +//! [`VerifyGraphBuffers`] pre-allocates all of that once at the worst-case shape +//! (`max_batch * span` rows) and refills it in place each step, then captures the +//! 
forward into a **piecewise** CUDA Graph: the dense ops (embedding, RMSNorm, +//! every GEMM, SwiGLU, residual adds — ~84% of the per-step kernel-launch gap) +//! are captured per segment and replayed, while the attention op runs EAGER +//! between segments. Attention must stay eager because FlashInfer's paged-prefill +//! kernel fixes its KV-iteration count when the graph is recorded; with the verify +//! context growing every step, a captured attention would under-read KV and +//! corrupt later tokens. The dense segments are shape-stable in the fixed +//! `span`-row layout, so one capture replays losslessly for the request's life. + +use anyhow::Result; +use cudarc::driver::CudaSlice; + +use openinfer_core::cuda_graph::CudaGraphState; +use openinfer_core::kv_pool::KvLayout; +use openinfer_core::ops; +use openinfer_core::ops::PrefillPagedPlan; +use openinfer_core::tensor::HiddenStates; +use openinfer_kv_cache::KvView; + +use crate::batch_decode_buffers::BATCH_BUCKETS; +use crate::config::PREFILL_ATTENTION_CTA_TILE_Q; +use crate::prefill::PrefillBuffers; +use crate::weights::Qwen3Model; + +/// All GPU scratch the verify forward needs, sized once for `max_batch * span` +/// rows and reused (in place) every step. Pointer-stable for CUDA Graph capture. +pub(crate) struct VerifyGraphBuffers { + /// Per-layer projection/attention scratch (reused exactly as the allocating + /// prefill path uses it). + prefill_bufs: PrefillBuffers, + /// Residual-stream hidden states `[hidden_dim, max_total_rows]`. + hidden: HiddenStates, + /// Captured target hidden states for the DFlash layers, + /// `[hidden_size * num_capture_layers, max_total_rows]`. + captured_hidden: HiddenStates, + /// RMS-norm output feeding the lm_head GEMM `[hidden_dim, max_total_rows]`. + all_logits_normed: HiddenStates, + /// All-position logits `[vocab, max_total_rows]` (the verify forward's output). + all_logits: HiddenStates, + /// Device-resident concatenated verify tokens `[max_total_rows]`. 
+ token_ids_d: CudaSlice, + /// Paged-attention plan, refilled in place each step. + plan: PrefillPagedPlan, + /// Piecewise CUDA Graphs: `graphs[bucket_idx][segment]`. Each bucket's verify + /// forward is split into `num_layers + 1` dense segments (attention runs eager + /// between them); a segment is captured once at its exact-bucket batch and + /// replayed thereafter. Empty (`CudaGraphState::new()`) until first captured. + graphs: Vec>, + max_batch: usize, + span: usize, +} + +impl VerifyGraphBuffers { + /// Allocate verify scratch for up to `max_batch` requests, each a fixed + /// `span`-token block. `num_capture_layers` is the DFlash target-layer count + /// (the captured-hidden buffer holds one `hidden_size` slice per layer). + /// `max_total_pages` bounds the paged-attention page list; pass the KV + /// pool's total block count for a guaranteed worst case. + pub(crate) fn new( + model: &Qwen3Model, + max_batch: usize, + span: usize, + num_capture_layers: usize, + max_total_pages: usize, + ) -> Result { + anyhow::ensure!(max_batch > 0, "verify buffers need max_batch >= 1"); + anyhow::ensure!(span > 0, "verify buffers need span >= 1"); + let ctx = model.device_ctx(); + let hidden_dim = model.config().hidden_size; + let q_dim = model.local_q_dim(); + let kv_dim = model.local_kv_dim(); + let inter_dim = model.local_intermediate_size(); + let vocab = model.config().vocab_size; + let max_total_rows = max_batch * span; + + // Each request's `span` query tokens fan out to `span * group_size` + // packed-QO rows; with a CTA tile of at least 1, that bounds tiles per + // request. `max_batch * span * group_size` is the conservative ceiling. 
+ let group_size = model.local_num_attention_heads() / model.local_num_key_value_heads(); + let max_tiles = max_batch * span * group_size.max(1); + + Ok(Self { + prefill_bufs: PrefillBuffers::new( + ctx, + hidden_dim, + q_dim, + kv_dim, + inter_dim, + max_total_rows, + )?, + hidden: HiddenStates::zeros(ctx, hidden_dim, max_total_rows)?, + captured_hidden: HiddenStates::zeros( + ctx, + hidden_dim * num_capture_layers.max(1), + max_total_rows, + )?, + all_logits_normed: HiddenStates::zeros(ctx, hidden_dim, max_total_rows)?, + all_logits: HiddenStates::zeros(ctx, vocab, max_total_rows)?, + token_ids_d: ctx.stream.alloc_zeros(max_total_rows)?, + plan: PrefillPagedPlan::new_preallocated( + ctx, + max_total_rows, + max_total_pages, + max_batch, + max_tiles, + )?, + // num_layers + 1 dense segments per bucket (attention is eager between). + graphs: BATCH_BUCKETS + .iter() + .map(|_| { + (0..model.config().num_hidden_layers + 1) + .map(|_| CudaGraphState::new()) + .collect() + }) + .collect(), + max_batch, + span, + }) + } + + /// Point every buffer's logical extent at `total_rows` (`<= max capacity`). + /// Like [`PrefillBuffers`] / [`super::batch_decode_buffers`], this only moves + /// `seq_len`; it never reallocates. + fn set_rows(&mut self, total_rows: usize) { + let cap = self.max_batch * self.span; + assert!( + total_rows <= cap, + "verify total_rows {total_rows} exceeds capacity {cap}" + ); + self.prefill_bufs.set_rows(total_rows); + self.hidden.seq_len = total_rows; + self.captured_hidden.seq_len = total_rows; + self.all_logits_normed.seq_len = total_rows; + self.all_logits.seq_len = total_rows; + } + + /// All-position logits `[vocab, total_rows]` from the last forward. + pub(crate) fn all_logits(&self) -> &HiddenStates { + &self.all_logits + } + + /// Captured target hidden states `[hidden_size * num_capture_layers, total_rows]`. 
+ pub(crate) fn captured_hidden(&self) -> &HiddenStates { + &self.captured_hidden + } +} + +impl Qwen3Model { + /// Fixed-buffer, piecewise-CUDA-Graph twin of [`Qwen3Model::batch_prefill`] + /// for the DFlash verify forward. Issues the same per-op kernels as the + /// allocating path (split into `forward_layer_pre_attn` / `forward_layer_attn` + /// / `forward_layer_post_attn`), so the all-position logits and captured hidden + /// states match it; only the buffer *source* differs (reused vs. freshly + /// allocated) and the dense ops replay from a graph. Results land in `bufs` + /// (`all_logits()` / `captured_hidden()`). + /// + /// `capture_layer_ids` must be the strictly-increasing DFlash target layers + /// whose count matches the `num_capture_layers` `bufs` was built with. + pub(crate) fn batch_prefill_into( + &self, + prompts: &[&[u32]], + kv_views: &[KvView], + kv_buffer: &CudaSlice, + layout: &KvLayout, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + let batch_size = prompts.len(); + anyhow::ensure!( + batch_size == kv_views.len(), + "verify prompts ({batch_size}) and kv_views ({}) length mismatch", + kv_views.len() + ); + anyhow::ensure!( + batch_size <= bufs.max_batch, + "verify batch {batch_size} exceeds buffer capacity {}", + bufs.max_batch + ); + anyhow::ensure!( + capture_layer_ids.windows(2).all(|pair| pair[0] < pair[1]), + "verify capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + capture_layer_ids + .iter() + .all(|&layer| layer < self.layers.len()), + "verify capture layer id out of range" + ); + let expected_capture_dim = self.config().hidden_size * capture_layer_ids.len().max(1); + anyhow::ensure!( + bufs.captured_hidden.hidden_dim == expected_capture_dim, + "verify capture buffer dim {} does not match {} capture layers", + bufs.captured_hidden.hidden_dim, + capture_layer_ids.len(), + ); + + let seq_lens: Vec = prompts.iter().map(|p| p.len()).collect(); + let total_tokens: usize = 
seq_lens.iter().sum(); + anyhow::ensure!(total_tokens > 0, "verify forward has no tokens"); + let start_positions: Vec = kv_views + .iter() + .zip(prompts.iter()) + .map(|(v, p)| v.seq_len() - p.len()) + .collect(); + + bufs.set_rows(total_tokens); + + // --- prep: H2D staging that MUST stay outside the graph capture (CUDA + // Graph forbids host round-trips in a captured segment). The embedding + // kernel itself runs inside graph segment 0 and reads this buffer. --- + let all_tokens: Vec = prompts.iter().flat_map(|p| p.iter().copied()).collect(); + anyhow::ensure!( + all_tokens.len() == total_tokens, + "verify token concat {} != total_tokens {total_tokens}", + all_tokens.len() + ); + let ctx = self.device_ctx(); + // Stage the active tokens into the front of the fixed device buffer; the + // embedding kernel reads exactly `total_tokens` ids from its base pointer, + // so the unused tail is never touched. + ctx.stream.memcpy_htod(&all_tokens, &mut bufs.token_ids_d)?; + + // Refill the paged plan in place (same host math as the allocating path). + let page_indices: Vec> = + kv_views.iter().map(|v| v.page_indices().to_vec()).collect(); + let last_page_lens: Vec = kv_views + .iter() + .map(openinfer_kv_cache::KvView::last_page_len) + .collect(); + bufs.plan.update_batch_with_cta_tile_q( + ctx, + &page_indices, + &last_page_lens, + &start_positions, + &seq_lens, + self.local_num_attention_heads(), + self.local_num_key_value_heads(), + self.config().head_dim, + PREFILL_ATTENTION_CTA_TILE_Q, + )?; + + // --- piecewise CUDA Graph: dense ops captured per segment, attention + // EAGER between segments. FlashInfer's prefill attention freezes its KV + // iteration count at capture time (it tracks the growing context), so + // capturing it corrupts later tokens; every other op is shape-stable in + // the fixed `span`-row layout. Segments: [embed + L0.pre] [L0.attn] + // [L0.post + L1.pre] [L1.attn] ... [L_last.post + lm_head]. 
--- + let num_layers = self.layers.len(); + match BATCH_BUCKETS.iter().position(|&b| b == batch_size) { + Some(bidx) => { + // Take the bucket's segment graphs out so the capture closures can + // borrow `bufs` mutably; restore them after (even on error). + let mut segs = std::mem::take(&mut bufs.graphs[bidx]); + let result = (|| -> Result<()> { + segs[0].run_or_capture(ctx, || self.verify_seg_embed_pre(bufs))?; + self.verify_attn(0, kv_buffer, layout, bufs)?; + for i in 1..num_layers { + segs[i].run_or_capture(ctx, || { + self.verify_seg_post_pre(i, capture_layer_ids, bufs) + })?; + self.verify_attn(i, kv_buffer, layout, bufs)?; + } + segs[num_layers].run_or_capture(ctx, || { + self.verify_seg_post_logits(capture_layer_ids, bufs) + })?; + Ok(()) + })(); + bufs.graphs[bidx] = segs; + result?; + } + None => { + // Off-bucket batch: run the same segments eager (no capture). + self.verify_seg_embed_pre(bufs)?; + self.verify_attn(0, kv_buffer, layout, bufs)?; + for i in 1..num_layers { + self.verify_seg_post_pre(i, capture_layer_ids, bufs)?; + self.verify_attn(i, kv_buffer, layout, bufs)?; + } + self.verify_seg_post_logits(capture_layer_ids, bufs)?; + } + } + + Ok(()) + } + + /// Graph segment 0: embedding (reads the staged `token_ids_d`) plus layer 0's + /// pre-attention dense ops. Verify never uses LoRA, so the LoRA group is empty. + fn verify_seg_embed_pre(&self, bufs: &mut VerifyGraphBuffers) -> Result<()> { + self.get_embeddings_batch_into(&bufs.token_ids_d, &mut bufs.hidden)?; + self.forward_layer_pre_attn( + 0, + &self.layers[0], + &bufs.hidden, + &[], + &mut bufs.prefill_bufs, + ) + } + + /// Eager attention for layer `i` — kept out of every graph (see + /// [`Self::forward_layer_attn`]). Touches only the fixed `prefill_bufs` and + /// the refilled `plan`. 
+ fn verify_attn( + &self, + i: usize, + kv_buffer: &CudaSlice, + layout: &KvLayout, + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + self.forward_layer_attn( + i, + &self.layers[i], + kv_buffer, + layout, + &bufs.plan, + &mut bufs.prefill_bufs, + ) + } + + /// Middle graph segment `i` (`1..num_layers`): finish layer `i-1` + /// (post-attention dense + DFlash-context capture), then start layer `i` + /// (pre-attention dense). + /// + /// The ping-pong swap inside `post_attn` is graph-safe regardless of layer + /// parity: `run_or_capture` runs the closure (and thus the CPU-side swap) only + /// on the capture step, so the captured graph bakes the exact buffer pointers + /// for every op; replay just relaunches them. Each step's segment-0 embedding + /// overwrites the same baked buffer that this segment's first `post_attn` + /// reads, so no stale residual can leak across steps. The only live (eager) op + /// between segments — attention — touches just `q/k/v_batch` / `attn_output`, + /// which never participate in the swap, so it is independent of `hidden`'s + /// logical pointer. Parity only decides which physical buffer holds the final + /// hidden; that choice is baked into the last segment either way. + fn verify_seg_post_pre( + &self, + i: usize, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + let prev = i - 1; + self.forward_layer_post_attn( + prev, + &self.layers[prev], + &mut bufs.hidden, + &[], + &mut bufs.prefill_bufs, + )?; + self.verify_capture_if_needed(prev, capture_layer_ids, bufs)?; + self.forward_layer_pre_attn( + i, + &self.layers[i], + &bufs.hidden, + &[], + &mut bufs.prefill_bufs, + ) + } + + /// Final graph segment: finish the last layer (post-attention + capture), then + /// the all-position logits (final RMSNorm + lm_head GEMM) into `all_logits`. 
+ fn verify_seg_post_logits( + &self, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + let last = self.layers.len() - 1; + self.forward_layer_post_attn( + last, + &self.layers[last], + &mut bufs.hidden, + &[], + &mut bufs.prefill_bufs, + )?; + self.verify_capture_if_needed(last, capture_layer_ids, bufs)?; + let ctx = self.device_ctx(); + ops::rms_norm_batch_into( + ctx, + &bufs.hidden, + &self.norm, + self.config().rms_norm_eps, + &mut bufs.all_logits_normed, + ); + ops::gemm_into( + ctx, + self.output_projection(), + &bufs.all_logits_normed, + &mut bufs.all_logits, + ); + Ok(()) + } + + /// Copy layer `layer_idx`'s residual-stream hidden into the captured-hidden + /// buffer when that layer is a DFlash target. `capture_layer_ids` is strictly + /// increasing, so its position is the capture slot. + fn verify_capture_if_needed( + &self, + layer_idx: usize, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + if let Some(slot) = capture_layer_ids.iter().position(|&l| l == layer_idx) { + let hidden_size = self.config().hidden_size; + ops::copy_hidden_rows_into( + self.device_ctx(), + &bufs.hidden, + &mut bufs.captured_hidden, + slot * hidden_size, + )?; + } + Ok(()) + } +} diff --git a/openinfer-qwen3-4b/src/weights.rs b/openinfer-qwen3-4b/src/weights.rs index 560f5ac2..bf909e3a 100644 --- a/openinfer-qwen3-4b/src/weights.rs +++ b/openinfer-qwen3-4b/src/weights.rs @@ -886,6 +886,7 @@ impl Qwen3Model { self.kv_budget_from_bytes( geometry, bytes_per_block, + 0, kv_budget_bytes, free_bytes, "heuristic", @@ -896,6 +897,7 @@ impl Qwen3Model { &self, max_prefill_tokens: usize, max_decode_batch_size: usize, + dflash_kv_bytes_per_token: usize, memory_options: Qwen3MemoryOptions, ) -> Result { let memory_options = memory_options.validate()?; @@ -1017,6 +1019,7 @@ impl Qwen3Model { Ok(self.kv_budget_from_bytes( geometry, bytes_per_block, + dflash_kv_bytes_per_token, kv_budget_bytes, initial_free_bytes, 
"profiled", @@ -1049,11 +1052,18 @@ impl Qwen3Model { &self, mut geometry: KvBudget, bytes_per_block: usize, + dflash_kv_bytes_per_token: usize, kv_budget_bytes: usize, free_bytes: usize, source: &'static str, ) -> KvBudget { - let num_blocks = (kv_budget_bytes / bytes_per_block).max(64); + // DFlash keeps its own per-request KV (plus prompt-scaling scratch) outside + // the paged pool, scaling with the same token count. Charge it as extra + // bytes per pool token so the target block count shrinks to leave room; the + // pool itself is still allocated at the target-only `bytes_per_block`. + let effective_bytes_per_block = + bytes_per_block + dflash_kv_bytes_per_token * geometry.block_size; + let num_blocks = (kv_budget_bytes / effective_bytes_per_block).max(64); let kv_mb = num_blocks * bytes_per_block / (1024 * 1024); log::info!( "KV cache ({source}): {num_blocks} blocks ({kv_mb} MB, {:.0}% of {:.0} MB free)", diff --git a/openinfer-qwen3-4b/tests/batch_invariance_reject.rs b/openinfer-qwen3-4b/tests/batch_invariance_reject.rs index e73785c4..37489503 100644 --- a/openinfer-qwen3-4b/tests/batch_invariance_reject.rs +++ b/openinfer-qwen3-4b/tests/batch_invariance_reject.rs @@ -27,6 +27,7 @@ fn batch_invariant_rejects_decode_overlap() { Qwen3MemoryOptions::default(), DecodeOverlap::SharedSm, true, + None, ) .err() .expect("--batch-invariant + --decode-overlap must be rejected"); diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs new file mode 100644 index 00000000..2f2127b5 --- /dev/null +++ b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs @@ -0,0 +1,442 @@ +//! DFlash speculative-decoding losslessness gate. +//! +//! Greedy speculative decoding must be *lossless*: every draft is verified by a +//! target forward and only the matching-argmax prefix (plus one bonus) is +//! committed, so the accepted tokens are the target model's own greedy +//! continuation. 
The catch is pure numerics — the verify path runs the +//! *prefill* attention kernel over the K+1 span while a plain decode runs the +//! *decode* kernel, and the two differ by ~1 bf16 ULP. On a near-tie that flips +//! one argmax, and from there two greedy runs fan out completely. +//! +//! So an exact `spec == baseline` token match is the wrong gate: it false-fails +//! on a benign tie flip. We use a *regret* test like `hf_golden_gate`: at the +//! first position the two sequences disagree (where they still share an +//! identical context, so the comparison is valid) we ask how far below the +//! argmax the speculative pick sits — measured *in the prefill kernel's own +//! distribution*, because that is the kernel the verify path runs. A re-prefill +//! of the shared context (`prefill_next`) gives that reference distribution. +//! The verify path's committed KV is built incrementally across batched +//! speculative spans, while a one-shot prefill builds it in a single forward; +//! the two K/V differ by a few bf16 ULP, so on a near-tie the argmax flips. +//! Within `MARGIN_TOL` of the prefill argmax ⇒ a benign numerical tie. Clearly +//! worse (or outside the prefill top-K) ⇒ the verify/accept/capture logic chose +//! a token the forward never favored — a real bug. A systematic bug corrupts +//! the non-tie positions too, so it cannot hide behind the tie band. +//! +//! (Empirically the one prompt that flips — "The capital of France is" — sits on +//! a Germany-vs-Paris near-tie: the prefill kernel scores them -0.71 vs -0.83, +//! a 0.12-nat gap, well inside `MARGIN_TOL`. The other four prompts are bit +//! identical. A real verify bug would not single out the one degenerate prompt.) +//! +//! The baseline runs with logprobs on (plain decode); the speculative engine +//! runs with logprobs off (logprobs force the spec path off by design), so it +//! reports chosen tokens only — exactly what the regret check needs. +//! +//! 
Runs the two engines sequentially (baseline dropped before the speculative +//! engine loads) so only one Qwen3-4B is resident at a time. +//! +//! Requires a CUDA GPU, Qwen3-4B weights, and the DFlash drafter. Set +//! `OPENINFER_TEST_MODEL_PATH` (target) and `OPENINFER_DFLASH_TEST_MODEL_PATH` +//! (drafter); skips cleanly when either is absent. + +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use openinfer_core::engine::{EngineHandle, GenerateRequest, TokenEvent, TokenSink}; +use openinfer_core::sampler::SamplingParams; +use openinfer_qwen3_4b::{ + DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, DecodeOverlap, + Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions, +}; + +mod common; + +const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B"); +const DRAFT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B-DFlash-b16"); +const GENERATED_TOKENS: usize = 64; +/// Top-K logprobs requested from the baseline; wide enough that the speculative +/// pick is in the set on any real tie (a pick outside the top-K is itself a red +/// flag the gate should catch). +const LOGPROBS: usize = 20; +/// Max acceptable regret: how far below the baseline's argmax (in the baseline's +/// own logprobs) the speculative pick may sit at the divergence point. ~3 bf16 +/// ULP at typical logit magnitudes — mirrors `hf_golden_gate`'s `MARGIN_TOL`. +const MARGIN_TOL: f32 = 0.20; + +/// Both tests launch a Qwen3-4B engine, and two at once overflow a 16 GB card. +/// Cargo runs tests in one binary concurrently, so serialize the engine-holding +/// bodies — only one engine is ever resident on the GPU. 
+static GPU: std::sync::Mutex<()> = std::sync::Mutex::new(());
+
+fn target_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => {
+            Some(MODEL_PATH.to_string())
+        }
+        Err(_) => {
+            eprintln!(
+                "skipping dflash gate: {MODEL_PATH}/config.json missing; set OPENINFER_TEST_MODEL_PATH"
+            );
+            None
+        }
+    }
+}
+
+fn draft_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => {
+            Some(DRAFT_PATH.to_string())
+        }
+        Err(_) => {
+            eprintln!(
+                "skipping dflash gate: {DRAFT_PATH}/config.json missing; set OPENINFER_DFLASH_TEST_MODEL_PATH"
+            );
+            None
+        }
+    }
+}
+
+fn launch_options(draft: Option<PathBuf>) -> Qwen3LaunchOptions {
+    Qwen3LaunchOptions {
+        device_ordinal: 0,
+        tp_size: 1,
+        cuda_graph: true,
+        offload: Qwen3OffloadOptions::disabled(),
+        // The speculative engine forces the prefix cache off; match it on the
+        // baseline so both take the same cold prefill path.
+        no_prefix_cache: true,
+        max_prefill_tokens: DEFAULT_MAX_PREFILL_TOKENS,
+        memory: Qwen3MemoryOptions::new(0.85, DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES)
+            .validate()
+            .expect("valid memory options"),
+        lora: None,
+        decode_overlap: DecodeOverlap::Off,
+        batch_invariant: false,
+        dflash_draft_model_path: draft,
+    }
+}
+
+/// One decoded position: the chosen token and (when requested) the top-K
+/// `(token, logprob)` distribution that produced it.
+struct Step {
+    id: u32,
+    top_logprobs: Vec<(u32, f32)>,
+}
+
+/// Submit one greedy request and collect the decoded steps until `Finished`.
+fn generate(handle: &EngineHandle, prompt_tokens: Vec<u32>, logprobs: usize) -> Vec<Step> {
+    let (token_tx, mut rx) = TokenSink::standalone();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens,
+            params: SamplingParams::default(),
+            max_tokens: GENERATED_TOKENS,
+            lora_adapter: None,
+            token_tx,
+            logprobs,
+            echo: false,
+        })
+        .expect("submit failed");
+
+    let mut steps = Vec::new();
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Token { id, logprob }) => steps.push(Step {
+                id,
+                top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(),
+            }),
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Finished { .. }) => return steps,
+            Some(TokenEvent::Error { message, .. }) => panic!("generation failed: {message}"),
+            Some(TokenEvent::Rejected { message, .. }) => panic!("generation rejected: {message}"),
+            None => panic!("scheduler channel closed without Finished"),
+        }
+    }
+}
+
+/// Prefill `context` (echo) and return the next-token distribution the *prefill*
+/// kernel produces — the kernel the speculative verify path also uses. This is
+/// the reference the spec pick should match (vs the plain-decode baseline, whose
+/// kernel resolves bifurcation ties to the other side). Returns the first
+/// generated token's `(id, top_logprobs)`.
+fn prefill_next(handle: &EngineHandle, context: Vec<u32>, logprobs: usize) -> Step {
+    let (token_tx, mut rx) = TokenSink::standalone();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens: context,
+            params: SamplingParams::default(),
+            max_tokens: 1,
+            lora_adapter: None,
+            token_tx,
+            logprobs,
+            echo: true,
+        })
+        .expect("submit failed");
+
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Token { id, logprob }) => {
+                return Step {
+                    id,
+                    top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(),
+                };
+            }
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Finished { .. }) => panic!("echo prefill finished without a token"),
+            Some(TokenEvent::Error { message, .. }) => panic!("prefill failed: {message}"),
+            Some(TokenEvent::Rejected { message, .. }) => panic!("prefill rejected: {message}"),
+            None => panic!("scheduler channel closed without a token"),
+        }
+    }
+}
+
+#[test]
+fn dflash_speculative_greedy_matches_plain_greedy() {
+    let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else {
+        return;
+    };
+    let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner());
+
+    let prompts = [
+        "The capital of France is",
+        "Here is a short story about a dragon. Once upon a time",
+        "def fibonacci(n):",
+        "Q: What is 17 multiplied by 23? A: Let's think step by step.",
+        "The three primary colors are",
+    ];
+
+    let tokenizer = common::load_tokenizer(&model_path);
+    let encoded: Vec<Vec<u32>> = prompts
+        .iter()
+        .map(|p| tokenizer.encode(p, false).expect("encode failed"))
+        .collect();
+
+    // 1. Baseline: plain greedy decode (speculative off), with logprobs so the
+    //    regret check has the reference distribution at the divergence point.
+    let baseline: Vec<Vec<Step>> = {
+        let handle = openinfer_qwen3_4b::launch(Path::new(&model_path), launch_options(None))
+            .expect("failed to start baseline engine");
+        let out = encoded
+            .iter()
+            .map(|t| generate(&handle, t.clone(), LOGPROBS))
+            .collect();
+        drop(handle);
+        // Let the scheduler thread tear down and free GPU memory before the
+        // speculative engine loads the same 8 GB target.
+        std::thread::sleep(Duration::from_secs(2));
+        out
+    };
+
+    // 2. Speculative: DFlash draft + verify (logprobs off ⇒ spec path active).
+    //    Keep the engine alive through analysis: at a divergence we re-prefill
+    //    the shared context to read the prefill-kernel reference (the kernel the
+    //    verify path uses), which the plain-decode baseline cannot provide.
+    let handle = openinfer_qwen3_4b::launch(
+        Path::new(&model_path),
+        launch_options(Some(PathBuf::from(&draft_path))),
+    )
+    .expect("failed to start speculative engine");
+
+    let mut failures = Vec::new();
+    for (i, prompt) in prompts.iter().enumerate() {
+        let base = &baseline[i];
+        let spec = generate(&handle, encoded[i].clone(), 0);
+        let matched = base
+            .iter()
+            .zip(&spec)
+            .take_while(|(b, s)| b.id == s.id)
+            .count();
+
+        // Identical sequences (or one a prefix of the other): perfectly lossless.
+        if matched == base.len().min(spec.len()) {
+            eprintln!(
+                "prompt {i} ({prompt:?}): {matched}/{} tokens identical (100% lossless)",
+                base.len()
+            );
+            continue;
+        }
+
+        let spec_id = spec[matched].id;
+        let decode_argmax = base[matched].top_logprobs[0].0;
+
+        // Diagnostic: show the exact branch point.
+        {
+            let lo = matched.saturating_sub(2);
+            let hi = (matched + 3).min(base.len()).min(spec.len());
+            let base_ids: Vec<u32> = base[..hi].iter().map(|s| s.id).collect();
+            let spec_ids: Vec<u32> = spec[..hi].iter().map(|s| s.id).collect();
+            eprintln!(" [diag] prompt {i} matched={matched}");
+            eprintln!(
+                " [diag] context+gen base ids {:?} = {:?}",
+                &base_ids,
+                tokenizer.decode(&base_ids, false).unwrap_or_default()
+            );
+            eprintln!(
+                " [diag] base[{lo}..{hi}] = {:?}",
+                base[lo..hi]
+                    .iter()
+                    .map(|s| (s.id, tokenizer.decode(&[s.id], false).unwrap_or_default()))
+                    .collect::<Vec<_>>()
+            );
+            eprintln!(
+                " [diag] spec[{lo}..{hi}] = {:?}",
+                spec[lo..hi]
+                    .iter()
+                    .map(|s| (s.id, tokenizer.decode(&[s.id], false).unwrap_or_default()))
+                    .collect::<Vec<_>>()
+            );
+            let _ = spec_ids;
+        }
+
+        // The verify path runs the prefill kernel, so the right reference for the
+        // spec pick is a plain *prefill* of the same shared context — not the
+        // plain-decode baseline, whose kernel resolves a bifurcation tie to the
+        // other side and amplifies the gap. Build that context from the matched
+        // tokens and ask what the prefill kernel predicts next.
+        let mut context = encoded[i].clone();
+        context.extend(base[..matched].iter().map(|s| s.id));
+        let prefill_ref = prefill_next(&handle, context, LOGPROBS);
+
+        if prefill_ref.id == spec_id {
+            // Spec faithfully reproduced the prefill-kernel greedy pick; the
+            // divergence is purely the pre-existing prefill-vs-decode kernel gap
+            // at a near-tie (here decode→{decode_argmax}, prefill→{spec_id}).
+            let decode_lp = base[matched]
+                .top_logprobs
+                .iter()
+                .find(|(t, _)| *t == spec_id)
+                .map(|(_, lp)| base[matched].top_logprobs[0].1 - lp);
+            eprintln!(
+                "prompt {i} ({prompt:?}): kernel-gap flip at token {matched} — verify(prefill)→{spec_id}, \
+                 decode→{decode_argmax}; spec matches prefill greedy (decode-margin {:?}). Not a spec bug.",
+                decode_lp
+            );
+            continue;
+        }
+
+        // Spec's greedy pick differs from the prefill-kernel argmax too.
The
+        // verify path builds its committed KV incrementally across batched
+        // speculative spans, while this reference prefill builds it in one
+        // shot; the two differ by a few bf16 ULP. On a near-tie that flips the
+        // argmax — benign. So the deciding question is *how far* below the
+        // prefill argmax the spec pick sits IN THE PREFILL KERNEL'S OWN
+        // distribution (the kernel the verify path uses). Within MARGIN_TOL ⇒
+        // a numerical tie flip, not a bug. Clearly worse ⇒ the verify/accept
+        // logic picked a token the forward never favored — a real bug.
+        let prefill_regret = prefill_ref
+            .top_logprobs
+            .iter()
+            .find(|(t, _)| *t == spec_id)
+            .map(|(_, lp)| prefill_ref.top_logprobs[0].1 - lp);
+
+        if let Some(regret) = prefill_regret {
+            if regret <= MARGIN_TOL {
+                eprintln!(
+                    "prompt {i} ({prompt:?}): tie flip at token {matched} — \
+                     verify(prefill)→{}, spec→{spec_id}, decode→{decode_argmax}; \
+                     spec pick is #2 in the prefill distribution (regret {regret:.3} ≤ {MARGIN_TOL}). \
+                     Not a spec bug.",
+                    prefill_ref.id,
+                );
+                continue;
+            }
+        }
+
+        // Either the spec pick is outside the prefill top-K entirely, or it sits
+        // clearly below the prefill argmax — neither is a benign tie.
+        let decode_regret = base[matched]
+            .top_logprobs
+            .iter()
+            .find(|(t, _)| *t == spec_id)
+            .map(|(_, lp)| base[matched].top_logprobs[0].1 - lp);
+        failures.push(format!(
+            "prompt {i}: at token {matched} spec chose {spec_id} but prefill greedy says {} and \
+             decode greedy says {decode_argmax} (spec regret in prefill dist: {prefill_regret:?} > \
+             {MARGIN_TOL}; in decode dist: {decode_regret:?}) — real spec bug",
+            prefill_ref.id,
+        ));
+    }
+
+    drop(handle);
+
+    assert!(
+        failures.is_empty(),
+        "speculative greedy decode is not lossless:\n{}",
+        failures.join("\n")
+    );
+}
+
+/// P2 regression: a request that fits the target context window but lands in the
+/// draft's `block_size` in-fill headroom (`max_pos - block_size < prompt +
+/// max_tokens <= max_pos`) must be rejected cleanly at admission. Before the
+/// admission cap, such a request was admitted on the target's limit and then
+/// panicked mid-prefill when the draft allocated KV past its own max positions.
+#[test]
+fn dflash_request_in_draft_headroom_is_rejected_not_panicked() {
+    let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else {
+        return;
+    };
+    let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner());
+
+    // Read the real context window so the boundary is exact regardless of the
+    // checkpoint, then size the request to sit inside the draft's final in-fill
+    // block — it fits the target window but not the DFlash-effective one.
+    let config: serde_json::Value = serde_json::from_str(
+        &std::fs::read_to_string(Path::new(&model_path).join("config.json")).expect("read config"),
+    )
+    .expect("parse config");
+    let max_pos = config["max_position_embeddings"]
+        .as_u64()
+        .expect("max_position_embeddings") as usize;
+    const BLOCK_SIZE: usize = 16; // DFlash drafter block size.
+    let prompt_len = 16usize;
+    // total in (max_pos - BLOCK_SIZE, max_pos]: clears the target check, trips
+    // the DFlash admission cap (max_pos - BLOCK_SIZE).
+    let max_tokens = max_pos - BLOCK_SIZE / 2 - prompt_len;
+
+    let handle = openinfer_qwen3_4b::launch(
+        Path::new(&model_path),
+        launch_options(Some(PathBuf::from(&draft_path))),
+    )
+    .expect("failed to start speculative engine");
+
+    let (token_tx, mut rx) = TokenSink::standalone();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens: vec![100u32; prompt_len],
+            params: SamplingParams::default(),
+            max_tokens,
+            lora_adapter: None,
+            token_tx,
+            logprobs: 0,
+            echo: false,
+        })
+        .expect("submit failed");
+
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Rejected { message, .. }) => {
+                eprintln!("draft-headroom request rejected as expected: {message}");
+                break;
+            }
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Token { .. } | TokenEvent::Finished { .. }) => {
+                panic!("draft-headroom request was admitted instead of rejected")
+            }
+            Some(TokenEvent::Error { message, .. }) => {
+                panic!(
+                    "draft-headroom request errored mid-flight instead of clean rejection: {message}"
+                )
+            }
+            None => panic!("scheduler channel closed without a rejection"),
+        }
+    }
+}
diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs
new file mode 100644
index 00000000..ff0e017a
--- /dev/null
+++ b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs
@@ -0,0 +1,168 @@
+//! DFlash speculative-decoding single-stream latency A/B.
+//!
+//! Speculative decoding's direct win is single-stream (batch=1) decode latency:
+//! plain decode is memory-bound (one target forward per token), while spec
+//! amortizes that forward over the accepted run. This measures end-to-end
+//! wall-clock to generate a fixed token budget, speculative OFF vs ON, on the
+//! same prompts and hardware, and reports the speedup.
+//!
+//! This is a measurement harness, not a pass/fail gate — it asserts only that
+//!
spec is not catastrophically slower (a guard against the draft mispredicting
+//! everything). Read the printed numbers for the real signal. `--nocapture`.
+//!
+//! Requires a CUDA GPU, Qwen3-4B weights, and the DFlash drafter. Set
+//! `OPENINFER_TEST_MODEL_PATH` + `OPENINFER_DFLASH_TEST_MODEL_PATH`; skips when
+//! either is absent. Single-stream only — throughput under load is a separate
+//! `vllm bench serve` A/B.
+
+use std::path::{Path, PathBuf};
+use std::time::{Duration, Instant};
+
+use openinfer_core::engine::{EngineHandle, GenerateRequest, TokenEvent, TokenSink};
+use openinfer_core::sampler::SamplingParams;
+use openinfer_qwen3_4b::{
+    DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, DecodeOverlap,
+    Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions,
+};
+
+mod common;
+
+const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B");
+const DRAFT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B-DFlash-b16");
+const GENERATED_TOKENS: usize = 256;
+
+fn target_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => {
+            Some(MODEL_PATH.to_string())
+        }
+        Err(_) => None,
+    }
+}
+
+fn draft_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => {
+            Some(DRAFT_PATH.to_string())
+        }
+        Err(_) => None,
+    }
+}
+
+fn launch_options(draft: Option<PathBuf>) -> Qwen3LaunchOptions {
+    Qwen3LaunchOptions {
+        device_ordinal: 0,
+        tp_size: 1,
+        cuda_graph: true,
+        offload: Qwen3OffloadOptions::disabled(),
+        no_prefix_cache: true,
+        max_prefill_tokens: DEFAULT_MAX_PREFILL_TOKENS,
+        memory: Qwen3MemoryOptions::new(0.85, DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES)
+            .validate()
+            .expect("valid memory options"),
+        lora: None,
+        decode_overlap: DecodeOverlap::Off,
+        batch_invariant:
false,
+        dflash_draft_model_path: draft,
+    }
+}
+
+/// Generate `GENERATED_TOKENS` greedily and return (token_count, elapsed).
+fn timed_generate(handle: &EngineHandle, prompt_tokens: Vec<u32>) -> (usize, Duration) {
+    let (token_tx, mut rx) = TokenSink::standalone();
+    let start = Instant::now();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens,
+            params: SamplingParams {
+                ignore_eos: true,
+                ..SamplingParams::default()
+            },
+            max_tokens: GENERATED_TOKENS,
+            lora_adapter: None,
+            token_tx,
+            logprobs: 0,
+            echo: false,
+        })
+        .expect("submit failed");
+
+    let mut count = 0usize;
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Token { .. }) => count += 1,
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Finished { .. }) => return (count, start.elapsed()),
+            Some(TokenEvent::Error { message, .. }) => panic!("generation failed: {message}"),
+            Some(TokenEvent::Rejected { message, .. }) => panic!("generation rejected: {message}"),
+            None => panic!("scheduler channel closed without Finished"),
+        }
+    }
+}
+
+/// Decode tok/s averaged over the prompts (one warm-up run discarded).
+fn measure(handle: &EngineHandle, prompts: &[Vec<u32>]) -> f64 {
+    // Warm up CUDA-graph capture / allocator on the first prompt.
+    let _ = timed_generate(handle, prompts[0].clone());
+    let mut tokens = 0usize;
+    let mut elapsed = Duration::ZERO;
+    for p in prompts {
+        let (n, dt) = timed_generate(handle, p.clone());
+        tokens += n;
+        elapsed += dt;
+    }
+    tokens as f64 / elapsed.as_secs_f64()
+}
+
+#[test]
+fn dflash_speculative_single_stream_speedup() {
+    let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else {
+        eprintln!(
+            "skipping dflash perf A/B: set OPENINFER_TEST_MODEL_PATH + OPENINFER_DFLASH_TEST_MODEL_PATH"
+        );
+        return;
+    };
+
+    let tokenizer = common::load_tokenizer(&model_path);
+    let prompts: Vec<Vec<u32>> = [
+        "Write a short essay about the history of the Roman Empire.",
+        "Explain how a transformer neural network works, step by step.",
+        "List ten facts about the planet Mars and describe each one.",
+    ]
+    .iter()
+    .map(|p| tokenizer.encode(p, false).expect("encode failed"))
+    .collect();
+
+    let baseline_tps = {
+        let handle = openinfer_qwen3_4b::launch(Path::new(&model_path), launch_options(None))
+            .expect("baseline engine");
+        let tps = measure(&handle, &prompts);
+        drop(handle);
+        std::thread::sleep(Duration::from_secs(2));
+        tps
+    };
+
+    let spec_tps = {
+        let handle = openinfer_qwen3_4b::launch(
+            Path::new(&model_path),
+            launch_options(Some(PathBuf::from(&draft_path))),
+        )
+        .expect("speculative engine");
+        measure(&handle, &prompts)
+    };
+
+    let speedup = spec_tps / baseline_tps;
+    eprintln!("───────────── DFlash single-stream decode A/B (bs=1) ─────────────");
+    eprintln!(" spec OFF (plain decode): {baseline_tps:7.1} tok/s");
+    eprintln!(" spec ON (DFlash): {spec_tps:7.1} tok/s");
+    eprintln!(" speedup: {speedup:7.2}×");
+    eprintln!("───────────────────────────────────────────────────────────────────────────");
+
+    assert!(
+        speedup > 0.8,
+        "speculative decode is catastrophically slower ({speedup:.2}×) — draft likely mispredicting"
+    );
+}
diff --git a/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs
b/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs index 6ba26a95..f39f7697 100644 --- a/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs +++ b/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs @@ -138,6 +138,7 @@ fn live_gpu_and_cpu_prefix_hits() { Qwen3LoraOptions::default(), Qwen3OffloadOptions::enabled(HOST_TIER_BYTES), openinfer_qwen3_4b::DEFAULT_MAX_PREFILL_TOKENS, + None, openinfer_qwen3_4b::Qwen3MemoryOptions::default(), ) .expect("build offload executor"); diff --git a/openinfer-qwen3-4b/tests/scheduler_robustness.rs b/openinfer-qwen3-4b/tests/scheduler_robustness.rs index 5688fc55..c9fe7fed 100644 --- a/openinfer-qwen3-4b/tests/scheduler_robustness.rs +++ b/openinfer-qwen3-4b/tests/scheduler_robustness.rs @@ -104,6 +104,7 @@ fn scheduler_survives_consumer_drop() { openinfer_qwen3_4b::Qwen3MemoryOptions::default(), openinfer_qwen3_4b::DecodeOverlap::Off, true, + None, ) .expect("failed to start engine"); assert_eq!( diff --git a/openinfer-server/src/config.rs b/openinfer-server/src/config.rs index 7ee423a6..db65959e 100644 --- a/openinfer-server/src/config.rs +++ b/openinfer-server/src/config.rs @@ -89,6 +89,12 @@ pub(crate) struct Args { #[arg(long, default_value_t = false)] pub no_prefix_cache: bool, + /// Enable Qwen3 DFlash speculative decoding with this drafter model path. + /// Single-GPU greedy only; incompatible with --enable-lora and --kv-offload, + /// and forces the prefix cache off (it needs clean target hidden states). + #[arg(long = "dflash-draft-model-path")] + pub dflash_draft_model_path: Option, + /// Cap on total prompt tokens forwarded in one Qwen3 scheduler step /// (chunked prefill). 
Prefill activation scratch scales with the step's /// prompt tokens, so this bounds peak VRAM under request bursts; prompts diff --git a/openinfer-server/src/main.rs b/openinfer-server/src/main.rs index 3af27a72..d46df28b 100644 --- a/openinfer-server/src/main.rs +++ b/openinfer-server/src/main.rs @@ -109,6 +109,16 @@ async fn main() -> anyhow::Result<()> { // defaults, capability constraints, cross-arg validation). The server only // picks the crate by detected model type and forwards the relevant CLI knobs. fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result { + // Only Qwen3 wires the DFlash drafter; fail loud rather than silently + // ignoring the flag for another model line. + #[cfg(feature = "qwen3-4b")] + let is_qwen3 = matches!(model_type, ModelType::Qwen3); + #[cfg(not(feature = "qwen3-4b"))] + let is_qwen3 = false; + anyhow::ensure!( + args.dflash_draft_model_path.is_none() || is_qwen3, + "--dflash-draft-model-path is only supported for Qwen3 (got {model_type:?})" + ); let handle = match model_type { #[cfg(feature = "deepseek-v4")] ModelType::DeepSeekV4 => openinfer_deepseek_v4::launch( @@ -154,6 +164,24 @@ fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result { + anyhow::ensure!( + !args.enable_lora, + "--dflash-draft-model-path is not supported with --enable-lora" + ); + anyhow::ensure!( + !args.kv_offload, + "--dflash-draft-model-path is not supported with --kv-offload" + ); + anyhow::ensure!( + args.tp_size == 1, + "--dflash-draft-model-path currently requires --tp-size=1" + ); + Some(path) + } + None => None, + }; openinfer_qwen3_4b::launch( &args.model_path, Qwen3LaunchOptions { @@ -171,6 +199,7 @@ fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result