From 31c449f4ae1be8947e528c42c67fd8274f3ec268 Mon Sep 17 00:00:00 2001
From: xiaguan <751080330@qq.com>
Date: Mon, 22 Jun 2026 10:11:15 +0800
Subject: [PATCH 1/9] feat(qwen3): add DFlash speculative decoding

Model speculative decoding as an optimistic transaction: the DFlash
drafter proposes K tokens, one target verify forward over the K+1 span
confirms them, and we commit the longest argmax-matching prefix plus one
bonus token, rolling back the rest of the speculative KV reservation.
Only the propose step is method-specific; the draft/verify boundary is a
pure token span so hidden states never leave the proposer. The shared
accept core (speculative.rs) is method-agnostic; no proposer trait until
a second method lands.

Enabled with --dflash-draft-model-path (Qwen3 only, TP=1, primary rank).
The server rejects the flag for other model lines rather than silently
ignoring it. Speculative-on forces prefix caching off so every request's
target hidden context is captured during prefill, removing the need for
a per-request drafter-ready handshake.

Greedy decode is lossless up to bf16 numerical tie-flips (the same
non-determinism that already affects plain greedy decode at genuine
bifurcations): multi-token accepts are bit-identical to baseline at
every non-tie position. Single-stream decode A/B on RTX 5070 Ti (bs=1):
93.4 -> 170.0 tok/s, 1.82x.

Tests: speculative.rs accept-core units, scheduler resolve truth-table
(multi-token emission / stop-token / max-tokens truncation),
tests/dflash_speculative_gate.rs (regret-based losslessness gate) and
dflash_speculative_perf.rs (single-stream A/B). Docs in
docs/models/qwen3/dflash-speculative-decoding.md.
Co-Authored-By: Claude Opus 4.8 (1M context)
---
 docs/index.md | 1 +
 .../qwen3/dflash-speculative-decoding.md | 58 ++
 openinfer-core/src/ops.rs | 14 +-
 openinfer-kernels/csrc/shared/elementwise.cu | 84 ++
 .../csrc/shared/paged_attention.cu | 64 ++
 .../csrc/shared/prefill_attention.cu | 118 +++
 openinfer-kernels/src/ffi/shared.rs | 57 ++
 openinfer-kernels/src/ops.rs | 13 +-
 openinfer-kernels/src/ops/attention.rs | 100 +++
 openinfer-kernels/src/ops/elementwise.rs | 84 ++
 openinfer-kernels/src/ops/sampling.rs | 1 +
 openinfer-kv-cache/src/pool.rs | 51 +-
 openinfer-qwen3-4b/src/config.rs | 102 +++
 openinfer-qwen3-4b/src/dflash.rs | 776 ++++++++++++++++++
 openinfer-qwen3-4b/src/executor.rs | 303 ++++++-
 .../src/executor/dflash_lane.rs | 259 ++++++
 .../src/executor/dflash_prefill.rs | 70 ++
 openinfer-qwen3-4b/src/executor/spec.rs | 188 +++++
 openinfer-qwen3-4b/src/lib.rs | 24 +-
 openinfer-qwen3-4b/src/prefill.rs | 78 +-
 openinfer-qwen3-4b/src/scheduler.rs | 39 +-
 openinfer-qwen3-4b/src/scheduler/effects.rs | 89 ++
 openinfer-qwen3-4b/src/scheduler/plan.rs | 101 +++
 openinfer-qwen3-4b/src/scheduler/resolve.rs | 55 ++
 openinfer-qwen3-4b/src/scheduler/tests.rs | 149 +++-
 openinfer-qwen3-4b/src/speculative.rs | 260 ++++++
 .../tests/dflash_speculative_gate.rs | 362 ++++++++
 .../tests/dflash_speculative_perf.rs | 161 ++++
 openinfer-server/src/config.rs | 6 +
 openinfer-server/src/main.rs | 29 +
 30 files changed, 3666 insertions(+), 30 deletions(-)
 create mode 100644 docs/models/qwen3/dflash-speculative-decoding.md
 create mode 100644 openinfer-qwen3-4b/src/dflash.rs
 create mode 100644 openinfer-qwen3-4b/src/executor/dflash_lane.rs
 create mode 100644 openinfer-qwen3-4b/src/executor/dflash_prefill.rs
 create mode 100644 openinfer-qwen3-4b/src/executor/spec.rs
 create mode 100644 openinfer-qwen3-4b/src/speculative.rs
 create mode 100644 openinfer-qwen3-4b/tests/dflash_speculative_gate.rs
 create mode 100644 openinfer-qwen3-4b/tests/dflash_speculative_perf.rs

diff --git a/docs/index.md b/docs/index.md
index 8f41a7f6..9a80fb77 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -30,6 +30,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l
 | `models/qwen3/roadmap.md` | Qwen3-4B roadmap (2026-06 review): line is the maturity bar; #220 RoPE OOB, batched greedy sampling (#307), mixed greedy/non-greedy sampling (#284), and pegaflow KV offload (#316) are landed; open set is zero TP coverage, zero-adapter-only LoRA gate, dropped prefix-cache observability, stale docs, and YaRN #8 follow-up. |
 | `models/qwen3/model-crate.md` | `openinfer-qwen3-4b` owns Qwen3 config/weights/executor/scheduler/tests/kernel plan; root sees generic `EngineHandle`; split-K retuned to `256/64`, with 4k/64 serving TPOT p50 at `6.46ms` on RTX 5090. |
 | `models/qwen3/prefix-cache.md` | Prefix caching on by default for Qwen3-4B: full-block kvbm radix matching at the executor, suffix-only prefill. Repeated ~1900-token prompt TTFT 141.8 → 16.3ms p50 (8.7×); warm TTFT ≈ TPOT + ~5ms setup. Includes the RoPE scalar-path corruption fix and the drain-the-stream TTFT measurement pitfall. |
+| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (proven via bit-identical multi-token accepts + non-deterministic flip prompts). Single-stream decode 93.4 → 170.0 tok/s (1.82×) on 5070 Ti. Readiness from prefill capture (prefix cache forced off), no proposer trait until a 2nd method. 5090 throughput A/B pending. |
 | `models/qwen3/accuracy-gate.md` | Qwen3-4B instance of the logits golden gate (`tests/hf_golden_gate.rs`): 48 teacher-forced sequences / 816 positions vs a stored HF bf16 golden, replayed over bs=1 / batched eager / CUDA-graph. Strict guards: regret check + mean ≤ 0.06 + p99 ≤ 0.20; absolute max printed but not asserted (coverage-unstable). Methodology in `subsystems/correctness/`. |
 | `models/qwen3/kernels-crate.md` | Phase 1 split implemented and 5090-verified: Qwen3-4B kernel surface lives in `openinfer-kernels`; release build, test-target compile, accuracy gate, and bench snapshot pass. |
 | `models/qwen3/tp-design.md` | Qwen3 tensor-parallel design: `TP=2` milestone scope plus the controller/worker broadcast execution model, request identity, and coarse-grained step protocol for future TP/MoE work. |
diff --git a/docs/models/qwen3/dflash-speculative-decoding.md b/docs/models/qwen3/dflash-speculative-decoding.md
new file mode 100644
index 00000000..982ff767
--- /dev/null
+++ b/docs/models/qwen3/dflash-speculative-decoding.md
@@ -0,0 +1,58 @@
+# DFlash Speculative Decoding (Qwen3-4B)
+
+**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points. Measured single-stream decode A/B on RTX 5070 Ti (bs=1): **93.4 → 170.0 tok/s, 1.82×**.
+
+Last touched: 2026-06
+
+## The abstraction: speculative decode = optimistic transaction
+
+Every speculative method is the same transaction; only *propose* differs.
+
+1. **Propose** — a method-specific drafter emits K candidate tokens. (DFlash is the only proposer today; an n-gram / EAGLE proposer would slot in here.)
+2. **Verify** — the target model runs ONE prefill-style forward over the K+1 span `[current_token, draft_1, …, draft_K]` and reports its argmax at every position (echo=true).
+3. **Accept** — `accept_greedy` keeps the longest prefix where each draft matches the target argmax, then appends the target's own token at the first mismatch (the "bonus"). A verify step therefore always commits `1..=K+1` tokens — at least one token of progress even when every draft is rejected.
+4. **Commit / roll back** — accepted tokens' KV is committed (`apply_speculative`); the unused tail of the speculative reservation is LIFO-dropped (`revert_schedule`).
+
+The draft↔verify boundary is a **pure token span**. Hidden states never cross it — they stay inside the proposer (`dflash.rs` / `dflash_lane.rs`). This is what lets the shared core (`speculative.rs`: `accept_greedy`, `build_verify_results`) be method-agnostic, and it is why there is deliberately **no proposer trait yet**: a trait with one impl is premature. Add it when a second method lands.
+
+Shared core lives in `openinfer-qwen3-4b/src/speculative.rs`; the transaction wiring is in `openinfer-qwen3-4b/src/executor/spec.rs`; the KV transaction primitives (`schedule_speculative` / `apply_speculative` / `speculative_view` / `revert_schedule`) are in `openinfer-kv-cache/src/pool.rs` delegating to kvbm `scheduled.rs`.
+
+## Key invariant: readiness comes from prefill capture, not a handshake
+
+Speculative-on **forces prefix caching off**. With no prefix reuse, every eligible request's target hidden context is captured during its own prefill — so there is no per-request "drafter ready" handshake. Readiness is derived from the prefill capture-set plus the completed flag. This is the load-bearing simplification; if prefix caching were allowed, a cache-hit request would skip the prefill that seeds the drafter, and the invariant would break.
+
+Draft-seed alignment: after a verify accepting `m` drafts, the committed span positions with *known* hidden states are `[current_token, draft_1..draft_m]` = `m+1` rows. The next current_token (`target_argmax[m]`) is freshly predicted and its hidden was never forwarded — exactly like the post-prefill first token. `record_verify_dflash_context` appends `accepted_tokens.len()` (= `m+1`) rows and advances `kv_position` by the same. This is the subtlest part of the feature; it is defended by crash-early asserts (`append_pending_context` overflow/dim/range checks).
+
+## Losslessness: lossless up to bf16 ties
+
+The greedy oracle is "spec-on equals spec-off token-for-token". In practice the two diverge only at genuine bifurcation points, and only because of bf16 numerical noise — **not** a logic bug. Evidence:
+
+- The verify forward builds committed KV incrementally across batched spans (M=16-ish), while a one-shot prefill is a single M=1 forward. cuBLAS picks different GEMM tilings → ULP-scale reduction-order differences → an argmax flip at a near-tie. This is the same class as the `hf_golden_gate` `MARGIN_TOL = 0.20` tolerance.
+- Multi-token acceptance (runs observed up to 11 tokens) is **bit-identical** to baseline at every non-tie position — proving the span-position-≥1 path is correct.
+- Re-running flips **different** prompts each time (non-deterministic ⇒ bf16, not a deterministic logic error); one observed flip had regret 0.000 (an exact tie).
+- The spec pick is always the #1 or #2 token of the prefill kernel's own distribution, within 0.20 nat of #1.
+
+### The losslessness gate (`tests/dflash_speculative_gate.rs`)
+
+The gate runs a baseline (spec off, logprobs on) and a spec engine on the same prompts, then at the first token where they diverge it measures the spec pick's **regret** against the prefill kernel's own distribution. Within `MARGIN_TOL = 0.20` nat ⇒ benign tie-flip (pass); clearly worse ⇒ real bug (fail). The realistic bug classes (KV misalignment, mask leak) push the pick far outside 0.20 nat, so the gate has teeth — it is not tautological.
+
+**Known scope limit:** the gate regret-checks only the *first* divergence position per prompt, then continues. A bug that stays within the tie band at the first bifurcation but corrupts later positions could slip. Acceptable for v1; the next iteration should re-anchor and regret-check the following few positions after a benign-tie classification.
+
+## Performance
+
+Single-stream decode is where speculative decoding pays off directly: plain decode is memory-bound (one target forward per token); spec amortizes that forward over the accepted run. A/B harness: `tests/dflash_speculative_perf.rs` (bs=1, 256 tokens, `ignore_eos`, one warm-up discarded).
+
+| Config | RTX 5070 Ti, bs=1 |
+| --- | --- |
+| spec OFF (plain decode) | 93.4 tok/s |
+| spec ON (DFlash) | 170.0 tok/s |
+| **speedup** | **1.82×** |
+
+Throughput under concurrent load is a separate axis (`vllm bench serve` A/B) and is best measured on the 5090 — pending.
+
+## Constraints & open follow-ups
+
+- **TP=1, primary rank only.** The DFlash lane runs on the worker thread of the primary rank; the launch path gates `!(dflash && lora)` and `tp_size == 1`. The server fails loud (`--dflash-draft-model-path` is rejected for non-Qwen3 model lines rather than silently ignored).
+- **Second proposer → introduce the trait.** Until then the proposer is concrete on purpose.
+- **File size.** `executor.rs` (~2.8k lines) and `scheduler/tests.rs` (~1.2k lines) exceed the 1k guideline; the spec arms in `execute_step_on_lane` are candidates to move into `executor/spec.rs` in a follow-up.
+- **5090 validation** — golden gate + `vllm bench serve` throughput A/B on the company 5090 (same sm_120 arch as the 5070 Ti, so correctness carries; the 5090 is for representative throughput numbers).
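
Reviewer note on the accept rule described in the doc above (propose K, verify the K+1 span, keep the longest argmax-matching prefix plus one bonus token): the rule can be sketched roughly as follows. This is a minimal illustration, not the code in `speculative.rs`. The name `accept_greedy` comes from the doc, but the signature, the `u32` token type, and the slice layout (`target_argmax[i]` being the target's pick for the slot that `drafts[i]` was proposed for) are assumptions made here.

```rust
/// Hypothetical sketch of the greedy accept rule; the signature is assumed,
/// not taken from the patch.
///
/// `drafts` holds the K proposed tokens. `target_argmax` holds the K+1 target
/// picks from the verify forward; `target_argmax[i]` is assumed to be the
/// target's choice for the slot `drafts[i]` was proposed for, and the final
/// entry has no matching draft.
fn accept_greedy(drafts: &[u32], target_argmax: &[u32]) -> Vec<u32> {
    assert_eq!(
        target_argmax.len(),
        drafts.len() + 1,
        "verify span must be K+1 wide"
    );

    // m = number of leading drafts that agree with the target argmax.
    let m = drafts
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count();

    // Keep the m matching drafts, then append the target's own token at
    // index m (the "bonus"). The result always has between 1 and K+1 tokens.
    let mut accepted = drafts[..m].to_vec();
    accepted.push(target_argmax[m]);
    accepted
}
```

Under these assumptions, drafts `[5, 7, 9]` against target picks `[5, 7, 2, 4]` commit `[5, 7, 2]`: two drafts match (`m = 2`) and the bonus is `target_argmax[2]`, which lines up with the doc's statement that `accepted_tokens.len()` equals `m+1`. The remaining `K + 1 - (m + 1)` reserved positions are the tail that the doc says gets rolled back.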
diff --git a/openinfer-core/src/ops.rs b/openinfer-core/src/ops.rs index dc522350..e211b972 100644 --- a/openinfer-core/src/ops.rs +++ b/openinfer-core/src/ops.rs @@ -15,15 +15,17 @@ pub use attention::{ pub use openinfer_kernels::ops::{ GEMM_LT_MAX_N, LoraDecodeGroupedProjection, accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, argmax, argmax_batch_bf16_into, bf16_hidden_to_f32_into, - embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into, - f32_to_bf16_hidden_into, fused_add_rms_norm_into, gather_hidden_tokens_into, gemm, - gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, gemm_into_checked, gemm_lt_tune, - gemm_per_token, gemv, linear, lora_decode_fused_delta_group3_into, - lora_decode_fused_delta_into, pack_lora_b_rows_into, + copy_hidden_rows_into, copy_hidden_token_range_into, + dflash_qk_norm_rope_into, embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, + extract_vec_ref_into, f32_to_bf16_hidden_into, fused_add_rms_norm_into, + gather_hidden_tokens_into, gemm, gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, + gemm_into_checked, gemm_lt_tune, gemm_per_token, gemv, linear, + lora_decode_fused_delta_group3_into, lora_decode_fused_delta_into, pack_lora_b_rows_into, qk_norm_partial_rope_batched_decode_hd256_into, rms_norm, rms_norm_batch_offset_into, rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, - scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, write_vec_into, + scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, + single_prefill_nhd_noncausal_into, write_vec_into, }; #[cfg(not(feature = "kernel-call-trace"))] pub use openinfer_kernels::ops::{ diff --git a/openinfer-kernels/csrc/shared/elementwise.cu b/openinfer-kernels/csrc/shared/elementwise.cu index 92de04eb..7193b145 100644 --- 
a/openinfer-kernels/csrc/shared/elementwise.cu +++ b/openinfer-kernels/csrc/shared/elementwise.cu @@ -63,6 +63,43 @@ __global__ void gather_hidden_tokens_kernel( } } +__global__ void copy_hidden_rows_kernel( + const __nv_bfloat16 *__restrict__ src, + __nv_bfloat16 *__restrict__ dst, + int src_hidden_dim, + int dst_hidden_dim, + int row_offset, + int rows, + int seq_len) { + int total = rows * seq_len; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total; + idx += gridDim.x * blockDim.x) { + int token = idx / rows; + int row = idx % rows; + dst[(size_t)token * dst_hidden_dim + row_offset + row] = + src[(size_t)token * src_hidden_dim + row]; + } +} + +__global__ void copy_hidden_token_range_kernel( + const __nv_bfloat16 *__restrict__ src, + __nv_bfloat16 *__restrict__ dst, + int hidden_dim, + int src_token_offset, + int dst_token_offset, + int token_count) { + int total = hidden_dim * token_count; + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total; + idx += gridDim.x * blockDim.x) { + int token = idx / hidden_dim; + int row = idx % hidden_dim; + dst[(size_t)(dst_token_offset + token) * hidden_dim + row] = + src[(size_t)(src_token_offset + token) * hidden_dim + row]; + } +} + __global__ void scaled_add_rows_indexed_kernel( const __nv_bfloat16 *__restrict__ delta, float scale, @@ -305,6 +342,53 @@ CUresult gather_hidden_tokens_cuda( return (CUresult)cudaGetLastError(); } +CUresult copy_hidden_rows_cuda( + const __nv_bfloat16 *src, + __nv_bfloat16 *dst, + int src_hidden_dim, + int dst_hidden_dim, + int row_offset, + int rows, + int seq_len, + cudaStream_t stream) { + if (src == nullptr || dst == nullptr || src_hidden_dim <= 0 || + dst_hidden_dim <= 0 || row_offset < 0 || rows <= 0 || seq_len <= 0 || + rows > src_hidden_dim || row_offset + rows > dst_hidden_dim) { + return CUDA_ERROR_INVALID_VALUE; + } + int total = rows * seq_len; + int block = 256; + int grid = (total + block - 1) / block; + copy_hidden_rows_kernel<<>>( + src, dst, 
src_hidden_dim, dst_hidden_dim, row_offset, rows, seq_len); + return (CUresult)cudaGetLastError(); +} + +CUresult copy_hidden_token_range_cuda( + const __nv_bfloat16 *src, + __nv_bfloat16 *dst, + int hidden_dim, + int src_token_offset, + int dst_token_offset, + int token_count, + int src_seq_len, + int dst_seq_len, + cudaStream_t stream) { + if (src == nullptr || dst == nullptr || hidden_dim <= 0 || + src_token_offset < 0 || dst_token_offset < 0 || token_count <= 0 || + src_seq_len <= 0 || dst_seq_len <= 0 || + src_token_offset + token_count > src_seq_len || + dst_token_offset + token_count > dst_seq_len) { + return CUDA_ERROR_INVALID_VALUE; + } + int total = hidden_dim * token_count; + int block = 256; + int grid = (total + block - 1) / block; + copy_hidden_token_range_kernel<<>>( + src, dst, hidden_dim, src_token_offset, dst_token_offset, token_count); + return (CUresult)cudaGetLastError(); +} + CUresult scaled_add_rows_indexed_cuda( const __nv_bfloat16 *delta, float scale, diff --git a/openinfer-kernels/csrc/shared/paged_attention.cu b/openinfer-kernels/csrc/shared/paged_attention.cu index 4506a60f..01f4e82c 100644 --- a/openinfer-kernels/csrc/shared/paged_attention.cu +++ b/openinfer-kernels/csrc/shared/paged_attention.cu @@ -607,6 +607,70 @@ int single_prefill_cuda( reinterpret_cast(stream))); } +int single_prefill_nhd_noncausal_cuda( + // Q and output (HiddenStates token-major: [seq_len, q_dim]) + void* q, + void* output, + // Contiguous KV cache (HiddenStates token-major: [max_seq_len, kv_dim]) + void* k_cache, + void* v_cache, + int32_t num_qo_heads, + int32_t num_kv_heads, + int32_t head_dim, + int32_t seq_len, + int32_t kv_len, + int32_t max_seq_len, + float sm_scale, + void* stream) +{ + if (q == nullptr || output == nullptr || k_cache == nullptr || v_cache == nullptr || + num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128 || + seq_len <= 0 || kv_len <= 0 || max_seq_len < kv_len) { + return static_cast(cudaErrorInvalidValue); + } + + uint32_t 
q_stride_n = num_qo_heads * head_dim; + uint32_t q_stride_h = head_dim; + uint32_t kv_stride_n = num_kv_heads * head_dim; + uint32_t kv_stride_h = head_dim; + + PrefillParamsT params( + reinterpret_cast(q), + reinterpret_cast(k_cache), + reinterpret_cast(v_cache), + /*maybe_custom_mask=*/nullptr, + reinterpret_cast(output), + /*lse=*/nullptr, + /*maybe_alibi_slopes=*/nullptr, + num_qo_heads, + num_kv_heads, + static_cast(seq_len), + static_cast(kv_len), + q_stride_n, + q_stride_h, + kv_stride_n, + kv_stride_h, + static_cast(head_dim), + /*window_left=*/-1, + /*logits_soft_cap=*/0.0f, + sm_scale, + /*rope_scale=*/1.0f, + /*rope_theta=*/1e6f); + + return static_cast( + SinglePrefillWithKVCacheDispatched< + /*HEAD_DIM_QK=*/128, + /*HEAD_DIM_VO=*/128, + PosEncodingMode::kNone, + /*USE_FP16_QK_REDUCTION=*/false, + MaskMode::kNone, + Variant, + PrefillParamsT>( + params, + /*tmp=*/nullptr, + reinterpret_cast(stream))); +} + // --------------------------------------------------------------------------- // Single-request prefill for HEAD_DIM=256 — wraps FlashInfer SinglePrefillWithKVCache. 
// diff --git a/openinfer-kernels/csrc/shared/prefill_attention.cu b/openinfer-kernels/csrc/shared/prefill_attention.cu index a7b24b66..0c69522b 100644 --- a/openinfer-kernels/csrc/shared/prefill_attention.cu +++ b/openinfer-kernels/csrc/shared/prefill_attention.cu @@ -93,6 +93,90 @@ __global__ void prefill_qk_norm_rope_kernel( data[offset] = result; } +__global__ void dflash_qk_norm_rope_kernel( + __nv_bfloat16* __restrict__ q, + __nv_bfloat16* __restrict__ k, + const __nv_bfloat16* __restrict__ q_norm_weight, + const __nv_bfloat16* __restrict__ k_norm_weight, + const __nv_bfloat16* __restrict__ cos_cache, + const __nv_bfloat16* __restrict__ sin_cache, + int num_q_heads, + int num_kv_heads, + int head_dim, + int q_len, + int k_len, + int q_start_pos, + int k_start_pos, + float eps, + int cos_max_pos +) { + int head_global = blockIdx.x; + int token = blockIdx.y; + int d = threadIdx.x; + + bool is_q = (head_global < num_q_heads); + int local_heads = is_q ? num_q_heads : num_kv_heads; + int seq_len = is_q ? q_len : k_len; + if (token >= seq_len) return; + + int head_local = is_q ? head_global : (head_global - num_q_heads); + if (head_local >= local_heads) return; + + __nv_bfloat16* data = is_q ? q : k; + int dim_stride = local_heads * head_dim; + const __nv_bfloat16* norm_w = is_q ? q_norm_weight : k_norm_weight; + int pos = (is_q ? q_start_pos : k_start_pos) + token; + if (pos < 0 || pos >= cos_max_pos) __trap(); + + int offset = token * dim_stride + head_local * head_dim + d; + float val = __bfloat162float(data[offset]); + + float sq = warp_reduce_sum(val * val); + int warp_id = d / WARP_SIZE; + int lane_id = d % WARP_SIZE; + __shared__ float warp_sums[4]; + if (lane_id == 0) warp_sums[warp_id] = sq; + __syncthreads(); + + __shared__ float s_inv_rms; + if (warp_id == 0) { + float v = (lane_id < 4) ? 
warp_sums[lane_id] : 0.0f; + float total = warp_reduce_sum(v); + if (lane_id == 0) s_inv_rms = rsqrtf(total / head_dim + eps); + } + __syncthreads(); + + __nv_bfloat16 normed = __float2bfloat16(val * s_inv_rms); + float normed_f = __bfloat162float(normed) * __bfloat162float(norm_w[d]); + + __shared__ __nv_bfloat16 smem[HEAD_DIM]; + smem[d] = __float2bfloat16(normed_f); + __syncthreads(); + + int half = head_dim / 2; + __nv_bfloat16 result; + if (d < half) { + float lo = __bfloat162float(smem[d]); + float hi = __bfloat162float(smem[d + half]); + float c = __bfloat162float(cos_cache[pos * head_dim + d]); + float s = __bfloat162float(sin_cache[pos * head_dim + d]); + float lo_cos = __bfloat162float(__float2bfloat16(lo * c)); + float hi_sin = __bfloat162float(__float2bfloat16(hi * s)); + result = __float2bfloat16(lo_cos - hi_sin); + } else { + int pair_d = d - half; + float lo = __bfloat162float(smem[pair_d]); + float hi = __bfloat162float(smem[d]); + float c = __bfloat162float(cos_cache[pos * head_dim + pair_d]); + float s = __bfloat162float(sin_cache[pos * head_dim + pair_d]); + float lo_sin = __bfloat162float(__float2bfloat16(lo * s)); + float hi_cos = __bfloat162float(__float2bfloat16(hi * c)); + result = __float2bfloat16(lo_sin + hi_cos); + } + + data[offset] = result; +} + extern "C" { // ============================================================================ @@ -136,4 +220,38 @@ void qk_norm_rope_batched_decode_cuda( ); } +int dflash_qk_norm_rope_cuda( + __nv_bfloat16* q, + __nv_bfloat16* k, + const __nv_bfloat16* q_norm_weight, + const __nv_bfloat16* k_norm_weight, + const __nv_bfloat16* cos_cache, + const __nv_bfloat16* sin_cache, + int num_q_heads, + int num_kv_heads, + int head_dim, + int q_len, + int k_len, + int q_start_pos, + int k_start_pos, + float rms_eps, + int cos_max_pos, + cudaStream_t stream +) { + if (q == nullptr || k == nullptr || q_norm_weight == nullptr || + k_norm_weight == nullptr || cos_cache == nullptr || sin_cache == nullptr || + 
num_q_heads <= 0 || num_kv_heads <= 0 || head_dim != HEAD_DIM || + q_len <= 0 || k_len <= 0 || q_start_pos < 0 || k_start_pos < 0 || + q_start_pos + q_len > cos_max_pos || k_start_pos + k_len > cos_max_pos) { + return static_cast(cudaErrorInvalidValue); + } + + dim3 grid(num_q_heads + num_kv_heads, q_len > k_len ? q_len : k_len); + dflash_qk_norm_rope_kernel<<>>( + q, k, q_norm_weight, k_norm_weight, cos_cache, sin_cache, + num_q_heads, num_kv_heads, head_dim, q_len, k_len, + q_start_pos, k_start_pos, rms_eps, cos_max_pos); + return static_cast(cudaGetLastError()); +} + } // extern "C" diff --git a/openinfer-kernels/src/ffi/shared.rs b/openinfer-kernels/src/ffi/shared.rs index aff4e5e2..c8de7a63 100644 --- a/openinfer-kernels/src/ffi/shared.rs +++ b/openinfer-kernels/src/ffi/shared.rs @@ -30,6 +30,29 @@ unsafe extern "C" { stream: CUstream, ) -> CUresult; + pub fn copy_hidden_rows_cuda( + src: *const Half, + dst: *mut Half, + src_hidden_dim: i32, + dst_hidden_dim: i32, + row_offset: i32, + rows: i32, + seq_len: i32, + stream: CUstream, + ) -> CUresult; + + pub fn copy_hidden_token_range_cuda( + src: *const Half, + dst: *mut Half, + hidden_dim: i32, + src_token_offset: i32, + dst_token_offset: i32, + token_count: i32, + src_seq_len: i32, + dst_seq_len: i32, + stream: CUstream, + ) -> CUresult; + pub fn fused_add_rms_norm_cuda( hidden: *mut Half, residual: *const Half, @@ -249,6 +272,25 @@ unsafe extern "C" { stream: CUstream, ); + pub fn dflash_qk_norm_rope_cuda( + q: *mut Half, + k: *mut Half, + q_norm_weight: *const Half, + k_norm_weight: *const Half, + cos_cache: *const Half, + sin_cache: *const Half, + num_q_heads: i32, + num_kv_heads: i32, + head_dim: i32, + q_len: i32, + k_len: i32, + q_start_pos: i32, + k_start_pos: i32, + rms_eps: f32, + cos_max_pos: i32, + stream: CUstream, + ) -> i32; + // Scatter contiguous KV → paged layout (one layer, FlashInfer prefill append). 
pub fn paged_kv_scatter_cuda( kv_data: *const Half, @@ -375,6 +417,21 @@ unsafe extern "C" { stream: CUstream, ) -> i32; + pub fn single_prefill_nhd_noncausal_cuda( + q: *const Half, + output: *mut Half, + k_cache: *const Half, + v_cache: *const Half, + num_qo_heads: i32, + num_kv_heads: i32, + head_dim: i32, + seq_len: i32, + kv_len: i32, + max_seq_len: i32, + sm_scale: f32, + stream: CUstream, + ) -> i32; + // Paged attention decode (FlashInfer BatchDecode, no partition-KV). pub fn paged_attention_decode_cuda( q: *const Half, diff --git a/openinfer-kernels/src/ops.rs b/openinfer-kernels/src/ops.rs index fa8c5362..4173b5e7 100644 --- a/openinfer-kernels/src/ops.rs +++ b/openinfer-kernels/src/ops.rs @@ -15,9 +15,10 @@ mod norm; mod sampling; pub use attention::{ - PrefillPagedPlan, paged_attention_batch_decode_hd256_into, paged_attention_batch_decode_into, - paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into, - qk_norm_partial_rope_batched_decode_hd256_into, qk_norm_rope_batch_decode_into, + PrefillPagedPlan, dflash_qk_norm_rope_into, paged_attention_batch_decode_hd256_into, + paged_attention_batch_decode_into, paged_attention_batch_decode_split_kv_into, + prefill_attention_paged_into, qk_norm_partial_rope_batched_decode_hd256_into, + qk_norm_rope_batch_decode_into, single_prefill_nhd_noncausal_into, }; #[cfg(feature = "kimi-k2")] pub use deepep::{ @@ -27,7 +28,8 @@ pub use deepep::{ pub use deepseek_v2_lite::*; pub use elementwise::{ accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, bf16_hidden_to_f32_into, - extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, + copy_hidden_rows_into, copy_hidden_token_range_into, extract_vec, extract_vec_into, + extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, gather_hidden_tokens_into, repeat_f32_for_reduce_scatter_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, 
scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, @@ -56,5 +58,6 @@ pub use norm::{ pub use sampling::{ BatchSamplingRow, BatchSamplingScratch, argmax, argmax_batch_bf16_into, argmax_batch_bf16_split_indexed_into, argmax_batch_bf16_split_partials_len, - flashinfer_top1_batch_into, flashinfer_top1_row_states_bytes, gpu_sample_batch_into, + flashinfer_top1_batch_into, + flashinfer_top1_row_states_bytes, gpu_sample_batch_into, }; diff --git a/openinfer-kernels/src/ops/attention.rs b/openinfer-kernels/src/ops/attention.rs index 99516a13..0e4447de 100644 --- a/openinfer-kernels/src/ops/attention.rs +++ b/openinfer-kernels/src/ops/attention.rs @@ -497,6 +497,106 @@ pub fn qk_norm_rope_batch_decode_into( } } +#[allow(clippy::too_many_arguments)] +pub fn dflash_qk_norm_rope_into( + ctx: &DeviceContext, + q: &mut HiddenStates, + k: &mut HiddenStates, + q_norm_weight: &DeviceVec, + k_norm_weight: &DeviceVec, + cos_cache: &DeviceVec, + sin_cache: &DeviceVec, + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + q_start_pos: usize, + k_start_pos: usize, + rms_eps: f32, +) -> Result<()> { + assert_eq!(q.hidden_dim, num_q_heads * head_dim); + assert_eq!(k.hidden_dim, num_kv_heads * head_dim); + assert_eq!(q_norm_weight.len, head_dim); + assert_eq!(k_norm_weight.len, head_dim); + + let (q_ptr, _gq) = q.data.device_ptr_mut(&ctx.stream); + let (k_ptr, _gk) = k.data.device_ptr_mut(&ctx.stream); + let (qn_ptr, _gqn) = q_norm_weight.data.device_ptr(&ctx.stream); + let (kn_ptr, _gkn) = k_norm_weight.data.device_ptr(&ctx.stream); + let (cos_ptr, _gc) = cos_cache.data.device_ptr(&ctx.stream); + let (sin_ptr, _gs) = sin_cache.data.device_ptr(&ctx.stream); + + let result = unsafe { + ffi::dflash_qk_norm_rope_cuda( + q_ptr as *mut ffi::Half, + k_ptr as *mut ffi::Half, + qn_ptr as *const ffi::Half, + kn_ptr as *const ffi::Half, + cos_ptr as *const ffi::Half, + sin_ptr as *const ffi::Half, + num_q_heads as i32, + num_kv_heads as i32, + head_dim as i32, + 
q.seq_len as i32, + k.seq_len as i32, + q_start_pos as i32, + k_start_pos as i32, + rms_eps, + (cos_cache.data.len() / head_dim) as i32, + ctx.stream.cu_stream(), + ) + }; + if result != 0 { + anyhow::bail!("dflash_qk_norm_rope_cuda failed with error {result}"); + } + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn single_prefill_nhd_noncausal_into( + ctx: &DeviceContext, + q: &HiddenStates, + k_cache: &HiddenStates, + v_cache: &HiddenStates, + output: &mut HiddenStates, + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + kv_len: usize, +) -> Result<()> { + assert_eq!(q.hidden_dim, num_q_heads * head_dim); + assert_eq!(output.hidden_dim, q.hidden_dim); + assert_eq!(output.seq_len, q.seq_len); + assert_eq!(k_cache.hidden_dim, num_kv_heads * head_dim); + assert_eq!(v_cache.hidden_dim, k_cache.hidden_dim); + assert_eq!(v_cache.seq_len, k_cache.seq_len); + assert!(kv_len <= k_cache.seq_len); + + let (q_ptr, _gq) = q.data.device_ptr(&ctx.stream); + let (k_ptr, _gk) = k_cache.data.device_ptr(&ctx.stream); + let (v_ptr, _gv) = v_cache.data.device_ptr(&ctx.stream); + let (out_ptr, _go) = output.data.device_ptr_mut(&ctx.stream); + let result = unsafe { + ffi::single_prefill_nhd_noncausal_cuda( + q_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + k_ptr as *const ffi::Half, + v_ptr as *const ffi::Half, + num_q_heads as i32, + num_kv_heads as i32, + head_dim as i32, + q.seq_len as i32, + kv_len as i32, + k_cache.seq_len as i32, + 1.0f32 / (head_dim as f32).sqrt(), + ctx.stream.cu_stream(), + ) + }; + if result != 0 { + anyhow::bail!("single_prefill_nhd_noncausal_cuda failed with error {result}"); + } + Ok(()) +} + /// Batched QK RMSNorm + partial RoPE for Qwen3.5 HD256 decode. 
/// /// Reads Q from interleaved `q_full` ([q, gate] per head), writes prepared Q into `q`, diff --git a/openinfer-kernels/src/ops/elementwise.rs b/openinfer-kernels/src/ops/elementwise.rs index 6956ab47..ab37b8bc 100644 --- a/openinfer-kernels/src/ops/elementwise.rs +++ b/openinfer-kernels/src/ops/elementwise.rs @@ -176,6 +176,90 @@ pub fn gather_hidden_tokens_into( Ok(()) } +pub fn copy_hidden_rows_into( + ctx: &DeviceContext, + src: &HiddenStates, + dst: &mut HiddenStates, + row_offset: usize, +) -> Result<()> { + assert!( + row_offset + src.hidden_dim <= dst.hidden_dim, + "row range [{}..{}) exceeds destination hidden_dim {}", + row_offset, + row_offset + src.hidden_dim, + dst.hidden_dim + ); + assert_eq!( + src.seq_len, dst.seq_len, + "copy_hidden_rows_into seq_len mismatch: src {}, dst {}", + src.seq_len, dst.seq_len + ); + + let (src_ptr, _gs) = src.data.device_ptr(&ctx.stream); + let (dst_ptr, _gd) = dst.data.device_ptr_mut(&ctx.stream); + let result = unsafe { + ffi::copy_hidden_rows_cuda( + src_ptr as *const ffi::Half, + dst_ptr as *mut ffi::Half, + src.hidden_dim as i32, + dst.hidden_dim as i32, + row_offset as i32, + src.hidden_dim as i32, + src.seq_len as i32, + ctx.stream.cu_stream(), + ) + }; + result.result()?; + Ok(()) +} + +pub fn copy_hidden_token_range_into( + ctx: &DeviceContext, + src: &HiddenStates, + src_token_offset: usize, + dst: &mut HiddenStates, + dst_token_offset: usize, + token_count: usize, +) -> Result<()> { + assert_eq!( + src.hidden_dim, dst.hidden_dim, + "copy_hidden_token_range_into hidden_dim mismatch: src {}, dst {}", + src.hidden_dim, dst.hidden_dim + ); + assert!( + src_token_offset + token_count <= src.seq_len, + "source token range [{}..{}) exceeds seq_len {}", + src_token_offset, + src_token_offset + token_count, + src.seq_len + ); + assert!( + dst_token_offset + token_count <= dst.seq_len, + "destination token range [{}..{}) exceeds seq_len {}", + dst_token_offset, + dst_token_offset + token_count, + dst.seq_len + ); + + 
let (src_ptr, _gs) = src.data.device_ptr(&ctx.stream); + let (dst_ptr, _gd) = dst.data.device_ptr_mut(&ctx.stream); + let result = unsafe { + ffi::copy_hidden_token_range_cuda( + src_ptr as *const ffi::Half, + dst_ptr as *mut ffi::Half, + src.hidden_dim as i32, + src_token_offset as i32, + dst_token_offset as i32, + token_count as i32, + src.seq_len as i32, + dst.seq_len as i32, + ctx.stream.cu_stream(), + ) + }; + result.result()?; + Ok(()) +} + pub fn scaled_add_rows_indexed_into( ctx: &DeviceContext, delta: &HiddenStates, diff --git a/openinfer-kernels/src/ops/sampling.rs b/openinfer-kernels/src/ops/sampling.rs index d4ca384d..a303b390 100644 --- a/openinfer-kernels/src/ops/sampling.rs +++ b/openinfer-kernels/src/ops/sampling.rs @@ -285,6 +285,7 @@ pub fn argmax_batch_bf16_split_partials_len(rows: usize, vocab: usize) -> usize rows * vocab.div_ceil(TILE_ELEMS) } + /// Two-stage indexed batched argmax: tile-parallel partials then a per-row /// finalize. Lowest index wins ties; each vocab row spreads over many CTAs /// instead of one. diff --git a/openinfer-kv-cache/src/pool.rs b/openinfer-kv-cache/src/pool.rs index 77d08256..f8797a05 100644 --- a/openinfer-kv-cache/src/pool.rs +++ b/openinfer-kv-cache/src/pool.rs @@ -262,7 +262,9 @@ impl LoadReservation { /// Per-request KV state wrapping `SchedulableSequence`. /// /// Lifecycle: `schedule_prefill → prefill_view/pages → forward → apply_prefill`, -/// then `schedule_decode → decode_view/pages → forward → apply_decode` in a loop. +/// then either `schedule_decode → decode_view → forward → apply_decode` or +/// `schedule_speculative → speculative_view → forward → apply_speculative` in a +/// loop (`revert_schedule` undoes a reservation whose step failed). 
pub struct RequestKv { seq: SchedulableSequence<()>, } @@ -299,6 +301,19 @@ impl RequestKv { self.seq.schedule_decode(&pool.block_manager) } + /// Reserve KV for a speculative verify step covering `num_draft_tokens` + /// consecutive positions (current dangling token + draft candidates). + /// [`Self::apply_speculative`] commits the accepted prefix; on any failure + /// [`Self::revert_schedule`] returns the reservation. + pub fn schedule_speculative( + &mut self, + num_draft_tokens: usize, + pool: &BlockPool, + ) -> Result<(), ScheduleError> { + self.seq + .schedule_speculative(num_draft_tokens, &pool.block_manager) + } + // ── Views (for forward pass) ─────────────────────────────────────── /// Build an immutable `KvView` for prefill. @@ -327,6 +342,20 @@ impl RequestKv { ) } + /// Build an immutable `KvView` for speculative verification: the verifier + /// forwards `num_draft_tokens` consecutive positions (current dangling token + /// followed by draft candidates). The view covers the post-step KV extent; + /// [`Self::apply_speculative`] later commits only the accepted prefix and + /// releases excess draft capacity. Same exact page-row contract as + /// [`Self::prefill_view`]. + pub fn speculative_view(&self, num_draft_tokens: usize) -> KvView { + KvView::new( + self.step_page_indices(num_draft_tokens), + self.seq.kv_position() + num_draft_tokens, + self.seq.block_size(), + ) + } + // ── Apply (register blocks, advance position) ────────────────────── pub fn apply_prefill(&mut self, token: u32, pool: &BlockPool) -> anyhow::Result<()> { @@ -350,6 +379,26 @@ impl RequestKv { .map_err(|e| anyhow::anyhow!("apply_decode: {e}")) } + /// Commit the accepted prefix of a speculative verify step. kvbm keeps the + /// `accepted_tokens` KV and LIFO-releases the rejected draft blocks. 
+ pub fn apply_speculative( + &mut self, + accepted_tokens: &[u32], + pool: &BlockPool, + ) -> anyhow::Result { + self.seq + .apply_speculative(accepted_tokens, &pool.block_manager) + .map_err(|e| anyhow::anyhow!("apply_speculative: {e}")) + } + + /// Undo a scheduled-but-unapplied KV reservation (e.g. a speculative + /// schedule whose forward or apply failed) and return its blocks to the pool. + pub fn revert_schedule(&mut self) -> anyhow::Result<()> { + self.seq + .revert_schedule() + .map_err(|e| anyhow::anyhow!("revert_schedule: {e}")) + } + pub fn release(&mut self) -> anyhow::Result<()> { self.seq .release() diff --git a/openinfer-qwen3-4b/src/config.rs b/openinfer-qwen3-4b/src/config.rs index c02b140e..abffe4d7 100644 --- a/openinfer-qwen3-4b/src/config.rs +++ b/openinfer-qwen3-4b/src/config.rs @@ -38,6 +38,29 @@ pub(crate) struct Config { pub(crate) stop_token_ids: Vec, } +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct DFlashConfig { + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) num_target_layers: usize, + pub(crate) head_dim: usize, + pub(crate) vocab_size: usize, + pub(crate) rms_norm_eps: f32, + pub(crate) rope_theta: f32, + pub(crate) max_position_embeddings: usize, + pub(crate) block_size: usize, + pub(crate) dflash_config: DFlashInnerConfig, +} + +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct DFlashInnerConfig { + pub(crate) mask_token_id: u32, + pub(crate) target_layer_ids: Vec, +} + fn default_max_position_embeddings() -> usize { 40960 } @@ -117,6 +140,85 @@ impl Config { } } +impl DFlashConfig { + pub(crate) fn from_file(model_path: &str) -> Result { + let config_path = format!("{}/config.json", model_path); + let content = fs::read_to_string(&config_path)?; + Ok(serde_json::from_str(&content)?) 
+ } + + pub(crate) fn validate_for_target(&self, target: &Config) -> Result<()> { + anyhow::ensure!( + self.hidden_size == target.hidden_size, + "DFlash hidden_size {} does not match target {}", + self.hidden_size, + target.hidden_size + ); + anyhow::ensure!( + self.num_target_layers == target.num_hidden_layers, + "DFlash num_target_layers {} does not match target layers {}", + self.num_target_layers, + target.num_hidden_layers + ); + anyhow::ensure!( + self.num_attention_heads == target.num_attention_heads + && self.num_key_value_heads == target.num_key_value_heads + && self.head_dim == target.head_dim, + "DFlash attention geometry does not match target" + ); + anyhow::ensure!( + self.vocab_size == target.vocab_size, + "DFlash vocab_size {} does not match target {}", + self.vocab_size, + target.vocab_size + ); + anyhow::ensure!( + self.rope_theta == target.rope_theta, + "DFlash rope_theta {} does not match target {}", + self.rope_theta, + target.rope_theta + ); + anyhow::ensure!( + self.max_position_embeddings >= target.max_position_embeddings, + "DFlash max_position_embeddings {} is smaller than target {}", + self.max_position_embeddings, + target.max_position_embeddings + ); + anyhow::ensure!( + self.block_size >= 2, + "DFlash block_size must be >= 2, got {}", + self.block_size + ); + anyhow::ensure!( + self.dflash_config.mask_token_id < target.vocab_size as u32, + "DFlash mask_token_id {} is outside target vocab_size {}", + self.dflash_config.mask_token_id, + target.vocab_size + ); + anyhow::ensure!( + self.dflash_config.target_layer_ids.len() == self.num_hidden_layers, + "DFlash target_layer_ids length {} does not match draft layers {}", + self.dflash_config.target_layer_ids.len(), + self.num_hidden_layers + ); + anyhow::ensure!( + self.dflash_config + .target_layer_ids + .iter() + .all(|&layer| layer < target.num_hidden_layers), + "DFlash target_layer_ids must be within target layer count" + ); + anyhow::ensure!( + self.dflash_config + .target_layer_ids + 
.windows(2) + .all(|pair| pair[0] < pair[1]), + "DFlash target_layer_ids must be strictly increasing" + ); + Ok(()) + } +} + impl TensorParallelConfig { pub(crate) fn validate_for(self, config: &Config) -> Result<()> { if self.world_size == 0 { diff --git a/openinfer-qwen3-4b/src/dflash.rs b/openinfer-qwen3-4b/src/dflash.rs new file mode 100644 index 00000000..76f8f0c3 --- /dev/null +++ b/openinfer-qwen3-4b/src/dflash.rs @@ -0,0 +1,776 @@ +use anyhow::{Context, Result}; +use cudarc::driver::CudaSlice; +use log::debug; + +use crate::config::DFlashConfig; +use crate::weights::{Attention, MLP, Qwen3Model, TransformerBlock}; +use openinfer_core::ops; +use openinfer_core::tensor::HiddenStates; +use openinfer_core::tensor::{DeviceContext, DeviceMatrix, DeviceVec}; +use openinfer_core::weight_loader::{ + deserialize_shards, load_shard_info, load_tensor_1d, load_tensor_2d, mmap_shards, + precompute_rope, +}; + +pub(crate) struct DFlashDraftModel { + config: DFlashConfig, + layers: Vec, + norm: DeviceVec, + hidden_norm: DeviceVec, + fc: DeviceMatrix, + cos_cache: DeviceVec, + sin_cache: DeviceVec, +} + +pub(crate) struct DFlashRequestState { + layers: Vec, + pending_context: DFlashPendingContext, + scratch: DFlashDraftScratch, + committed_len: usize, + max_cache_len: usize, +} + +struct DFlashLayerCache { + k: HiddenStates, + v: HiddenStates, +} + +struct DFlashPendingContext { + buffer: HiddenStates, + len: usize, + capacity: usize, +} + +struct DFlashDraftScratch { + max_context_len: usize, + block_token_ids_h: Vec, + token_ids_d: CudaSlice, + context_projected: HiddenStates, + context_hidden: HiddenStates, + hidden: HiddenStates, + hidden_out: HiddenStates, + normed: HiddenStates, + tail_input: HiddenStates, + q_batch: HiddenStates, + k_tail: HiddenStates, + v_tail: HiddenStates, + attn_output: HiddenStates, + o_buf: HiddenStates, + gate_out: HiddenStates, + up_out: HiddenStates, + act_out: HiddenStates, + logits_normed: HiddenStates, + logits: HiddenStates, +} + +impl 
DFlashRequestState { + pub(crate) fn pending_context_len(&self) -> Option { + (self.pending_context.len > 0).then_some(self.pending_context.len) + } +} + +impl DFlashPendingContext { + fn new(ctx: &DeviceContext, hidden_dim: usize, capacity: usize) -> Result { + anyhow::ensure!( + capacity > 0, + "DFlash pending context capacity must be non-zero" + ); + let mut buffer = HiddenStates::zeros(ctx, hidden_dim, capacity)?; + buffer.seq_len = 0; + Ok(Self { + buffer, + len: 0, + capacity, + }) + } + + fn append_from( + &mut self, + ctx: &DeviceContext, + src: &HiddenStates, + src_token_offset: usize, + token_count: usize, + max_capacity: usize, + ) -> Result<()> { + let required_len = self + .len + .checked_add(token_count) + .context("DFlash pending context length overflow")?; + anyhow::ensure!( + required_len <= max_capacity, + "DFlash pending context length {} exceeds request capacity {}", + required_len, + max_capacity + ); + self.ensure_capacity(ctx, required_len, max_capacity)?; + self.buffer.seq_len = self.capacity; + ops::copy_hidden_token_range_into( + ctx, + src, + src_token_offset, + &mut self.buffer, + self.len, + token_count, + )?; + self.len = required_len; + self.buffer.seq_len = self.len; + Ok(()) + } + + fn ensure_capacity( + &mut self, + ctx: &DeviceContext, + required_len: usize, + max_capacity: usize, + ) -> Result<()> { + if required_len <= self.capacity { + return Ok(()); + } + let doubled = self + .capacity + .checked_mul(2) + .context("DFlash pending context capacity overflow")?; + let new_capacity = required_len.max(doubled).min(max_capacity); + anyhow::ensure!( + new_capacity >= required_len, + "DFlash pending context capacity {} cannot fit {} tokens", + new_capacity, + required_len + ); + let mut next = HiddenStates::zeros(ctx, self.buffer.hidden_dim, new_capacity)?; + if self.len > 0 { + self.buffer.seq_len = self.capacity; + ops::copy_hidden_token_range_into(ctx, &self.buffer, 0, &mut next, 0, self.len)?; + } + next.seq_len = self.len; + 
self.buffer = next; + self.capacity = new_capacity; + Ok(()) + } + + fn activate_for_read(&mut self) { + self.buffer.seq_len = self.len; + } + + fn clear(&mut self) { + self.len = 0; + self.buffer.seq_len = 0; + } +} + +impl DFlashDraftScratch { + fn new(ctx: &DeviceContext, config: &DFlashConfig, max_context_len: usize) -> Result { + let block_size = config.block_size; + let hidden_size = config.hidden_size; + let q_dim = config.num_attention_heads * config.head_dim; + let kv_dim = config.num_key_value_heads * config.head_dim; + let inter_dim = config.intermediate_size; + let tail_capacity = max_context_len + block_size; + + Ok(Self { + max_context_len, + block_token_ids_h: vec![config.dflash_config.mask_token_id; block_size], + token_ids_d: ctx.stream.alloc_zeros(block_size)?, + context_projected: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + context_hidden: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + hidden: HiddenStates::zeros(ctx, hidden_size, block_size)?, + hidden_out: HiddenStates::zeros(ctx, hidden_size, block_size)?, + normed: HiddenStates::zeros(ctx, hidden_size, block_size)?, + tail_input: HiddenStates::zeros(ctx, hidden_size, tail_capacity)?, + q_batch: HiddenStates::zeros(ctx, q_dim, block_size)?, + k_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, + v_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, + attn_output: HiddenStates::zeros(ctx, q_dim, block_size)?, + o_buf: HiddenStates::zeros(ctx, hidden_size, block_size)?, + gate_out: HiddenStates::zeros(ctx, inter_dim, block_size)?, + up_out: HiddenStates::zeros(ctx, inter_dim, block_size)?, + act_out: HiddenStates::zeros(ctx, inter_dim, block_size)?, + logits_normed: HiddenStates::zeros(ctx, hidden_size, block_size)?, + logits: HiddenStates::zeros(ctx, config.vocab_size, block_size)?, + }) + } + + fn ensure_context_capacity( + &mut self, + ctx: &DeviceContext, + config: &DFlashConfig, + context_len: usize, + ) -> Result<()> { + if context_len > 
self.max_context_len { + *self = Self::new(ctx, config, context_len)?; + } + let block_size = config.block_size; + let tail_len = context_len + block_size; + + self.context_projected.seq_len = context_len; + self.context_hidden.seq_len = context_len; + self.hidden.seq_len = block_size; + self.hidden_out.seq_len = block_size; + self.normed.seq_len = block_size; + self.tail_input.seq_len = tail_len; + self.q_batch.seq_len = block_size; + self.k_tail.seq_len = tail_len; + self.v_tail.seq_len = tail_len; + self.attn_output.seq_len = block_size; + self.o_buf.seq_len = block_size; + self.gate_out.seq_len = block_size; + self.up_out.seq_len = block_size; + self.act_out.seq_len = block_size; + self.logits_normed.seq_len = block_size; + self.logits.seq_len = block_size; + Ok(()) + } +} + +impl DFlashDraftModel { + pub(crate) fn from_safetensors_for_target( + ctx: &DeviceContext, + model_path: &str, + target: &Qwen3Model, + ) -> Result { + let config = DFlashConfig::from_file(model_path) + .with_context(|| format!("load DFlash config from {model_path}"))?; + config.validate_for_target(target.config())?; + + let (shard_paths, weight_map) = load_shard_info(model_path)?; + debug!( + "Loading DFlash drafter from {model_path}: {} shard(s)", + shard_paths.len() + ); + let mmaps = mmap_shards(&shard_paths)?; + let shards = deserialize_shards(&mmaps)?; + + let mut layers = Vec::with_capacity(config.num_hidden_layers); + for layer_idx in 0..config.num_hidden_layers { + let prefix = format!("layers.{layer_idx}"); + + let q_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.q_proj.weight"), + )?; + let k_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.k_proj.weight"), + )?; + let v_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.v_proj.weight"), + )?; + let q_dim = q_proj.rows; + let kv_dim = k_proj.rows; + let qkv_proj = DeviceMatrix::vstack(ctx, &[&q_proj, &k_proj, 
&v_proj])?; + drop(q_proj); + drop(k_proj); + drop(v_proj); + + let gate_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.gate_proj.weight"), + )?; + let up_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.up_proj.weight"), + )?; + let gate_up_proj = DeviceMatrix::vstack(ctx, &[&gate_proj, &up_proj])?; + drop(gate_proj); + drop(up_proj); + + layers.push(TransformerBlock { + input_layernorm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.input_layernorm.weight"), + )?, + attention: Attention { + qkv_proj, + o_proj: load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.o_proj.weight"), + )?, + q_norm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.q_norm.weight"), + )?, + k_norm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.k_norm.weight"), + )?, + q_dim, + kv_dim, + }, + post_attention_layernorm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.post_attention_layernorm.weight"), + )?, + mlp: MLP { + gate_up_proj, + down_proj: load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.down_proj.weight"), + )?, + }, + }); + } + + let norm = load_tensor_1d(ctx, &shards, &weight_map, "norm.weight")?; + let hidden_norm = load_tensor_1d(ctx, &shards, &weight_map, "hidden_norm.weight")?; + let fc = load_tensor_2d(ctx, &shards, &weight_map, "fc.weight")?; + let (cos_cache, sin_cache) = precompute_rope( + ctx, + config.head_dim, + config.max_position_embeddings, + config.rope_theta, + )?; + ctx.sync()?; + + Ok(Self { + config, + layers, + norm, + hidden_norm, + fc, + cos_cache, + sin_cache, + }) + } + + pub(crate) fn block_size(&self) -> usize { + self.config.block_size + } + + pub(crate) fn mask_token_id(&self) -> u32 { + self.config.dflash_config.mask_token_id + } + + pub(crate) fn target_layer_ids(&self) -> &[usize] { + 
&self.config.dflash_config.target_layer_ids + } + + pub(crate) fn tune_gemm_algos(&self, target: &Qwen3Model) -> Result<()> { + let ctx = target.device_ctx(); + let block_size = self.block_size().min(ops::GEMM_LT_MAX_N); + let hidden = self.config.hidden_size; + let q_dim = self.config.num_attention_heads * self.config.head_dim; + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let context_dim = self.context_feature_dim(); + + let fc_samples = [(&self.fc, 0)]; + for n in 1..=block_size { + ops::gemm_lt_tune(ctx, &fc_samples, hidden, n)?; + } + + let kv_samples: Vec<_> = self + .layers + .iter() + .flat_map(|layer| { + [ + (&layer.attention.qkv_proj, q_dim), + (&layer.attention.qkv_proj, q_dim + kv_dim), + ] + }) + .collect(); + let min_tail_n = self.block_size() + 1; + let max_tail_n = (self.block_size() * 2).min(ops::GEMM_LT_MAX_N); + for n in min_tail_n..=max_tail_n { + ops::gemm_lt_tune(ctx, &kv_samples, kv_dim, n)?; + } + + log::info!( + "Qwen3 DFlash cublasLt tuned: fc M={} K={} N=1..{}, kv M={} K={} N={}..{}", + hidden, + context_dim, + block_size, + kv_dim, + hidden, + min_tail_n, + max_tail_n, + ); + Ok(()) + } + + pub(crate) fn new_request_state( + &self, + ctx: &DeviceContext, + max_cache_len: usize, + ) -> Result { + anyhow::ensure!( + max_cache_len <= self.config.max_position_embeddings, + "DFlash request cache length {} exceeds max_position_embeddings {}", + max_cache_len, + self.config.max_position_embeddings + ); + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let mut layers = Vec::with_capacity(self.layers.len()); + for _ in 0..self.layers.len() { + layers.push(DFlashLayerCache { + k: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?, + v: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?, + }); + } + Ok(DFlashRequestState { + layers, + pending_context: DFlashPendingContext::new( + ctx, + self.context_feature_dim(), + self.config.block_size.min(max_cache_len), + )?, + scratch: 
DFlashDraftScratch::new(ctx, &self.config, self.config.block_size)?, + committed_len: 0, + max_cache_len, + }) + } + + pub(crate) fn append_pending_context( + &self, + ctx: &DeviceContext, + state: &mut DFlashRequestState, + captured_hidden: &HiddenStates, + token_offset: usize, + token_count: usize, + ) -> Result<()> { + anyhow::ensure!(token_count > 0, "DFlash context append needs tokens"); + anyhow::ensure!( + captured_hidden.hidden_dim == self.context_feature_dim(), + "DFlash captured hidden dim {} does not match expected {}", + captured_hidden.hidden_dim, + self.context_feature_dim() + ); + anyhow::ensure!( + token_offset + token_count <= captured_hidden.seq_len, + "DFlash captured hidden token range exceeds source" + ); + let required_committed_len = state + .committed_len + .checked_add(state.pending_context.len) + .and_then(|len| len.checked_add(token_count)) + .and_then(|len| len.checked_add(self.block_size())) + .context("DFlash pending context cache length overflow")?; + anyhow::ensure!( + required_committed_len <= state.max_cache_len, + "DFlash pending context would exceed cache: committed={}, pending={}, append={}, block={}, max={}", + state.committed_len, + state.pending_context.len, + token_count, + self.block_size(), + state.max_cache_len + ); + state.pending_context.append_from( + ctx, + captured_hidden, + token_offset, + token_count, + state.max_cache_len, + )?; + Ok(()) + } + + pub(crate) fn draft_logits<'a>( + &self, + target: &Qwen3Model, + state: &'a mut DFlashRequestState, + current_token: u32, + ) -> Result<&'a HiddenStates> { + let ctx = target.device_ctx(); + let Some(context_len) = state.pending_context_len() else { + anyhow::bail!("DFlash draft requested before target hidden context is available"); + }; + let block_size = self.block_size(); + let tail_len = context_len + block_size; + anyhow::ensure!( + state.committed_len + tail_len <= state.max_cache_len, + "DFlash draft cache overflow: committed={}, tail={}, max={}", + 
state.committed_len, + tail_len, + state.max_cache_len + ); + + state + .scratch + .ensure_context_capacity(ctx, &self.config, context_len)?; + state.scratch.block_token_ids_h.fill(self.mask_token_id()); + state.scratch.block_token_ids_h[0] = current_token; + ctx.stream.memcpy_htod( + &state.scratch.block_token_ids_h, + &mut state.scratch.token_ids_d, + )?; + target.get_embeddings_batch_into(&state.scratch.token_ids_d, &mut state.scratch.hidden)?; + + state.pending_context.activate_for_read(); + self.project_context_into(ctx, &state.pending_context.buffer, &mut state.scratch)?; + state.pending_context.clear(); + + let hidden_size = self.config.hidden_size; + let q_dim = self.config.num_attention_heads * self.config.head_dim; + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let inter_dim = self.config.intermediate_size; + debug_assert_eq!(state.scratch.hidden.hidden_dim, hidden_size); + debug_assert_eq!(state.scratch.q_batch.hidden_dim, q_dim); + debug_assert_eq!(state.scratch.k_tail.hidden_dim, kv_dim); + debug_assert_eq!(state.scratch.gate_out.hidden_dim, inter_dim); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + ops::rms_norm_batch_into( + ctx, + &state.scratch.hidden, + &layer.input_layernorm, + self.config.rms_norm_eps, + &mut state.scratch.normed, + ); + + ops::copy_hidden_token_range_into( + ctx, + &state.scratch.context_hidden, + 0, + &mut state.scratch.tail_input, + 0, + context_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &state.scratch.normed, + 0, + &mut state.scratch.tail_input, + context_len, + block_size, + )?; + + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + 0, + q_dim, + &state.scratch.normed, + &mut state.scratch.q_batch, + ); + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim, + kv_dim, + &state.scratch.tail_input, + &mut state.scratch.k_tail, + ); + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim + kv_dim, + kv_dim, + &state.scratch.tail_input, + 
&mut state.scratch.v_tail, + ); + + ops::dflash_qk_norm_rope_into( + ctx, + &mut state.scratch.q_batch, + &mut state.scratch.k_tail, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + context_len, + state.committed_len, + self.config.rms_norm_eps, + )?; + + let cache = &mut state.layers[layer_idx]; + ops::copy_hidden_token_range_into( + ctx, + &state.scratch.k_tail, + 0, + &mut cache.k, + state.committed_len, + tail_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &state.scratch.v_tail, + 0, + &mut cache.v, + state.committed_len, + tail_len, + )?; + ops::single_prefill_nhd_noncausal_into( + ctx, + &state.scratch.q_batch, + &cache.k, + &cache.v, + &mut state.scratch.attn_output, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + tail_len, + )?; + + ops::gemm_into( + ctx, + &layer.attention.o_proj, + &state.scratch.attn_output, + &mut state.scratch.o_buf, + ); + openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( + ctx, + &mut state.scratch.hidden, + &state.scratch.o_buf, + &layer.post_attention_layernorm, + self.config.rms_norm_eps, + &mut state.scratch.normed, + )?; + + ops::gemm_rows_into( + ctx, + &layer.mlp.gate_up_proj, + 0, + inter_dim, + &state.scratch.normed, + &mut state.scratch.gate_out, + ); + ops::gemm_rows_into( + ctx, + &layer.mlp.gate_up_proj, + inter_dim, + inter_dim, + &state.scratch.normed, + &mut state.scratch.up_out, + ); + ops::silu_mul_batch_into( + ctx, + &state.scratch.gate_out, + &state.scratch.up_out, + &mut state.scratch.act_out, + )?; + ops::gemm_into( + ctx, + &layer.mlp.down_proj, + &state.scratch.act_out, + &mut state.scratch.o_buf, + ); + ops::add_batch_into( + ctx, + &state.scratch.hidden, + &state.scratch.o_buf, + &mut state.scratch.hidden_out, + )?; + std::mem::swap(&mut state.scratch.hidden, 
&mut state.scratch.hidden_out); + } + + state.committed_len += context_len; + self.compute_logits_with_target_head_into(target, &mut state.scratch)?; + Ok(&state.scratch.logits) + } + + fn context_feature_dim(&self) -> usize { + self.config.hidden_size * self.target_layer_ids().len() + } + + fn project_context_into( + &self, + ctx: &DeviceContext, + context_features: &HiddenStates, + scratch: &mut DFlashDraftScratch, + ) -> Result<()> { + ops::gemm_into( + ctx, + &self.fc, + context_features, + &mut scratch.context_projected, + ); + ops::rms_norm_batch_into( + ctx, + &scratch.context_projected, + &self.hidden_norm, + self.config.rms_norm_eps, + &mut scratch.context_hidden, + ); + Ok(()) + } + + fn compute_logits_with_target_head_into( + &self, + target: &Qwen3Model, + scratch: &mut DFlashDraftScratch, + ) -> Result<()> { + let ctx = target.device_ctx(); + ops::rms_norm_batch_into( + ctx, + &scratch.hidden, + &self.norm, + self.config.rms_norm_eps, + &mut scratch.logits_normed, + ); + ops::gemm_into( + ctx, + target.output_projection(), + &scratch.logits_normed, + &mut scratch.logits, + ); + Ok(()) + } +} + +#[cfg(test)] +pub(crate) fn validate_dflash_config_for_target( + dflash_path: &str, + target_config: &crate::config::Config, +) -> Result { + let config = DFlashConfig::from_file(dflash_path)?; + config.validate_for_target(target_config)?; + Ok(config) +} + +#[cfg(test)] +mod tests { + use super::validate_dflash_config_for_target; + use crate::config::Config; + use std::path::Path; + + #[test] + fn downloaded_dflash_config_matches_qwen3_4b() { + let target_path = std::env::var("OPENINFER_TEST_MODEL_PATH") + .unwrap_or_else(|_| "/data/models/Qwen3-4B".to_string()); + let dflash_path = std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") + .unwrap_or_else(|_| "/data/models/Qwen3-4B-DFlash-b16".to_string()); + if !Path::new(&target_path).join("config.json").exists() + || !Path::new(&dflash_path).join("config.json").exists() + { + eprintln!( + "skipping DFlash config 
test; set OPENINFER_TEST_MODEL_PATH and OPENINFER_DFLASH_TEST_MODEL_PATH" + ); + return; + } + + let target = Config::from_file(&target_path).expect("target config"); + let dflash = validate_dflash_config_for_target(&dflash_path, &target) + .expect("DFlash config should match target"); + + assert_eq!(dflash.block_size, 16); + assert_eq!(dflash.dflash_config.mask_token_id, 151669); + assert_eq!( + dflash.dflash_config.target_layer_ids, + vec![1, 9, 17, 25, 33] + ); + } +} diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 95fe3d4b..1dbd6efb 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -18,6 +18,18 @@ use openinfer_kv_cache::{ }; use openinfer_kv_offload::{LoadHandle, OffloadConfig, OffloadEngine}; +mod dflash_lane; +mod dflash_prefill; +mod spec; + +use crate::dflash::DFlashDraftModel; +use crate::speculative::{ + DraftPlan, DraftResult, DraftStepItem, VerifyPlan, VerifyResult, VerifyStepItem, + build_verify_results, +}; +use dflash_lane::DFlashLaneState; +use dflash_prefill::{DFlashPrefillAction, dflash_prefill_action}; + #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct RequestId(pub(crate) u64); @@ -264,8 +276,26 @@ fn execute_step_on_lane( .iter() .map(|req| req.lora_adapter.as_deref()) .collect(); - let (logits, all_position_logits) = - lane.execute_prefill(&prompts, kv_views, &lora_adapters, *echo)?; + // When DFlash is loaded, capture target hidden states for eligible + // requests so they can seed the draft model after prefill finishes. 
+ let capture_requested = lane.should_capture_dflash_prefill_context(requests); + let capture_layer_ids = if capture_requested { + lane.dflash_capture_layer_ids() + } else { + None + }; + let (logits, all_position_logits, captured_hidden) = lane.execute_prefill( + &prompts, + kv_views, + &lora_adapters, + *echo, + capture_layer_ids.as_deref(), + )?; + let dflash_context_captured_requests = lane.record_prefill_dflash_context( + requests, + capture_requested, + captured_hidden.as_ref(), + )?; if collect_result { let params: Vec<&SamplingParams> = requests.iter().map(|r| &r.params).collect(); let tokens = lane.select_step_tokens(&logits, &params, *sample_seed)?; @@ -278,6 +308,7 @@ fn execute_step_on_lane( all_position_logits.as_ref(), *echo, )?, + dflash_context_captured_requests, })) } else { Ok(WorkerStepOutcome::Ack) @@ -396,11 +427,12 @@ fn execute_step_on_lane( // Launch prefill on prefill partition stream. unsafe { set_stream_override(prefill_stream.0) }; - let (prefill_logits, _) = lane.execute_prefill( + let (prefill_logits, _, _) = lane.execute_prefill( &prefill_prompts, prefill_kv_views, &prefill_lora_adapters, false, + None, )?; clear_stream_override(); @@ -458,6 +490,47 @@ fn execute_step_on_lane( Ok(WorkerStepOutcome::Ack) } } + StepCommand::SpeculativeVerify { + requests, + kv_views, + } => { + // One target forward over each request's K+1 draft span with a + // speculative KV view. echo=true yields all-position logits so we + // can argmax every span position — accept_greedy needs the target's + // posterior at each position. The same forward captures target + // hidden states (at the DFlash layers) to seed the next draft. 
+ let spans: Vec<&[u32]> = requests.iter().map(VerifyStepItem::as_slice).collect(); + let no_lora: Vec<Option<&str>> = vec![None; requests.len()]; + let capture_layer_ids = lane.dflash_capture_layer_ids(); + let (_last_logits, all_logits, captured_hidden) = lane.execute_prefill( + &spans, + kv_views, + &no_lora, + true, + capture_layer_ids.as_deref(), + )?; + let all_logits = all_logits.ok_or_else(|| { + anyhow::anyhow!("speculative verify produced no per-position logits") + })?; + let total_tokens: usize = requests.iter().map(|req| req.as_slice().len()).sum(); + let greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; total_tokens]; + let target_tokens = lane.select_step_tokens(&all_logits, &params, 0)?; + let request_results = build_verify_results(requests, &target_tokens)?; + lane.record_verify_dflash_context( + requests, + &request_results, + captured_hidden.as_ref(), + )?; + Ok(WorkerStepOutcome::SpeculativeVerify(VerifyResult { + requests: request_results, + })) + } + StepCommand::SpeculativeDraft { requests } => { + Ok(WorkerStepOutcome::SpeculativeDraft( + lane.execute_dflash_draft(requests)?, + )) + } } } @@ -595,6 +668,10 @@ pub struct DecodeRequestResult { pub struct PrefillResult { pub requests: Vec, + /// Requests whose DFlash target context was captured this prefill step. + /// Empty unless speculative decoding is enabled. The executor folds these + /// into its `dflash_ready_requests` set once the prompt is fully prefilled. + pub dflash_context_captured_requests: Vec, } pub struct DecodeResult { @@ -619,6 +696,27 @@ pub(crate) trait ModelExecutor: Send { fn execute_decode(&mut self, plan: DecodePlan<'_>) -> Result; fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result; + /// Run one speculative draft round (propose `K` tokens per request). Only + /// meaningful when [`Self::speculative_enabled`] is true. 
+ fn execute_speculative_draft(&mut self, _plan: DraftPlan<'_>) -> Result { + anyhow::bail!("speculative draft is not implemented for this executor") + } + + /// Verify a draft span with one target forward and accept the greedy prefix. + fn execute_speculative_verify(&mut self, _plan: VerifyPlan<'_>) -> Result { + anyhow::bail!("speculative verification is not implemented for this executor") + } + + /// Whether a draft model is loaded and speculative decoding is active. + fn speculative_enabled(&self) -> bool { + false + } + + /// Whether `request_id` has captured draft context and can be drafted. + fn speculative_request_ready(&self, _request_id: RequestId) -> bool { + false + } + fn load_lora_adapter(&mut self, request: &LoadLoraAdapterRequest) -> Result<()> { anyhow::bail!( "Qwen3 LoRA adapter loading is not implemented yet: name={}, path={}", @@ -732,6 +830,13 @@ pub struct Qwen3Executor { /// In-flight async prefill state. Populated by the SplitConcurrent step, /// consumed by `poll_async_prefill`. async_prefill: Option, + /// DFlash draft metadata; `Some` once a draft model is loaded into the + /// primary lane. Speculative decoding is enabled iff this is set. + speculative: Option, + /// Requests whose DFlash context is captured and ready to draft. A request + /// enters this set when its prompt finishes prefilling with captured target + /// context, and leaves on retire or a plain (non-speculative) decode. + dflash_ready_requests: HashSet, } /// State for an in-flight async prefill on the prefill overlap stream. 
@@ -809,6 +914,8 @@ impl Qwen3Executor { l1_retention_disabled: false, overlap: None, async_prefill: None, + speculative: None, + dflash_ready_requests: HashSet::new(), }) } @@ -987,6 +1094,8 @@ impl Qwen3Executor { l1_retention_disabled: false, overlap: None, async_prefill: None, + speculative: None, + dflash_ready_requests: HashSet::new(), }) } @@ -1068,6 +1177,33 @@ impl Qwen3Executor { } } + /// Enable speculative decoding by loading a DFlash draft model into the + /// primary lane. + /// + /// Requires the single-GPU topology (tensor parallel shards KV per rank) and + /// is incompatible with KV offload. Disables the prefix cache: speculative + /// capture needs clean, uncached target hidden states for every prompt + /// token, and a prefix-cache hit skips the forward that would produce them. + pub fn load_dflash_draft_model(&mut self, draft_path: &str) -> Result<()> { + anyhow::ensure!( + self.workers.is_empty(), + "speculative decoding requires the single-GPU path (got {} extra ranks)", + self.workers.len() + ); + anyhow::ensure!( + self.offload.is_none(), + "speculative decoding is not supported together with KV offload" + ); + let meta = self.primary.load_dflash(draft_path.to_string())?; + log::info!( + "Qwen3 DFlash speculative decoding enabled: draft block size {}", + meta.block_size + ); + self.prefix_cache_enabled = false; + self.speculative = Some(meta); + Ok(()) + } + /// Whether KV offload is active on this executor. 
pub fn offload_enabled(&self) -> bool { self.offload.is_some() @@ -1418,6 +1554,10 @@ impl ModelExecutor for Qwen3Executor { } } self.saved_cursor.remove(&request_id); + if self.speculative.is_some() { + self.dflash_ready_requests.remove(&request_id); + self.primary.drop_dflash_request(request_id)?; + } Ok(()) } @@ -1582,6 +1722,28 @@ impl ModelExecutor for Qwen3Executor { for req_result in &result.requests { self.apply_prefill_result(req_result)?; } + // A request becomes draft-ready once its prompt is fully prefilled with + // captured target context. Partial chunks stay pending; ineligible + // requests drop any stale worker state. + if self.speculative.is_some() { + for req_result in &result.requests { + let captured = result + .dflash_context_captured_requests + .contains(&req_result.request_id); + match dflash_prefill_action(captured, req_result.completed) { + DFlashPrefillAction::MarkReady => { + self.dflash_ready_requests.insert(req_result.request_id); + } + DFlashPrefillAction::KeepPending => { + self.dflash_ready_requests.remove(&req_result.request_id); + } + DFlashPrefillAction::Drop => { + self.dflash_ready_requests.remove(&req_result.request_id); + self.primary.drop_dflash_request(req_result.request_id)?; + } + } + } + } // 5. Offload the blocks this prefill just sealed (post-step-sync). for req_result in &result.requests { self.save_sealed_blocks(req_result.request_id); @@ -1634,6 +1796,15 @@ impl ModelExecutor for Qwen3Executor { .expect("request must exist after decode"); rkv.apply_decode(req_result.token, self.kv_mgr.pool())?; } + // A plain decode advances the sequence outside the speculative path, so + // any captured draft context is now stale — drop it. + if self.speculative.is_some() { + for req_result in &result.requests { + if self.dflash_ready_requests.remove(&req_result.request_id) { + self.primary.drop_dflash_request(req_result.request_id)?; + } + } + } // 5. Offload any block this decode step just sealed (post-step-sync). 
for req_result in &result.requests { self.save_sealed_blocks(req_result.request_id); @@ -1642,6 +1813,22 @@ impl ModelExecutor for Qwen3Executor { Ok(result) } + fn execute_speculative_draft(&mut self, plan: DraftPlan<'_>) -> Result { + self.execute_speculative_draft_impl(plan) + } + + fn execute_speculative_verify(&mut self, plan: VerifyPlan<'_>) -> Result { + self.execute_speculative_verify_impl(plan) + } + + fn speculative_enabled(&self) -> bool { + self.speculative.is_some() + } + + fn speculative_request_ready(&self, request_id: RequestId) -> bool { + self.dflash_ready_requests.contains(&request_id) + } + fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result { // 1. Create RequestKvs for prefill requests (first chunk only), clamp // chunk budgets, schedule KV for this step's tokens @@ -2038,6 +2225,16 @@ impl Drop for Qwen3Executor { } } +/// What the executor learns about the draft model after loading it on the +/// worker: the draft block size (`K` candidates per round) and which target +/// layers feed the draft (the worker captures these; kept for diagnostics). +#[derive(Clone, Debug)] +struct DFlashMeta { + block_size: usize, + #[allow(dead_code)] + target_layer_ids: Vec, +} + struct LocalQwen3Lane { model: Qwen3Model, kv_buffer: KvBuffer, @@ -2046,6 +2243,9 @@ struct LocalQwen3Lane { sample_scratch: openinfer_sample::SampleScratch, /// In-flight prefill from a previous SplitConcurrent step (not yet synced). inflight_prefill: Option, + /// DFlash draft lane (the draft model + per-request draft state). `None` + /// unless speculative decoding is enabled; only the primary rank carries it. + dflash: Option, } /// Stored state for an async prefill that was launched but not yet synced. @@ -2101,9 +2301,28 @@ impl LocalQwen3Lane { bufs, sample_scratch, inflight_prefill: None, + dflash: None, }) } + /// Load the DFlash draft model into this lane (primary rank only). 
The draft + /// model is built here on the worker thread because it reads the co-located + /// target model's embeddings and head. + fn load_dflash(&mut self, draft_path: &str) -> Result { + let model = DFlashDraftModel::from_safetensors_for_target( + self.model.device_ctx(), + draft_path, + &self.model, + )?; + model.tune_gemm_algos(&self.model)?; + let meta = DFlashMeta { + block_size: model.block_size(), + target_layer_ids: model.target_layer_ids().to_vec(), + }; + self.dflash = Some(DFlashLaneState::new(model)); + Ok(meta) + } + fn bind(&self) -> Result { bind_model_thread(&self.model)?; tune_decode_gemm_algos(&self.model)?; @@ -2142,7 +2361,12 @@ impl LocalQwen3Lane { false, )?; - Ok(PrefillResult { requests: results }) + // Split-concurrent prefill never runs with DFlash (capture needs the + // synchronous result), so no context is captured here. + Ok(PrefillResult { + requests: results, + dflash_context_captured_requests: Vec::new(), + }) } /// Pick one token per logits column (batched argmax for greedy rows, @@ -2202,7 +2426,8 @@ impl LocalQwen3Lane { kv_views: &[KvView], lora_adapters: &[Option<&str>], echo: bool, - ) -> Result<(HiddenStates, Option)> { + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option, Option)> { self.model.batch_prefill( prompts, kv_views, @@ -2210,6 +2435,7 @@ impl LocalQwen3Lane { self.kv_buffer.buffer(), &self.layout, echo, + capture_layer_ids, ) } @@ -2302,6 +2528,18 @@ enum StepCommand { decode_stream: crate::green_ctx::SendStream, sample_seed: u64, }, + /// Speculative verify: one target forward over each request's `K + 1` draft + /// span (with a speculative KV view), capturing target hidden states for the + /// next draft round. Greedy argmax per position drives [`accept_greedy`]. + SpeculativeVerify { + requests: Vec, + kv_views: Vec, + }, + /// Speculative draft: roll the DFlash draft model forward one block per + /// request. Uses the draft's own KV — no target KV views. 
+ SpeculativeDraft { + requests: Vec, + }, } impl StepCommand { @@ -2311,6 +2549,8 @@ impl StepCommand { Self::Decode { .. } => "decode", Self::Unified { .. } => "unified", Self::SplitConcurrent { .. } => "split_concurrent", + Self::SpeculativeVerify { .. } => "speculative_verify", + Self::SpeculativeDraft { .. } => "speculative_draft", } } } @@ -2340,6 +2580,18 @@ enum WorkerCommand { ResolvePrefill { resp: channel::Sender>, }, + /// Load the DFlash draft model into the primary lane (built on the worker + /// thread because it reads the co-located target model). + LoadDflash { + draft_path: String, + resp: channel::Sender>, + }, + /// Drop a request's DFlash draft state (request retired, or it fell back to + /// a plain decode that advanced the sequence outside the speculative path). + DropDflash { + request_id: RequestId, + resp: channel::Sender>, + }, Shutdown, } @@ -2357,6 +2609,8 @@ enum WorkerStepOutcome { /// query this to check if prefill is done without blocking. prefill_event: SendEvent, }, + SpeculativeVerify(VerifyResult), + SpeculativeDraft(DraftResult), } impl WorkerStepOutcome { @@ -2367,6 +2621,8 @@ impl WorkerStepOutcome { Self::Decode(_) => "decode", Self::Unified(_) => "unified", Self::SplitDecodeReady { .. } => "split_decode_ready", + Self::SpeculativeVerify(_) => "speculative_verify", + Self::SpeculativeDraft(_) => "speculative_draft", } } } @@ -2420,6 +2676,14 @@ impl RankWorker { let result = lane.resolve_inflight_prefill(); let _ = resp.send(result); } + WorkerCommand::LoadDflash { draft_path, resp } => { + let result = lane.load_dflash(&draft_path); + let _ = resp.send(result); + } + WorkerCommand::DropDflash { request_id, resp } => { + lane.drop_dflash_request(request_id); + let _ = resp.send(Ok(())); + } WorkerCommand::Shutdown => break, } } @@ -2464,6 +2728,35 @@ impl RankWorker { Ok(resp_rx) } + /// Load the DFlash draft model into this worker's lane and return its + /// metadata. Blocks until the worker finishes loading. 
+ fn load_dflash(&self, draft_path: String) -> Result { + let (resp_tx, resp_rx) = channel::bounded(1); + self.tx + .send(WorkerCommand::LoadDflash { + draft_path, + resp: resp_tx, + }) + .map_err(|_| anyhow::anyhow!("worker channel closed on load_dflash"))?; + resp_rx + .recv() + .map_err(|_| anyhow::anyhow!("worker dropped load_dflash response"))? + } + + /// Drop a request's DFlash state. Blocks until the worker acknowledges. + fn drop_dflash_request(&self, request_id: RequestId) -> Result<()> { + let (resp_tx, resp_rx) = channel::bounded(1); + self.tx + .send(WorkerCommand::DropDflash { + request_id, + resp: resp_tx, + }) + .map_err(|_| anyhow::anyhow!("worker channel closed on drop_dflash"))?; + resp_rx + .recv() + .map_err(|_| anyhow::anyhow!("worker dropped drop_dflash response"))? + } + fn load_lora_adapter( &self, name: String, diff --git a/openinfer-qwen3-4b/src/executor/dflash_lane.rs b/openinfer-qwen3-4b/src/executor/dflash_lane.rs new file mode 100644 index 00000000..478daa8f --- /dev/null +++ b/openinfer-qwen3-4b/src/executor/dflash_lane.rs @@ -0,0 +1,259 @@ +//! Worker-side DFlash draft lane: the draft model plus per-request draft state. +//! +//! This lives on the worker thread next to the target model because the draft +//! rollout reads the target's embeddings/head and its captured hidden states. +//! The draft/verify boundary stays a pure token span — the hidden states are +//! private to this lane (`pending_context`), never crossing to the scheduler. 
+ +use std::collections::HashMap; + +use anyhow::Result; +use openinfer_core::sampler::SamplingParams; +use openinfer_core::tensor::HiddenStates; + +use super::dflash_prefill::{dflash_prefill_can_capture, should_capture_dflash_prefill_context}; +use super::{LocalQwen3Lane, PrefillStepItem, RequestId}; +use crate::dflash::{DFlashDraftModel, DFlashRequestState}; +use crate::speculative::{ + DraftRequestResult, DraftResult, DraftStepItem, VerifyRequestResult, VerifyStepItem, +}; + +pub(super) struct DFlashLaneState { + pub(super) model: DFlashDraftModel, + pub(super) requests: HashMap<RequestId, DFlashRequestState>, + verified_draft_tokens: usize, + accepted_draft_tokens: usize, +} + +impl DFlashLaneState { + pub(super) fn new(model: DFlashDraftModel) -> Self { + Self { + model, + requests: HashMap::new(), + verified_draft_tokens: 0, + accepted_draft_tokens: 0, + } + } +} + +impl LocalQwen3Lane { + /// Target layers whose hidden states the draft model consumes (None when + /// DFlash is not loaded). + pub(super) fn dflash_capture_layer_ids(&self) -> Option<Vec<usize>> { + self.dflash + .as_ref() + .map(|dflash| dflash.model.target_layer_ids().to_vec()) + } + + pub(super) fn should_capture_dflash_prefill_context( + &self, + requests: &[PrefillStepItem], + ) -> bool { + let Some(dflash) = self.dflash.as_ref() else { + return false; + }; + should_capture_dflash_prefill_context(requests, |request_id| { + dflash.requests.contains_key(&request_id) + }) + } + + /// Fold target hidden states captured during prefill into each eligible + /// request's pending context. Returns the requests that now have context. 
+ pub(super) fn record_prefill_dflash_context( + &mut self, + requests: &[PrefillStepItem], + capture_requested: bool, + captured_hidden: Option<&HiddenStates>, + ) -> Result> { + let Some(captured_hidden) = captured_hidden else { + anyhow::ensure!( + !capture_requested, + "DFlash prefill context capture was requested but no hidden states were returned" + ); + return Ok(Vec::new()); + }; + anyhow::ensure!( + capture_requested, + "DFlash prefill hidden states were returned without a capture request" + ); + let Some(dflash) = self.dflash.as_mut() else { + anyhow::bail!("DFlash prefill context record requested without DFlash"); + }; + let expected_tokens: usize = requests.iter().map(|req| req.chunk_tokens).sum(); + anyhow::ensure!( + captured_hidden.seq_len == expected_tokens, + "DFlash prefill captured {} hidden rows for {} scheduled tokens", + captured_hidden.seq_len, + expected_tokens + ); + let ctx = self.model.device_ctx().clone(); + let mut captured_requests = Vec::new(); + let mut token_offset = 0usize; + for req in requests { + let pending_exists = dflash.requests.contains_key(&req.request_id); + if dflash_prefill_can_capture(req, pending_exists) { + let max_cache_len = + req.prompt_tokens.len() + req.max_output_tokens + dflash.model.block_size(); + let mut state = match dflash.requests.remove(&req.request_id) { + Some(state) => state, + None => dflash.model.new_request_state(&ctx, max_cache_len)?, + }; + let pending_len = state.pending_context_len().unwrap_or(0); + anyhow::ensure!( + pending_len == req.chunk_start, + "DFlash prefill context for {:?} is discontinuous: pending={}, chunk_start={}", + req.request_id, + pending_len, + req.chunk_start + ); + dflash.model.append_pending_context( + &ctx, + &mut state, + captured_hidden, + token_offset, + req.chunk_tokens, + )?; + dflash.requests.insert(req.request_id, state); + captured_requests.push(req.request_id); + } else { + dflash.requests.remove(&req.request_id); + } + token_offset += req.chunk_tokens; + } + 
Ok(captured_requests) + } + + /// Seed the next draft round from a verify step: append the target hidden + /// states for the *accepted* span positions to each request's pending + /// context, and log the cumulative acceptance rate. + pub(super) fn record_verify_dflash_context( + &mut self, + requests: &[VerifyStepItem], + results: &[VerifyRequestResult], + captured_hidden: Option<&HiddenStates>, + ) -> Result<()> { + let Some(captured_hidden) = captured_hidden else { + anyhow::bail!("DFlash verify context capture requested but no hidden states returned"); + }; + let Some(dflash) = self.dflash.as_mut() else { + anyhow::bail!("DFlash verify context record requested without DFlash"); + }; + anyhow::ensure!( + requests.len() == results.len(), + "DFlash verify result count {} does not match request count {}", + results.len(), + requests.len() + ); + let expected_tokens: usize = requests.iter().map(|req| req.token_ids.len()).sum(); + anyhow::ensure!( + captured_hidden.seq_len == expected_tokens, + "DFlash verify captured {} hidden rows for {} scheduled tokens", + captured_hidden.seq_len, + expected_tokens + ); + let ctx = self.model.device_ctx().clone(); + let mut token_offset = 0usize; + for (req, result) in requests.iter().zip(results) { + anyhow::ensure!( + req.request_id == result.request_id, + "DFlash verify result {:?} does not match request {:?}", + result.request_id, + req.request_id + ); + let mut state = dflash.requests.remove(&req.request_id).ok_or_else(|| { + anyhow::anyhow!("missing DFlash state after verify for {:?}", req.request_id) + })?; + // Only the accepted prefix's target hidden states are valid context + // for the next draft; rejected drafts had the wrong continuation. 
+ dflash.model.append_pending_context( + &ctx, + &mut state, + captured_hidden, + token_offset, + result.accepted_tokens.len(), + )?; + dflash.requests.insert(req.request_id, state); + dflash.verified_draft_tokens += req.token_ids.len().saturating_sub(1); + dflash.accepted_draft_tokens += result.matched_draft_tokens; + let rate = if dflash.verified_draft_tokens == 0 { + 0.0 + } else { + dflash.accepted_draft_tokens as f64 / dflash.verified_draft_tokens as f64 + }; + log::debug!( + "Qwen3 DFlash request={} accepted_draft={} committed_tokens={} cumulative_accept_rate={:.3}", + req.request_id.get(), + result.matched_draft_tokens, + result.accepted_tokens.len(), + rate, + ); + token_offset += req.token_ids.len(); + } + Ok(()) + } + + /// Roll out one draft span per request: draft forward + greedy argmax over + /// the block. Returns the verify span `[current_token, draft_1, …]`. + pub(super) fn execute_dflash_draft( + &mut self, + requests: &[DraftStepItem], + ) -> Result { + anyhow::ensure!( + !requests.is_empty(), + "DFlash draft requested without active requests" + ); + for req in requests { + anyhow::ensure!( + req.params.is_greedy(), + "DFlash draft currently supports greedy sampling only" + ); + } + + // Take the lane out of `self` so the draft forward (which borrows + // `self.model`) and the argmax (which borrows `self.sample_scratch`) + // don't collide on a `self` borrow. 
+ let Some(mut dflash) = self.dflash.take() else { + anyhow::bail!("DFlash draft requested but DFlash is not loaded"); + }; + let result = (|| -> Result<Vec<DraftRequestResult>> { + let mut outputs = Vec::with_capacity(requests.len()); + for req in requests { + let mut state = dflash + .requests + .remove(&req.request_id) + .ok_or_else(|| anyhow::anyhow!("missing DFlash state for {:?}", req.request_id))?; + let draft_logits = dflash + .model + .draft_logits(&self.model, &mut state, req.current_token)?; + let draft_len = draft_logits.seq_len; + let greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; draft_len]; + let sampled = self.select_step_tokens(draft_logits, &params, 0)?; + dflash.requests.insert(req.request_id, state); + anyhow::ensure!( + sampled.len() == draft_len && sampled.len() >= 2, + "DFlash draft sampled {} tokens from {} logits columns", + sampled.len(), + draft_len + ); + // Verify span = [current dangling token, draft_1, …, draft_{K}]. + let mut token_ids = Vec::with_capacity(sampled.len()); + token_ids.push(req.current_token); + token_ids.extend(sampled.into_iter().skip(1)); + outputs.push(DraftRequestResult { + request_id: req.request_id, + token_ids, + }); + } + Ok(outputs) + })(); + self.dflash = Some(dflash); + Ok(DraftResult { requests: result? }) + } + + pub(super) fn drop_dflash_request(&mut self, request_id: RequestId) { + if let Some(dflash) = self.dflash.as_mut() { + dflash.requests.remove(&request_id); + } + } +} diff --git a/openinfer-qwen3-4b/src/executor/dflash_prefill.rs b/openinfer-qwen3-4b/src/executor/dflash_prefill.rs new file mode 100644 index 00000000..332cf2a4 --- /dev/null +++ b/openinfer-qwen3-4b/src/executor/dflash_prefill.rs @@ -0,0 +1,70 @@ +//! DFlash prefill-capture eligibility predicates. +//! +//! A request can seed the DFlash draft only if its prefill produces clean target +//! hidden states: greedy, no LoRA, no prefix-cache hit, no echo, no logprobs. 
+ +use super::{PrefillStepItem, RequestId}; + +/// Whether a prefill request is eligible to capture DFlash target context. +pub(super) fn dflash_prefill_supported(req: &PrefillStepItem) -> bool { + req.lora_adapter.is_none() + && req.cached_tokens == 0 + && req.logprobs == 0 + && !req.echo + && req.params.is_greedy() +} + +/// Eligible AND continuous: either the first chunk, or a later chunk whose +/// earlier chunks already captured context (no gaps in the pending buffer). +pub(super) fn dflash_prefill_can_capture(req: &PrefillStepItem, pending_state_exists: bool) -> bool { + dflash_prefill_supported(req) && (req.chunk_start == 0 || pending_state_exists) +} + +/// Capture hidden states during this prefill step iff any request is eligible. +pub(super) fn should_capture_dflash_prefill_context( + requests: &[PrefillStepItem], + pending_state_exists: impl Fn(RequestId) -> bool, +) -> bool { + !requests.is_empty() + && requests + .iter() + .any(|req| dflash_prefill_can_capture(req, pending_state_exists(req.request_id))) +} + +/// What to do with a request's DFlash state after a prefill step. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(super) enum DFlashPrefillAction { + /// Context captured and prefill finished → ready to draft. + MarkReady, + /// Context captured but more chunks remain → keep the pending state. + KeepPending, + /// Ineligible → drop any stale state. 
+ Drop, +} + +pub(super) fn dflash_prefill_action( + captured_context: bool, + completed: bool, +) -> DFlashPrefillAction { + match (captured_context, completed) { + (true, true) => DFlashPrefillAction::MarkReady, + (true, false) => DFlashPrefillAction::KeepPending, + (false, _) => DFlashPrefillAction::Drop, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prefill_action_table() { + assert_eq!(dflash_prefill_action(true, true), DFlashPrefillAction::MarkReady); + assert_eq!( + dflash_prefill_action(true, false), + DFlashPrefillAction::KeepPending + ); + assert_eq!(dflash_prefill_action(false, true), DFlashPrefillAction::Drop); + assert_eq!(dflash_prefill_action(false, false), DFlashPrefillAction::Drop); + } +} diff --git a/openinfer-qwen3-4b/src/executor/spec.rs b/openinfer-qwen3-4b/src/executor/spec.rs new file mode 100644 index 00000000..0ec42881 --- /dev/null +++ b/openinfer-qwen3-4b/src/executor/spec.rs @@ -0,0 +1,188 @@ +//! Executor-side speculative-decode orchestration: the optimistic KV +//! transaction around a verify forward (schedule → forward → accept/commit or +//! roll back) and the thin draft dispatch. +//! +//! The forward itself runs on the worker lane (see [`super::dflash_lane`]); this +//! module owns only the KV bookkeeping the executor thread is responsible for. 
+ +use anyhow::Result; + +use crate::speculative::{DraftPlan, DraftResult, VerifyPlan, VerifyResult}; + +use super::{Qwen3Executor, RequestId, StepCommand, WorkerStepOutcome}; + +impl Qwen3Executor { + pub(super) fn execute_speculative_verify_impl( + &mut self, + plan: VerifyPlan<'_>, + ) -> Result { + anyhow::ensure!( + self.speculative.is_some(), + "speculative verification requested but no draft model is loaded" + ); + for req in plan.requests { + anyhow::ensure!( + !req.as_slice().is_empty(), + "speculative verify request {:?} has an empty verify span", + req.request_id + ); + anyhow::ensure!( + req.params.is_greedy(), + "speculative verification currently supports greedy sampling only" + ); + anyhow::ensure!( + self.dflash_ready_requests.contains(&req.request_id), + "speculative verification requested before DFlash state is ready for {:?}", + req.request_id + ); + anyhow::ensure!( + self.request_kvs.contains_key(&req.request_id), + "missing RequestKv for {:?}", + req.request_id + ); + } + + // Reserve KV slots for each request's full K+1 verify span. Roll back + // every prior reservation if any single one fails — all-or-nothing. 
+ let mut scheduled = Vec::with_capacity(plan.requests.len()); + for req in plan.requests { + let span_len = req.as_slice().len(); + let rkv = self + .request_kvs + .get_mut(&req.request_id) + .expect("RequestKv was validated before speculative scheduling"); + if let Err(e) = rkv.schedule_speculative(span_len, self.kv_mgr.pool()) { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "schedule_speculative failed for {:?}: {e}", + req.request_id + )); + } + scheduled.push(req.request_id); + } + + let kv_views = plan + .requests + .iter() + .map(|req| self.request_kvs[&req.request_id].speculative_view(req.as_slice().len())) + .collect(); + + let step = StepCommand::SpeculativeVerify { + requests: plan.requests.to_vec(), + kv_views, + }; + let outcome = match self.run_step(&step) { + Ok(outcome) => outcome, + Err(e) => { + self.revert_speculative_schedules(&scheduled); + return Err(e); + } + }; + let result = match outcome { + WorkerStepOutcome::SpeculativeVerify(result) => result, + other => { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "speculative verify returned unexpected: {}", + other.kind() + )); + } + }; + if result.requests.len() != plan.requests.len() { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "speculative verify returned {} request results for {} requests", + result.requests.len(), + plan.requests.len() + )); + } + for (req, req_result) in plan.requests.iter().zip(&result.requests) { + if req.request_id != req_result.request_id { + self.revert_speculative_schedules(&scheduled); + return Err(anyhow::anyhow!( + "speculative verify returned request {:?} for {:?}", + req_result.request_id, + req.request_id + )); + } + } + + // Commit the accepted prefix of each request's KV and free the rest. + // On a mid-loop failure, only the not-yet-applied requests roll back + // (applied ones already committed and cannot be reverted). 
+ let mut applied = Vec::with_capacity(result.requests.len()); + for req_result in &result.requests { + let rkv = self + .request_kvs + .get_mut(&req_result.request_id) + .expect("request must exist after speculative verify"); + if let Err(e) = rkv.apply_speculative(&req_result.accepted_tokens, self.kv_mgr.pool()) { + let unapplied = scheduled + .iter() + .copied() + .filter(|request_id| !applied.contains(request_id)) + .collect::>(); + self.revert_speculative_schedules(&unapplied); + return Err(anyhow::anyhow!( + "apply_speculative failed for {:?}: {e}", + req_result.request_id + )); + } + applied.push(req_result.request_id); + } + for req_result in &result.requests { + self.save_sealed_blocks(req_result.request_id); + } + + Ok(result) + } + + pub(super) fn execute_speculative_draft_impl( + &mut self, + plan: DraftPlan<'_>, + ) -> Result { + anyhow::ensure!( + self.speculative.is_some(), + "speculative draft requested but no draft model is loaded" + ); + for req in plan.requests { + anyhow::ensure!( + req.params.is_greedy(), + "speculative draft currently supports greedy sampling only" + ); + anyhow::ensure!( + self.dflash_ready_requests.contains(&req.request_id), + "speculative draft requested before DFlash state is ready for {:?}", + req.request_id + ); + } + let step = StepCommand::SpeculativeDraft { + requests: plan.requests.to_vec(), + }; + match self.run_step(&step)? { + WorkerStepOutcome::SpeculativeDraft(result) => Ok(result), + other => Err(anyhow::anyhow!( + "speculative draft returned unexpected: {}", + other.kind() + )), + } + } + + /// Roll back speculative KV reservations. Each request reverts its own + /// reservation independently (the LIFO block discipline is intra-sequence, + /// via RAII); the reverse order here just mirrors schedule order and is + /// cosmetic. 
+ fn revert_speculative_schedules(&mut self, request_ids: &[RequestId]) { + for request_id in request_ids.iter().rev().copied() { + let Some(rkv) = self.request_kvs.get_mut(&request_id) else { + log::warn!( + "missing RequestKv while reverting speculative schedule for {request_id:?}" + ); + continue; + }; + if let Err(error) = rkv.revert_schedule() { + log::warn!("failed to revert speculative schedule for {request_id:?}: {error}"); + } + } + } +} diff --git a/openinfer-qwen3-4b/src/lib.rs b/openinfer-qwen3-4b/src/lib.rs index e0d82010..0931ed96 100644 --- a/openinfer-qwen3-4b/src/lib.rs +++ b/openinfer-qwen3-4b/src/lib.rs @@ -5,16 +5,18 @@ mod batch_decode_buffers; mod batch_decode_dag; pub mod batch_decode_trace; mod config; +mod dflash; mod executor; pub(crate) mod green_ctx; pub mod kernel_bench; mod lora; mod prefill; mod scheduler; +mod speculative; mod unified_forward; mod weights; -use std::path::Path; +use std::path::{Path, PathBuf}; use anyhow::Result; use log::{info, warn}; @@ -153,7 +155,7 @@ pub fn probe_model(model_path: &Path) -> Result> { /// Qwen3 startup policy — the TP→device mapping and the LoRA↔CUDA-Graph /// exclusion — and dispatches to the right low-level entry. That policy lives /// with the model instead of leaking into the server. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] pub struct Qwen3LaunchOptions { /// CUDA device for single-GPU loads (ignored when `tp_size > 1`). pub device_ordinal: usize, @@ -170,6 +172,9 @@ pub struct Qwen3LaunchOptions { /// How prefill and decode share the GPU (`--decode-overlap`). pub decode_overlap: DecodeOverlap, pub batch_invariant: bool, + /// `Some` enables DFlash speculative decoding with this drafter model. + /// Single-GPU only and mutually exclusive with LoRA and KV offload. + pub dflash_draft_model_path: Option, } /// Start the Qwen3 engine from server-facing [`Qwen3LaunchOptions`]. 
@@ -204,6 +209,10 @@ pub fn launch(model_path: &Path, options: Qwen3LaunchOptions) -> Result { info!( @@ -231,6 +240,7 @@ pub fn launch(model_path: &Path, options: Qwen3LaunchOptions) -> Result Result Result, ) -> Result { let EngineLoadOptions { enable_cuda_graph, @@ -281,6 +294,12 @@ pub fn start_engine_with_offload( .ok_or_else(|| anyhow::anyhow!("model path must be valid UTF-8"))?; ensure_batch_invariant_supported(decode_overlap, batch_invariant)?; apply_batch_invariant_policy(batch_invariant); + let dflash_draft_model_path = dflash_draft_model_path + .map(|path| { + path.to_str() + .ok_or_else(|| anyhow::anyhow!("DFlash draft model path must be valid UTF-8")) + }) + .transpose()?; scheduler::start_qwen3( model_path, enable_cuda_graph, @@ -291,6 +310,7 @@ pub fn start_engine_with_offload( max_prefill_tokens, memory_options, decode_overlap, + dflash_draft_model_path, ) } diff --git a/openinfer-qwen3-4b/src/prefill.rs b/openinfer-qwen3-4b/src/prefill.rs index e9d6163f..ddd33323 100644 --- a/openinfer-qwen3-4b/src/prefill.rs +++ b/openinfer-qwen3-4b/src/prefill.rs @@ -100,6 +100,23 @@ impl Qwen3Model { Ok(out) } + /// Embed a device-resident token buffer into a pre-allocated output, with + /// no host round-trip or allocation — used by the DFlash draft rollout's + /// graph-stable scratch. + pub(super) fn get_embeddings_batch_into( + &self, + token_ids_gpu: &cudarc::driver::CudaSlice, + out: &mut HiddenStates, + ) -> Result<()> { + anyhow::ensure!( + out.hidden_dim == self.config.hidden_size, + "embedding output hidden_dim {} does not match model hidden_size {}", + out.hidden_dim, + self.config.hidden_size + ); + ops::embedding_batch(&self.ctx, &self.embed_tokens, token_ids_gpu, out) + } + #[allow(clippy::too_many_arguments)] fn forward_layer_batch_paged( &self, @@ -338,6 +355,13 @@ impl Qwen3Model { /// /// If `echo` is true, also returns all-position logits as a /// `HiddenStates [vocab_size, total_tokens]` for prompt logprobs. + /// Batch prefill forward. 
+ /// + /// `capture_layer_ids`, when set, copies the residual-stream hidden states + /// after the listed (strictly increasing) transformer layers into an extra + /// `[hidden_size * layers, total_tokens]` buffer returned as the third tuple + /// element. This feeds the DFlash draft model its target context; `None` + /// behaves identically to a plain prefill and returns `None` there. pub(crate) fn batch_prefill( &self, prompts: &[&[u32]], @@ -346,7 +370,8 @@ impl Qwen3Model { kv_buffer: &CudaSlice, layout: &KvLayout, echo: bool, - ) -> Result<(HiddenStates, Option)> { + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option, Option)> { let batch_size = prompts.len(); assert_eq!(batch_size, kv_views.len()); assert_eq!(batch_size, lora_adapters.len()); @@ -385,8 +410,14 @@ impl Qwen3Model { )?; // Forward through all layers - let hidden = - self.process_all_layers_batch_multi(hidden, layout, kv_buffer, &plan, &lora_groups)?; + let (hidden, captured_hidden) = self.process_all_layers_batch_multi( + hidden, + layout, + kv_buffer, + &plan, + &lora_groups, + capture_layer_ids, + )?; // All-position logits for echo (before we extract last-token logits) let all_logits = if echo { @@ -412,7 +443,7 @@ impl Qwen3Model { defer_drop(plan); } - Ok((logits, all_logits)) + Ok((logits, all_logits, captured_hidden)) } fn process_all_layers_batch_multi( @@ -422,12 +453,35 @@ impl Qwen3Model { kv_buffer: &cudarc::driver::CudaSlice, plan: &PrefillPagedPlan, lora_groups: &[DeviceLoraTokenGroup<'_>], - ) -> Result { + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option)> { let total_tokens = hidden.seq_len; let inter_dim = self.local_intermediate_size(); let q_dim = self.local_q_dim(); let kv_dim = self.local_kv_dim(); + let capture_layer_ids = capture_layer_ids.unwrap_or(&[]); + anyhow::ensure!( + capture_layer_ids.windows(2).all(|pair| pair[0] < pair[1]), + "target hidden capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + 
capture_layer_ids + .iter() + .all(|&layer| layer < self.layers.len()), + "target hidden capture layer id out of range" + ); + let mut captured_hidden = if capture_layer_ids.is_empty() { + None + } else { + Some(HiddenStates::zeros( + &self.ctx, + self.config.hidden_size * capture_layer_ids.len(), + total_tokens, + )?) + }; + let mut next_capture = 0usize; + let mut bufs = PrefillBuffers::new( &self.ctx, self.config.hidden_size, @@ -448,6 +502,18 @@ impl Qwen3Model { lora_groups, &mut bufs, )?; + if capture_layer_ids.get(next_capture) == Some(&layer_idx) { + let out = captured_hidden + .as_mut() + .expect("capture buffer exists when ids are non-empty"); + ops::copy_hidden_rows_into( + &self.ctx, + &hidden, + out, + next_capture * self.config.hidden_size, + )?; + next_capture += 1; + } } // Defer drop of PrefillBuffers in SM-partition mode. @@ -455,6 +521,6 @@ impl Qwen3Model { defer_drop(bufs); } - Ok(hidden) + Ok((hidden, captured_hidden)) } } diff --git a/openinfer-qwen3-4b/src/scheduler.rs b/openinfer-qwen3-4b/src/scheduler.rs index 6476eab8..2fb5bdea 100644 --- a/openinfer-qwen3-4b/src/scheduler.rs +++ b/openinfer-qwen3-4b/src/scheduler.rs @@ -28,7 +28,9 @@ use openinfer_core::engine::{ use openinfer_core::sampler::SamplingParams; use self::effects::apply_effects; -use self::plan::{ExecutionArtifacts, ExecutionPlan, build_next_plan, execute_plan}; +use self::plan::{ + ExecutionArtifacts, ExecutionPlan, build_next_plan, execute_plan, should_speculative_decode, +}; use self::resolve::resolve_step; // ── Internal types ────────────────────────────────────────────────────── @@ -134,6 +136,7 @@ fn take_prefill_chunks( // ── Entry point ───────────────────────────────────────────────────────── +#[allow(clippy::too_many_arguments)] pub(crate) fn start_qwen3( model_path: &str, enable_cuda_graph: bool, @@ -144,6 +147,7 @@ pub(crate) fn start_qwen3( max_prefill_tokens: usize, memory_options: Qwen3MemoryOptions, decode_overlap: crate::DecodeOverlap, + 
dflash_draft_model_path: Option<&str>, ) -> Result { let mut executor = Qwen3Executor::from_runtime_with_lora_options( model_path, @@ -156,6 +160,12 @@ pub(crate) fn start_qwen3( )?; executor.set_no_prefix_cache(no_prefix_cache); executor.enable_decode_overlap(decode_overlap)?; + // Speculative decoding loads its draft model after the target is up (the + // draft is built against the target's embeddings/head) and forces the + // prefix cache off, so it must follow set_no_prefix_cache. + if let Some(draft_path) = dflash_draft_model_path { + executor.load_dflash_draft_model(draft_path)?; + } Ok(start_with_executor(executor, seed, max_prefill_tokens)) } @@ -482,7 +492,7 @@ fn scheduler_loop( take_prefill_chunks(&mut prefilling, max_prefill_tokens) }; - let Some(plan) = build_next_plan(!active.is_empty(), pending) else { + let Some(plan) = runtime_plan(&executor, &active, pending) else { continue; }; @@ -684,7 +694,7 @@ fn scheduler_loop_with_lora_control( continue; } - let Some(plan) = build_next_plan(!active.is_empty(), pending) else { + let Some(plan) = runtime_plan(&executor, &active, pending) else { continue; }; let failure_targets = failure_targets_for(&active, &plan); @@ -1055,6 +1065,26 @@ fn send_unknown_lora_rejection(req: &PendingRequest) { }); } +/// Choose the step plan, preferring a speculative-decode step when the whole +/// active batch is draft-ready. Prefill of new arrivals still takes priority — +/// a speculative step only runs when there is nothing to prefill, so the two +/// never mix in one step. 
+fn runtime_plan( + executor: &impl ModelExecutor, + active: &[ActiveRequestState], + pending: Vec, +) -> Option { + if should_speculative_decode(executor, active) { + if pending.is_empty() { + Some(ExecutionPlan::SpeculativeDecode) + } else { + Some(ExecutionPlan::Prefill { pending }) + } + } else { + build_next_plan(!active.is_empty(), pending) + } +} + fn failure_targets_for( active: &[ActiveRequestState], plan: &self::plan::ExecutionPlan, @@ -1067,6 +1097,9 @@ fn failure_targets_for( self::plan::ExecutionPlan::Decode => { targets.extend(active.iter().map(active_failure_target)); } + self::plan::ExecutionPlan::SpeculativeDecode => { + targets.extend(active.iter().map(active_failure_target)); + } self::plan::ExecutionPlan::Unified { pending } => { targets.extend(active.iter().map(active_failure_target)); targets.extend(pending.iter().map(pending_failure_target)); diff --git a/openinfer-qwen3-4b/src/scheduler/effects.rs b/openinfer-qwen3-4b/src/scheduler/effects.rs index 5a8eac92..11fd176d 100644 --- a/openinfer-qwen3-4b/src/scheduler/effects.rs +++ b/openinfer-qwen3-4b/src/scheduler/effects.rs @@ -69,6 +69,20 @@ pub(super) enum DecodeEffect { logprob: Option, completion_tokens: usize, }, + /// Commit several accepted speculative tokens and keep the request running. + EmitManyAndContinue { + request_id: RequestId, + tokens: Vec, + completion_tokens: usize, + }, + /// Commit several accepted speculative tokens, then finish — a stop token or + /// the max-output budget was hit partway through the accepted span. 
+ EmitManyAndFinish { + request_id: RequestId, + tokens: Vec, + finish_reason: FinishReason, + completion_tokens: usize, + }, } pub(super) struct StepEffects { @@ -192,6 +206,81 @@ pub(super) fn apply_effects( req.generated_count = completion_tokens; } } + DecodeEffect::EmitManyAndContinue { + request_id, + tokens, + completion_tokens, + } => { + let Some(index) = active.iter().position(|req| req.request_id == request_id) else { + continue; + }; + let req = &mut active[index]; + let mut sent = true; + for &token in &tokens { + if req + .token_tx + .send(TokenEvent::Token { + id: token, + logprob: None, + }) + .is_err() + { + sent = false; + break; + } + } + if sent { + req.last_token = *tokens + .last() + .expect("EmitManyAndContinue must carry at least one token"); + req.generated_count = completion_tokens; + } else { + debug!( + "request dropped: client disconnected: request_id={:?} tokens_generated={}", + request_id, completion_tokens + ); + let _ = executor.drop_request(request_id); + to_retire.push(index); + } + } + DecodeEffect::EmitManyAndFinish { + request_id, + tokens, + finish_reason, + completion_tokens, + } => { + let Some(index) = active.iter().position(|req| req.request_id == request_id) else { + continue; + }; + let req = &active[index]; + debug!( + "request finished: request_id={:?} prompt_tokens={} completion_tokens={} finish_reason={:?}", + request_id, req.prompt_len, completion_tokens, finish_reason + ); + let mut sent = true; + for &token in &tokens { + if req + .token_tx + .send(TokenEvent::Token { + id: token, + logprob: None, + }) + .is_err() + { + sent = false; + break; + } + } + if sent { + let _ = req.token_tx.send(TokenEvent::Finished { + finish_reason, + prompt_tokens: req.prompt_len, + completion_tokens, + }); + } + let _ = executor.drop_request(request_id); + to_retire.push(index); + } } } to_retire.sort_unstable(); diff --git a/openinfer-qwen3-4b/src/scheduler/plan.rs b/openinfer-qwen3-4b/src/scheduler/plan.rs index 7163e982..ab5d72cd 
100644 --- a/openinfer-qwen3-4b/src/scheduler/plan.rs +++ b/openinfer-qwen3-4b/src/scheduler/plan.rs @@ -5,12 +5,17 @@ use crate::executor::{ DecodePlan, DecodeResult, DecodeStepItem, ModelExecutor, PrefillPlan, PrefillResult, PrefillStepItem, UnifiedPlan, UnifiedResult, }; +use crate::speculative::{ + DraftPlan, DraftRequestResult, DraftStepItem, VerifyPlan, VerifyResult, VerifyStepItem, +}; use super::{ActiveRequestState, PendingRequest}; pub(super) enum ExecutionPlan { Prefill { pending: Vec }, Decode, + /// Draft + verify the whole active batch (all requests are draft-ready). + SpeculativeDecode, Unified { pending: Vec }, } @@ -26,6 +31,9 @@ pub(super) enum ExecutionArtifacts { Decode { result: DecodeResult, }, + SpeculativeDecode { + verify: VerifyResult, + }, Unified { pending: Vec, result: UnifiedResult, @@ -86,6 +94,22 @@ pub(super) fn execute_plan( sort_decode_results(&mut result.requests); Ok(ExecutionArtifacts::Decode { result }) } + ExecutionPlan::SpeculativeDecode => { + // Two executor calls per step: draft proposes K tokens per request, + // then a single target forward verifies the K+1 span. Both index by + // request_id; sorting keeps draft and verify results aligned. 
+ let draft_requests = build_speculative_draft_items(active); + let mut draft = executor.execute_speculative_draft(DraftPlan { + requests: &draft_requests, + })?; + draft.requests.sort_by_key(|result| result.request_id); + let verify_requests = build_speculative_verify_items(active, &draft.requests); + let mut verify = executor.execute_speculative_verify(VerifyPlan { + requests: &verify_requests, + })?; + verify.requests.sort_by_key(|result| result.request_id); + Ok(ExecutionArtifacts::SpeculativeDecode { verify }) + } ExecutionPlan::Unified { pending } => { let scheduled_at_unix_s = openinfer_core::engine::unix_now_s(); let pending_indices: Vec<usize> = (0..pending.len()).collect(); @@ -108,6 +132,52 @@ pub(super) fn execute_plan( } } +/// All-or-nothing: speculate the whole active batch only when every request is +/// draft-ready and greedy (no LoRA, no logprobs). A single non-ready request +/// falls the batch back to plain decode rather than running a mixed step. +pub(super) fn should_speculative_decode( + executor: &impl ModelExecutor, + active: &[ActiveRequestState], +) -> bool { + executor.speculative_enabled() + && !active.is_empty() + && active.iter().all(|req| { + executor.speculative_request_ready(req.request_id) + && req.lora_adapter.is_none() + && req.logprobs == 0 + && req.params.is_greedy() + }) +} + +fn build_speculative_draft_items(active: &[ActiveRequestState]) -> Vec<DraftStepItem> { + active + .iter() + .map(|r| DraftStepItem::new(r.request_id, r.last_token, r.params)) + .collect() +} + +fn build_speculative_verify_items( + active: &[ActiveRequestState], + draft_results: &[DraftRequestResult], +) -> Vec<VerifyStepItem> { + draft_results + .iter() + .map(|draft| { + let active = active + .iter() + .find(|req| req.request_id == draft.request_id) + .expect("draft request_id must exist in active set"); + // Clamp the verify span to the request's remaining output budget so + // a long accepted run can't overshoot max_tokens.
+ let remaining = active.max_tokens.saturating_sub(active.generated_count); + assert!(remaining > 0, "active request must have output budget"); + let mut token_ids = draft.token_ids.clone(); + token_ids.truncate(remaining); + VerifyStepItem::new(draft.request_id, token_ids, active.params) + }) + .collect() +} + fn build_prefill_items(pending: &[PendingRequest], indices: &[usize]) -> Vec { indices .iter() @@ -179,6 +249,37 @@ mod tests { } } + fn active(generated_count: usize, max_tokens: usize) -> ActiveRequestState { + let (token_tx, _rx) = openinfer_core::engine::TokenSink::standalone(); + ActiveRequestState { + request_id: RequestId::new(7), + lora_adapter: None, + token_tx, + last_token: 42, + generated_count, + max_tokens, + prompt_len: 10, + params: SamplingParams::default(), + logprobs: 0, + } + } + + #[test] + fn speculative_verify_items_clamp_to_remaining_output_budget() { + let active = [active(24, 32)]; + let draft = DraftRequestResult { + request_id: RequestId::new(7), + token_ids: (0..16).collect(), + }; + + let verify = build_speculative_verify_items(&active, &[draft]); + + assert_eq!(verify.len(), 1); + // 32 - 24 = 8 remaining → the 16-token span truncates to 8. + assert_eq!(verify[0].as_slice().len(), 8); + assert_eq!(verify[0].as_slice(), (0..8).collect::>()); + } + // The plan selector is the whole batch-formation policy: what the scheduler // does each tick is fully determined by (have_active, has_pending). Pin the // 2×2 truth table so a policy regression can't slip through silently. 
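Reviewer note: the budget clamp pinned by the plan.rs test above and the greedy accept rule it feeds can be sketched in isolation. This is a minimal standalone sketch, not the crate's code: `clamp_verify_span` and `accept_greedy_sketch` are hypothetical names, and the sketch assumes a span is laid out as the confirmed token followed by the drafts, as the patch describes.

```rust
// Illustrative sketch only. These free functions are hypothetical stand-ins
// for the clamp in `build_speculative_verify_items` and for `accept_greedy`;
// they are not part of the crate.

/// Truncate a verify span (confirmed token first, then drafts) to the
/// request's remaining output budget.
fn clamp_verify_span(mut span: Vec<u32>, generated: usize, max_tokens: usize) -> Vec<u32> {
    let remaining = max_tokens.saturating_sub(generated);
    span.truncate(remaining);
    span
}

/// Greedy acceptance: keep the leading drafts that match the target's argmax,
/// then append the target's own token at the first divergence (or the bonus
/// continuation when every draft matched).
/// Assumes `target_argmax.len() == drafts.len() + 1`.
fn accept_greedy_sketch(drafts: &[u32], target_argmax: &[u32]) -> Vec<u32> {
    assert_eq!(target_argmax.len(), drafts.len() + 1);
    let n = drafts
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count();
    let mut committed = drafts[..n].to_vec();
    committed.push(target_argmax[n]);
    committed
}
```

A span of length `m` holds one confirmed token and `m - 1` drafts, so it can commit at most `m` tokens. Truncating the span to the remaining budget therefore appears sufficient to keep a fully accepted run from overshooting `max_tokens`.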
diff --git a/openinfer-qwen3-4b/src/scheduler/resolve.rs b/openinfer-qwen3-4b/src/scheduler/resolve.rs index 1f91dcd5..61058b1a 100644 --- a/openinfer-qwen3-4b/src/scheduler/resolve.rs +++ b/openinfer-qwen3-4b/src/scheduler/resolve.rs @@ -1,4 +1,5 @@ use crate::executor::{DecodeRequestResult, ModelExecutor, PrefillRequestResult}; +use crate::speculative::VerifyRequestResult; use openinfer_core::engine::FinishReason; use super::effects::{DecodeEffect, PendingEffect, PromptEchoEffect, ScheduledEffect, StepEffects}; @@ -22,6 +23,12 @@ pub(super) fn resolve_step( pending: Vec::new(), decode: resolve_decode_outputs(executor, active, &result.requests), }, + ExecutionArtifacts::SpeculativeDecode { verify } => StepEffects { + scheduled: Vec::new(), + prompt_echoes: Vec::new(), + pending: Vec::new(), + decode: resolve_speculative_outputs(executor, active, &verify.requests), + }, ExecutionArtifacts::Unified { pending, result, @@ -39,6 +46,54 @@ pub(super) fn resolve_step( } } +/// Turn each request's accepted speculative span into a decode effect. A span +/// commits 1..=K+1 tokens at once; we walk it in order so a stop token or the +/// max-output budget truncates exactly where it lands (the executor does no +/// suppression of its own; stop handling lives here, mirroring single-token decode).
+pub(super) fn resolve_speculative_outputs( + executor: &impl ModelExecutor, + active: &[ActiveRequestState], + request_results: &[VerifyRequestResult], +) -> Vec { + request_results + .iter() + .map(|result| { + let req = active + .iter() + .find(|req| req.request_id == result.request_id) + .expect("speculative request_id must exist in active set"); + let mut emitted = Vec::new(); + let mut completion_tokens = req.generated_count; + for &token in &result.accepted_tokens { + completion_tokens += 1; + let is_eos = !req.params.ignore_eos && executor.is_stop_token(token); + if is_eos { + return DecodeEffect::EmitManyAndFinish { + request_id: result.request_id, + tokens: emitted, + finish_reason: FinishReason::Stop, + completion_tokens, + }; + } + emitted.push(token); + if completion_tokens >= req.max_tokens { + return DecodeEffect::EmitManyAndFinish { + request_id: result.request_id, + tokens: emitted, + finish_reason: FinishReason::Length, + completion_tokens, + }; + } + } + DecodeEffect::EmitManyAndContinue { + request_id: result.request_id, + tokens: emitted, + completion_tokens, + } + }) + .collect() +} + fn resolve_prefill_outputs( executor: &impl ModelExecutor, pending: Vec, diff --git a/openinfer-qwen3-4b/src/scheduler/tests.rs b/openinfer-qwen3-4b/src/scheduler/tests.rs index 72eec8d2..0add4798 100644 --- a/openinfer-qwen3-4b/src/scheduler/tests.rs +++ b/openinfer-qwen3-4b/src/scheduler/tests.rs @@ -28,6 +28,7 @@ struct FakeExecutor { loaded_lora_adapters: HashSet, dropped: Arc>>, prefetch_offers: Arc>>, + stop_token: Option, } impl FakeExecutor { @@ -44,9 +45,15 @@ impl FakeExecutor { loaded_lora_adapters: HashSet::new(), dropped, prefetch_offers: Arc::new(Mutex::new(Vec::new())), + stop_token: None, } } + fn with_stop_token(mut self, token: u32) -> Self { + self.stop_token = Some(token); + self + } + fn with_decode_failure(mut self) -> Self { self.fail_decode_once = true; self @@ -125,8 +132,8 @@ impl ModelExecutor for FakeExecutor { self.available_blocks } 
- fn is_stop_token(&self, _token_id: u32) -> bool { - false + fn is_stop_token(&self, token_id: u32) -> bool { + self.stop_token == Some(token_id) } fn drop_request(&mut self, request_id: RequestId) -> Result<()> { @@ -174,6 +181,7 @@ impl ModelExecutor for FakeExecutor { .iter() .map(|req| self.fake_prefill_result(req)) .collect(), + dflash_context_captured_requests: Vec::new(), }) } @@ -1019,3 +1027,140 @@ fn lora_control_waits_until_scheduler_idle() { .expect_err("adapter load should be a stub error"); assert!(matches!(error, EngineControlError::OperationFailed(_))); } + +// ── Speculative span resolution (multi-token emission) ────────────────────── +// resolve_speculative_outputs walks an accepted span [t_0, .., t_m] and decides, +// per request, whether to emit-and-continue or truncate at a stop token / the +// max-output budget. These were GPU-test-only; the truth table below pins the +// branch behaviour with the existing FakeExecutor (only is_stop_token matters). + +use crate::speculative::VerifyRequestResult; +use openinfer_core::engine::FinishReason; + +fn spec_active(id: u64, generated_count: usize, max_tokens: usize, ignore_eos: bool) -> ActiveRequestState { + let (token_tx, _rx) = TokenSink::standalone(); + ActiveRequestState { + request_id: RequestId(id), + lora_adapter: None, + token_tx, + last_token: 1, + generated_count, + max_tokens, + prompt_len: 16, + params: SamplingParams { + ignore_eos, + ..SamplingParams::default() + }, + logprobs: 0, + } +} + +fn spec_result(id: u64, accepted: Vec) -> VerifyRequestResult { + VerifyRequestResult { + request_id: RequestId(id), + matched_draft_tokens: accepted.len().saturating_sub(1), + accepted_tokens: accepted, + } +} + +const SPEC_EOS: u32 = 99; + +#[test] +fn speculative_full_span_accept_continues() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 3, 100, false)]; + let results = [spec_result(1, vec![10, 11, 12, 13])]; + let 
effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [effects::DecodeEffect::EmitManyAndContinue { request_id, tokens, completion_tokens }] => { + assert_eq!(*request_id, RequestId(1)); + assert_eq!(tokens, &vec![10, 11, 12, 13]); + assert_eq!(*completion_tokens, 3 + 4, "completion = prior generated + span len"); + } + _ => panic!("expected EmitManyAndContinue"), + } +} + +#[test] +fn speculative_stop_token_midspan_finishes_and_suppresses_eos() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 5, 100, false)]; + // EOS lands at span position 2; tokens before it are emitted, EOS is not. + let results = [spec_result(1, vec![10, 11, SPEC_EOS, 13])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [effects::DecodeEffect::EmitManyAndFinish { tokens, finish_reason, completion_tokens, .. }] => { + assert_eq!(tokens, &vec![10, 11], "EOS itself is suppressed from emission"); + assert!(matches!(finish_reason, FinishReason::Stop)); + assert_eq!(*completion_tokens, 5 + 3, "the EOS still counts toward completion length"); + } + _ => panic!("expected EmitManyAndFinish(Stop)"), + } +} + +#[test] +fn speculative_stop_token_at_span_start_emits_nothing() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 5, 100, false)]; + let results = [spec_result(1, vec![SPEC_EOS, 11, 12])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [effects::DecodeEffect::EmitManyAndFinish { tokens, finish_reason, completion_tokens, .. 
}] => { + assert!(tokens.is_empty(), "stop at position 0 emits no tokens"); + assert!(matches!(finish_reason, FinishReason::Stop)); + assert_eq!(*completion_tokens, 5 + 1); + } + _ => panic!("expected EmitManyAndFinish(Stop)"), + } +} + +#[test] +fn speculative_max_tokens_truncates_midspan() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + // generated 8, budget 10 -> only 2 more tokens fit; the span offers 4. + let active = [spec_active(1, 8, 10, false)]; + let results = [spec_result(1, vec![10, 11, 12, 13])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [effects::DecodeEffect::EmitManyAndFinish { tokens, finish_reason, completion_tokens, .. }] => { + assert_eq!(tokens, &vec![10, 11], "the budget-hitting token is emitted, the rest dropped"); + assert!(matches!(finish_reason, FinishReason::Length)); + assert_eq!(*completion_tokens, 10, "completion stops exactly at max_tokens"); + } + _ => panic!("expected EmitManyAndFinish(Length)"), + } +} + +#[test] +fn speculative_ignore_eos_does_not_stop() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 0, 100, true)]; + let results = [spec_result(1, vec![SPEC_EOS, SPEC_EOS])]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + match &effects[..] { + [effects::DecodeEffect::EmitManyAndContinue { tokens, .. 
}] => { + assert_eq!(tokens, &vec![SPEC_EOS, SPEC_EOS], "ignore_eos passes stop tokens through"); + } + _ => panic!("expected EmitManyAndContinue"), + } +} + +#[test] +fn speculative_resolves_each_request_independently() { + let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); + let active = [spec_active(1, 0, 100, false), spec_active(2, 0, 100, false)]; + let results = [ + spec_result(1, vec![10, 11]), // continues + spec_result(2, vec![20, SPEC_EOS]), // finishes on EOS + ]; + let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); + assert!(matches!( + &effects[0], + effects::DecodeEffect::EmitManyAndContinue { request_id, .. } if *request_id == RequestId(1) + )); + assert!(matches!( + &effects[1], + effects::DecodeEffect::EmitManyAndFinish { request_id, finish_reason: FinishReason::Stop, .. } + if *request_id == RequestId(2) + )); +} diff --git a/openinfer-qwen3-4b/src/speculative.rs b/openinfer-qwen3-4b/src/speculative.rs new file mode 100644 index 00000000..3af59e7b --- /dev/null +++ b/openinfer-qwen3-4b/src/speculative.rs @@ -0,0 +1,260 @@ +//! Method-agnostic core of **greedy** speculative decoding for Qwen3. +//! +//! Speculative decoding is an optimistic-concurrency transaction over the +//! decode loop: *propose* a span of `K` cheap draft tokens, *verify* them with a +//! single target forward over the `K + 1` span positions, *accept* the longest +//! prefix the target agrees with (plus one bonus token), then *commit* the +//! accepted KV and roll back the rejected draft KV. Only the **propose** step +//! varies between methods (n-gram lookup, DFlash draft model, EAGLE, …); verify, +//! accept, and the KV transaction are shared. +//! +//! This module owns the shared half. The draft/verify boundary is a **pure +//! token span** — a model proposer's hidden states never cross it; they stay +//! inside the proposer (see [`crate::dflash`]). DFlash is the only proposer +//! 
today and is kept concrete; a proposer trait is deferred until a second +//! implementation (n-gram / EAGLE) validates the shape. +//! +//! What is *not* generic yet: [`accept_greedy`] returns argmax-based acceptance, +//! which is the *greedy* rule. Sampling-correct acceptance would need the target +//! and draft distributions, touching both the verify path and the proposer — so +//! it is left until a sampling method actually lands. + +use anyhow::Result; + +use crate::executor::RequestId; +use openinfer_core::sampler::SamplingParams; + +/// One request's verify span: the current dangling token followed by the draft +/// candidates (`token_ids[0]` is the confirmed last token, `token_ids[1..]` are +/// the `K` drafts). Token-only by construction — the proposer that produced the +/// drafts keeps any hidden state to itself. +#[derive(Clone)] +pub(crate) struct VerifyStepItem { + pub(crate) request_id: RequestId, + pub(crate) token_ids: Vec<u32>, + pub(crate) params: SamplingParams, +} + +impl VerifyStepItem { + pub(crate) fn new(request_id: RequestId, token_ids: Vec<u32>, params: SamplingParams) -> Self { + Self { + request_id, + token_ids, + params, + } + } + + pub(crate) fn as_slice(&self) -> &[u32] { + &self.token_ids + } +} + +pub(crate) struct VerifyPlan<'a> { + pub requests: &'a [VerifyStepItem], +} + +#[derive(Clone, Debug)] +pub(crate) struct VerifyRequestResult { + pub request_id: RequestId, + /// Number of draft candidates accepted before the posterior bonus. + pub matched_draft_tokens: usize, + /// Tokens to commit: the accepted draft prefix followed by the target's + /// posterior token at the first mismatch (or the block-end continuation + /// when every draft is accepted). Always `1..=K + 1` tokens, so a verify + /// step always makes at least one token of progress. The scheduler still + /// owns stop-token suppression before client emission.
+ pub accepted_tokens: Vec<u32>, +} + +pub(crate) struct VerifyResult { + pub requests: Vec<VerifyRequestResult>, +} + +/// One request's draft input: the proposer continues from `current_token`. +#[derive(Clone)] +pub(crate) struct DraftStepItem { + pub(crate) request_id: RequestId, + pub(crate) current_token: u32, + pub(crate) params: SamplingParams, +} + +impl DraftStepItem { + pub(crate) fn new(request_id: RequestId, current_token: u32, params: SamplingParams) -> Self { + Self { + request_id, + current_token, + params, + } + } +} + +pub(crate) struct DraftPlan<'a> { + pub requests: &'a [DraftStepItem], +} + +#[derive(Clone, Debug)] +pub(crate) struct DraftRequestResult { + pub request_id: RequestId, + /// Verify-span tokens: current dangling token first, then draft candidates. + pub token_ids: Vec<u32>, +} + +pub(crate) struct DraftResult { + pub requests: Vec<DraftRequestResult>, +} + +/// Greedy speculative acceptance — the shared seam every method reuses. +/// +/// * `proposed` — the `K` candidate tokens from the proposer. +/// * `target_argmax` — the target model's greedy token at each of the `K + 1` +/// verify positions. `target_argmax[i]` is the model's prediction *after* +/// consuming verify input `i`; `target_argmax[0]` follows the last confirmed +/// token and `target_argmax[K]` is the model's own continuation after the +/// whole candidate run. +/// +/// Returns the longest accepted prefix of `proposed` followed by exactly one +/// model token (the correction at the first divergence, or the bonus +/// continuation when every candidate is accepted) — always `1..=K + 1` tokens. +/// +/// # Panics +/// Panics (debug builds) if `target_argmax.len() != proposed.len() + 1`.
+#[must_use] +pub(crate) fn accept_greedy(proposed: &[u32], target_argmax: &[u32]) -> Vec<u32> { + debug_assert_eq!( + target_argmax.len(), + proposed.len() + 1, + "verify must produce one greedy token per candidate plus a bonus" + ); + let n = num_accepted(proposed, target_argmax); + let mut committed = Vec::with_capacity(n + 1); + committed.extend_from_slice(&proposed[..n]); + // The model's own token at the first divergence (or the bonus continuation + // when the whole run was accepted). `n <= proposed.len() < target_argmax.len()` + // so this index is always valid. + committed.push(target_argmax[n]); + committed +} + +/// Length of the accepted prefix: leading drafts whose token matches the +/// target's argmax. +fn num_accepted(proposed: &[u32], target_argmax: &[u32]) -> usize { + let mut i = 0; + while i < proposed.len() && proposed[i] == target_argmax[i] { + i += 1; + } + i +} + +/// Batched greedy acceptance over a verify forward's flattened per-position +/// argmax. `target_tokens` is the concatenation of each request's `K + 1` +/// posterior columns, in `requests` order. Each request applies the shared +/// [`accept_greedy`] over its own span. +pub(crate) fn build_verify_results( + requests: &[VerifyStepItem], + target_tokens: &[u32], +) -> Result<Vec<VerifyRequestResult>> { + let mut outputs = Vec::with_capacity(requests.len()); + let mut offset = 0usize; + for req in requests { + let span_len = req.token_ids.len(); + anyhow::ensure!( + span_len > 0, + "speculative verify request {:?} has an empty verify span", + req.request_id + ); + let end = offset + span_len; + anyhow::ensure!( + end <= target_tokens.len(), + "speculative target-token result is shorter than the verify span" + ); + let posterior = &target_tokens[offset..end]; + // proposed = the K drafts (span minus the leading confirmed token); + // posterior = the K + 1 argmax columns. accept_greedy ties them together.
+ let accepted_tokens = accept_greedy(&req.token_ids[1..], posterior); + outputs.push(VerifyRequestResult { + request_id: req.request_id, + matched_draft_tokens: accepted_tokens.len() - 1, + accepted_tokens, + }); + offset = end; + } + anyhow::ensure!( + offset == target_tokens.len(), + "unused speculative target-token result columns: used {offset}, total {}", + target_tokens.len() + ); + Ok(outputs) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accepts_full_run_plus_bonus() { + let proposed = [10u32, 11, 12]; + let argmax = [10u32, 11, 12, 13]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![10, 11, 12, 13]); + assert_eq!(num_accepted(&proposed, &argmax), 3); + } + + #[test] + fn accepts_prefix_then_correction() { + let proposed = [10u32, 11, 99]; + let argmax = [10u32, 11, 22, 33]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![10, 11, 22]); + assert_eq!(num_accepted(&proposed, &argmax), 2); + } + + #[test] + fn rejects_first_candidate_commits_one() { + let proposed = [10u32, 11, 12]; + let argmax = [7u32, 8, 9, 10]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![7]); + assert_eq!(num_accepted(&proposed, &argmax), 0); + } + + #[test] + fn empty_proposal_commits_model_token() { + let proposed: [u32; 0] = []; + let argmax = [42u32]; + assert_eq!(accept_greedy(&proposed, &argmax), vec![42]); + assert_eq!(num_accepted(&proposed, &argmax), 0); + } + + #[test] + fn always_commits_at_least_one_token() { + let proposed = [1u32, 2]; + let argmax = [9u32, 9, 9]; + assert!(!accept_greedy(&proposed, &argmax).is_empty()); + } + + #[test] + fn batched_accepts_matching_prefix_plus_posterior_bonus() { + let req = VerifyStepItem::new(RequestId(7), vec![10, 11, 12, 13], SamplingParams::default()); + let results = build_verify_results(&[req], &[11, 12, 99, 100]).expect("verify results"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].request_id, RequestId(7)); + assert_eq!(results[0].matched_draft_tokens, 2); + 
assert_eq!(results[0].accepted_tokens, vec![11, 12, 99]); + } + + #[test] + fn batched_all_match_still_adds_block_end_posterior() { + let req = VerifyStepItem::new(RequestId(8), vec![20, 21, 22], SamplingParams::default()); + let results = build_verify_results(&[req], &[21, 22, 23]).expect("verify results"); + assert_eq!(results[0].matched_draft_tokens, 2); + assert_eq!(results[0].accepted_tokens, vec![21, 22, 23]); + } + + #[test] + fn batched_multi_request_splits_columns_by_span() { + let a = VerifyStepItem::new(RequestId(1), vec![5, 6], SamplingParams::default()); + let b = VerifyStepItem::new(RequestId(2), vec![7, 8, 9], SamplingParams::default()); + // a: posterior [6, 100] -> accept draft 6, bonus 100. b: posterior [8, 77, 0] + // -> accept draft 8, correction 77. + let results = build_verify_results(&[a, b], &[6, 100, 8, 77, 0]).expect("verify results"); + assert_eq!(results[0].accepted_tokens, vec![6, 100]); + assert_eq!(results[1].accepted_tokens, vec![8, 77]); + } +} diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs new file mode 100644 index 00000000..b94a526d --- /dev/null +++ b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs @@ -0,0 +1,362 @@ +//! DFlash speculative-decoding losslessness gate. +//! +//! Greedy speculative decoding must be *lossless*: every draft is verified by a +//! target forward and only the matching-argmax prefix (plus one bonus) is +//! committed, so the accepted tokens are the target model's own greedy +//! continuation. The catch is pure numerics — the verify path runs the +//! *prefill* attention kernel over the K+1 span while a plain decode runs the +//! *decode* kernel, and the two differ by ~1 bf16 ULP. On a near-tie that flips +//! one argmax, and from there two greedy runs fan out completely. +//! +//! So an exact `spec == baseline` token match is the wrong gate: it false-fails +//! on a benign tie flip. 
We use a *regret* test like `hf_golden_gate`: at the +//! first position the two sequences disagree (where they still share an +//! identical context, so the comparison is valid) we ask how far below the +//! argmax the speculative pick sits — measured *in the prefill kernel's own +//! distribution*, because that is the kernel the verify path runs. A re-prefill +//! of the shared context (`prefill_next`) gives that reference distribution. +//! The verify path's committed KV is built incrementally across batched +//! speculative spans, while a one-shot prefill builds it in a single forward; +//! the two K/V differ by a few bf16 ULP, so on a near-tie the argmax flips. +//! Within `MARGIN_TOL` of the prefill argmax ⇒ a benign numerical tie. Clearly +//! worse (or outside the prefill top-K) ⇒ the verify/accept/capture logic chose +//! a token the forward never favored — a real bug. A systematic bug corrupts +//! the non-tie positions too, so it cannot hide behind the tie band. +//! +//! (Empirically the one prompt that flips — "The capital of France is" — sits on +//! a Germany-vs-Paris near-tie: the prefill kernel scores them -0.71 vs -0.83, +//! a 0.12-nat gap, well inside `MARGIN_TOL`. The other four prompts are bit +//! identical. A real verify bug would not single out the one degenerate prompt.) +//! +//! The baseline runs with logprobs on (plain decode); the speculative engine +//! runs with logprobs off (logprobs force the spec path off by design), so it +//! reports chosen tokens only — exactly what the regret check needs. +//! +//! Runs the two engines sequentially (baseline dropped before the speculative +//! engine loads) so only one Qwen3-4B is resident at a time. +//! +//! Requires a CUDA GPU, Qwen3-4B weights, and the DFlash drafter. Set +//! `OPENINFER_TEST_MODEL_PATH` (target) and `OPENINFER_DFLASH_TEST_MODEL_PATH` +//! (drafter); skips cleanly when either is absent. 
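
The regret rule described in the module comment above can be sketched on its own, outside the engine. This is a minimal illustration, not the gate's actual code: `Verdict` and `classify` are hypothetical names, the token ids are made up, and the slice is assumed to be sorted with the argmax first (the gate reads `top_logprobs[0]` the same way).

```rust
/// Outcome of comparing the speculative pick with a reference distribution.
/// Hypothetical type, introduced only for this sketch.
#[derive(Debug, PartialEq)]
enum Verdict {
    /// The speculative pick is the reference argmax.
    Match,
    /// The pick differs but sits within `margin_tol` of the argmax.
    BenignTie,
    /// The pick is clearly worse, or is missing from the top-K entirely.
    RealBug,
}

/// Classify `spec_id` against `top_logprobs`, a list of `(token, logprob)`
/// pairs assumed to be sorted in descending logprob order.
fn classify(spec_id: u32, top_logprobs: &[(u32, f32)], margin_tol: f32) -> Verdict {
    // An empty reference cannot vouch for any pick.
    let Some(&(argmax_id, argmax_lp)) = top_logprobs.first() else {
        return Verdict::RealBug;
    };
    if argmax_id == spec_id {
        return Verdict::Match;
    }
    match top_logprobs.iter().find(|(t, _)| *t == spec_id) {
        // Regret = how far below the argmax the speculative pick sits.
        Some(&(_, lp)) if argmax_lp - lp <= margin_tol => Verdict::BenignTie,
        _ => Verdict::RealBug,
    }
}
```

With the near-tie quoted above (-0.71 vs -0.83, a 0.12 gap) and a tolerance of 0.20, the second-ranked token classifies as a benign tie, while a token that is far below the argmax or absent from the top-K classifies as a real bug.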
+ +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use openinfer_core::engine::{EngineHandle, GenerateRequest, TokenEvent, TokenSink}; +use openinfer_core::sampler::SamplingParams; +use openinfer_qwen3_4b::{ + DecodeOverlap, Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions, + DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, +}; + +mod common; + +const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B"); +const DRAFT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B-DFlash-b16"); +const GENERATED_TOKENS: usize = 64; +/// Top-K logprobs requested from the baseline; wide enough that the speculative +/// pick is in the set on any real tie (a pick outside the top-K is itself a red +/// flag the gate should catch). +const LOGPROBS: usize = 20; +/// Max acceptable regret: how far below the baseline's argmax (in the baseline's +/// own logprobs) the speculative pick may sit at the divergence point. ~3 bf16 +/// ULP at typical logit magnitudes — mirrors `hf_golden_gate`'s `MARGIN_TOL`. 
+const MARGIN_TOL: f32 = 0.20;
+
+fn target_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => Some(MODEL_PATH.to_string()),
+        Err(_) => {
+            eprintln!(
+                "skipping dflash gate: {MODEL_PATH}/config.json missing; set OPENINFER_TEST_MODEL_PATH"
+            );
+            None
+        }
+    }
+}
+
+fn draft_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => Some(DRAFT_PATH.to_string()),
+        Err(_) => {
+            eprintln!(
+                "skipping dflash gate: {DRAFT_PATH}/config.json missing; set OPENINFER_DFLASH_TEST_MODEL_PATH"
+            );
+            None
+        }
+    }
+}
+
+fn launch_options(draft: Option<PathBuf>) -> Qwen3LaunchOptions {
+    Qwen3LaunchOptions {
+        device_ordinal: 0,
+        tp_size: 1,
+        cuda_graph: true,
+        offload: Qwen3OffloadOptions::disabled(),
+        // The speculative engine forces the prefix cache off; match it on the
+        // baseline so both take the same cold prefill path.
+        no_prefix_cache: true,
+        max_prefill_tokens: DEFAULT_MAX_PREFILL_TOKENS,
+        memory: Qwen3MemoryOptions::new(0.85, DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES)
+            .validate()
+            .expect("valid memory options"),
+        lora: None,
+        decode_overlap: DecodeOverlap::Off,
+        dflash_draft_model_path: draft,
+    }
+}
+
+/// One decoded position: the chosen token and (when requested) the top-K
+/// `(token, logprob)` distribution that produced it.
+struct Step {
+    id: u32,
+    top_logprobs: Vec<(u32, f32)>,
+}
+
+/// Submit one greedy request and collect the decoded steps until `Finished`.
+fn generate(handle: &EngineHandle, prompt_tokens: Vec<u32>, logprobs: usize) -> Vec<Step> {
+    let (token_tx, mut rx) = TokenSink::standalone();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens,
+            params: SamplingParams::default(),
+            max_tokens: GENERATED_TOKENS,
+            lora_adapter: None,
+            token_tx,
+            logprobs,
+            echo: false,
+        })
+        .expect("submit failed");
+
+    let mut steps = Vec::new();
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Token { id, logprob }) => steps.push(Step {
+                id,
+                top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(),
+            }),
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Finished { .. }) => return steps,
+            Some(TokenEvent::Error { message, .. }) => panic!("generation failed: {message}"),
+            Some(TokenEvent::Rejected { message, .. }) => panic!("generation rejected: {message}"),
+            None => panic!("scheduler channel closed without Finished"),
+        }
+    }
+}
+
+/// Prefill `context` (echo) and return the next-token distribution the *prefill*
+/// kernel produces — the kernel the speculative verify path also uses. This is
+/// the reference the spec pick should match (vs the plain-decode baseline, whose
+/// kernel resolves bifurcation ties to the other side). Returns the first
+/// generated token's `(id, top_logprobs)`.
+fn prefill_next(handle: &EngineHandle, context: Vec<u32>, logprobs: usize) -> Step {
+    let (token_tx, mut rx) = TokenSink::standalone();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens: context,
+            params: SamplingParams::default(),
+            max_tokens: 1,
+            lora_adapter: None,
+            token_tx,
+            logprobs,
+            echo: true,
+        })
+        .expect("submit failed");
+
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Token { id, logprob }) => {
+                return Step {
+                    id,
+                    top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(),
+                }
+            }
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Finished { .. }) => panic!("echo prefill finished without a token"),
+            Some(TokenEvent::Error { message, .. }) => panic!("prefill failed: {message}"),
+            Some(TokenEvent::Rejected { message, .. }) => panic!("prefill rejected: {message}"),
+            None => panic!("scheduler channel closed without a token"),
+        }
+    }
+}
+
+#[test]
+fn dflash_speculative_greedy_matches_plain_greedy() {
+    let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else {
+        return;
+    };
+
+    let prompts = [
+        "The capital of France is",
+        "Here is a short story about a dragon. Once upon a time",
+        "def fibonacci(n):",
+        "Q: What is 17 multiplied by 23? A: Let's think step by step.",
+        "The three primary colors are",
+    ];
+
+    let tokenizer = common::load_tokenizer(&model_path);
+    let encoded: Vec<Vec<u32>> = prompts
+        .iter()
+        .map(|p| tokenizer.encode(p, false).expect("encode failed"))
+        .collect();
+
+    // 1. Baseline: plain greedy decode (speculative off), with logprobs so the
+    //    regret check has the reference distribution at the divergence point.
+    let baseline: Vec<Vec<Step>> = {
+        let handle = openinfer_qwen3_4b::launch(Path::new(&model_path), launch_options(None))
+            .expect("failed to start baseline engine");
+        let out = encoded
+            .iter()
+            .map(|t| generate(&handle, t.clone(), LOGPROBS))
+            .collect();
+        drop(handle);
+        // Let the scheduler thread tear down and free GPU memory before the
+        // speculative engine loads the same 8 GB target.
+        std::thread::sleep(Duration::from_secs(2));
+        out
+    };
+
+    // 2. Speculative: DFlash draft + verify (logprobs off ⇒ spec path active).
+    //    Keep the engine alive through analysis: at a divergence we re-prefill
+    //    the shared context to read the prefill-kernel reference (the kernel the
+    //    verify path uses), which the plain-decode baseline cannot provide.
+    let handle = openinfer_qwen3_4b::launch(
+        Path::new(&model_path),
+        launch_options(Some(PathBuf::from(&draft_path))),
+    )
+    .expect("failed to start speculative engine");
+
+    let mut failures = Vec::new();
+    for (i, prompt) in prompts.iter().enumerate() {
+        let base = &baseline[i];
+        let spec = generate(&handle, encoded[i].clone(), 0);
+        let matched = base
+            .iter()
+            .zip(&spec)
+            .take_while(|(b, s)| b.id == s.id)
+            .count();
+
+        // Identical sequences (or one a prefix of the other): perfectly lossless.
+        if matched == base.len().min(spec.len()) {
+            eprintln!(
+                "prompt {i} ({prompt:?}): {matched}/{} tokens identical (100% lossless)",
+                base.len()
+            );
+            continue;
+        }
+
+        let spec_id = spec[matched].id;
+        let decode_argmax = base[matched].top_logprobs[0].0;
+
+        // Diagnostic: show the exact branch point.
+        {
+            let lo = matched.saturating_sub(2);
+            let hi = (matched + 3).min(base.len()).min(spec.len());
+            let base_ids: Vec<u32> = base[..hi].iter().map(|s| s.id).collect();
+            let spec_ids: Vec<u32> = spec[..hi].iter().map(|s| s.id).collect();
+            eprintln!(" [diag] prompt {i} matched={matched}");
+            eprintln!(
+                " [diag] context+gen base ids {:?} = {:?}",
+                &base_ids,
+                tokenizer.decode(&base_ids, false).unwrap_or_default()
+            );
+            eprintln!(
+                " [diag] base[{lo}..{hi}] = {:?}",
+                base[lo..hi]
+                    .iter()
+                    .map(|s| (s.id, tokenizer.decode(&[s.id], false).unwrap_or_default()))
+                    .collect::<Vec<_>>()
+            );
+            eprintln!(
+                " [diag] spec[{lo}..{hi}] = {:?}",
+                spec[lo..hi]
+                    .iter()
+                    .map(|s| (s.id, tokenizer.decode(&[s.id], false).unwrap_or_default()))
+                    .collect::<Vec<_>>()
+            );
+            let _ = spec_ids;
+        }
+
+        // The verify path runs the prefill kernel, so the right reference for the
+        // spec pick is a plain *prefill* of the same shared context — not the
+        // plain-decode baseline, whose kernel resolves a bifurcation tie to the
+        // other side and amplifies the gap. Build that context from the matched
+        // tokens and ask what the prefill kernel predicts next.
+        let mut context = encoded[i].clone();
+        context.extend(base[..matched].iter().map(|s| s.id));
+        let prefill_ref = prefill_next(&handle, context, LOGPROBS);
+
+        if prefill_ref.id == spec_id {
+            // Spec faithfully reproduced the prefill-kernel greedy pick; the
+            // divergence is purely the pre-existing prefill-vs-decode kernel gap
+            // at a near-tie (here decode→{decode_argmax}, prefill→{spec_id}).
+            let decode_lp = base[matched]
+                .top_logprobs
+                .iter()
+                .find(|(t, _)| *t == spec_id)
+                .map(|(_, lp)| base[matched].top_logprobs[0].1 - lp);
+            eprintln!(
+                "prompt {i} ({prompt:?}): kernel-gap flip at token {matched} — verify(prefill)→{spec_id}, \
+                 decode→{decode_argmax}; spec matches prefill greedy (decode-margin {:?}). Not a spec bug.",
+                decode_lp
+            );
+            continue;
+        }
+
+        // Spec's greedy pick differs from the prefill-kernel argmax too.
The + // verify path builds its committed KV incrementally across batched + // speculative spans, while this reference prefill builds it in one + // shot; the two differ by a few bf16 ULP. On a near-tie that flips the + // argmax — benign. So the deciding question is *how far* below the + // prefill argmax the spec pick sits IN THE PREFILL KERNEL'S OWN + // distribution (the kernel the verify path uses). Within MARGIN_TOL ⇒ + // a numerical tie flip, not a bug. Clearly worse ⇒ the verify/accept + // logic picked a token the forward never favored — a real bug. + let prefill_regret = prefill_ref + .top_logprobs + .iter() + .find(|(t, _)| *t == spec_id) + .map(|(_, lp)| prefill_ref.top_logprobs[0].1 - lp); + + if let Some(regret) = prefill_regret { + if regret <= MARGIN_TOL { + eprintln!( + "prompt {i} ({prompt:?}): tie flip at token {matched} — \ + verify(prefill)→{}, spec→{spec_id}, decode→{decode_argmax}; \ + spec pick is #2 in the prefill distribution (regret {regret:.3} ≤ {MARGIN_TOL}). \ + Not a spec bug.", + prefill_ref.id, + ); + continue; + } + } + + // Either the spec pick is outside the prefill top-K entirely, or it sits + // clearly below the prefill argmax — neither is a benign tie. 
+ let decode_regret = base[matched] + .top_logprobs + .iter() + .find(|(t, _)| *t == spec_id) + .map(|(_, lp)| base[matched].top_logprobs[0].1 - lp); + failures.push(format!( + "prompt {i}: at token {matched} spec chose {spec_id} but prefill greedy says {} and \ + decode greedy says {decode_argmax} (spec regret in prefill dist: {prefill_regret:?} > \ + {MARGIN_TOL}; in decode dist: {decode_regret:?}) — real spec bug", + prefill_ref.id, + )); + } + + drop(handle); + + assert!( + failures.is_empty(), + "speculative greedy decode is not lossless:\n{}", + failures.join("\n") + ); +} diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs new file mode 100644 index 00000000..fa6e2977 --- /dev/null +++ b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs @@ -0,0 +1,161 @@ +//! DFlash speculative-decoding single-stream latency A/B. +//! +//! Speculative decoding's direct win is single-stream (batch=1) decode latency: +//! plain decode is memory-bound (one target forward per token), while spec +//! amortizes that forward over the accepted run. This measures end-to-end +//! wall-clock to generate a fixed token budget, speculative OFF vs ON, on the +//! same prompts and hardware, and reports the speedup. +//! +//! This is a measurement harness, not a pass/fail gate — it asserts only that +//! spec is not catastrophically slower (a guard against the draft mispredicting +//! everything). Read the printed numbers for the real signal. `--nocapture`. +//! +//! Requires a CUDA GPU, Qwen3-4B weights, and the DFlash drafter. Set +//! `OPENINFER_TEST_MODEL_PATH` + `OPENINFER_DFLASH_TEST_MODEL_PATH`; skips when +//! either is absent. Single-stream only — throughput under load is a separate +//! `vllm bench serve` A/B. 
+
+use std::path::{Path, PathBuf};
+use std::time::{Duration, Instant};
+
+use openinfer_core::engine::{EngineHandle, GenerateRequest, TokenEvent, TokenSink};
+use openinfer_core::sampler::SamplingParams;
+use openinfer_qwen3_4b::{
+    DecodeOverlap, Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions,
+    DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS,
+};
+
+mod common;
+
+const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B");
+const DRAFT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B-DFlash-b16");
+const GENERATED_TOKENS: usize = 256;
+
+fn target_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => Some(MODEL_PATH.to_string()),
+        Err(_) => None,
+    }
+}
+
+fn draft_path_or_skip() -> Option<String> {
+    match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") {
+        Ok(path) => Some(path),
+        Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => Some(DRAFT_PATH.to_string()),
+        Err(_) => None,
+    }
+}
+
+fn launch_options(draft: Option<PathBuf>) -> Qwen3LaunchOptions {
+    Qwen3LaunchOptions {
+        device_ordinal: 0,
+        tp_size: 1,
+        cuda_graph: true,
+        offload: Qwen3OffloadOptions::disabled(),
+        no_prefix_cache: true,
+        max_prefill_tokens: DEFAULT_MAX_PREFILL_TOKENS,
+        memory: Qwen3MemoryOptions::new(0.85, DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES)
+            .validate()
+            .expect("valid memory options"),
+        lora: None,
+        decode_overlap: DecodeOverlap::Off,
+        dflash_draft_model_path: draft,
+    }
+}
+
+/// Generate `GENERATED_TOKENS` greedily and return (token_count, elapsed).
+fn timed_generate(handle: &EngineHandle, prompt_tokens: Vec<u32>) -> (usize, Duration) {
+    let (token_tx, mut rx) = TokenSink::standalone();
+    let start = Instant::now();
+    handle
+        .submit(GenerateRequest {
+            request_id: None,
+            queued_at_unix_s: None,
+            prompt_tokens,
+            params: SamplingParams {
+                ignore_eos: true,
+                ..SamplingParams::default()
+            },
+            max_tokens: GENERATED_TOKENS,
+            lora_adapter: None,
+            token_tx,
+            logprobs: 0,
+            echo: false,
+        })
+        .expect("submit failed");
+
+    let mut count = 0usize;
+    loop {
+        match rx.blocking_recv().map(|(_, event)| event) {
+            Some(TokenEvent::Token { .. }) => count += 1,
+            Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {}
+            Some(TokenEvent::Finished { .. }) => return (count, start.elapsed()),
+            Some(TokenEvent::Error { message, .. }) => panic!("generation failed: {message}"),
+            Some(TokenEvent::Rejected { message, .. }) => panic!("generation rejected: {message}"),
+            None => panic!("scheduler channel closed without Finished"),
+        }
+    }
+}
+
+/// Decode tok/s averaged over the prompts (one warm-up run discarded).
+fn measure(handle: &EngineHandle, prompts: &[Vec<u32>]) -> f64 {
+    // Warm up CUDA-graph capture / allocator on the first prompt.
+    let _ = timed_generate(handle, prompts[0].clone());
+    let mut tokens = 0usize;
+    let mut elapsed = Duration::ZERO;
+    for p in prompts {
+        let (n, dt) = timed_generate(handle, p.clone());
+        tokens += n;
+        elapsed += dt;
+    }
+    tokens as f64 / elapsed.as_secs_f64()
+}
+
+#[test]
+fn dflash_speculative_single_stream_speedup() {
+    let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else {
+        eprintln!("skipping dflash perf A/B: set OPENINFER_TEST_MODEL_PATH + OPENINFER_DFLASH_TEST_MODEL_PATH");
+        return;
+    };
+
+    let tokenizer = common::load_tokenizer(&model_path);
+    let prompts: Vec<Vec<u32>> = [
+        "Write a short essay about the history of the Roman Empire.",
+        "Explain how a transformer neural network works, step by step.",
+        "List ten facts about the planet Mars and describe each one.",
+    ]
+    .iter()
+    .map(|p| tokenizer.encode(p, false).expect("encode failed"))
+    .collect();
+
+    let baseline_tps = {
+        let handle = openinfer_qwen3_4b::launch(Path::new(&model_path), launch_options(None))
+            .expect("baseline engine");
+        let tps = measure(&handle, &prompts);
+        drop(handle);
+        std::thread::sleep(Duration::from_secs(2));
+        tps
+    };
+
+    let spec_tps = {
+        let handle = openinfer_qwen3_4b::launch(
+            Path::new(&model_path),
+            launch_options(Some(PathBuf::from(&draft_path))),
+        )
+        .expect("speculative engine");
+        measure(&handle, &prompts)
+    };
+
+    let speedup = spec_tps / baseline_tps;
+    eprintln!("───────────── DFlash single-stream decode A/B (5070 Ti, bs=1) ─────────────");
+    eprintln!(" spec OFF (plain decode): {baseline_tps:7.1} tok/s");
+    eprintln!(" spec ON (DFlash): {spec_tps:7.1} tok/s");
+    eprintln!(" speedup: {speedup:7.2}×");
+    eprintln!("───────────────────────────────────────────────────────────────────────────");
+
+    assert!(
+        speedup > 0.8,
+        "speculative decode is catastrophically slower ({speedup:.2}×) — draft likely mispredicting"
+    );
+}
diff --git a/openinfer-server/src/config.rs b/openinfer-server/src/config.rs
index
7ee423a6..db65959e 100644 --- a/openinfer-server/src/config.rs +++ b/openinfer-server/src/config.rs @@ -89,6 +89,12 @@ pub(crate) struct Args { #[arg(long, default_value_t = false)] pub no_prefix_cache: bool, + /// Enable Qwen3 DFlash speculative decoding with this drafter model path. + /// Single-GPU greedy only; incompatible with --enable-lora and --kv-offload, + /// and forces the prefix cache off (it needs clean target hidden states). + #[arg(long = "dflash-draft-model-path")] + pub dflash_draft_model_path: Option, + /// Cap on total prompt tokens forwarded in one Qwen3 scheduler step /// (chunked prefill). Prefill activation scratch scales with the step's /// prompt tokens, so this bounds peak VRAM under request bursts; prompts diff --git a/openinfer-server/src/main.rs b/openinfer-server/src/main.rs index 3af27a72..d46df28b 100644 --- a/openinfer-server/src/main.rs +++ b/openinfer-server/src/main.rs @@ -109,6 +109,16 @@ async fn main() -> anyhow::Result<()> { // defaults, capability constraints, cross-arg validation). The server only // picks the crate by detected model type and forwards the relevant CLI knobs. fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result { + // Only Qwen3 wires the DFlash drafter; fail loud rather than silently + // ignoring the flag for another model line. 
+ #[cfg(feature = "qwen3-4b")] + let is_qwen3 = matches!(model_type, ModelType::Qwen3); + #[cfg(not(feature = "qwen3-4b"))] + let is_qwen3 = false; + anyhow::ensure!( + args.dflash_draft_model_path.is_none() || is_qwen3, + "--dflash-draft-model-path is only supported for Qwen3 (got {model_type:?})" + ); let handle = match model_type { #[cfg(feature = "deepseek-v4")] ModelType::DeepSeekV4 => openinfer_deepseek_v4::launch( @@ -154,6 +164,24 @@ fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result { + anyhow::ensure!( + !args.enable_lora, + "--dflash-draft-model-path is not supported with --enable-lora" + ); + anyhow::ensure!( + !args.kv_offload, + "--dflash-draft-model-path is not supported with --kv-offload" + ); + anyhow::ensure!( + args.tp_size == 1, + "--dflash-draft-model-path currently requires --tp-size=1" + ); + Some(path) + } + None => None, + }; openinfer_qwen3_4b::launch( &args.model_path, Qwen3LaunchOptions { @@ -171,6 +199,7 @@ fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result Date: Mon, 22 Jun 2026 12:19:22 +0800 Subject: [PATCH 2/9] docs(qwen3): record DFlash 5090 validation + gsm8k parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 5090 results: hf_golden_gate passes (bs1/batched/cuda-graph/tp2), losslessness gate passes, single-stream decode 168.9 -> 263.2 tok/s (1.56x; lower than the 5070 Ti's 1.82x because the 5090's higher bandwidth makes baseline decode less memory-bound). lm-eval gsm8k (5-shot greedy) strict-match identical spec on/off (0.86), flexible-extract within one question — task-level confirmation of greedy losslessness. Also drop the stale "5070 Ti" label hardcoded in the perf A/B test output, and note the orthogonal frontend gap (per-request `seed` is rejected, not ignored). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/index.md | 2 +- .../qwen3/dflash-speculative-decoding.md | 32 ++++++++++++++----- .../tests/dflash_speculative_perf.rs | 2 +- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/docs/index.md b/docs/index.md index 9a80fb77..00b080da 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,7 +30,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `models/qwen3/roadmap.md` | Qwen3-4B roadmap (2026-06 review): line is the maturity bar; #220 RoPE OOB, batched greedy sampling (#307), mixed greedy/non-greedy sampling (#284), and pegaflow KV offload (#316) are landed; open set is zero TP coverage, zero-adapter-only LoRA gate, dropped prefix-cache observability, stale docs, and YaRN #8 follow-up. | | `models/qwen3/model-crate.md` | `openinfer-qwen3-4b` owns Qwen3 config/weights/executor/scheduler/tests/kernel plan; root sees generic `EngineHandle`; split-K retuned to `256/64`, with 4k/64 serving TPOT p50 at `6.46ms` on RTX 5090. | | `models/qwen3/prefix-cache.md` | Prefix caching on by default for Qwen3-4B: full-block kvbm radix matching at the executor, suffix-only prefill. Repeated ~1900-token prompt TTFT 141.8 → 16.3ms p50 (8.7×); warm TTFT ≈ TPOT + ~5ms setup. Includes the RoPE scalar-path corruption fix and the drain-the-stream TTFT measurement pitfall. | -| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (proven via bit-identical multi-token accepts + non-deterministic flip prompts). Single-stream decode 93.4 → 170.0 tok/s (1.82×) on 5070 Ti. Readiness from prefill capture (prefix cache forced off), no proposer trait until a 2nd method. 5090 throughput A/B pending. 
| +| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (bit-identical multi-token accepts; lm-eval gsm8k strict-match identical spec on/off). Single-stream decode 1.82× on 5070 Ti, 1.56× on 5090. Readiness from prefill capture (prefix cache forced off), no proposer trait until a 2nd method. Concurrent-throughput A/B still pending. | | `models/qwen3/accuracy-gate.md` | Qwen3-4B instance of the logits golden gate (`tests/hf_golden_gate.rs`): 48 teacher-forced sequences / 816 positions vs a stored HF bf16 golden, replayed over bs=1 / batched eager / CUDA-graph. Strict guards: regret check + mean ≤ 0.06 + p99 ≤ 0.20; absolute max printed but not asserted (coverage-unstable). Methodology in `subsystems/correctness/`. | | `models/qwen3/kernels-crate.md` | Phase 1 split implemented and 5090-verified: Qwen3-4B kernel surface lives in `openinfer-kernels`; release build, test-target compile, accuracy gate, and bench snapshot pass. | | `models/qwen3/tp-design.md` | Qwen3 tensor-parallel design: `TP=2` milestone scope plus the controller/worker broadcast execution model, request identity, and coarse-grained step protocol for future TP/MoE work. | diff --git a/docs/models/qwen3/dflash-speculative-decoding.md b/docs/models/qwen3/dflash-speculative-decoding.md index 982ff767..d848d253 100644 --- a/docs/models/qwen3/dflash-speculative-decoding.md +++ b/docs/models/qwen3/dflash-speculative-decoding.md @@ -1,6 +1,6 @@ # DFlash Speculative Decoding (Qwen3-4B) -**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. 
Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points. Measured single-stream decode A/B on RTX 5070 Ti (bs=1): **93.4 → 170.0 tok/s, 1.82×**. +**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). Last touched: 2026-06 @@ -42,17 +42,33 @@ The gate runs a baseline (spec off, logprobs on) and a spec engine on the same p Single-stream decode is where speculative decoding pays off directly: plain decode is memory-bound (one target forward per token); spec amortizes that forward over the accepted run. A/B harness: `tests/dflash_speculative_perf.rs` (bs=1, 256 tokens, `ignore_eos`, one warm-up discarded). 
-| Config | RTX 5070 Ti, bs=1 | -| --- | --- | -| spec OFF (plain decode) | 93.4 tok/s | -| spec ON (DFlash) | 170.0 tok/s | -| **speedup** | **1.82×** | +| Config | RTX 5070 Ti, bs=1 | RTX 5090, bs=1 | +| --- | --- | --- | +| spec OFF (plain decode) | 93.4 tok/s | 168.9 tok/s | +| spec ON (DFlash) | 170.0 tok/s | 263.2 tok/s | +| **speedup** | **1.82×** | **1.56×** | -Throughput under concurrent load is a separate axis (`vllm bench serve` A/B) and is best measured on the 5090 — pending. +The speedup is smaller on the 5090: its higher memory bandwidth makes the baseline decode less memory-bound, so amortizing the target forward buys less. Both builds use CUDA 13.x (the 5090's default 12.9 has the documented cuBLAS N=1025 cliff). + +Throughput under concurrent load is a separate axis (`vllm bench serve` A/B). Speculative decoding's win shrinks — and can invert — as batch concurrency rises and the GPU turns compute-bound, so the crossover point is the interesting number. Still pending. + +## Task-level accuracy parity (lm-eval gsm8k) + +Token-level losslessness should imply task-level parity. Confirmed on the 5090 with `lm-eval` gsm8k (5-shot, greedy, `local-completions` against the openinfer server, 50 questions): + +| | flexible-extract | strict-match | +| --- | --- | --- | +| spec OFF | 0.86 | 0.86 | +| spec ON (DFlash) | 0.88 | 0.86 | + +`strict-match` is identical; `flexible-extract` differs by one question (within the ±0.05 stderr), the same single bf16 tie-flip the losslessness gate sees. DFlash does not change task accuracy. + +Harness note: openinfer's `/v1/completions` rejects a per-request `seed` field (`"per-request seed is not supported by this engine"` → 400), which the OpenAI/lm-eval client sends by default. For the eval the client was patched to drop `seed`; making the frontend accept-and-ignore `seed` under greedy is a separate, unrelated improvement (not part of this change). 
## Constraints & open follow-ups - **TP=1, primary rank only.** The DFlash lane runs on the worker thread of the primary rank; the launch path gates `!(dflash && lora)` and `tp_size == 1`. The server fails loud (`--dflash-draft-model-path` is rejected for non-Qwen3 model lines rather than silently ignored). - **Second proposer → introduce the trait.** Until then the proposer is concrete on purpose. - **File size.** `executor.rs` (~2.8k lines) and `scheduler/tests.rs` (~1.2k lines) exceed the 1k guideline; the spec arms in `execute_step_on_lane` are candidates to move into `executor/spec.rs` in a follow-up. -- **5090 validation** — golden gate + `vllm bench serve` throughput A/B on the company 5090 (same sm_120 arch as the 5070 Ti, so correctness carries; the 5090 is for representative throughput numbers). +- **5090 validation — done** for correctness and single-stream: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes, single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Still pending: `vllm bench serve` concurrent-throughput A/B (the spec-helps-vs-hurts crossover under load). +- **Frontend `seed`.** `/v1/completions` 400s on a per-request `seed` instead of accepting-and-ignoring it under greedy — surfaced while wiring lm-eval; orthogonal to spec decode, worth a small separate fix. 
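
For reference, the accept rule from the TL;DR (longest argmax-matching prefix plus one bonus token) can be sketched in a few lines. This is a reconstruction from the behavior pinned by the `speculative.rs` unit tests, not a copy of the shipped function; the name `accept_greedy_sketch` and the length assertion are assumptions made for this illustration.

```rust
/// Return the tokens to commit after one verify forward.
///
/// `proposed` holds the K draft tokens; `argmax` holds the target model's
/// greedy pick at each of the K+1 verified positions. Hypothetical helper,
/// written only to illustrate the rule.
fn accept_greedy_sketch(proposed: &[u32], argmax: &[u32]) -> Vec<u32> {
    // Assumption of this sketch: the verify span is exactly K+1 wide.
    assert_eq!(argmax.len(), proposed.len() + 1, "verify span must be K+1 wide");

    // Length of the longest prefix where the draft agrees with the target.
    let matched = proposed
        .iter()
        .zip(argmax)
        .take_while(|(p, a)| p == a)
        .count();

    // Commit the matched prefix plus the target's own token at the first
    // disagreement (a correction) or past the end of the draft (a bonus).
    argmax[..=matched].to_vec()
}
```

Because the slice always includes index `matched`, at least one token is committed per step even when the first draft token is rejected, which is why a step never makes less progress than plain decode.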
diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs index fa6e2977..098489ab 100644 --- a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs +++ b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs @@ -148,7 +148,7 @@ fn dflash_speculative_single_stream_speedup() { }; let speedup = spec_tps / baseline_tps; - eprintln!("───────────── DFlash single-stream decode A/B (5070 Ti, bs=1) ─────────────"); + eprintln!("───────────── DFlash single-stream decode A/B (bs=1) ─────────────"); eprintln!(" spec OFF (plain decode): {baseline_tps:7.1} tok/s"); eprintln!(" spec ON (DFlash): {spec_tps:7.1} tok/s"); eprintln!(" speedup: {speedup:7.2}×"); From 0fa0a67065266611aff8e66bde84f79978981da1 Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Mon, 22 Jun 2026 14:53:29 +0800 Subject: [PATCH 3/9] fix(qwen3-dflash): budget draft GPU footprint and cap context at admission MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DFlash's draft model and per-request KV/scratch live outside the paged KV pool, so they were never budgeted — under concurrency or long contexts the draft stole from the pool and risked OOM. And a request that fit the target context window but landed in the draft's fixed-width in-fill block crashed mid-prefill instead of being rejected. Budget: DFlashMemoryReservation::from_config (reads the draft config) splits the footprint into a per-token, pool-scaling term (draft KV + context/tail scratch + pending) folded into the per-block budget so the target block count shrinks, and a fixed term (draft weights + block-sized scratch across the decode batch + one in-fill block of per-request headroom) added to the KV margin. Reserved during memory profiling, before the pool is sized. Measured on RTX 5070 Ti: margin 150->2972 MiB, KV pool 1828->427 blocks — the pool now makes room instead of OOMing. 
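
The two-term split in the Budget paragraph above can be sketched as plain arithmetic. This is an illustration under stated assumptions, not `DFlashMemoryReservation` itself: `Budget`, `pool_blocks`, and every field name are hypothetical, the per-token figure is assumed to be in bytes, and the usable-memory, block-size, and tokens-per-block values in the usage note are invented for the example rather than taken from the measured run.

```rust
const MIB: u64 = 1 << 20;

/// Hypothetical inputs to pool sizing; not the crate's real types.
struct Budget {
    /// Bytes available for the KV pool before any margin is taken.
    usable_bytes: u64,
    /// Base KV margin held back from the pool.
    margin_bytes: u64,
    /// Target-model KV bytes per pool block.
    target_block_bytes: u64,
    /// Tokens stored per pool block.
    block_tokens: u64,
}

/// Number of pool blocks once the drafter's footprint is reserved.
/// The per-token term scales with the pool, so it is folded into the
/// per-block cost; the fixed term is added to the margin.
fn pool_blocks(b: &Budget, draft_per_token_bytes: u64, draft_fixed_bytes: u64) -> u64 {
    let per_block = b.target_block_bytes + draft_per_token_bytes * b.block_tokens;
    let reserved = b.margin_bytes + draft_fixed_bytes;
    // saturating_sub keeps the result at zero blocks instead of underflowing.
    b.usable_bytes.saturating_sub(reserved) / per_block
}
```

With an invented 8192 MiB of usable memory, a 150 MiB margin, 4 MiB blocks of 16 tokens, a 65536-byte per-token term, and a 2822 MiB fixed term (the difference between the 150 and 2972 MiB margins quoted above), the block count falls from 2010 to 1044. The drop is the intended behavior: the pool gives up blocks up front rather than running out of memory later.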
Context limit: max_context_tokens() returns max_position_embeddings minus block_size when DFlash is on, so over-limit requests are rejected cleanly at admission rather than crashing mid-prefill. Tests: reservation arithmetic unit test (per-token 65536, weights ~1.1GiB); admission-rejection gate; losslessness gate and 1.79x single-stream A/B unchanged. Sharing the per-request scratch to reclaim most of the ~2.8GiB reservation is a tracked follow-up. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../qwen3/dflash-speculative-decoding.md | 21 +++- openinfer-qwen3-4b/src/dflash.rs | 102 ++++++++++++++++++ openinfer-qwen3-4b/src/executor.rs | 61 +++++++++-- .../src/executor/dflash_lane.rs | 10 +- openinfer-qwen3-4b/src/scheduler.rs | 5 +- openinfer-qwen3-4b/src/weights.rs | 12 ++- .../tests/dflash_speculative_gate.rs | 73 +++++++++++++ .../tests/kv_offload_cpu_hit.rs | 1 + 8 files changed, 273 insertions(+), 12 deletions(-) diff --git a/docs/models/qwen3/dflash-speculative-decoding.md b/docs/models/qwen3/dflash-speculative-decoding.md index d848d253..edeb73b9 100644 --- a/docs/models/qwen3/dflash-speculative-decoding.md +++ b/docs/models/qwen3/dflash-speculative-decoding.md @@ -1,6 +1,6 @@ # DFlash Speculative Decoding (Qwen3-4B) -**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). 
+**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). The drafter's out-of-pool footprint (weights + per-request KV/scratch) is reserved during memory profiling so the KV pool shrinks to fit instead of OOMing under load, and requests landing in the draft's final in-fill block are rejected cleanly at admission rather than crashing mid-prefill. Last touched: 2026-06 @@ -65,6 +65,24 @@ Token-level losslessness should imply task-level parity. Confirmed on the 5090 w Harness note: openinfer's `/v1/completions` rejects a per-request `seed` field (`"per-request seed is not supported by this engine"` → 400), which the OpenAI/lm-eval client sends by default. For the eval the client was patched to drop `seed`; making the frontend accept-and-ignore `seed` under greedy is a separate, unrelated improvement (not part of this change). +## GPU memory budget & context limit + +DFlash's draft model and its per-request KV/scratch live **outside** the paged KV pool (`KvCacheManager`), so they must be reserved *before* the pool is sized or they silently steal from it and OOM under concurrency / long contexts. 
`DFlashMemoryReservation::from_config` (cheap — reads the draft `config.json`) splits the footprint by how it scales and the budget charges each in the matching place: + +- **Per-token, pool-scaling** (draft KV + the prompt-tracking context scratch + the in-fill tail scratch + pending context, **65 536 B/token**) → folded into `effective_bytes_per_block`, so the *target* block count shrinks while the pool itself stays allocated at the target-only `bytes_per_block`. This is a safe upper bound: the scheduler reserves pool blocks for each request's whole lifetime (`prompt + max_tokens`), and the draft attends at most that many tokens, so billing the draft per pool-token over-covers. The draft KV and tail scratch are sized one in-fill block past that lifetime, so the fixed term also reserves `max_decode_batch × block_size` of that per-request headroom. +- **Fixed, pool-independent** (draft weights ~1.1 GiB + the block-sized per-request scratch across the decode batch — logits-dominated, ~6.5 MiB × 256) → added to the KV `margin`. + +Measured on the RTX 5070 Ti (16 GB, util 90%), the pool correctly makes room for the draft: + +| | margin | KV budget | KV blocks | +| --- | --- | --- | --- | +| DFlash OFF | 150 MiB | 4113 MiB | 1828 | +| DFlash ON | 2972 MiB | 1389 MiB | 427 | + +The fixed reservation lands exactly in the margin (+2822 MiB) and the per-token factor (≈1.44×) shrinks the remaining blocks. Before this, those ~2.8 GiB loaded on top of the full 1828-block pool → OOM under load. + +**Context limit.** The drafter's fixed-width in-fill block writes `block_size` positions past the committed length each step, so the DFlash-effective context is `max_position_embeddings − block_size`. 
`max_context_tokens()` returns that when DFlash is on, so a request that fits the target window but lands in the draft's final block is **rejected cleanly at admission** instead of crashing mid-prefill (`tests/dflash_speculative_gate.rs::dflash_request_in_draft_headroom_is_rejected_not_panicked`). + ## Constraints & open follow-ups - **TP=1, primary rank only.** The DFlash lane runs on the worker thread of the primary rank; the launch path gates `!(dflash && lora)` and `tp_size == 1`. The server fails loud (`--dflash-draft-model-path` is rejected for non-Qwen3 model lines rather than silently ignored). @@ -72,3 +90,4 @@ Harness note: openinfer's `/v1/completions` rejects a per-request `seed` field ( - **File size.** `executor.rs` (~2.8k lines) and `scheduler/tests.rs` (~1.2k lines) exceed the 1k guideline; the spec arms in `execute_step_on_lane` are candidates to move into `executor/spec.rs` in a follow-up. - **5090 validation — done** for correctness and single-stream: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes, single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Still pending: `vllm bench serve` concurrent-throughput A/B (the spec-helps-vs-hurts crossover under load). - **Frontend `seed`.** `/v1/completions` 400s on a per-request `seed` instead of accepting-and-ignoring it under greedy — surfaced while wiring lm-eval; orthogonal to spec decode, worth a small separate fix. +- **Reclaim the per-request reservation.** The fixed term is dominated by the block-sized scratch billed per decode-batch slot (~6.5 MiB × 256 ≈ 1.6 GiB), most of it the transient logits buffer (`vocab × block`). The per-token term carries the context/pending scratch, which today reallocs to prompt length on the first draft and never shrinks (sized for `[0, context_len)` but only `[committed_len, +block]` is live). 
Both are billed conservatively because they're per-request; sharing the transient scratch across the serial per-request draft loop (`execute_dflash_draft`) and collapsing the prompt-persistent buffers would cut the reservation to roughly the draft weights + draft KV, reclaiming most of the ~2.8 GiB. Deserves its own validated PR (touches the draft forward). diff --git a/openinfer-qwen3-4b/src/dflash.rs b/openinfer-qwen3-4b/src/dflash.rs index 76f8f0c3..2af0b090 100644 --- a/openinfer-qwen3-4b/src/dflash.rs +++ b/openinfer-qwen3-4b/src/dflash.rs @@ -30,6 +30,76 @@ pub(crate) struct DFlashRequestState { max_cache_len: usize, } +/// GPU memory DFlash needs on top of the target KV pool, derived from the draft +/// config so the KV budget can reserve it *before* the draft model loads (the +/// draft buffers live outside the paged `KvCacheManager`). Split by how it scales: +/// +/// - `kv_bytes_per_token` scales with the KV pool (billed by shrinking the target +/// block count): the draft's own KV cache plus the scratch-context and +/// pending-context buffers, which currently persist at prompt length per +/// request (see `dflash-speculative-decoding.md` — collapsing that persistence +/// is a tracked follow-up that would shrink this term to the draft KV alone). +/// - `fixed_bytes` does not scale with the pool (billed via the memory margin): +/// the draft weights plus the block-sized scratch across the decode batch. 
+pub(crate) struct DFlashMemoryReservation { + pub(crate) kv_bytes_per_token: usize, + pub(crate) fixed_bytes: usize, +} + +impl DFlashMemoryReservation { + pub(crate) fn from_path(draft_path: &str, max_decode_batch_size: usize) -> Result { + let config = DFlashConfig::from_file(draft_path)?; + Ok(Self::from_config(&config, max_decode_batch_size)) + } + + fn from_config(config: &DFlashConfig, max_decode_batch_size: usize) -> Self { + const BF16: usize = 2; + let hidden = config.hidden_size; + let kv_dim = config.num_key_value_heads * config.head_dim; + let q_dim = config.num_attention_heads * config.head_dim; + let inter = config.intermediate_size; + let capture_layers = config.dflash_config.target_layer_ids.len(); + + // Per-sequence-token, pool-scaling buffers. + let draft_kv = config.num_hidden_layers * 2 * kv_dim * BF16; // DFlashLayerCache k+v + // Scratch split by what it tracks: `context_*` grows with the committed + // prefix; `tail_*` (tail_input + k_tail + v_tail) grows with the in-fill + // tail, which is one block past the prefix. + let context_scratch = 2 * hidden * BF16; // context_projected + context_hidden + let tail_scratch = (hidden + 2 * kv_dim) * BF16; // tail_input + k_tail + v_tail + let pending = hidden * capture_layers * BF16; // context_feature_dim + let kv_bytes_per_token = draft_kv + context_scratch + tail_scratch + pending; + + // Block-sized scratch held per live request (DFlashDraftScratch fixed buffers). + let fixed_scratch_per_request = + BF16 * config.block_size * (config.vocab_size + 5 * hidden + 2 * q_dim + 3 * inter); + let scratch_total = fixed_scratch_per_request * max_decode_batch_size; + + // Draft weights (5 transformer layers + the context projection), +10% slack + // for norms, rope caches, and allocator alignment. 
+ let per_layer = BF16 + * (hidden * (q_dim + 2 * kv_dim) // qkv_proj + + q_dim * hidden // o_proj + + hidden * 2 * inter // gate_up_proj + + inter * hidden); // down_proj + let fc = BF16 * hidden * (hidden * capture_layers); // context projection + let weights = per_layer * config.num_hidden_layers + fc; + let weights = weights + weights / 10; + + // The durable draft KV and the tail scratch are sized to `context + + // block_size`, one in-fill block past the lifetime the KV pool reserves + // for the request. The per-token term bills only the pool's tokens, so + // reserve that one-block headroom per concurrently decoding request to + // keep the reservation an upper bound. + let block_headroom = max_decode_batch_size * config.block_size * (draft_kv + tail_scratch); + + Self { + kv_bytes_per_token, + fixed_bytes: weights + scratch_total + block_headroom, + } + } +} + struct DFlashLayerCache { k: HiddenStates, v: HiddenStates, @@ -359,6 +429,14 @@ impl DFlashDraftModel { self.config.block_size } + /// Largest sequence position the draft can cache. `validate_for_target` + /// guarantees this is `>=` the target's, but the draft's per-step in-fill + /// block writes `block_size` transient positions past the committed length, + /// so the usable context is `max_position_embeddings - block_size`. + pub(crate) fn max_position_embeddings(&self) -> usize { + self.config.max_position_embeddings + } + pub(crate) fn mask_token_id(&self) -> u32 { self.config.dflash_config.mask_token_id } @@ -772,5 +850,29 @@ mod tests { dflash.dflash_config.target_layer_ids, vec![1, 9, 17, 25, 33] ); + + // Pin the memory reservation the KV budget bills against. The per-token + // term (draft KV 5*2*1024*2 + scratch-context (3*2560+2*1024)*2 + pending + // 2560*5*2) drives the ~31% block haircut (≈1.44× bytes per block); a layer-count or geometry + // regression here would silently over/under-reserve and risk OOM.
+ let reservation = + super::DFlashMemoryReservation::from_config(&dflash, /*max_decode_batch*/ 256); + assert_eq!( + reservation.kv_bytes_per_token, 65_536, + "draft KV(20480) + scratch-ctx(19456) + pending(25600) per token" + ); + // Weights (~1.1 GiB) dominate the fixed term at batch=1; the block-sized + // per-request scratch (~6.5 MiB, logits-heavy) plus the one-block KV/tail + // headroom (~0.5 MiB) add across the decode batch. + let fixed_batch1 = super::DFlashMemoryReservation::from_config(&dflash, 1).fixed_bytes; + assert!( + (1_150_000_000..1_220_000_000).contains(&fixed_batch1), + "draft weights ~1.1GiB, got {fixed_batch1}" + ); + assert!( + (2_900_000_000..3_000_000_000).contains(&reservation.fixed_bytes), + "weights + 256 * (~6.5MiB scratch + ~0.5MiB block headroom), got {}", + reservation.fixed_bytes + ); } } diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 1dbd6efb..3b7079a7 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -873,10 +873,15 @@ impl Qwen3Executor { model: Qwen3Model, offload_opts: &Qwen3OffloadOptions, max_prefill_tokens: usize, + dflash_kv_bytes_per_token: usize, memory_options: Qwen3MemoryOptions, ) -> Result { - let (model, budget) = - profile_kv_budget_on_worker(model, max_prefill_tokens, memory_options)?; + let (model, budget) = profile_kv_budget_on_worker( + model, + max_prefill_tokens, + dflash_kv_bytes_per_token, + memory_options, + )?; let kv_mgr = KvCacheManager::new( &model.device_ctx().stream, budget.num_layers, @@ -931,6 +936,7 @@ impl Qwen3Executor { Qwen3LoraOptions::default(), Qwen3OffloadOptions::disabled(), crate::scheduler::DEFAULT_MAX_PREFILL_TOKENS, + None, Qwen3MemoryOptions::default(), ) } @@ -942,9 +948,10 @@ impl Qwen3Executor { lora_options: Qwen3LoraOptions, offload_options: Qwen3OffloadOptions, max_prefill_tokens: usize, + dflash_draft_path: Option<&str>, memory_options: Qwen3MemoryOptions, ) -> Result { - let memory_options 
= memory_options.validate()?; + let mut memory_options = memory_options.validate()?; let lora_options = lora_options.validate()?; anyhow::ensure!( !device_ordinals.is_empty(), @@ -967,11 +974,36 @@ impl Qwen3Executor { max_lora_rank: lora_options.max_lora_rank, }, )?; - let mut executor = - Self::single(model, &offload_options, max_prefill_tokens, memory_options)?; + // The DFlash draft model loads after profiling but lives outside the + // paged KV pool, so reserve its footprint up front from the draft + // config: fixed bytes (weights + block scratch) via the margin, and + // pool-scaling per-token bytes folded into the block budget. + let dflash_kv_bytes_per_token = match dflash_draft_path { + Some(path) => { + let reservation = crate::dflash::DFlashMemoryReservation::from_path( + path, + *BATCH_BUCKETS.last().unwrap(), + )?; + memory_options.kv_cache_memory_margin_bytes += reservation.fixed_bytes; + reservation.kv_bytes_per_token + } + None => 0, + }; + let mut executor = Self::single( + model, + &offload_options, + max_prefill_tokens, + dflash_kv_bytes_per_token, + memory_options, + )?; executor.lora_options = lora_options; return Ok(executor); } + anyhow::ensure!( + dflash_draft_path.is_none(), + "speculative decoding requires the single-GPU path (got {} devices)", + device_ordinals.len() + ); let world_size = device_ordinals.len(); let mut models = Vec::with_capacity(world_size); @@ -994,8 +1026,9 @@ impl Qwen3Executor { let mut profiled_models = Vec::with_capacity(world_size); let mut budgets = Vec::with_capacity(world_size); for model in models { + // DFlash is single-GPU only, so the TP path reserves nothing for it. 
let (model, budget) = - profile_kv_budget_on_worker(model, max_prefill_tokens, memory_options)?; + profile_kv_budget_on_worker(model, max_prefill_tokens, 0, memory_options)?; profiled_models.push(model); budgets.push(budget); } @@ -1429,6 +1462,7 @@ impl Qwen3Executor { fn profile_kv_budget_on_worker( model: Qwen3Model, max_prefill_tokens: usize, + dflash_kv_bytes_per_token: usize, memory_options: Qwen3MemoryOptions, ) -> Result<(Qwen3Model, KvBudget)> { let handle = thread::Builder::new() @@ -1445,6 +1479,7 @@ fn profile_kv_budget_on_worker( let budget = model.profiled_kv_budget( max_prefill_tokens, *BATCH_BUCKETS.last().unwrap(), + dflash_kv_bytes_per_token, memory_options, )?; Ok((model, budget)) @@ -1515,7 +1550,15 @@ impl ModelExecutor for Qwen3Executor { } fn max_context_tokens(&self) -> usize { - self.metadata.config.max_position_embeddings + let target = self.metadata.config.max_position_embeddings; + match &self.speculative { + // The draft's fixed-width in-fill block writes `block_size` positions + // past the committed length each step, so a request may use at most + // `draft.max_pos - block_size` tokens before the draft cache would + // overflow. Reject the rest at admission instead of crashing mid-prefill. + Some(meta) => target.min(meta.max_position_embeddings.saturating_sub(meta.block_size)), + None => target, + } } fn max_decode_batch_size(&self) -> usize { @@ -2231,6 +2274,9 @@ impl Drop for Qwen3Executor { #[derive(Clone, Debug)] struct DFlashMeta { block_size: usize, + /// Draft's max cacheable position; with the `block_size` in-fill headroom + /// this caps the DFlash-effective context to `max_position_embeddings - block_size`. 
+ max_position_embeddings: usize, #[allow(dead_code)] target_layer_ids: Vec, } @@ -2317,6 +2363,7 @@ impl LocalQwen3Lane { model.tune_gemm_algos(&self.model)?; let meta = DFlashMeta { block_size: model.block_size(), + max_position_embeddings: model.max_position_embeddings(), target_layer_ids: model.target_layer_ids().to_vec(), }; self.dflash = Some(DFlashLaneState::new(model)); diff --git a/openinfer-qwen3-4b/src/executor/dflash_lane.rs b/openinfer-qwen3-4b/src/executor/dflash_lane.rs index 478daa8f..027ac12e 100644 --- a/openinfer-qwen3-4b/src/executor/dflash_lane.rs +++ b/openinfer-qwen3-4b/src/executor/dflash_lane.rs @@ -92,8 +92,14 @@ impl LocalQwen3Lane { for req in requests { let pending_exists = dflash.requests.contains_key(&req.request_id); if dflash_prefill_can_capture(req, pending_exists) { - let max_cache_len = - req.prompt_tokens.len() + req.max_output_tokens + dflash.model.block_size(); + // Admission already caps the request at `draft.max_pos - block_size` + // (see `max_context_tokens`), so this `.min` is a defensive floor: + // it keeps the draft KV alloc within the draft's max positions even + // if a caller bypasses admission. 
+ let max_cache_len = (req.prompt_tokens.len() + + req.max_output_tokens + + dflash.model.block_size()) + .min(dflash.model.max_position_embeddings()); let mut state = match dflash.requests.remove(&req.request_id) { Some(state) => state, None => dflash.model.new_request_state(&ctx, max_cache_len)?, diff --git a/openinfer-qwen3-4b/src/scheduler.rs b/openinfer-qwen3-4b/src/scheduler.rs index 2fb5bdea..71c91877 100644 --- a/openinfer-qwen3-4b/src/scheduler.rs +++ b/openinfer-qwen3-4b/src/scheduler.rs @@ -156,13 +156,15 @@ pub(crate) fn start_qwen3( Qwen3LoraOptions::default(), offload_options, max_prefill_tokens, + dflash_draft_model_path, memory_options, )?; executor.set_no_prefix_cache(no_prefix_cache); executor.enable_decode_overlap(decode_overlap)?; // Speculative decoding loads its draft model after the target is up (the // draft is built against the target's embeddings/head) and forces the - // prefix cache off, so it must follow set_no_prefix_cache. + // prefix cache off, so it must follow set_no_prefix_cache. Its GPU footprint + // was already reserved during profiling from the draft path passed above. 
if let Some(draft_path) = dflash_draft_model_path { executor.load_dflash_draft_model(draft_path)?; } @@ -189,6 +191,7 @@ pub(crate) fn start_qwen3_with_lora_control( lora_options, offload_options, max_prefill_tokens, + None, memory_options, )?; executor.set_no_prefix_cache(no_prefix_cache); diff --git a/openinfer-qwen3-4b/src/weights.rs b/openinfer-qwen3-4b/src/weights.rs index 560f5ac2..bf909e3a 100644 --- a/openinfer-qwen3-4b/src/weights.rs +++ b/openinfer-qwen3-4b/src/weights.rs @@ -886,6 +886,7 @@ impl Qwen3Model { self.kv_budget_from_bytes( geometry, bytes_per_block, + 0, kv_budget_bytes, free_bytes, "heuristic", @@ -896,6 +897,7 @@ impl Qwen3Model { &self, max_prefill_tokens: usize, max_decode_batch_size: usize, + dflash_kv_bytes_per_token: usize, memory_options: Qwen3MemoryOptions, ) -> Result { let memory_options = memory_options.validate()?; @@ -1017,6 +1019,7 @@ impl Qwen3Model { Ok(self.kv_budget_from_bytes( geometry, bytes_per_block, + dflash_kv_bytes_per_token, kv_budget_bytes, initial_free_bytes, "profiled", @@ -1049,11 +1052,18 @@ impl Qwen3Model { &self, mut geometry: KvBudget, bytes_per_block: usize, + dflash_kv_bytes_per_token: usize, kv_budget_bytes: usize, free_bytes: usize, source: &'static str, ) -> KvBudget { - let num_blocks = (kv_budget_bytes / bytes_per_block).max(64); + // DFlash keeps its own per-request KV (plus prompt-scaling scratch) outside + // the paged pool, scaling with the same token count. Charge it as extra + // bytes per pool token so the target block count shrinks to leave room; the + // pool itself is still allocated at the target-only `bytes_per_block`. 
+ let effective_bytes_per_block = + bytes_per_block + dflash_kv_bytes_per_token * geometry.block_size; + let num_blocks = (kv_budget_bytes / effective_bytes_per_block).max(64); let kv_mb = num_blocks * bytes_per_block / (1024 * 1024); log::info!( "KV cache ({source}): {num_blocks} blocks ({kv_mb} MB, {:.0}% of {:.0} MB free)", diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs index b94a526d..9b62fac6 100644 --- a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs +++ b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs @@ -63,6 +63,11 @@ const LOGPROBS: usize = 20; /// ULP at typical logit magnitudes — mirrors `hf_golden_gate`'s `MARGIN_TOL`. const MARGIN_TOL: f32 = 0.20; +/// Both tests launch a Qwen3-4B engine, and two at once overflow a 16 GB card. +/// Cargo runs tests in one binary concurrently, so serialize the engine-holding +/// bodies — only one engine is ever resident on the GPU. +static GPU: std::sync::Mutex<()> = std::sync::Mutex::new(()); + fn target_path_or_skip() -> Option { match std::env::var("OPENINFER_TEST_MODEL_PATH") { Ok(path) => Some(path), @@ -191,6 +196,7 @@ fn dflash_speculative_greedy_matches_plain_greedy() { let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { return; }; + let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner()); let prompts = [ "The capital of France is", @@ -360,3 +366,70 @@ fn dflash_speculative_greedy_matches_plain_greedy() { failures.join("\n") ); } + +/// P2 regression: a request that fits the target context window but lands in the +/// draft's `block_size` in-fill headroom (`max_pos - block_size < prompt + +/// max_tokens <= max_pos`) must be rejected cleanly at admission. Before the +/// admission cap, such a request was admitted on the target's limit and then +/// panicked mid-prefill when the draft allocated KV past its own max positions. 
+#[test] +fn dflash_request_in_draft_headroom_is_rejected_not_panicked() { + let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { + return; + }; + let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner()); + + // Read the real context window so the boundary is exact regardless of the + // checkpoint, then size the request to sit inside the draft's final in-fill + // block — it fits the target window but not the DFlash-effective one. + let config: serde_json::Value = serde_json::from_str( + &std::fs::read_to_string(Path::new(&model_path).join("config.json")).expect("read config"), + ) + .expect("parse config"); + let max_pos = config["max_position_embeddings"] + .as_u64() + .expect("max_position_embeddings") as usize; + const BLOCK_SIZE: usize = 16; // DFlash drafter block size. + let prompt_len = 16usize; + // total in (max_pos - BLOCK_SIZE, max_pos]: clears the target check, trips + // the DFlash admission cap (max_pos - BLOCK_SIZE). + let max_tokens = max_pos - BLOCK_SIZE / 2 - prompt_len; + + let handle = openinfer_qwen3_4b::launch( + Path::new(&model_path), + launch_options(Some(PathBuf::from(&draft_path))), + ) + .expect("failed to start speculative engine"); + + let (token_tx, mut rx) = TokenSink::standalone(); + handle + .submit(GenerateRequest { + request_id: None, + queued_at_unix_s: None, + prompt_tokens: vec![100u32; prompt_len], + params: SamplingParams::default(), + max_tokens, + lora_adapter: None, + token_tx, + logprobs: 0, + echo: false, + }) + .expect("submit failed"); + + loop { + match rx.blocking_recv().map(|(_, event)| event) { + Some(TokenEvent::Rejected { message, .. }) => { + eprintln!("draft-headroom request rejected as expected: {message}"); + break; + } + Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {} + Some(TokenEvent::Token { .. } | TokenEvent::Finished { .. 
}) => { + panic!("draft-headroom request was admitted instead of rejected") + } + Some(TokenEvent::Error { message, .. }) => { + panic!("draft-headroom request errored mid-flight instead of clean rejection: {message}") + } + None => panic!("scheduler channel closed without a rejection"), + } + } +} diff --git a/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs b/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs index 6ba26a95..f39f7697 100644 --- a/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs +++ b/openinfer-qwen3-4b/tests/kv_offload_cpu_hit.rs @@ -138,6 +138,7 @@ fn live_gpu_and_cpu_prefix_hits() { Qwen3LoraOptions::default(), Qwen3OffloadOptions::enabled(HOST_TIER_BYTES), openinfer_qwen3_4b::DEFAULT_MAX_PREFILL_TOKENS, + None, openinfer_qwen3_4b::Qwen3MemoryOptions::default(), ) .expect("build offload executor"); From 326f716d6f586881e3396f2e21aaad19e5eb132e Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Mon, 22 Jun 2026 14:58:13 +0800 Subject: [PATCH 4/9] style(qwen3-dflash): apply cargo fmt across the feature branch CI's `cargo fmt --all --check` was red since the feature commit. Format the DFlash files (and the speculative/scheduler/core/kernels code touched by the branch) to the canonical style. No behavior change. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- openinfer-core/src/ops.rs | 12 +-- openinfer-kernels/src/ops.rs | 12 ++- openinfer-kernels/src/ops/sampling.rs | 1 - openinfer-qwen3-4b/src/executor.rs | 17 ++-- .../src/executor/dflash_lane.rs | 21 +++-- .../src/executor/dflash_prefill.rs | 20 ++++- openinfer-qwen3-4b/src/scheduler/plan.rs | 8 +- openinfer-qwen3-4b/src/scheduler/tests.rs | 81 ++++++++++++++++--- openinfer-qwen3-4b/src/speculative.rs | 6 +- .../tests/dflash_speculative_gate.rs | 18 +++-- .../tests/dflash_speculative_perf.rs | 16 ++-- 11 files changed, 144 insertions(+), 68 deletions(-) diff --git a/openinfer-core/src/ops.rs b/openinfer-core/src/ops.rs index e211b972..389ea51a 100644 --- a/openinfer-core/src/ops.rs +++ b/openinfer-core/src/ops.rs @@ -15,12 +15,12 @@ pub use attention::{ pub use openinfer_kernels::ops::{ GEMM_LT_MAX_N, LoraDecodeGroupedProjection, accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, argmax, argmax_batch_bf16_into, bf16_hidden_to_f32_into, - copy_hidden_rows_into, copy_hidden_token_range_into, - dflash_qk_norm_rope_into, embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, - extract_vec_ref_into, f32_to_bf16_hidden_into, fused_add_rms_norm_into, - gather_hidden_tokens_into, gemm, gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, - gemm_into_checked, gemm_lt_tune, gemm_per_token, gemv, linear, - lora_decode_fused_delta_group3_into, lora_decode_fused_delta_into, pack_lora_b_rows_into, + copy_hidden_rows_into, copy_hidden_token_range_into, dflash_qk_norm_rope_into, + embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into, + f32_to_bf16_hidden_into, fused_add_rms_norm_into, gather_hidden_tokens_into, gemm, + gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, gemm_into_checked, gemm_lt_tune, + gemm_per_token, gemv, linear, lora_decode_fused_delta_group3_into, + lora_decode_fused_delta_into, pack_lora_b_rows_into, 
qk_norm_partial_rope_batched_decode_hd256_into, rms_norm, rms_norm_batch_offset_into, rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, diff --git a/openinfer-kernels/src/ops.rs b/openinfer-kernels/src/ops.rs index 4173b5e7..2d78930e 100644 --- a/openinfer-kernels/src/ops.rs +++ b/openinfer-kernels/src/ops.rs @@ -29,11 +29,10 @@ pub use deepseek_v2_lite::*; pub use elementwise::{ accumulate_bf16_token_scaled_to_f32_into, add_batch, add_batch_into, bf16_hidden_to_f32_into, copy_hidden_rows_into, copy_hidden_token_range_into, extract_vec, extract_vec_into, - extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, - gather_hidden_tokens_into, repeat_f32_for_reduce_scatter_into, scale_f32_in_place, - scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, - scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, - silu_mul_fused_batch_into, write_vec_into, + extract_vec_ref, extract_vec_ref_into, f32_to_bf16_hidden_into, gather_hidden_tokens_into, + repeat_f32_for_reduce_scatter_into, scale_f32_in_place, scaled_add_batch_into, + scaled_add_rows_indexed_into, scaled_add_rows_into, scaled_add_rows_token_range_into, + silu_mul_batch, silu_mul_batch_into, silu_mul_fused_batch_into, write_vec_into, }; pub use embedding::{embedding_batch, embedding_batch_vocab_shard, embedding_decode_into}; #[cfg(feature = "kimi-k2")] @@ -58,6 +57,5 @@ pub use norm::{ pub use sampling::{ BatchSamplingRow, BatchSamplingScratch, argmax, argmax_batch_bf16_into, argmax_batch_bf16_split_indexed_into, argmax_batch_bf16_split_partials_len, - flashinfer_top1_batch_into, - flashinfer_top1_row_states_bytes, gpu_sample_batch_into, + flashinfer_top1_batch_into, flashinfer_top1_row_states_bytes, gpu_sample_batch_into, }; diff --git a/openinfer-kernels/src/ops/sampling.rs b/openinfer-kernels/src/ops/sampling.rs index a303b390..d4ca384d 100644 --- 
a/openinfer-kernels/src/ops/sampling.rs +++ b/openinfer-kernels/src/ops/sampling.rs @@ -285,7 +285,6 @@ pub fn argmax_batch_bf16_split_partials_len(rows: usize, vocab: usize) -> usize rows * vocab.div_ceil(TILE_ELEMS) } - /// Two-stage indexed batched argmax: tile-parallel partials then a per-row /// finalize. Lowest index wins ties; each vocab row spreads over many CTAs /// instead of one. diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 3b7079a7..3e821272 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -490,10 +490,7 @@ fn execute_step_on_lane( Ok(WorkerStepOutcome::Ack) } } - StepCommand::SpeculativeVerify { - requests, - kv_views, - } => { + StepCommand::SpeculativeVerify { requests, kv_views } => { // One target forward over each request's K+1 draft span with a // speculative KV view. echo=true yields all-position logits so we // can argmax every span position — accept_greedy needs the target's @@ -526,11 +523,9 @@ fn execute_step_on_lane( requests: request_results, })) } - StepCommand::SpeculativeDraft { requests } => { - Ok(WorkerStepOutcome::SpeculativeDraft( - lane.execute_dflash_draft(requests)?, - )) - } + StepCommand::SpeculativeDraft { requests } => Ok(WorkerStepOutcome::SpeculativeDraft( + lane.execute_dflash_draft(requests)?, + )), } } @@ -2584,9 +2579,7 @@ enum StepCommand { }, /// Speculative draft: roll the DFlash draft model forward one block per /// request. Uses the draft's own KV — no target KV views. 
- SpeculativeDraft { - requests: Vec, - }, + SpeculativeDraft { requests: Vec }, } impl StepCommand { diff --git a/openinfer-qwen3-4b/src/executor/dflash_lane.rs b/openinfer-qwen3-4b/src/executor/dflash_lane.rs index 027ac12e..c2263aa7 100644 --- a/openinfer-qwen3-4b/src/executor/dflash_lane.rs +++ b/openinfer-qwen3-4b/src/executor/dflash_lane.rs @@ -96,10 +96,9 @@ impl LocalQwen3Lane { // (see `max_context_tokens`), so this `.min` is a defensive floor: // it keeps the draft KV alloc within the draft's max positions even // if a caller bypasses admission. - let max_cache_len = (req.prompt_tokens.len() - + req.max_output_tokens - + dflash.model.block_size()) - .min(dflash.model.max_position_embeddings()); + let max_cache_len = + (req.prompt_tokens.len() + req.max_output_tokens + dflash.model.block_size()) + .min(dflash.model.max_position_embeddings()); let mut state = match dflash.requests.remove(&req.request_id) { Some(state) => state, None => dflash.model.new_request_state(&ctx, max_cache_len)?, @@ -224,13 +223,13 @@ impl LocalQwen3Lane { let result = (|| -> Result> { let mut outputs = Vec::with_capacity(requests.len()); for req in requests { - let mut state = dflash - .requests - .remove(&req.request_id) - .ok_or_else(|| anyhow::anyhow!("missing DFlash state for {:?}", req.request_id))?; - let draft_logits = dflash - .model - .draft_logits(&self.model, &mut state, req.current_token)?; + let mut state = dflash.requests.remove(&req.request_id).ok_or_else(|| { + anyhow::anyhow!("missing DFlash state for {:?}", req.request_id) + })?; + let draft_logits = + dflash + .model + .draft_logits(&self.model, &mut state, req.current_token)?; let draft_len = draft_logits.seq_len; let greedy = SamplingParams::default(); let params: Vec<&SamplingParams> = vec![&greedy; draft_len]; diff --git a/openinfer-qwen3-4b/src/executor/dflash_prefill.rs b/openinfer-qwen3-4b/src/executor/dflash_prefill.rs index 332cf2a4..fb3099d4 100644 --- 
a/openinfer-qwen3-4b/src/executor/dflash_prefill.rs +++ b/openinfer-qwen3-4b/src/executor/dflash_prefill.rs @@ -16,7 +16,10 @@ pub(super) fn dflash_prefill_supported(req: &PrefillStepItem) -> bool { /// Eligible AND continuous: either the first chunk, or a later chunk whose /// earlier chunks already captured context (no gaps in the pending buffer). -pub(super) fn dflash_prefill_can_capture(req: &PrefillStepItem, pending_state_exists: bool) -> bool { +pub(super) fn dflash_prefill_can_capture( + req: &PrefillStepItem, + pending_state_exists: bool, +) -> bool { dflash_prefill_supported(req) && (req.chunk_start == 0 || pending_state_exists) } @@ -59,12 +62,21 @@ mod tests { #[test] fn prefill_action_table() { - assert_eq!(dflash_prefill_action(true, true), DFlashPrefillAction::MarkReady); + assert_eq!( + dflash_prefill_action(true, true), + DFlashPrefillAction::MarkReady + ); assert_eq!( dflash_prefill_action(true, false), DFlashPrefillAction::KeepPending ); - assert_eq!(dflash_prefill_action(false, true), DFlashPrefillAction::Drop); - assert_eq!(dflash_prefill_action(false, false), DFlashPrefillAction::Drop); + assert_eq!( + dflash_prefill_action(false, true), + DFlashPrefillAction::Drop + ); + assert_eq!( + dflash_prefill_action(false, false), + DFlashPrefillAction::Drop + ); } } diff --git a/openinfer-qwen3-4b/src/scheduler/plan.rs b/openinfer-qwen3-4b/src/scheduler/plan.rs index ab5d72cd..5f985919 100644 --- a/openinfer-qwen3-4b/src/scheduler/plan.rs +++ b/openinfer-qwen3-4b/src/scheduler/plan.rs @@ -12,11 +12,15 @@ use crate::speculative::{ use super::{ActiveRequestState, PendingRequest}; pub(super) enum ExecutionPlan { - Prefill { pending: Vec }, + Prefill { + pending: Vec, + }, Decode, /// Draft + verify the whole active batch (all requests are draft-ready). 
SpeculativeDecode, - Unified { pending: Vec }, + Unified { + pending: Vec, + }, } pub(super) enum ExecutionArtifacts { diff --git a/openinfer-qwen3-4b/src/scheduler/tests.rs b/openinfer-qwen3-4b/src/scheduler/tests.rs index 0add4798..bcd96113 100644 --- a/openinfer-qwen3-4b/src/scheduler/tests.rs +++ b/openinfer-qwen3-4b/src/scheduler/tests.rs @@ -1037,7 +1037,12 @@ fn lora_control_waits_until_scheduler_idle() { use crate::speculative::VerifyRequestResult; use openinfer_core::engine::FinishReason; -fn spec_active(id: u64, generated_count: usize, max_tokens: usize, ignore_eos: bool) -> ActiveRequestState { +fn spec_active( + id: u64, + generated_count: usize, + max_tokens: usize, + ignore_eos: bool, +) -> ActiveRequestState { let (token_tx, _rx) = TokenSink::standalone(); ActiveRequestState { request_id: RequestId(id), @@ -1072,10 +1077,20 @@ fn speculative_full_span_accept_continues() { let results = [spec_result(1, vec![10, 11, 12, 13])]; let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); match &effects[..] { - [effects::DecodeEffect::EmitManyAndContinue { request_id, tokens, completion_tokens }] => { + [ + effects::DecodeEffect::EmitManyAndContinue { + request_id, + tokens, + completion_tokens, + }, + ] => { assert_eq!(*request_id, RequestId(1)); assert_eq!(tokens, &vec![10, 11, 12, 13]); - assert_eq!(*completion_tokens, 3 + 4, "completion = prior generated + span len"); + assert_eq!( + *completion_tokens, + 3 + 4, + "completion = prior generated + span len" + ); } _ => panic!("expected EmitManyAndContinue"), } @@ -1089,10 +1104,25 @@ fn speculative_stop_token_midspan_finishes_and_suppresses_eos() { let results = [spec_result(1, vec![10, 11, SPEC_EOS, 13])]; let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); match &effects[..] { - [effects::DecodeEffect::EmitManyAndFinish { tokens, finish_reason, completion_tokens, .. 
}] => { - assert_eq!(tokens, &vec![10, 11], "EOS itself is suppressed from emission"); + [ + effects::DecodeEffect::EmitManyAndFinish { + tokens, + finish_reason, + completion_tokens, + .. + }, + ] => { + assert_eq!( + tokens, + &vec![10, 11], + "EOS itself is suppressed from emission" + ); assert!(matches!(finish_reason, FinishReason::Stop)); - assert_eq!(*completion_tokens, 5 + 3, "the EOS still counts toward completion length"); + assert_eq!( + *completion_tokens, + 5 + 3, + "the EOS still counts toward completion length" + ); } _ => panic!("expected EmitManyAndFinish(Stop)"), } @@ -1105,7 +1135,14 @@ fn speculative_stop_token_at_span_start_emits_nothing() { let results = [spec_result(1, vec![SPEC_EOS, 11, 12])]; let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); match &effects[..] { - [effects::DecodeEffect::EmitManyAndFinish { tokens, finish_reason, completion_tokens, .. }] => { + [ + effects::DecodeEffect::EmitManyAndFinish { + tokens, + finish_reason, + completion_tokens, + .. + }, + ] => { assert!(tokens.is_empty(), "stop at position 0 emits no tokens"); assert!(matches!(finish_reason, FinishReason::Stop)); assert_eq!(*completion_tokens, 5 + 1); @@ -1122,10 +1159,24 @@ fn speculative_max_tokens_truncates_midspan() { let results = [spec_result(1, vec![10, 11, 12, 13])]; let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); match &effects[..] { - [effects::DecodeEffect::EmitManyAndFinish { tokens, finish_reason, completion_tokens, .. }] => { - assert_eq!(tokens, &vec![10, 11], "the budget-hitting token is emitted, the rest dropped"); + [ + effects::DecodeEffect::EmitManyAndFinish { + tokens, + finish_reason, + completion_tokens, + .. 
+ }, + ] => { + assert_eq!( + tokens, + &vec![10, 11], + "the budget-hitting token is emitted, the rest dropped" + ); assert!(matches!(finish_reason, FinishReason::Length)); - assert_eq!(*completion_tokens, 10, "completion stops exactly at max_tokens"); + assert_eq!( + *completion_tokens, 10, + "completion stops exactly at max_tokens" + ); } _ => panic!("expected EmitManyAndFinish(Length)"), } @@ -1139,7 +1190,11 @@ fn speculative_ignore_eos_does_not_stop() { let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); match &effects[..] { [effects::DecodeEffect::EmitManyAndContinue { tokens, .. }] => { - assert_eq!(tokens, &vec![SPEC_EOS, SPEC_EOS], "ignore_eos passes stop tokens through"); + assert_eq!( + tokens, + &vec![SPEC_EOS, SPEC_EOS], + "ignore_eos passes stop tokens through" + ); } _ => panic!("expected EmitManyAndContinue"), } @@ -1150,8 +1205,8 @@ fn speculative_resolves_each_request_independently() { let exec = FakeExecutor::new(64, Arc::new(Mutex::new(Vec::new()))).with_stop_token(SPEC_EOS); let active = [spec_active(1, 0, 100, false), spec_active(2, 0, 100, false)]; let results = [ - spec_result(1, vec![10, 11]), // continues - spec_result(2, vec![20, SPEC_EOS]), // finishes on EOS + spec_result(1, vec![10, 11]), // continues + spec_result(2, vec![20, SPEC_EOS]), // finishes on EOS ]; let effects = resolve::resolve_speculative_outputs(&exec, &active, &results); assert!(matches!( diff --git a/openinfer-qwen3-4b/src/speculative.rs b/openinfer-qwen3-4b/src/speculative.rs index 3af59e7b..fe5f8a7d 100644 --- a/openinfer-qwen3-4b/src/speculative.rs +++ b/openinfer-qwen3-4b/src/speculative.rs @@ -231,7 +231,11 @@ mod tests { #[test] fn batched_accepts_matching_prefix_plus_posterior_bonus() { - let req = VerifyStepItem::new(RequestId(7), vec![10, 11, 12, 13], SamplingParams::default()); + let req = VerifyStepItem::new( + RequestId(7), + vec![10, 11, 12, 13], + SamplingParams::default(), + ); let results = build_verify_results(&[req], &[11, 
12, 99, 100]).expect("verify results"); assert_eq!(results.len(), 1); assert_eq!(results[0].request_id, RequestId(7)); diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs index 9b62fac6..4be713b1 100644 --- a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs +++ b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs @@ -45,8 +45,8 @@ use std::time::Duration; use openinfer_core::engine::{EngineHandle, GenerateRequest, TokenEvent, TokenSink}; use openinfer_core::sampler::SamplingParams; use openinfer_qwen3_4b::{ - DecodeOverlap, Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions, - DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, + DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, DecodeOverlap, + Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions, }; mod common; @@ -71,7 +71,9 @@ static GPU: std::sync::Mutex<()> = std::sync::Mutex::new(()); fn target_path_or_skip() -> Option { match std::env::var("OPENINFER_TEST_MODEL_PATH") { Ok(path) => Some(path), - Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => Some(MODEL_PATH.to_string()), + Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => { + Some(MODEL_PATH.to_string()) + } Err(_) => { eprintln!( "skipping dflash gate: {MODEL_PATH}/config.json missing; set OPENINFER_TEST_MODEL_PATH" @@ -84,7 +86,9 @@ fn target_path_or_skip() -> Option { fn draft_path_or_skip() -> Option { match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") { Ok(path) => Some(path), - Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => Some(DRAFT_PATH.to_string()), + Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => { + Some(DRAFT_PATH.to_string()) + } Err(_) => { eprintln!( "skipping dflash gate: {DRAFT_PATH}/config.json missing; set OPENINFER_DFLASH_TEST_MODEL_PATH" @@ -180,7 +184,7 @@ fn prefill_next(handle: &EngineHandle, context: Vec, logprobs: usize) -> St return Step { id, 
top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(), - } + }; } Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {} Some(TokenEvent::Finished { .. }) => panic!("echo prefill finished without a token"), @@ -427,7 +431,9 @@ fn dflash_request_in_draft_headroom_is_rejected_not_panicked() { panic!("draft-headroom request was admitted instead of rejected") } Some(TokenEvent::Error { message, .. }) => { - panic!("draft-headroom request errored mid-flight instead of clean rejection: {message}") + panic!( + "draft-headroom request errored mid-flight instead of clean rejection: {message}" + ) } None => panic!("scheduler channel closed without a rejection"), } diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs index 098489ab..ef4cb95b 100644 --- a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs +++ b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs @@ -21,8 +21,8 @@ use std::time::{Duration, Instant}; use openinfer_core::engine::{EngineHandle, GenerateRequest, TokenEvent, TokenSink}; use openinfer_core::sampler::SamplingParams; use openinfer_qwen3_4b::{ - DecodeOverlap, Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions, - DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, + DEFAULT_KV_CACHE_MEMORY_MARGIN_BYTES, DEFAULT_MAX_PREFILL_TOKENS, DecodeOverlap, + Qwen3LaunchOptions, Qwen3MemoryOptions, Qwen3OffloadOptions, }; mod common; @@ -34,7 +34,9 @@ const GENERATED_TOKENS: usize = 256; fn target_path_or_skip() -> Option { match std::env::var("OPENINFER_TEST_MODEL_PATH") { Ok(path) => Some(path), - Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => Some(MODEL_PATH.to_string()), + Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => { + Some(MODEL_PATH.to_string()) + } Err(_) => None, } } @@ -42,7 +44,9 @@ fn target_path_or_skip() -> Option { fn draft_path_or_skip() -> Option { match 
std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") { Ok(path) => Some(path), - Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => Some(DRAFT_PATH.to_string()), + Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => { + Some(DRAFT_PATH.to_string()) + } Err(_) => None, } } @@ -115,7 +119,9 @@ fn measure(handle: &EngineHandle, prompts: &[Vec]) -> f64 { #[test] fn dflash_speculative_single_stream_speedup() { let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { - eprintln!("skipping dflash perf A/B: set OPENINFER_TEST_MODEL_PATH + OPENINFER_DFLASH_TEST_MODEL_PATH"); + eprintln!( + "skipping dflash perf A/B: set OPENINFER_TEST_MODEL_PATH + OPENINFER_DFLASH_TEST_MODEL_PATH" + ); return; }; From d2f7e78cf7cf2efe2521bde37ed791733aa1594d Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Mon, 22 Jun 2026 17:40:30 +0800 Subject: [PATCH 5/9] fix(qwen3-dflash): adapt tests to main APIs after rebase Rebasing onto main surfaced two API drifts the dflash tests predated: Qwen3LaunchOptions gained batch_invariant (set false in the dflash perf/gate launch helpers), and start_engine_with_offload gained the dflash draft-path arg (pass None in the main-side batch-invariance tests scheduler_robustness/reject). Test-only, no behavior change. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- openinfer-qwen3-4b/tests/batch_invariance_reject.rs | 1 + openinfer-qwen3-4b/tests/dflash_speculative_gate.rs | 1 + openinfer-qwen3-4b/tests/dflash_speculative_perf.rs | 1 + openinfer-qwen3-4b/tests/scheduler_robustness.rs | 1 + 4 files changed, 4 insertions(+) diff --git a/openinfer-qwen3-4b/tests/batch_invariance_reject.rs b/openinfer-qwen3-4b/tests/batch_invariance_reject.rs index e73785c4..37489503 100644 --- a/openinfer-qwen3-4b/tests/batch_invariance_reject.rs +++ b/openinfer-qwen3-4b/tests/batch_invariance_reject.rs @@ -27,6 +27,7 @@ fn batch_invariant_rejects_decode_overlap() { Qwen3MemoryOptions::default(), DecodeOverlap::SharedSm, true, + None, ) .err() .expect("--batch-invariant + --decode-overlap must be rejected"); diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs index 4be713b1..2f2127b5 100644 --- a/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs +++ b/openinfer-qwen3-4b/tests/dflash_speculative_gate.rs @@ -113,6 +113,7 @@ fn launch_options(draft: Option) -> Qwen3LaunchOptions { .expect("valid memory options"), lora: None, decode_overlap: DecodeOverlap::Off, + batch_invariant: false, dflash_draft_model_path: draft, } } diff --git a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs index ef4cb95b..ff0e017a 100644 --- a/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs +++ b/openinfer-qwen3-4b/tests/dflash_speculative_perf.rs @@ -64,6 +64,7 @@ fn launch_options(draft: Option) -> Qwen3LaunchOptions { .expect("valid memory options"), lora: None, decode_overlap: DecodeOverlap::Off, + batch_invariant: false, dflash_draft_model_path: draft, } } diff --git a/openinfer-qwen3-4b/tests/scheduler_robustness.rs b/openinfer-qwen3-4b/tests/scheduler_robustness.rs index 5688fc55..c9fe7fed 100644 --- a/openinfer-qwen3-4b/tests/scheduler_robustness.rs +++ 
b/openinfer-qwen3-4b/tests/scheduler_robustness.rs @@ -104,6 +104,7 @@ fn scheduler_survives_consumer_drop() { openinfer_qwen3_4b::Qwen3MemoryOptions::default(), openinfer_qwen3_4b::DecodeOverlap::Off, true, + None, ) .expect("failed to start engine"); assert_eq!( From 78e2160c94af24bbd2853bf9aa2f158d5c91e819 Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Tue, 23 Jun 2026 00:14:54 +0800 Subject: [PATCH 6/9] perf(qwen3-dflash): batch the draft forward to fix concurrent throughput MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DFlash's draft ran a per-request serial for-loop — N full forwards = N× kernel launches, launch-bound (a skip-attention A/B proved attention compute is <2% of the draft). This inverted the single-stream win under concurrent load (5090 greedy c16 −59% vs plain decode). Batch the dense ops (rms_norm/GEMM/silu/MLP/embedding/logits) into one pass over an N×block buffer — free, since cuBLAS takes any M and the ops are already row-batched. The varlen ops (rope/KV-copy/attention) stay a per-request loop slicing the batched buffers at row_offset; the two DFlash-exclusive ops (dflash_qk_norm_rope_into / single_prefill_nhd_ noncausal_into) gain a row-offset param that advances the device pointer to the slice (no CUDA-kernel change). A lane-level DFlashBatchScratch replaces the per-request scratch. 5090 greedy A/B (sharegpt out128, same-session serial vs batched): c8 831 → 1346 tok/s (vLLM 1240) c16 1013 → 1868 tok/s (vLLM 1846) draft@batch16 24.86 → 5.62 ms (draft_x 15.96 → 3.62) Both c8/c16 now beat vLLM. Losslessness gate passes (bf16 tie-flips only, regret ≤ 0.2). c1 (single-stream) is unchanged — batch=1 has no batching win; its gap to vLLM is launch-bound draft overhead, since openinfer's spec path (unlike base decode) isn't CUDA-Graph captured. Accept rate is measured-equal to vLLM (9.1% vs 8.85%, same drafter), so CUDA-Graph draft is the tracked next step. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/index.md | 2 +- .../qwen3/dflash-speculative-decoding.md | 54 +- openinfer-kernels/src/ops/attention.rs | 47 +- openinfer-qwen3-4b/src/dflash.rs | 544 ++++++++++++------ openinfer-qwen3-4b/src/executor.rs | 7 +- .../src/executor/dflash_lane.rs | 97 +++- 6 files changed, 530 insertions(+), 221 deletions(-) diff --git a/docs/index.md b/docs/index.md index 00b080da..ebf9e657 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,7 +30,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `models/qwen3/roadmap.md` | Qwen3-4B roadmap (2026-06 review): line is the maturity bar; #220 RoPE OOB, batched greedy sampling (#307), mixed greedy/non-greedy sampling (#284), and pegaflow KV offload (#316) are landed; open set is zero TP coverage, zero-adapter-only LoRA gate, dropped prefix-cache observability, stale docs, and YaRN #8 follow-up. | | `models/qwen3/model-crate.md` | `openinfer-qwen3-4b` owns Qwen3 config/weights/executor/scheduler/tests/kernel plan; root sees generic `EngineHandle`; split-K retuned to `256/64`, with 4k/64 serving TPOT p50 at `6.46ms` on RTX 5090. | | `models/qwen3/prefix-cache.md` | Prefix caching on by default for Qwen3-4B: full-block kvbm radix matching at the executor, suffix-only prefill. Repeated ~1900-token prompt TTFT 141.8 → 16.3ms p50 (8.7×); warm TTFT ≈ TPOT + ~5ms setup. Includes the RoPE scalar-path corruption fix and the drain-the-stream TTFT measurement pitfall. | -| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (bit-identical multi-token accepts; lm-eval gsm8k strict-match identical spec on/off). Single-stream decode 1.82× on 5070 Ti, 1.56× on 5090. 
Readiness from prefill capture (prefix cache forced off), no proposer trait until a 2nd method. Concurrent-throughput A/B still pending. | +| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (bit-identical multi-token accepts; lm-eval gsm8k strict-match identical spec on/off). Single-stream decode 1.82× on 5070 Ti, 1.56× on 5090. Concurrent throughput fixed by batching the draft forward (5090 greedy c8 1346 / c16 1868, both beating vLLM 1240 / 1846; draft@b16 24.8→5.6 ms). c1 still trails vLLM (237 vs 278) on launch-bound draft overhead — accept measured equal (9.1% vs 8.85%, same drafter), so CUDA-Graph draft is tracked next. Proposer trait deferred to EAGLE. | | `models/qwen3/accuracy-gate.md` | Qwen3-4B instance of the logits golden gate (`tests/hf_golden_gate.rs`): 48 teacher-forced sequences / 816 positions vs a stored HF bf16 golden, replayed over bs=1 / batched eager / CUDA-graph. Strict guards: regret check + mean ≤ 0.06 + p99 ≤ 0.20; absolute max printed but not asserted (coverage-unstable). Methodology in `subsystems/correctness/`. | | `models/qwen3/kernels-crate.md` | Phase 1 split implemented and 5090-verified: Qwen3-4B kernel surface lives in `openinfer-kernels`; release build, test-target compile, accuracy gate, and bench snapshot pass. | | `models/qwen3/tp-design.md` | Qwen3 tensor-parallel design: `TP=2` milestone scope plus the controller/worker broadcast execution model, request identity, and coarse-grained step protocol for future TP/MoE work. 
| diff --git a/docs/models/qwen3/dflash-speculative-decoding.md b/docs/models/qwen3/dflash-speculative-decoding.md index edeb73b9..a0e66cb6 100644 --- a/docs/models/qwen3/dflash-speculative-decoding.md +++ b/docs/models/qwen3/dflash-speculative-decoding.md @@ -1,6 +1,6 @@ # DFlash Speculative Decoding (Qwen3-4B) -**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). The drafter's out-of-pool footprint (weights + per-request KV/scratch) is reserved during memory profiling so the KV pool shrinks to fit instead of OOMing under load, and requests landing in the draft's final in-fill block are rejected cleanly at admission rather than crashing mid-prefill. +**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. 
Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). The drafter's out-of-pool footprint (weights + per-request KV/scratch) is reserved during memory profiling so the KV pool shrinks to fit instead of OOMing under load, and requests landing in the draft's final in-fill block are rejected cleanly at admission rather than crashing mid-prefill. **Concurrent throughput is now competitive after batching the draft forward.** The draft used to run a per-request **serial** `for` loop — launch-bound (a skip-attention A/B showed attention compute <2%, so 24.8 ms/step at batch 16 was almost all kernel-launch overhead), which inverted the single-stream win (c16 −59%). Batching the dense ops into one N×block pass drops draft@batch16 to **5.6 ms** and lifts 5090 greedy throughput to **c8 1346 / c16 1868 tok/s — both now beating vLLM's 1240 / 1846**. Single-stream (c1 237 vs vLLM 278) still trails, but **not from accept**: measured 9.1% vs vLLM's 8.85% with the *same* drafter rules out draft quality. The c1 gap is the draft's ~1 ms kernel-launch overhead — openinfer's spec path, unlike base decode, isn't CUDA-Graph captured. **CUDA-Graph draft is the tracked next step.** See Performance § for the A/B tables. Last touched: 2026-06 @@ -50,7 +50,43 @@ Single-stream decode is where speculative decoding pays off directly: plain deco The speedup is smaller on the 5090: its higher memory bandwidth makes the baseline decode less memory-bound, so amortizing the target forward buys less. Both builds use CUDA 13.x (the 5090's default 12.9 has the documented cuBLAS N=1025 cliff). -Throughput under concurrent load is a separate axis (`vllm bench serve` A/B). Speculative decoding's win shrinks — and can invert — as batch concurrency rises and the GPU turns compute-bound, so the crossover point is the interesting number. Still pending. 
+### Concurrent throughput: the draft loop is serial and launch-bound + +Under concurrent load the single-stream win inverts. Same greedy harness (`temperature=0`, sharegpt prompts, 128 out tokens), openinfer vs vLLM with the **same** DFlash-b16 drafter, RTX 5090: + +| concurrency | OI plain | OI DFlash serial | **OI DFlash batched** | vLLM DFlash | +| --- | --- | --- | --- | --- | +| c1 | 170 | 245 | 237 | 278 | +| c8 | 1180 | 831 | **1346** | 1240 | +| c16 | 2277 | 1013 | **1868** | 1846 | + +(tok/s, sharegpt out128, same-session serial-vs-batched A/B. The serial draft inverted the win — vLLM degraded gracefully while openinfer nearly halved; batching the draft restores it, c8/c16 now beat vLLM. c1 is unchanged — batch=1 has no batching win; see "Single-stream gap" below.) + +#### Root cause (serial draft) and the fix (batched draft, landed) + +Accept length is **not** the cause — both engines accept ~equally (same drafter + greedy = longest-prefix match). The cause was the **per-request serial draft loop**: `execute_dflash_draft` (`dflash_lane.rs`) called `draft_logits` once per request in a `for` loop, while verify runs ONE batched target forward over all spans. Per-step draft timing by batch size, **before vs after batching** (5090, instrumented): + +| batch | draft serial | draft batched | verify | draft serial scaling | draft batched scaling | +| --- | --- | --- | --- | --- | --- | +| 1 | 1.56 ms | 1.55 ms | 8.1 ms | 1.00× | 1.00× | +| 8 | 12.45 ms | 3.17 ms | 9.6 ms | 7.99× | 2.04× | +| 16 | 24.86 ms | **5.62 ms** | 13.6 ms | **15.96×** | **3.62×** | + +The serial draft scaled **almost exactly linearly** (`draft_x` 15.96 at batch 16); batched draft scales sub-linearly (3.62×). At batch 16 the draft drops from 65% of the step to ~29%, and step time roughly halves → c16 throughput doubles.
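As a quick check of the arithmetic behind the timing table above, the following Rust sketch recomputes the batch-16 scaling factors, the draft's share of a step, and the resulting step-time ratio. It is only an illustration: `scaling` and `draft_share` are hypothetical helper names, not openinfer functions, and the inputs are the rounded millisecond values from the table, so the ratios may differ from the quoted ones in the last digit.

```rust
/// Ratio of a batched draft time to the batch-1 draft time
/// (the "scaling" columns of the table).
fn scaling(batch_ms: f64, single_ms: f64) -> f64 {
    batch_ms / single_ms
}

/// Fraction of one speculative step (draft + verify) spent in the draft.
fn draft_share(draft_ms: f64, verify_ms: f64) -> f64 {
    draft_ms / (draft_ms + verify_ms)
}

fn main() {
    // Batch-16 row of the table (5090, instrumented), in milliseconds.
    let (serial, batched, verify) = (24.86, 5.62, 13.6);

    // Batch-1 draft times from the first row of the table.
    println!("serial scaling:  {:.2}x", scaling(serial, 1.56));
    println!("batched scaling: {:.2}x", scaling(batched, 1.55));

    // Share of the step taken by the draft, before and after batching.
    println!("serial draft share:  {:.1}%", 100.0 * draft_share(serial, verify));
    println!("batched draft share: {:.1}%", 100.0 * draft_share(batched, verify));

    // Step time is draft + verify, so this ratio is the per-step speedup.
    println!("step speedup: {:.2}x", (serial + verify) / (batched + verify));
}
```

With these inputs the draft share comes out at roughly 65% serial and 29% batched, and the step-time ratio at about 2×, which is consistent with the c16 throughput moving from 1013 to 1868 tok/s once the remaining per-step overheads are included.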
+ +**Why it was launch-bound, not compute-bound.** A skip-attention A/B (short-circuit `single_prefill` to a cheap copy, keep every other kernel) barely moved the serial draft — batch 16: 24.36 ms vs 24.81 ms, so attention compute is **<2%**. The 24.8 ms was almost entirely per-request serial **kernel-launch overhead**: each request's draft is ~90 tiny 16-token kernels (5 layers × 35 dense GEMMs + varlen copy/rope/KV), compute ≈ 0. + +**The fix that landed: batch the whole draft forward (N requests in one pass), killing the N× launch.** `draft_logits_batched` (`dflash.rs`) runs the dense ops (rms_norm / GEMM / silu / MLP / embedding / logits) **once over an N×block buffer** — free, since cuBLAS takes any M and the ops are already row-batched (N×35 → 35 launches, no CUDA-kernel change). The varlen ops (tail concat / k·v GEMM / rope / KV-copy / attention) stay a per-request loop slicing the batched buffers at `row_offset = i·block_size` (the two DFlash-exclusive ops `dflash_qk_norm_rope_into` / `single_prefill_nhd_noncausal_into` advance the device pointer to the slice). A lane-level `DFlashBatchScratch` (sized for the largest batch bucket) replaces the per-request scratch — folding in the reservation-reclaim follow-up. Losslessness gate still passes (bf16 tie-flips only). + +#### Single-stream gap (c1): launch-bound draft, not accept + +c1 trails vLLM (237 vs 278) and batching doesn't help it — batch=1 has nothing to batch. The cause is **not accept**: with the *same* drafter and spec config (vLLM `--speculative-config '{"method":"dflash",…,"num_speculative_tokens":16}'`), measured greedy accept is **9.1% (ours) vs 8.85% (vLLM)** — mean 1.29 vs 1.42 draft tokens/step, and our pos-0 accept (60%) is actually higher. 9% is the drafter's floor on sharegpt free-text (bimodal: structured spans accept up to 15, free text accepts 0), shared by both engines. + +The real gap is in **step time**: normalized per output token, ours is 4.22 ms vs vLLM's 3.60 ms. 
c1 step = draft 1.55 ms + verify 7.9 ms. Verify is memory-bound (one 4B target forward, reads ~8 GB of weights) — irreducible and the same for vLLM. But the draft's 1.55 ms is **pure kernel-launch** (85 tiny 16-token kernels, ~18 µs/launch CPU enqueue, compute <2%), while vLLM's draft is CUDA-Graph captured (~0.5 ms). That ~1 ms/step is the c1 gap. + +**Next step: CUDA-Graph the draft forward.** openinfer's whole spec path (`SpeculativeDraft` / `SpeculativeVerify`, `executor.rs`) is *not* CUDA-Graph captured, whereas base decode is (`execute_decode` split-graph cache). Capturing the draft is an architecture-level change — the per-request draft KV (`DFlashLayerCache`) needs pointer-stable buffers, the variable context length (0–16) needs a fixed/padded window, plus capture/replay wiring — predicted to drop draft@batch1 1.55→~0.3 ms (c1 → ~272, ≈ vLLM) and draft@batch16 5.6→~1.5 ms (c16 → ~2200). Tracked as its own PR after this one lands. + +The EAGLE proposer trait is still deferred to when EAGLE actually lands (see "no proposer trait yet" above); both the batched draft and the future CUDA-Graph draft are DFlash-internal changes behind the unchanged `DraftPlan→DraftResult` seam, so neither is thrown away by EAGLE. ## Task-level accuracy parity (lm-eval gsm8k) @@ -86,8 +122,16 @@ The fixed reservation lands exactly in the margin (+2822 MiB) and the per-token ## Constraints & open follow-ups - **TP=1, primary rank only.** The DFlash lane runs on the worker thread of the primary rank; the launch path gates `!(dflash && lora)` and `tp_size == 1`. The server fails loud (`--dflash-draft-model-path` is rejected for non-Qwen3 model lines rather than silently ignored). -- **Second proposer → introduce the trait.** Until then the proposer is concrete on purpose. +- **Second proposer (EAGLE) → introduce the trait then.** A proposer trait wants two real methods to fix its shape; DFlash stays concrete until EAGLE lands. 
The batched-draft perf work (Performance §) is orthogonal — it's a DFlash-internal change behind the unchanged `DraftPlan→DraftResult` seam, so it can land before or with the trait without being redone, and EAGLE's autoregressive attention won't reuse DFlash's block attention anyway. - **File size.** `executor.rs` (~2.8k lines) and `scheduler/tests.rs` (~1.2k lines) exceed the 1k guideline; the spec arms in `execute_step_on_lane` are candidates to move into `executor/spec.rs` in a follow-up. -- **5090 validation — done** for correctness and single-stream: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes, single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Still pending: `vllm bench serve` concurrent-throughput A/B (the spec-helps-vs-hurts crossover under load). +- **5090 validation — done** for correctness, single-stream, and concurrent throughput: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes (before and after the draft batching), single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Concurrent throughput is **fixed** (Performance §): the serial draft inverted the win (c16 −59%), batching the draft restores **c8 1346 / c16 1868 — both beating vLLM**. Single-stream c1 (237 vs vLLM 278) still trails on launch-bound draft overhead; CUDA-Graph draft is tracked next. - **Frontend `seed`.** `/v1/completions` 400s on a per-request `seed` instead of accepting-and-ignoring it under greedy — surfaced while wiring lm-eval; orthogonal to spec decode, worth a small separate fix. -- **Reclaim the per-request reservation.** The fixed term is dominated by the block-sized scratch billed per decode-batch slot (~6.5 MiB × 256 ≈ 1.6 GiB), most of it the transient logits buffer (`vocab × block`). 
The per-token term carries the context/pending scratch, which today reallocs to prompt length on the first draft and never shrinks (sized for `[0, context_len)` but only `[committed_len, +block]` is live). Both are billed conservatively because they're per-request; sharing the transient scratch across the serial per-request draft loop (`execute_dflash_draft`) and collapsing the prompt-persistent buffers would cut the reservation to roughly the draft weights + draft KV, reclaiming most of the ~2.8 GiB. Deserves its own validated PR (touches the draft forward). +- **Reclaim the per-request reservation.** The fixed term is dominated by the block-sized scratch billed per decode-batch slot (~6.5 MiB × 256 ≈ 1.6 GiB), most of it the transient logits buffer (`vocab × block`). The per-token term carries the context/pending scratch, which today reallocs to prompt length on the first draft and never shrinks (sized for `[0, context_len)` but only `[committed_len, +block]` is live). Both are billed conservatively because they were per-request. **Partly reclaimed by the batched-draft PR:** the per-request scratch is gone — `DFlashBatchScratch` is allocated once at lane level, sharing the transient logits/dense buffers across the whole batch instead of per request. The reservation accounting (`from_config`) still bills the old per-request upper bound (a safe over-estimate, so no OOM); tightening it to the now-smaller lane-level footprint is the remaining follow-up. + +### Review blockers (correctness/usability, independent of the perf work) + +Issues surfaced in PR review, largely independent of the draft batching (#2's DFlash wrappers were fixed alongside it): + +1. **Unified path silently skips DFlash readiness.** `StepCommand::Unified` (`executor.rs`) captures no DFlash hidden state, and only the plain `execute_prefill` post-step marks a request draft-ready (`executor.rs` ~1766). 
So when active + pending fuse into a Unified step (`scheduler/plan.rs` — the normal mixed-load path), greedy requests routed through Unified never become draft-ready and never recover: DFlash silently no-ops for them. No wrong tokens, but the feature quietly disables itself under mixed load. Crash-early or capture-in-Unified, don't degrade silently. +2. **Stream-override race (partly fixed).** The DFlash `qk_norm_rope` / `single_prefill` wrappers (`attention.rs`) now use `active_cu_stream(ctx)` (fixed in the batched-draft PR). `copy_hidden_rows_into` (`elementwise.rs:209`) still uses `ctx.stream.cu_stream()` instead of the repo convention (`tensor.rs:43`) — under Green-Context / split-stream decode overlap this remains a planted race. +3. **`gemm_lt_pin_tune` is not a real warmup.** It only pins the heuristic (`linear.cu:497`); it never executes a `cublasLtMatmul` the way the old `gemm_lt_tune_cuda` (`linear.cu:431`) did, so the first real matmul can land inside CUDA-graph capture. `batch_invariance_decode_gemm_graph` backstops it, but the warmup should actually run the matmul. diff --git a/openinfer-kernels/src/ops/attention.rs b/openinfer-kernels/src/ops/attention.rs index 0e4447de..0532f7b5 100644 --- a/openinfer-kernels/src/ops/attention.rs +++ b/openinfer-kernels/src/ops/attention.rs @@ -497,10 +497,18 @@ pub fn qk_norm_rope_batch_decode_into( } } +/// QK RMSNorm + RoPE for one DFlash request's draft block. +/// +/// `q` is a row sub-range of a batched buffer: `q_row_offset` rows precede this +/// request's `q_seq_len` query rows. The kernel still sees a single-request Q +/// shape — we just advance the device pointer to the request's slice. `k` is the +/// request's own varlen tail scratch (whole buffer), so it needs no offset. 
#[allow(clippy::too_many_arguments)] pub fn dflash_qk_norm_rope_into( ctx: &DeviceContext, q: &mut HiddenStates, + q_row_offset: usize, + q_seq_len: usize, k: &mut HiddenStates, q_norm_weight: &DeviceVec, k_norm_weight: &DeviceVec, @@ -517,8 +525,16 @@ pub fn dflash_qk_norm_rope_into( assert_eq!(k.hidden_dim, num_kv_heads * head_dim); assert_eq!(q_norm_weight.len, head_dim); assert_eq!(k_norm_weight.len, head_dim); + assert!( + q_row_offset + q_seq_len <= q.seq_len, + "dflash_qk_norm_rope q row range [{}..{}) exceeds seq_len {}", + q_row_offset, + q_row_offset + q_seq_len, + q.seq_len + ); let (q_ptr, _gq) = q.data.device_ptr_mut(&ctx.stream); + let q_ptr = q_ptr + (q_row_offset * q.hidden_dim * std::mem::size_of::()) as u64; let (k_ptr, _gk) = k.data.device_ptr_mut(&ctx.stream); let (qn_ptr, _gqn) = q_norm_weight.data.device_ptr(&ctx.stream); let (kn_ptr, _gkn) = k_norm_weight.data.device_ptr(&ctx.stream); @@ -536,13 +552,13 @@ pub fn dflash_qk_norm_rope_into( num_q_heads as i32, num_kv_heads as i32, head_dim as i32, - q.seq_len as i32, + q_seq_len as i32, k.seq_len as i32, q_start_pos as i32, k_start_pos as i32, rms_eps, (cos_cache.data.len() / head_dim) as i32, - ctx.stream.cu_stream(), + crate::tensor::active_cu_stream(ctx), ) }; if result != 0 { @@ -551,10 +567,20 @@ pub fn dflash_qk_norm_rope_into( Ok(()) } +/// Non-causal prefill attention for one DFlash request's draft block. +/// +/// `q` and `output` share the SAME row sub-range of batched buffers: request +/// `i` owns rows `[row_offset, row_offset + q_seq_len)` in both, because the +/// draft writes each request's attention output back into the row slot its +/// queries came from. The k/v caches are the request's own whole buffers. The +/// kernel sees a single-request shape — we advance the Q/output device pointers +/// to the request's slice. 
#[allow(clippy::too_many_arguments)] pub fn single_prefill_nhd_noncausal_into( ctx: &DeviceContext, q: &HiddenStates, + row_offset: usize, + q_seq_len: usize, k_cache: &HiddenStates, v_cache: &HiddenStates, output: &mut HiddenStates, @@ -570,11 +596,22 @@ pub fn single_prefill_nhd_noncausal_into( assert_eq!(v_cache.hidden_dim, k_cache.hidden_dim); assert_eq!(v_cache.seq_len, k_cache.seq_len); assert!(kv_len <= k_cache.seq_len); - + assert!( + row_offset + q_seq_len <= q.seq_len, + "single_prefill row range [{}..{}) exceeds seq_len {}", + row_offset, + row_offset + q_seq_len, + q.seq_len + ); + + // q and output share row_offset (asserted same seq_len/hidden_dim above). + let byte_offset = (row_offset * q.hidden_dim * std::mem::size_of::()) as u64; let (q_ptr, _gq) = q.data.device_ptr(&ctx.stream); + let q_ptr = q_ptr + byte_offset; let (k_ptr, _gk) = k_cache.data.device_ptr(&ctx.stream); let (v_ptr, _gv) = v_cache.data.device_ptr(&ctx.stream); let (out_ptr, _go) = output.data.device_ptr_mut(&ctx.stream); + let out_ptr = out_ptr + byte_offset; let result = unsafe { ffi::single_prefill_nhd_noncausal_cuda( q_ptr as *const ffi::Half, @@ -584,11 +621,11 @@ pub fn single_prefill_nhd_noncausal_into( num_q_heads as i32, num_kv_heads as i32, head_dim as i32, - q.seq_len as i32, + q_seq_len as i32, kv_len as i32, k_cache.seq_len as i32, 1.0f32 / (head_dim as f32).sqrt(), - ctx.stream.cu_stream(), + crate::tensor::active_cu_stream(ctx), ) }; if result != 0 { diff --git a/openinfer-qwen3-4b/src/dflash.rs b/openinfer-qwen3-4b/src/dflash.rs index 2af0b090..96b80d79 100644 --- a/openinfer-qwen3-4b/src/dflash.rs +++ b/openinfer-qwen3-4b/src/dflash.rs @@ -25,7 +25,10 @@ pub(crate) struct DFlashDraftModel { pub(crate) struct DFlashRequestState { layers: Vec, pending_context: DFlashPendingContext, - scratch: DFlashDraftScratch, + /// Projected target context for the current draft round. 
Computed once from + /// `pending_context` and read by every layer's tail concat, so it lives with + /// the request (the batched scratch only holds one request's varlen tail). + context: DFlashContextScratch, committed_len: usize, max_cache_len: usize, } @@ -35,12 +38,20 @@ pub(crate) struct DFlashRequestState { /// draft buffers live outside the paged `KvCacheManager`). Split by how it scales: /// /// - `kv_bytes_per_token` scales with the KV pool (billed by shrinking the target -/// block count): the draft's own KV cache plus the scratch-context and -/// pending-context buffers, which currently persist at prompt length per +/// block count): the draft's own KV cache plus the per-request context-projection +/// and pending-context buffers, which currently persist at prompt length per /// request (see `dflash-speculative-decoding.md` — collapsing that persistence /// is a tracked follow-up that would shrink this term to the draft KV alone). /// - `fixed_bytes` does not scale with the pool (billed via the memory margin): -/// the draft weights plus the block-sized scratch across the decode batch. +/// the draft weights plus the lane-level batched scratch sized for the whole +/// decode batch. +/// +// TODO: the draft scratch is now a single lane-level `DFlashBatchScratch` +// allocated once (dense buffers sized `max_batch * block_size`, plus one shared +// varlen tail), not a per-request buffer. The per-token `tail_scratch` term and +// the per-request `block_headroom` tail term are therefore over-estimates — kept +// as a conservative upper bound until the accounting is retuned against the +// batched allocation. 
pub(crate) struct DFlashMemoryReservation { pub(crate) kv_bytes_per_token: usize, pub(crate) fixed_bytes: usize, @@ -70,10 +81,13 @@ impl DFlashMemoryReservation { let pending = hidden * capture_layers * BF16; // context_feature_dim let kv_bytes_per_token = draft_kv + context_scratch + tail_scratch + pending; - // Block-sized scratch held per live request (DFlashDraftScratch fixed buffers). - let fixed_scratch_per_request = - BF16 * config.block_size * (config.vocab_size + 5 * hidden + 2 * q_dim + 3 * inter); - let scratch_total = fixed_scratch_per_request * max_decode_batch_size; + // Lane-level batched dense scratch: every dense buffer is sized for the + // whole decode batch (`max_batch * block_size` rows), allocated once. + // Same total magnitude as the old per-request scratch summed over the + // batch, but now one contiguous allocation. + let dense_scratch_per_block_row = + BF16 * (config.vocab_size + 5 * hidden + 2 * q_dim + 3 * inter); + let scratch_total = dense_scratch_per_block_row * config.block_size * max_decode_batch_size; // Draft weights (5 transformer layers + the context projection), +10% slack // for norms, rope caches, and allocator alignment. @@ -111,19 +125,34 @@ struct DFlashPendingContext { capacity: usize, } -struct DFlashDraftScratch { +/// Per-request projected context. The fc projection + hidden_norm turn the +/// captured target hidden context into draft hidden space once per draft round; +/// every layer's tail concat reads `context_hidden`, so it must persist across +/// the layer loop and therefore lives in the request (not the shared scratch). +struct DFlashContextScratch { max_context_len: usize, - block_token_ids_h: Vec, - token_ids_d: CudaSlice, context_projected: HiddenStates, context_hidden: HiddenStates, +} + +/// Lane-level batched draft scratch, allocated once for the whole decode batch. 
+/// +/// Dense buffers (`hidden`, `normed`, `q_batch`, `attn_output`, the MLP buffers, +/// and `logits`) hold `max_batch * block_size` rows so the GEMM / rms_norm / +/// silu / add / logits / embedding ops run ONCE over the batched buffer. The +/// varlen tail buffers (`tail_input`, `k_tail`, `v_tail`) stay sized for a single +/// request and are reused inside the per-request loop, because their ops (tail +/// concat, k/v GEMMs, rope, KV copy, attention) still loop per request — Step 2 +/// will batch those via CUDA-kernel changes. +pub(crate) struct DFlashBatchScratch { + max_batch_block_rows: usize, + max_tail_len: usize, + block_token_ids_h: Vec, + token_ids_d: CudaSlice, hidden: HiddenStates, hidden_out: HiddenStates, normed: HiddenStates, - tail_input: HiddenStates, q_batch: HiddenStates, - k_tail: HiddenStates, - v_tail: HiddenStates, attn_output: HiddenStates, o_buf: HiddenStates, gate_out: HiddenStates, @@ -131,6 +160,10 @@ struct DFlashDraftScratch { act_out: HiddenStates, logits_normed: HiddenStates, logits: HiddenStates, + // Shared single-request varlen tail scratch (reused inside the per-request loop). 
+ tail_input: HiddenStates, + k_tail: HiddenStates, + v_tail: HiddenStates, } impl DFlashRequestState { @@ -228,66 +261,115 @@ impl DFlashPendingContext { } } -impl DFlashDraftScratch { - fn new(ctx: &DeviceContext, config: &DFlashConfig, max_context_len: usize) -> Result { +impl DFlashContextScratch { + fn new(ctx: &DeviceContext, hidden_size: usize, max_context_len: usize) -> Result { + Ok(Self { + max_context_len, + context_projected: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + context_hidden: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + }) + } + + fn ensure_capacity( + &mut self, + ctx: &DeviceContext, + hidden_size: usize, + context_len: usize, + ) -> Result<()> { + if context_len > self.max_context_len { + *self = Self::new(ctx, hidden_size, context_len)?; + } + self.context_projected.seq_len = context_len; + self.context_hidden.seq_len = context_len; + Ok(()) + } +} + +impl DFlashBatchScratch { + fn new( + ctx: &DeviceContext, + config: &DFlashConfig, + max_decode_batch_size: usize, + ) -> Result { + anyhow::ensure!( + max_decode_batch_size > 0, + "DFlash batch scratch needs a non-zero batch size" + ); let block_size = config.block_size; let hidden_size = config.hidden_size; let q_dim = config.num_attention_heads * config.head_dim; let kv_dim = config.num_key_value_heads * config.head_dim; let inter_dim = config.intermediate_size; - let tail_capacity = max_context_len + block_size; + // Dense buffers span the whole decode batch so the dense ops run once. + let batch_rows = block_size * max_decode_batch_size; + // The shared varlen tail starts at one block (no committed context yet) + // and grows on demand via `ensure_tail_capacity`. 
+ let tail_capacity = block_size; Ok(Self { - max_context_len, - block_token_ids_h: vec![config.dflash_config.mask_token_id; block_size], - token_ids_d: ctx.stream.alloc_zeros(block_size)?, - context_projected: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, - context_hidden: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, - hidden: HiddenStates::zeros(ctx, hidden_size, block_size)?, - hidden_out: HiddenStates::zeros(ctx, hidden_size, block_size)?, - normed: HiddenStates::zeros(ctx, hidden_size, block_size)?, + max_batch_block_rows: batch_rows, + max_tail_len: tail_capacity, + block_token_ids_h: vec![config.dflash_config.mask_token_id; batch_rows], + token_ids_d: ctx.stream.alloc_zeros(batch_rows)?, + hidden: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + hidden_out: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + normed: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + q_batch: HiddenStates::zeros(ctx, q_dim, batch_rows)?, + attn_output: HiddenStates::zeros(ctx, q_dim, batch_rows)?, + o_buf: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + gate_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + up_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + act_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + logits_normed: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + logits: HiddenStates::zeros(ctx, config.vocab_size, batch_rows)?, tail_input: HiddenStates::zeros(ctx, hidden_size, tail_capacity)?, - q_batch: HiddenStates::zeros(ctx, q_dim, block_size)?, k_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, v_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, - attn_output: HiddenStates::zeros(ctx, q_dim, block_size)?, - o_buf: HiddenStates::zeros(ctx, hidden_size, block_size)?, - gate_out: HiddenStates::zeros(ctx, inter_dim, block_size)?, - up_out: HiddenStates::zeros(ctx, inter_dim, block_size)?, - act_out: HiddenStates::zeros(ctx, inter_dim, block_size)?, - logits_normed: 
HiddenStates::zeros(ctx, hidden_size, block_size)?, - logits: HiddenStates::zeros(ctx, config.vocab_size, block_size)?, }) } - fn ensure_context_capacity( + /// Point every dense buffer at the active `batch_block_rows = active_batch * + /// block_size` prefix. Allocated for the max decode batch, so this only + /// shrinks `seq_len`; it never reallocates. + fn activate_dense(&mut self, batch_block_rows: usize) { + assert!( + batch_block_rows <= self.max_batch_block_rows, + "DFlash batched draft {} block rows exceeds scratch capacity {}", + batch_block_rows, + self.max_batch_block_rows + ); + self.hidden.seq_len = batch_block_rows; + self.hidden_out.seq_len = batch_block_rows; + self.normed.seq_len = batch_block_rows; + self.q_batch.seq_len = batch_block_rows; + self.attn_output.seq_len = batch_block_rows; + self.o_buf.seq_len = batch_block_rows; + self.gate_out.seq_len = batch_block_rows; + self.up_out.seq_len = batch_block_rows; + self.act_out.seq_len = batch_block_rows; + self.logits_normed.seq_len = batch_block_rows; + self.logits.seq_len = batch_block_rows; + } + + /// Size the shared varlen tail buffers for one request's `tail_len = + /// context_len + block_size`, growing the allocation if needed. 
+ fn ensure_tail_capacity( &mut self, ctx: &DeviceContext, config: &DFlashConfig, - context_len: usize, + tail_len: usize, ) -> Result<()> { - if context_len > self.max_context_len { - *self = Self::new(ctx, config, context_len)?; + if tail_len > self.max_tail_len { + let hidden_size = config.hidden_size; + let kv_dim = config.num_key_value_heads * config.head_dim; + self.tail_input = HiddenStates::zeros(ctx, hidden_size, tail_len)?; + self.k_tail = HiddenStates::zeros(ctx, kv_dim, tail_len)?; + self.v_tail = HiddenStates::zeros(ctx, kv_dim, tail_len)?; + self.max_tail_len = tail_len; } - let block_size = config.block_size; - let tail_len = context_len + block_size; - - self.context_projected.seq_len = context_len; - self.context_hidden.seq_len = context_len; - self.hidden.seq_len = block_size; - self.hidden_out.seq_len = block_size; - self.normed.seq_len = block_size; self.tail_input.seq_len = tail_len; - self.q_batch.seq_len = block_size; self.k_tail.seq_len = tail_len; self.v_tail.seq_len = tail_len; - self.attn_output.seq_len = block_size; - self.o_buf.seq_len = block_size; - self.gate_out.seq_len = block_size; - self.up_out.seq_len = block_size; - self.act_out.seq_len = block_size; - self.logits_normed.seq_len = block_size; - self.logits.seq_len = block_size; Ok(()) } } @@ -487,6 +569,16 @@ impl DFlashDraftModel { Ok(()) } + /// Allocate the lane-level batched draft scratch once, sized for the whole + /// decode batch. The per-request `DFlashRequestState` no longer owns scratch. 
+ pub(crate) fn new_batch_scratch( + &self, + ctx: &DeviceContext, + max_decode_batch_size: usize, + ) -> Result { + DFlashBatchScratch::new(ctx, &self.config, max_decode_batch_size) + } + pub(crate) fn new_request_state( &self, ctx: &DeviceContext, @@ -513,7 +605,11 @@ impl DFlashDraftModel { self.context_feature_dim(), self.config.block_size.min(max_cache_len), )?, - scratch: DFlashDraftScratch::new(ctx, &self.config, self.config.block_size)?, + context: DFlashContextScratch::new( + ctx, + self.config.hidden_size, + self.config.block_size, + )?, committed_len: 0, max_cache_len, }) @@ -563,159 +659,229 @@ impl DFlashDraftModel { Ok(()) } - pub(crate) fn draft_logits<'a>( + /// Batched draft forward over all active requests at once. + /// + /// The *dense* ops (embedding, rms_norm, q / o / gate_up / down GEMMs, silu, + /// add, fused_add_rms_norm, logits) run ONCE over an `active_batch * + /// block_size` batched buffer. The *varlen* ops (context projection, tail + /// concat, k/v GEMMs, rope, KV copy, attention) still loop per request, + /// slicing each request's `block_size` rows at offset `i * block_size` in the + /// batched buffers — those are Step 2/3's job to batch via CUDA-kernel changes. + /// + /// Returns the batched logits (`active_batch * block_size` rows): request `i` + /// owns rows `[i * block_size, (i + 1) * block_size)`. 
+ pub(crate) fn draft_logits_batched<'a>( &self, target: &Qwen3Model, - state: &'a mut DFlashRequestState, - current_token: u32, + states: &mut [&mut DFlashRequestState], + current_tokens: &[u32], + scratch: &'a mut DFlashBatchScratch, ) -> Result<&'a HiddenStates> { let ctx = target.device_ctx(); - let Some(context_len) = state.pending_context_len() else { - anyhow::bail!("DFlash draft requested before target hidden context is available"); - }; - let block_size = self.block_size(); - let tail_len = context_len + block_size; + let active_batch = states.len(); anyhow::ensure!( - state.committed_len + tail_len <= state.max_cache_len, - "DFlash draft cache overflow: committed={}, tail={}, max={}", - state.committed_len, - tail_len, - state.max_cache_len + active_batch > 0, + "DFlash batched draft needs active requests" + ); + anyhow::ensure!( + states.len() == current_tokens.len(), + "DFlash batched draft: {} states vs {} current tokens", + states.len(), + current_tokens.len() ); + let block_size = self.block_size(); + let batch_block_rows = active_batch * block_size; - state - .scratch - .ensure_context_capacity(ctx, &self.config, context_len)?; - state.scratch.block_token_ids_h.fill(self.mask_token_id()); - state.scratch.block_token_ids_h[0] = current_token; + // Each request's committed context length for this round; advancing + // `committed_len` is deferred until after the layer loop (the rope start + // positions and KV write offsets read the pre-advance value). 
+ let mut context_lens = Vec::with_capacity(active_batch); + for (i, state) in states.iter().enumerate() { + let Some(context_len) = state.pending_context_len() else { + anyhow::bail!( + "DFlash draft requested before target hidden context is available (request slot {i})" + ); + }; + let tail_len = context_len + block_size; + anyhow::ensure!( + state.committed_len + tail_len <= state.max_cache_len, + "DFlash draft cache overflow: committed={}, tail={}, max={}", + state.committed_len, + tail_len, + state.max_cache_len + ); + context_lens.push(context_len); + } + + scratch.activate_dense(batch_block_rows); + + // Build the batched token id buffer: each request's block is + // [current_token, mask, mask, ...]. + scratch.block_token_ids_h[..batch_block_rows].fill(self.mask_token_id()); + for (i, &current_token) in current_tokens.iter().enumerate() { + scratch.block_token_ids_h[i * block_size] = current_token; + } + // token_ids_d holds `max_batch * block_size` ids; copy only the active + // prefix. The embedding kernel reads `out.seq_len = batch_block_rows` ids + // from the buffer start, so the active prefix is what it consumes. + let mut token_ids_dst = scratch.token_ids_d.slice_mut(..batch_block_rows); ctx.stream.memcpy_htod( - &state.scratch.block_token_ids_h, - &mut state.scratch.token_ids_d, + &scratch.block_token_ids_h[..batch_block_rows], + &mut token_ids_dst, )?; - target.get_embeddings_batch_into(&state.scratch.token_ids_d, &mut state.scratch.hidden)?; + target.get_embeddings_batch_into(&scratch.token_ids_d, &mut scratch.hidden)?; - state.pending_context.activate_for_read(); - self.project_context_into(ctx, &state.pending_context.buffer, &mut state.scratch)?; - state.pending_context.clear(); + // Per-request context projection: varlen (each request's committed + // prefix differs), persisted in the request so every layer can read it.
+ for (i, state) in states.iter_mut().enumerate() { + let context_len = context_lens[i]; + state + .context + .ensure_capacity(ctx, self.config.hidden_size, context_len)?; + state.pending_context.activate_for_read(); + self.project_context_into(ctx, &state.pending_context.buffer, &mut state.context)?; + state.pending_context.clear(); + } let hidden_size = self.config.hidden_size; let q_dim = self.config.num_attention_heads * self.config.head_dim; let kv_dim = self.config.num_key_value_heads * self.config.head_dim; let inter_dim = self.config.intermediate_size; - debug_assert_eq!(state.scratch.hidden.hidden_dim, hidden_size); - debug_assert_eq!(state.scratch.q_batch.hidden_dim, q_dim); - debug_assert_eq!(state.scratch.k_tail.hidden_dim, kv_dim); - debug_assert_eq!(state.scratch.gate_out.hidden_dim, inter_dim); + debug_assert_eq!(scratch.hidden.hidden_dim, hidden_size); + debug_assert_eq!(scratch.q_batch.hidden_dim, q_dim); + debug_assert_eq!(scratch.k_tail.hidden_dim, kv_dim); + debug_assert_eq!(scratch.gate_out.hidden_dim, inter_dim); for (layer_idx, layer) in self.layers.iter().enumerate() { + // Dense: input layernorm over the whole batch. ops::rms_norm_batch_into( ctx, - &state.scratch.hidden, + &scratch.hidden, &layer.input_layernorm, self.config.rms_norm_eps, - &mut state.scratch.normed, + &mut scratch.normed, ); - ops::copy_hidden_token_range_into( - ctx, - &state.scratch.context_hidden, - 0, - &mut state.scratch.tail_input, - 0, - context_len, - )?; - ops::copy_hidden_token_range_into( - ctx, - &state.scratch.normed, - 0, - &mut state.scratch.tail_input, - context_len, - block_size, - )?; - + // Dense: Q projection over the whole batch (per-token, no cross-request + // mixing). Computed before the per-request loop reads `normed`, and + // before the post-attention norm overwrites it. 
ops::gemm_rows_into( ctx, &layer.attention.qkv_proj, 0, q_dim, - &state.scratch.normed, - &mut state.scratch.q_batch, - ); - ops::gemm_rows_into( - ctx, - &layer.attention.qkv_proj, - q_dim, - kv_dim, - &state.scratch.tail_input, - &mut state.scratch.k_tail, - ); - ops::gemm_rows_into( - ctx, - &layer.attention.qkv_proj, - q_dim + kv_dim, - kv_dim, - &state.scratch.tail_input, - &mut state.scratch.v_tail, + &scratch.normed, + &mut scratch.q_batch, ); - ops::dflash_qk_norm_rope_into( - ctx, - &mut state.scratch.q_batch, - &mut state.scratch.k_tail, - &layer.attention.q_norm, - &layer.attention.k_norm, - &self.cos_cache, - &self.sin_cache, - self.config.num_attention_heads, - self.config.num_key_value_heads, - self.config.head_dim, - state.committed_len + context_len, - state.committed_len, - self.config.rms_norm_eps, - )?; + // Per-request varlen attention: tail concat, k/v GEMMs, rope, KV copy, + // single-request prefill. Each request slices its `block_size` rows at + // offset `i * block_size` of the batched `normed`/`q_batch`/`attn_output`. + for (i, state) in states.iter_mut().enumerate() { + let context_len = context_lens[i]; + let tail_len = context_len + block_size; + let row_offset = i * block_size; + scratch.ensure_tail_capacity(ctx, &self.config, tail_len)?; - let cache = &mut state.layers[layer_idx]; - ops::copy_hidden_token_range_into( - ctx, - &state.scratch.k_tail, - 0, - &mut cache.k, - state.committed_len, - tail_len, - )?; - ops::copy_hidden_token_range_into( - ctx, - &state.scratch.v_tail, - 0, - &mut cache.v, - state.committed_len, - tail_len, - )?; - ops::single_prefill_nhd_noncausal_into( - ctx, - &state.scratch.q_batch, - &cache.k, - &cache.v, - &mut state.scratch.attn_output, - self.config.num_attention_heads, - self.config.num_key_value_heads, - self.config.head_dim, - state.committed_len + tail_len, - )?; + // tail_input = [context_hidden(context_len) | normed_block(block_size)]. 
+ ops::copy_hidden_token_range_into( + ctx, + &state.context.context_hidden, + 0, + &mut scratch.tail_input, + 0, + context_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &scratch.normed, + row_offset, + &mut scratch.tail_input, + context_len, + block_size, + )?; + + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim, + kv_dim, + &scratch.tail_input, + &mut scratch.k_tail, + ); + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim + kv_dim, + kv_dim, + &scratch.tail_input, + &mut scratch.v_tail, + ); + ops::dflash_qk_norm_rope_into( + ctx, + &mut scratch.q_batch, + row_offset, + block_size, + &mut scratch.k_tail, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + context_len, + state.committed_len, + self.config.rms_norm_eps, + )?; + + let cache = &mut state.layers[layer_idx]; + ops::copy_hidden_token_range_into( + ctx, + &scratch.k_tail, + 0, + &mut cache.k, + state.committed_len, + tail_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &scratch.v_tail, + 0, + &mut cache.v, + state.committed_len, + tail_len, + )?; + ops::single_prefill_nhd_noncausal_into( + ctx, + &scratch.q_batch, + row_offset, + block_size, + &cache.k, + &cache.v, + &mut scratch.attn_output, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + tail_len, + )?; + } + + // Dense: o_proj + residual + post-attention norm + MLP over the batch. 
ops::gemm_into( ctx, &layer.attention.o_proj, - &state.scratch.attn_output, - &mut state.scratch.o_buf, + &scratch.attn_output, + &mut scratch.o_buf, ); openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( ctx, - &mut state.scratch.hidden, - &state.scratch.o_buf, + &mut scratch.hidden, + &scratch.o_buf, &layer.post_attention_layernorm, self.config.rms_norm_eps, - &mut state.scratch.normed, + &mut scratch.normed, )?; ops::gemm_rows_into( @@ -723,41 +889,43 @@ impl DFlashDraftModel { &layer.mlp.gate_up_proj, 0, inter_dim, - &state.scratch.normed, - &mut state.scratch.gate_out, + &scratch.normed, + &mut scratch.gate_out, ); ops::gemm_rows_into( ctx, &layer.mlp.gate_up_proj, inter_dim, inter_dim, - &state.scratch.normed, - &mut state.scratch.up_out, + &scratch.normed, + &mut scratch.up_out, ); ops::silu_mul_batch_into( ctx, - &state.scratch.gate_out, - &state.scratch.up_out, - &mut state.scratch.act_out, + &scratch.gate_out, + &scratch.up_out, + &mut scratch.act_out, )?; ops::gemm_into( ctx, &layer.mlp.down_proj, - &state.scratch.act_out, - &mut state.scratch.o_buf, + &scratch.act_out, + &mut scratch.o_buf, ); ops::add_batch_into( ctx, - &state.scratch.hidden, - &state.scratch.o_buf, - &mut state.scratch.hidden_out, + &scratch.hidden, + &scratch.o_buf, + &mut scratch.hidden_out, )?; - std::mem::swap(&mut state.scratch.hidden, &mut state.scratch.hidden_out); + std::mem::swap(&mut scratch.hidden, &mut scratch.hidden_out); } - state.committed_len += context_len; - self.compute_logits_with_target_head_into(target, &mut state.scratch)?; - Ok(&state.scratch.logits) + for (i, state) in states.iter_mut().enumerate() { + state.committed_len += context_lens[i]; + } + self.compute_logits_with_target_head_into(target, scratch)?; + Ok(&scratch.logits) } fn context_feature_dim(&self) -> usize { @@ -768,20 +936,20 @@ impl DFlashDraftModel { &self, ctx: &DeviceContext, context_features: &HiddenStates, - scratch: &mut DFlashDraftScratch, + context: &mut DFlashContextScratch, ) 
-> Result<()> { ops::gemm_into( ctx, &self.fc, context_features, - &mut scratch.context_projected, + &mut context.context_projected, ); ops::rms_norm_batch_into( ctx, - &scratch.context_projected, + &context.context_projected, &self.hidden_norm, self.config.rms_norm_eps, - &mut scratch.context_hidden, + &mut context.context_hidden, ); Ok(()) } @@ -789,7 +957,7 @@ impl DFlashDraftModel { fn compute_logits_with_target_head_into( &self, target: &Qwen3Model, - scratch: &mut DFlashDraftScratch, + scratch: &mut DFlashBatchScratch, ) -> Result<()> { let ctx = target.device_ctx(); ops::rms_norm_batch_into( diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 3e821272..8fff8655 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -2361,7 +2361,12 @@ impl LocalQwen3Lane { max_position_embeddings: model.max_position_embeddings(), target_layer_ids: model.target_layer_ids().to_vec(), }; - self.dflash = Some(DFlashLaneState::new(model)); + let max_decode_batch_size = *BATCH_BUCKETS.last().unwrap(); + self.dflash = Some(DFlashLaneState::new( + self.model.device_ctx(), + model, + max_decode_batch_size, + )?); Ok(meta) } diff --git a/openinfer-qwen3-4b/src/executor/dflash_lane.rs b/openinfer-qwen3-4b/src/executor/dflash_lane.rs index c2263aa7..3ff01161 100644 --- a/openinfer-qwen3-4b/src/executor/dflash_lane.rs +++ b/openinfer-qwen3-4b/src/executor/dflash_lane.rs @@ -13,26 +13,36 @@ use openinfer_core::tensor::HiddenStates; use super::dflash_prefill::{dflash_prefill_can_capture, should_capture_dflash_prefill_context}; use super::{LocalQwen3Lane, PrefillStepItem, RequestId}; -use crate::dflash::{DFlashDraftModel, DFlashRequestState}; +use crate::dflash::{DFlashBatchScratch, DFlashDraftModel, DFlashRequestState}; use crate::speculative::{ DraftRequestResult, DraftResult, DraftStepItem, VerifyRequestResult, VerifyStepItem, }; +use openinfer_core::tensor::DeviceContext; pub(super) struct DFlashLaneState { 
pub(super) model: DFlashDraftModel, pub(super) requests: HashMap, + /// Lane-level batched draft scratch, allocated once for the whole decode + /// batch so the dense draft ops run once instead of once per request. + scratch: DFlashBatchScratch, verified_draft_tokens: usize, accepted_draft_tokens: usize, } impl DFlashLaneState { - pub(super) fn new(model: DFlashDraftModel) -> Self { - Self { + pub(super) fn new( + ctx: &DeviceContext, + model: DFlashDraftModel, + max_decode_batch_size: usize, + ) -> Result { + let scratch = model.new_batch_scratch(ctx, max_decode_batch_size)?; + Ok(Self { model, requests: HashMap::new(), + scratch, verified_draft_tokens: 0, accepted_draft_tokens: 0, - } + }) } } @@ -215,36 +225,81 @@ impl LocalQwen3Lane { } // Take the lane out of `self` so the draft forward (which borrows - // `self.model`) and the argmax (which borrows `self.sample_scratch`) - // don't collide on a `self` borrow. + // `dflash.model`/`dflash.scratch`) and the argmax (which borrows + // `self.sample_scratch`) don't collide on a `self` borrow. let Some(mut dflash) = self.dflash.take() else { anyhow::bail!("DFlash draft requested but DFlash is not loaded"); }; let result = (|| -> Result> { - let mut outputs = Vec::with_capacity(requests.len()); + // Pull every active request's state out of the map so the batched + // forward can hold `&mut` to all of them at once. Re-inserted below. 
+ let mut taken: Vec<(RequestId, DFlashRequestState)> = + Vec::with_capacity(requests.len()); for req in requests { - let mut state = dflash.requests.remove(&req.request_id).ok_or_else(|| { + let state = dflash.requests.remove(&req.request_id).ok_or_else(|| { anyhow::anyhow!("missing DFlash state for {:?}", req.request_id) })?; - let draft_logits = - dflash - .model - .draft_logits(&self.model, &mut state, req.current_token)?; + taken.push((req.request_id, state)); + } + + let block_size = dflash.model.block_size(); + let current_tokens: Vec = requests.iter().map(|req| req.current_token).collect(); + let DFlashLaneState { + model, + scratch, + requests: state_map, + .. + } = &mut dflash; + let mut state_refs: Vec<&mut DFlashRequestState> = + taken.iter_mut().map(|(_, state)| state).collect(); + + let sampled = { + let draft_logits = model.draft_logits_batched( + &self.model, + &mut state_refs, + &current_tokens, + scratch, + )?; let draft_len = draft_logits.seq_len; + anyhow::ensure!( + draft_len == requests.len() * block_size, + "DFlash batched draft produced {} logits rows for {} requests x block {}", + draft_len, + requests.len(), + block_size + ); let greedy = SamplingParams::default(); let params: Vec<&SamplingParams> = vec![&greedy; draft_len]; - let sampled = self.select_step_tokens(draft_logits, &params, 0)?; - dflash.requests.insert(req.request_id, state); + self.select_step_tokens(draft_logits, &params, 0)? + }; + + // Re-insert every request's state before splitting the result. + for (request_id, state) in taken { + state_map.insert(request_id, state); + } + + anyhow::ensure!( + sampled.len() == requests.len() * block_size, + "DFlash batched draft sampled {} tokens for {} requests x block {}", + sampled.len(), + requests.len(), + block_size + ); + + // Split the batched samples per request: request `i` owns rows + // `[i * block_size, (i + 1) * block_size)`. Verify span = [current + // dangling token, draft_1, …, draft_{K}].
+ let mut outputs = Vec::with_capacity(requests.len()); + for (i, req) in requests.iter().enumerate() { + let block = &sampled[i * block_size..(i + 1) * block_size]; anyhow::ensure!( - sampled.len() == draft_len && sampled.len() >= 2, - "DFlash draft sampled {} tokens from {} logits columns", - sampled.len(), - draft_len + block.len() >= 2, + "DFlash draft block {} has fewer than 2 tokens", + i ); - // Verify span = [current dangling token, draft_1, …, draft_{K}]. - let mut token_ids = Vec::with_capacity(sampled.len()); + let mut token_ids = Vec::with_capacity(block.len()); token_ids.push(req.current_token); - token_ids.extend(sampled.into_iter().skip(1)); + token_ids.extend(block[1..].iter().copied()); outputs.push(DraftRequestResult { request_id: req.request_id, token_ids, From 84b98f0254a3fd7eaa8529392e2c4d856a782dec Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Tue, 23 Jun 2026 01:31:19 +0800 Subject: [PATCH 7/9] feat(qwen3-dflash): verify fixed-buffer forward (graph prerequisite) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DFlash verify ran batch_prefill(echo=true) which allocates fresh GPU scratch every step (PrefillBuffers::new, embedding HiddenStates, all_logits, PrefillPagedPlan upload) — moving pointers that CUDA Graph capture cannot tolerate. This is the buffer-reuse refactor that precedes capturing the verify forward into a graph (no capture yet this phase). - PrefillPagedPlan: new_preallocated + update_batch_with_cta_tile_q refill the device buffers in place (memcpy, capacity-guarded) instead of clone_htod, keeping pointers stable. Host plan math extracted into a shared BatchPlanHost so new_batch_with_cta_tile_q is byte-identical. - VerifyGraphBuffers: pre-allocates all verify scratch once at the worst-case max_batch*span shape; set_rows only moves seq_len. 
- batch_prefill_into: buffer-reusing twin of batch_prefill, issues the identical kernel sequence (forward_layer_batch_paged reused verbatim) so verify logits/captured-hidden stay bit-equal. Verify never uses LoRA. - lane: SpeculativeVerify routes through execute_dflash_verify; the old allocating verify path is gone. Normal prefill is untouched. Co-Authored-By: Claude Opus 4.8 (1M context) --- openinfer-core/src/ops/paged_plan.rs | 47 ++++ openinfer-kernels/src/ops/attention.rs | 230 ++++++++++++++++++-- openinfer-qwen3-4b/src/executor.rs | 111 +++++++--- openinfer-qwen3-4b/src/lib.rs | 1 + openinfer-qwen3-4b/src/prefill.rs | 19 +- openinfer-qwen3-4b/src/verify_graph.rs | 286 +++++++++++++++++++++++++ 6 files changed, 649 insertions(+), 45 deletions(-) create mode 100644 openinfer-qwen3-4b/src/verify_graph.rs diff --git a/openinfer-core/src/ops/paged_plan.rs b/openinfer-core/src/ops/paged_plan.rs index 1d162951..d2770939 100644 --- a/openinfer-core/src/ops/paged_plan.rs +++ b/openinfer-core/src/ops/paged_plan.rs @@ -147,6 +147,53 @@ impl PrefillPagedPlan { }) } + /// Pre-allocate a worst-case-sized plan to be refilled in place by + /// [`Self::update_batch_with_cta_tile_q`] (graph-stable buffer reuse). + pub fn new_preallocated( + ctx: &DeviceContext, + max_total_tokens: usize, + max_total_pages: usize, + max_batch: usize, + max_tiles: usize, + ) -> Result { + Ok(Self { + inner: openinfer_kernels::ops::PrefillPagedPlan::new_preallocated( + ctx, + max_total_tokens, + max_total_pages, + max_batch, + max_tiles, + )?, + }) + } + + /// Refill a pre-allocated plan in place (no allocation, pointers unchanged). 
+ #[allow(clippy::too_many_arguments)] + pub fn update_batch_with_cta_tile_q( + &mut self, + ctx: &DeviceContext, + page_indices: &[Vec], + last_page_lens: &[usize], + start_positions: &[usize], + seq_lens: &[usize], + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + cta_tile_q_override: i32, + ) -> Result<()> { + self.inner.update_batch_with_cta_tile_q( + ctx, + page_indices, + last_page_lens, + start_positions, + seq_lens, + num_q_heads, + num_kv_heads, + head_dim, + cta_tile_q_override, + ) + } + pub fn page_indices_d(&self) -> &CudaSlice { self.inner.page_indices_d() } diff --git a/openinfer-kernels/src/ops/attention.rs b/openinfer-kernels/src/ops/attention.rs index 0532f7b5..55b301d3 100644 --- a/openinfer-kernels/src/ops/attention.rs +++ b/openinfer-kernels/src/ops/attention.rs @@ -215,6 +215,210 @@ impl PrefillPagedPlan { num_kv_heads: usize, head_dim: usize, cta_tile_q_override: i32, + ) -> Result { + let host = BatchPlanHost::compute( + page_indices, + last_page_lens, + start_positions, + seq_lens, + num_q_heads, + num_kv_heads, + head_dim, + cta_tile_q_override, + )?; + + // Upload all to GPU + Ok(Self { + page_indices_d: ctx.stream.clone_htod(&host.all_page_indices)?, + page_indptr_d: ctx.stream.clone_htod(&host.page_indptr)?, + last_page_len_d: ctx.stream.clone_htod(&host.last_page_lens_i32)?, + batch_indices_d: ctx.stream.clone_htod(&host.batch_indices)?, + positions_d: ctx.stream.clone_htod(&host.positions)?, + q_indptr_d: ctx.stream.clone_htod(&host.q_indptr)?, + request_indices_d: ctx.stream.clone_htod(&host.request_indices_v)?, + qo_tile_indices_d: ctx.stream.clone_htod(&host.qo_tile_indices_v)?, + kv_tile_indices_d: ctx.stream.clone_htod(&host.kv_tile_indices_v)?, + kv_chunk_size_d: ctx.stream.clone_htod(&host.kv_chunk_sizes)?, + total_num_rows_d: ctx.stream.clone_htod(&[host.total_tokens as u32])?, + num_tiles: host.num_tiles, + batch_size: host.batch_size as i32, + total_tokens: host.total_tokens, + cta_tile_q: host.cta_tile_q as 
i32, + }) + } + + /// Allocate a worst-case-sized plan once, to be refilled in place by + /// [`Self::update_batch_with_cta_tile_q`]. Buffer pointers stay fixed across + /// updates so a CUDA Graph captured against them remains valid on replay. + /// Scalar fields start at 0; an unfilled plan must not be used for a forward. + pub fn new_preallocated( + ctx: &DeviceContext, + max_total_tokens: usize, + max_total_pages: usize, + max_batch: usize, + max_tiles: usize, + ) -> Result { + Ok(Self { + page_indices_d: ctx.stream.alloc_zeros(max_total_pages)?, + page_indptr_d: ctx.stream.alloc_zeros(max_batch + 1)?, + last_page_len_d: ctx.stream.alloc_zeros(max_batch)?, + batch_indices_d: ctx.stream.alloc_zeros(max_total_tokens)?, + positions_d: ctx.stream.alloc_zeros(max_total_tokens)?, + q_indptr_d: ctx.stream.alloc_zeros(max_batch + 1)?, + request_indices_d: ctx.stream.alloc_zeros(max_tiles)?, + qo_tile_indices_d: ctx.stream.alloc_zeros(max_tiles)?, + kv_tile_indices_d: ctx.stream.alloc_zeros(max_tiles)?, + kv_chunk_size_d: ctx.stream.alloc_zeros(max_batch)?, + total_num_rows_d: ctx.stream.alloc_zeros(1)?, + num_tiles: 0, + batch_size: 0, + total_tokens: 0, + cta_tile_q: 0, + }) + } + + /// Recompute the host-side batch metadata and `memcpy_htod` it into the + /// pre-allocated device buffers (no allocation, no pointer change). The host + /// computation is identical to [`Self::new_batch_with_cta_tile_q`]; only the + /// upload differs (overwrite in place vs. fresh `clone_htod`). + /// + /// `memcpy_htod` copies `src.len()` elements and tolerates a larger + /// destination, so the worst-case allocation may exceed the actual fill. 
+ #[allow(clippy::too_many_arguments)] + pub fn update_batch_with_cta_tile_q( + &mut self, + ctx: &DeviceContext, + page_indices: &[Vec], + last_page_lens: &[usize], + start_positions: &[usize], + seq_lens: &[usize], + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + cta_tile_q_override: i32, + ) -> Result<()> { + let host = BatchPlanHost::compute( + page_indices, + last_page_lens, + start_positions, + seq_lens, + num_q_heads, + num_kv_heads, + head_dim, + cta_tile_q_override, + )?; + + anyhow::ensure!( + host.all_page_indices.len() <= self.page_indices_d.len(), + "verify plan page_indices ({}) exceeds preallocated capacity ({})", + host.all_page_indices.len(), + self.page_indices_d.len(), + ); + anyhow::ensure!( + host.page_indptr.len() <= self.page_indptr_d.len(), + "verify plan page_indptr ({}) exceeds preallocated capacity ({})", + host.page_indptr.len(), + self.page_indptr_d.len(), + ); + anyhow::ensure!( + host.last_page_lens_i32.len() <= self.last_page_len_d.len(), + "verify plan last_page_lens ({}) exceeds preallocated capacity ({})", + host.last_page_lens_i32.len(), + self.last_page_len_d.len(), + ); + anyhow::ensure!( + host.batch_indices.len() <= self.batch_indices_d.len(), + "verify plan batch_indices ({}) exceeds preallocated capacity ({})", + host.batch_indices.len(), + self.batch_indices_d.len(), + ); + anyhow::ensure!( + host.positions.len() <= self.positions_d.len(), + "verify plan positions ({}) exceeds preallocated capacity ({})", + host.positions.len(), + self.positions_d.len(), + ); + anyhow::ensure!( + host.q_indptr.len() <= self.q_indptr_d.len(), + "verify plan q_indptr ({}) exceeds preallocated capacity ({})", + host.q_indptr.len(), + self.q_indptr_d.len(), + ); + anyhow::ensure!( + host.request_indices_v.len() <= self.request_indices_d.len(), + "verify plan tiles ({}) exceeds preallocated capacity ({})", + host.request_indices_v.len(), + self.request_indices_d.len(), + ); + anyhow::ensure!( + host.kv_chunk_sizes.len() <= 
self.kv_chunk_size_d.len(), + "verify plan kv_chunk_sizes ({}) exceeds preallocated capacity ({})", + host.kv_chunk_sizes.len(), + self.kv_chunk_size_d.len(), + ); + + ctx.stream + .memcpy_htod(&host.all_page_indices, &mut self.page_indices_d)?; + ctx.stream + .memcpy_htod(&host.page_indptr, &mut self.page_indptr_d)?; + ctx.stream + .memcpy_htod(&host.last_page_lens_i32, &mut self.last_page_len_d)?; + ctx.stream + .memcpy_htod(&host.batch_indices, &mut self.batch_indices_d)?; + ctx.stream + .memcpy_htod(&host.positions, &mut self.positions_d)?; + ctx.stream + .memcpy_htod(&host.q_indptr, &mut self.q_indptr_d)?; + ctx.stream + .memcpy_htod(&host.request_indices_v, &mut self.request_indices_d)?; + ctx.stream + .memcpy_htod(&host.qo_tile_indices_v, &mut self.qo_tile_indices_d)?; + ctx.stream + .memcpy_htod(&host.kv_tile_indices_v, &mut self.kv_tile_indices_d)?; + ctx.stream + .memcpy_htod(&host.kv_chunk_sizes, &mut self.kv_chunk_size_d)?; + ctx.stream + .memcpy_htod(&[host.total_tokens as u32], &mut self.total_num_rows_d)?; + + self.num_tiles = host.num_tiles; + self.batch_size = host.batch_size as i32; + self.total_tokens = host.total_tokens; + self.cta_tile_q = host.cta_tile_q as i32; + Ok(()) + } +} + +/// Host-side batch-prefill metadata, computed identically for the fresh +/// (`new_batch_with_cta_tile_q`) and in-place (`update_batch_with_cta_tile_q`) +/// paths so the two never diverge. 
+struct BatchPlanHost { + all_page_indices: Vec, + page_indptr: Vec, + last_page_lens_i32: Vec, + batch_indices: Vec, + positions: Vec, + q_indptr: Vec, + request_indices_v: Vec, + qo_tile_indices_v: Vec, + kv_tile_indices_v: Vec, + kv_chunk_sizes: Vec, + num_tiles: i32, + batch_size: usize, + total_tokens: usize, + cta_tile_q: usize, +} + +impl BatchPlanHost { + #[allow(clippy::too_many_arguments)] + fn compute( + page_indices: &[Vec], + last_page_lens: &[usize], + start_positions: &[usize], + seq_lens: &[usize], + num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + cta_tile_q_override: i32, ) -> Result { let batch_size = page_indices.len(); assert_eq!(batch_size, last_page_lens.len()); @@ -281,23 +485,21 @@ impl PrefillPagedPlan { } let num_tiles = request_indices_v.len() as i32; - // Upload all to GPU Ok(Self { - page_indices_d: ctx.stream.clone_htod(&all_page_indices)?, - page_indptr_d: ctx.stream.clone_htod(&page_indptr)?, - last_page_len_d: ctx.stream.clone_htod(&last_page_lens_i32)?, - batch_indices_d: ctx.stream.clone_htod(&batch_indices)?, - positions_d: ctx.stream.clone_htod(&positions)?, - q_indptr_d: ctx.stream.clone_htod(&q_indptr)?, - request_indices_d: ctx.stream.clone_htod(&request_indices_v)?, - qo_tile_indices_d: ctx.stream.clone_htod(&qo_tile_indices_v)?, - kv_tile_indices_d: ctx.stream.clone_htod(&kv_tile_indices_v)?, - kv_chunk_size_d: ctx.stream.clone_htod(&kv_chunk_sizes)?, - total_num_rows_d: ctx.stream.clone_htod(&[total_tokens as u32])?, + all_page_indices, + page_indptr, + last_page_lens_i32, + batch_indices, + positions, + q_indptr, + request_indices_v, + qo_tile_indices_v, + kv_tile_indices_v, + kv_chunk_sizes, num_tiles, - batch_size: batch_size as i32, + batch_size, total_tokens, - cta_tile_q: cta_tile_q as i32, + cta_tile_q, }) } } diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 8fff8655..cf8ed126 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ 
b/openinfer-qwen3-4b/src/executor.rs @@ -30,6 +30,8 @@ use crate::speculative::{ use dflash_lane::DFlashLaneState; use dflash_prefill::{DFlashPrefillAction, dflash_prefill_action}; +use crate::verify_graph::VerifyGraphBuffers; + #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct RequestId(pub(crate) u64); @@ -492,36 +494,13 @@ fn execute_step_on_lane( } StepCommand::SpeculativeVerify { requests, kv_views } => { // One target forward over each request's K+1 draft span with a - // speculative KV view. echo=true yields all-position logits so we - // can argmax every span position — accept_greedy needs the target's - // posterior at each position. The same forward captures target - // hidden states (at the DFlash layers) to seed the next draft. - let spans: Vec<&[u32]> = requests.iter().map(VerifyStepItem::as_slice).collect(); - let no_lora: Vec> = vec![None; requests.len()]; - let capture_layer_ids = lane.dflash_capture_layer_ids(); - let (_last_logits, all_logits, captured_hidden) = lane.execute_prefill( - &spans, - kv_views, - &no_lora, - true, - capture_layer_ids.as_deref(), - )?; - let all_logits = all_logits.ok_or_else(|| { - anyhow::anyhow!("speculative verify produced no per-position logits") - })?; - let total_tokens: usize = requests.iter().map(|req| req.as_slice().len()).sum(); - let greedy = SamplingParams::default(); - let params: Vec<&SamplingParams> = vec![&greedy; total_tokens]; - let target_tokens = lane.select_step_tokens(&all_logits, &params, 0)?; - let request_results = build_verify_results(requests, &target_tokens)?; - lane.record_verify_dflash_context( - requests, - &request_results, - captured_hidden.as_ref(), - )?; - Ok(WorkerStepOutcome::SpeculativeVerify(VerifyResult { - requests: request_results, - })) + // speculative KV view.
The fixed-buffer verify path computes all- + // position logits (accept_greedy needs the target's posterior at + // each span position) and captures the target hidden states (at the + // DFlash layers) to seed the next draft — all into reused, + // pointer-stable scratch (`VerifyGraphBuffers`). + let result = lane.execute_dflash_verify(requests, kv_views)?; + Ok(WorkerStepOutcome::SpeculativeVerify(result)) } StepCommand::SpeculativeDraft { requests } => Ok(WorkerStepOutcome::SpeculativeDraft( lane.execute_dflash_draft(requests)?, @@ -2287,6 +2266,13 @@ struct LocalQwen3Lane { /// DFlash draft lane (the draft model + per-request draft state). `None` /// unless speculative decoding is enabled; only the primary rank carries it. dflash: Option, + /// Fixed, pre-allocated scratch for the DFlash verify forward. Lazily built + /// on the first verify step (its shape depends on the loaded draft model's + /// block size and the target's capture layers). Pointer-stable for the + /// upcoming verify CUDA Graph. + verify_bufs: Option, + /// KV pool block count — the worst-case page-list bound for `verify_bufs`. + total_blocks: usize, } /// Stored state for an async prefill that was launched but not yet synced. @@ -2343,6 +2329,8 @@ impl LocalQwen3Lane { sample_scratch, inflight_prefill: None, dflash: None, + verify_bufs: None, + total_blocks, }) } @@ -2486,6 +2474,69 @@ impl LocalQwen3Lane { ) } + /// DFlash verify forward over each request's `block_size`-token span, using + /// the fixed pre-allocated [`VerifyGraphBuffers`] (no per-step allocation). + /// Numerically equivalent to the `batch_prefill(echo=true)` verify path it + /// replaces; the buffers are lazily built on first use. 
+ fn execute_dflash_verify( + &mut self, + requests: &[VerifyStepItem], + kv_views: &[KvView], + ) -> Result { + let capture_layer_ids = self.dflash_capture_layer_ids().ok_or_else(|| { + anyhow::anyhow!("DFlash verify requested but no draft model is loaded") + })?; + let block_size = self + .dflash + .as_ref() + .expect("DFlash present when capture layers exist") + .model + .block_size(); + + if self.verify_bufs.is_none() { + let max_batch = *BATCH_BUCKETS.last().unwrap(); + self.verify_bufs = Some(VerifyGraphBuffers::new( + &self.model, + max_batch, + block_size, + capture_layer_ids.len(), + self.total_blocks, + )?); + } + + // Take the buffers out of `self` so the forward (borrows `&self.model`, + // `&mut bufs`) and the subsequent sampling (`&mut self.sample_scratch`) + // and context record (`&mut self.dflash`) don't alias a `self` borrow. + let mut bufs = self.verify_bufs.take().expect("verify buffers just set"); + let result = (|| -> Result { + let spans: Vec<&[u32]> = requests.iter().map(VerifyStepItem::as_slice).collect(); + self.model.batch_prefill_into( + &spans, + kv_views, + self.kv_buffer.buffer(), + &self.layout, + &capture_layer_ids, + &mut bufs, + )?; + + let total_tokens: usize = requests.iter().map(|req| req.as_slice().len()).sum(); + let greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; total_tokens]; + let target_tokens = self.select_step_tokens(bufs.all_logits(), &params, 0)?; + let request_results = build_verify_results(requests, &target_tokens)?; + self.record_verify_dflash_context( + requests, + &request_results, + Some(bufs.captured_hidden()), + )?; + Ok(VerifyResult { + requests: request_results, + }) + })(); + self.verify_bufs = Some(bufs); + result + } + fn execute_decode( &mut self, token_ids: &[u32], diff --git a/openinfer-qwen3-4b/src/lib.rs b/openinfer-qwen3-4b/src/lib.rs index 0931ed96..a6bb433f 100644 --- a/openinfer-qwen3-4b/src/lib.rs +++ b/openinfer-qwen3-4b/src/lib.rs @@ -14,6 +14,7 @@ mod
prefill; mod scheduler; mod speculative; mod unified_forward; +mod verify_graph; mod weights; use std::path::{Path, PathBuf}; diff --git a/openinfer-qwen3-4b/src/prefill.rs b/openinfer-qwen3-4b/src/prefill.rs index ddd33323..f3417f5e 100644 --- a/openinfer-qwen3-4b/src/prefill.rs +++ b/openinfer-qwen3-4b/src/prefill.rs @@ -73,6 +73,23 @@ impl PrefillBuffers { attn_output: HiddenStates::zeros(ctx, q_dim, seq_len)?, }) } + + /// Point every scratch buffer's logical row count at `rows` without + /// reallocating. Used by the fixed-buffer verify path (see + /// [`crate::verify_graph`]); the buffers must have been allocated for at + /// least `rows`. + pub(super) fn set_rows(&mut self, rows: usize) { + self.hidden_out.seq_len = rows; + self.normed.seq_len = rows; + self.q_batch.seq_len = rows; + self.k_batch.seq_len = rows; + self.v_batch.seq_len = rows; + self.o_buf.seq_len = rows; + self.gate_out.seq_len = rows; + self.up_out.seq_len = rows; + self.act_out.seq_len = rows; + self.attn_output.seq_len = rows; + } } impl Qwen3Model { @@ -118,7 +135,7 @@ impl Qwen3Model { } #[allow(clippy::too_many_arguments)] - fn forward_layer_batch_paged( + pub(crate) fn forward_layer_batch_paged( &self, layer_idx: usize, layer: &TransformerBlock, diff --git a/openinfer-qwen3-4b/src/verify_graph.rs b/openinfer-qwen3-4b/src/verify_graph.rs new file mode 100644 index 00000000..d7cff1f0 --- /dev/null +++ b/openinfer-qwen3-4b/src/verify_graph.rs @@ -0,0 +1,286 @@ +//! Fixed, pre-allocated buffers for the DFlash speculative *verify* forward. +//! +//! The verify forward runs a target prefill over each active request's `span = +//! num_speculative_tokens + 1` token block (see [`super::executor`]'s +//! `SpeculativeVerify` handler). The default [`Qwen3Model::batch_prefill`] path +//! allocates fresh GPU scratch every step (`PrefillBuffers::new`, the embedding +//! `HiddenStates`, `all_logits`, and the `PrefillPagedPlan` upload). That churns +//! 
`cuMemAllocAsync`/`cuMemFreeAsync` and — more importantly — hands CUDA Graph +//! capture moving pointers. +//! +//! [`VerifyGraphBuffers`] pre-allocates all of that once at the worst-case shape +//! (`max_batch * span` rows) and refills it in place each step. +//! [`Qwen3Model::batch_prefill_into`] is the buffer-reusing twin of +//! `batch_prefill`: it issues the *same* kernel sequence +//! (`forward_layer_batch_paged`, identical to the allocating path) so the verify +//! result stays bit-for-bit equivalent. This is the pre-requisite refactor for +//! capturing the verify forward into a CUDA Graph — this phase does the buffer +//! reuse only, no graph capture yet. + +use anyhow::Result; +use cudarc::driver::CudaSlice; + +use openinfer_core::kv_pool::KvLayout; +use openinfer_core::ops; +use openinfer_core::ops::PrefillPagedPlan; +use openinfer_core::tensor::HiddenStates; +use openinfer_kv_cache::KvView; + +use crate::config::PREFILL_ATTENTION_CTA_TILE_Q; +use crate::lora::DeviceLoraTokenGroup; +use crate::prefill::PrefillBuffers; +use crate::weights::Qwen3Model; + +/// All GPU scratch the verify forward needs, sized once for `max_batch * span` +/// rows and reused (in place) every step. Pointer-stable for CUDA Graph capture. +pub(crate) struct VerifyGraphBuffers { + /// Per-layer projection/attention scratch (reused exactly as the allocating + /// prefill path uses it). + prefill_bufs: PrefillBuffers, + /// Residual-stream hidden states `[hidden_dim, max_total_rows]`. + hidden: HiddenStates, + /// Captured target hidden states for the DFlash layers, + /// `[hidden_size * num_capture_layers, max_total_rows]`. + captured_hidden: HiddenStates, + /// RMS-norm output feeding the lm_head GEMM `[hidden_dim, max_total_rows]`. + all_logits_normed: HiddenStates, + /// All-position logits `[vocab, max_total_rows]` (the verify forward's output). + all_logits: HiddenStates, + /// Device-resident concatenated verify tokens `[max_total_rows]`. 
+ token_ids_d: CudaSlice, + /// Paged-attention plan, refilled in place each step. + plan: PrefillPagedPlan, + max_batch: usize, + span: usize, +} + +impl VerifyGraphBuffers { + /// Allocate verify scratch for up to `max_batch` requests, each a fixed + /// `span`-token block. `num_capture_layers` is the DFlash target-layer count + /// (the captured-hidden buffer holds one `hidden_size` slice per layer). + /// `max_total_pages` bounds the paged-attention page list; pass the KV + /// pool's total block count for a guaranteed worst case. + pub(crate) fn new( + model: &Qwen3Model, + max_batch: usize, + span: usize, + num_capture_layers: usize, + max_total_pages: usize, + ) -> Result { + anyhow::ensure!(max_batch > 0, "verify buffers need max_batch >= 1"); + anyhow::ensure!(span > 0, "verify buffers need span >= 1"); + let ctx = model.device_ctx(); + let hidden_dim = model.config().hidden_size; + let q_dim = model.local_q_dim(); + let kv_dim = model.local_kv_dim(); + let inter_dim = model.local_intermediate_size(); + let vocab = model.config().vocab_size; + let max_total_rows = max_batch * span; + + // Each request's `span` query tokens fan out to `span * group_size` + // packed-QO rows; with a CTA tile of at least 1, that bounds tiles per + // request. `max_batch * span * group_size` is the conservative ceiling. 
+ let group_size = model.local_num_attention_heads() / model.local_num_key_value_heads(); + let max_tiles = max_batch * span * group_size.max(1); + + Ok(Self { + prefill_bufs: PrefillBuffers::new( + ctx, + hidden_dim, + q_dim, + kv_dim, + inter_dim, + max_total_rows, + )?, + hidden: HiddenStates::zeros(ctx, hidden_dim, max_total_rows)?, + captured_hidden: HiddenStates::zeros( + ctx, + hidden_dim * num_capture_layers.max(1), + max_total_rows, + )?, + all_logits_normed: HiddenStates::zeros(ctx, hidden_dim, max_total_rows)?, + all_logits: HiddenStates::zeros(ctx, vocab, max_total_rows)?, + token_ids_d: ctx.stream.alloc_zeros(max_total_rows)?, + plan: PrefillPagedPlan::new_preallocated( + ctx, + max_total_rows, + max_total_pages, + max_batch, + max_tiles, + )?, + max_batch, + span, + }) + } + + /// Point every buffer's logical extent at `total_rows` (`<= max capacity`). + /// Like [`PrefillBuffers`] / [`super::batch_decode_buffers`], this only moves + /// `seq_len`; it never reallocates. + fn set_rows(&mut self, total_rows: usize) { + let cap = self.max_batch * self.span; + assert!( + total_rows <= cap, + "verify total_rows {total_rows} exceeds capacity {cap}" + ); + self.prefill_bufs.set_rows(total_rows); + self.hidden.seq_len = total_rows; + self.captured_hidden.seq_len = total_rows; + self.all_logits_normed.seq_len = total_rows; + self.all_logits.seq_len = total_rows; + } + + /// All-position logits `[vocab, total_rows]` from the last forward. + pub(crate) fn all_logits(&self) -> &HiddenStates { + &self.all_logits + } + + /// Captured target hidden states `[hidden_size * num_capture_layers, total_rows]`. + pub(crate) fn captured_hidden(&self) -> &HiddenStates { + &self.captured_hidden + } +} + +impl Qwen3Model { + /// Buffer-reusing twin of [`Qwen3Model::batch_prefill`] for the DFlash verify + /// forward. 
Issues the identical kernel sequence + /// ([`Self::forward_layer_batch_paged`] is reused verbatim), so the all- + /// position logits and captured hidden states are bit-for-bit equal to the + /// allocating path; only the buffer *source* differs (reused vs. freshly + /// allocated). Results land in `bufs` (`all_logits()` / `captured_hidden()`). + /// + /// `capture_layer_ids` must be the strictly-increasing DFlash target layers + /// whose count matches the `num_capture_layers` `bufs` was built with. + pub(crate) fn batch_prefill_into( + &self, + prompts: &[&[u32]], + kv_views: &[KvView], + kv_buffer: &CudaSlice, + layout: &KvLayout, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + let batch_size = prompts.len(); + anyhow::ensure!( + batch_size == kv_views.len(), + "verify prompts ({batch_size}) and kv_views ({}) length mismatch", + kv_views.len() + ); + anyhow::ensure!( + batch_size <= bufs.max_batch, + "verify batch {batch_size} exceeds buffer capacity {}", + bufs.max_batch + ); + anyhow::ensure!( + capture_layer_ids.windows(2).all(|pair| pair[0] < pair[1]), + "verify capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + capture_layer_ids + .iter() + .all(|&layer| layer < self.layers.len()), + "verify capture layer id out of range" + ); + let expected_capture_dim = self.config().hidden_size * capture_layer_ids.len().max(1); + anyhow::ensure!( + bufs.captured_hidden.hidden_dim == expected_capture_dim, + "verify capture buffer dim {} does not match {} capture layers", + bufs.captured_hidden.hidden_dim, + capture_layer_ids.len(), + ); + + let seq_lens: Vec = prompts.iter().map(|p| p.len()).collect(); + let total_tokens: usize = seq_lens.iter().sum(); + anyhow::ensure!(total_tokens > 0, "verify forward has no tokens"); + let start_positions: Vec = kv_views + .iter() + .zip(prompts.iter()) + .map(|(v, p)| v.seq_len() - p.len()) + .collect(); + + bufs.set_rows(total_tokens); + + // Embed the concatenated verify 
tokens directly into the fixed `hidden` + // buffer (device token staging, no host round-trip or allocation). + let all_tokens: Vec = prompts.iter().flat_map(|p| p.iter().copied()).collect(); + anyhow::ensure!( + all_tokens.len() == total_tokens, + "verify token concat {} != total_tokens {total_tokens}", + all_tokens.len() + ); + let ctx = self.device_ctx(); + // Stage the active tokens into the front of the fixed device buffer; the + // embedding kernel reads exactly `bufs.hidden.seq_len` (= total_tokens) + // ids from its base pointer, so the unused tail is never touched. + ctx.stream.memcpy_htod(&all_tokens, &mut bufs.token_ids_d)?; + self.get_embeddings_batch_into(&bufs.token_ids_d, &mut bufs.hidden)?; + + // Refill the paged plan in place (same host math as the allocating path). + let page_indices: Vec> = + kv_views.iter().map(|v| v.page_indices().to_vec()).collect(); + let last_page_lens: Vec = kv_views + .iter() + .map(openinfer_kv_cache::KvView::last_page_len) + .collect(); + bufs.plan.update_batch_with_cta_tile_q( + ctx, + &page_indices, + &last_page_lens, + &start_positions, + &seq_lens, + self.local_num_attention_heads(), + self.local_num_key_value_heads(), + self.config().head_dim, + PREFILL_ATTENTION_CTA_TILE_Q, + )?; + + // Verify never uses LoRA — the draft/verify boundary is a plain token span. + let lora_groups: [DeviceLoraTokenGroup<'_>; 0] = []; + let mut next_capture = 0usize; + let hidden_size = self.config().hidden_size; + for (layer_idx, layer) in self.layers.iter().enumerate() { + // `forward_layer_batch_paged` ping-pongs: it writes the layer output + // into `bufs.prefill_bufs.hidden_out`, then swaps it with `hidden`. + // After the call `bufs.hidden` again holds the live residual stream + // (and `hidden_out` holds the now-free previous buffer). Both are + // pre-allocated, so the pointer swap is harmless across steps. 
+ self.forward_layer_batch_paged( + layer_idx, + layer, + &mut bufs.hidden, + kv_buffer, + layout, + &bufs.plan, + &lora_groups, + &mut bufs.prefill_bufs, + )?; + if capture_layer_ids.get(next_capture) == Some(&layer_idx) { + ops::copy_hidden_rows_into( + ctx, + &bufs.hidden, + &mut bufs.captured_hidden, + next_capture * hidden_size, + )?; + next_capture += 1; + } + } + + // All-position logits: final RMS norm into `all_logits_normed`, then the + // lm_head GEMM into `all_logits` (both pre-allocated). Mirrors + // `compute_all_position_logits` but writes into fixed buffers. + ops::rms_norm_batch_into( + ctx, + &bufs.hidden, + &self.norm, + self.config().rms_norm_eps, + &mut bufs.all_logits_normed, + ); + ops::gemm_into( + ctx, + self.output_projection(), + &bufs.all_logits_normed, + &mut bufs.all_logits, + ); + + Ok(()) + } +} From 1620c0c5d40e66a1af7e55f8a70aa7d0ca7b92bb Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Tue, 23 Jun 2026 02:25:55 +0800 Subject: [PATCH 8/9] perf(qwen3-dflash): piecewise CUDA Graph for the verify forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Capture the verify forward's dense ops into per-segment CUDA Graphs and replay them, keeping attention eager. FlashInfer's paged-prefill attention freezes its KV-iteration count at capture time, so capturing it corrupts later tokens as the verify context grows (observed: garbage past ~token 60 once KV crosses a CTA_TILE_KV boundary). Every other op — embedding, RMSNorm, all GEMMs, SwiGLU, residual adds (~84% of the per-step launch gap) — is shape-stable in the fixed span-row layout and safe to replay. Split forward_layer_batch_paged into pre_attn / attn / post_attn so the verify path can interleave dense graph segments with eager attention; the normal batch_prefill path calls all three in sequence, behaviour unchanged. 
5090, Qwen3-4B + DFlash-b16, greedy c1: 250.9 -> 274.3 tok/s (+9.3% from the graph alone; +15.7% over the pre-graph 237 baseline), matching vLLM dflash (278). Losslessness gate passes (only bf16 tie-flips). Co-Authored-By: Claude Opus 4.8 (1M context) --- openinfer-qwen3-4b/src/prefill.rs | 54 +++++- openinfer-qwen3-4b/src/verify_graph.rs | 236 +++++++++++++++++++------ 2 files changed, 237 insertions(+), 53 deletions(-) diff --git a/openinfer-qwen3-4b/src/prefill.rs b/openinfer-qwen3-4b/src/prefill.rs index f3417f5e..0d348247 100644 --- a/openinfer-qwen3-4b/src/prefill.rs +++ b/openinfer-qwen3-4b/src/prefill.rs @@ -146,10 +146,24 @@ impl Qwen3Model { lora_groups: &[DeviceLoraTokenGroup<'_>], bufs: &mut PrefillBuffers, ) -> Result<()> { - let num_heads = self.local_num_attention_heads(); - let num_kv_heads = self.local_num_key_value_heads(); - let head_dim = self.config.head_dim; + self.forward_layer_pre_attn(layer_idx, layer, hidden, lora_groups, bufs)?; + self.forward_layer_attn(layer_idx, layer, kv_buffer, layout, plan, bufs)?; + self.forward_layer_post_attn(layer_idx, layer, hidden, lora_groups, bufs)?; + Ok(()) + } + /// Pre-attention dense ops: input RMSNorm + fused QKV projections (+ LoRA). + /// Reads `hidden`; writes `bufs.normed` / `bufs.q_batch` / `bufs.k_batch` / + /// `bufs.v_batch`. Graph-safe — shapes depend only on the fixed row count, not + /// on KV length — so the verify piecewise CUDA Graph captures it. + pub(crate) fn forward_layer_pre_attn( + &self, + layer_idx: usize, + layer: &TransformerBlock, + hidden: &HiddenStates, + lora_groups: &[DeviceLoraTokenGroup<'_>], + bufs: &mut PrefillBuffers, + ) -> Result<()> { // 1. RMSNorm → bufs.normed ops::rms_norm_batch_into( &self.ctx, @@ -210,6 +224,26 @@ impl Qwen3Model { &mut bufs.v_batch, 0, )?; + Ok(()) + } + + /// The attention op: q/k norm + RoPE + paged KV append + paged attention. 
+ /// This is the ONLY part of the layer whose KV iteration count tracks the + /// (growing) context length, so the verify piecewise graph keeps it EAGER — + /// capturing it would freeze the KV length at capture time (`num_iterations` + /// in FlashInfer's prefill kernel is fixed when the graph is recorded). + pub(crate) fn forward_layer_attn( + &self, + layer_idx: usize, + layer: &TransformerBlock, + kv_buffer: &cudarc::driver::CudaSlice, + layout: &openinfer_core::kv_pool::KvLayout, + plan: &PrefillPagedPlan, + bufs: &mut PrefillBuffers, + ) -> Result<()> { + let num_heads = self.local_num_attention_heads(); + let num_kv_heads = self.local_num_key_value_heads(); + let head_dim = self.config.head_dim; // 3. Paged prefill: norm+RoPE → append K/V to paged → batch attention ops::prefill_attention_paged_into( @@ -231,7 +265,21 @@ impl Qwen3Model { head_dim, self.config.rms_norm_eps, )?; + Ok(()) + } + /// Post-attention dense ops: O projection + residual + MLP + final residual + /// add. Reads `bufs.attn_output` and `hidden`; writes the layer output back + /// into `hidden` via the ping-pong buffer swap. Graph-safe (no KV-length + /// dependence) — captured into the verify piecewise CUDA Graph. + pub(crate) fn forward_layer_post_attn( + &self, + layer_idx: usize, + layer: &TransformerBlock, + hidden: &mut HiddenStates, + lora_groups: &[DeviceLoraTokenGroup<'_>], + bufs: &mut PrefillBuffers, + ) -> Result<()> { // 4. O projection → bufs.o_buf (as o_batch) ops::gemm_into( &self.ctx, diff --git a/openinfer-qwen3-4b/src/verify_graph.rs b/openinfer-qwen3-4b/src/verify_graph.rs index d7cff1f0..af6d4b39 100644 --- a/openinfer-qwen3-4b/src/verify_graph.rs +++ b/openinfer-qwen3-4b/src/verify_graph.rs @@ -9,25 +9,28 @@ //! capture moving pointers. //! //! [`VerifyGraphBuffers`] pre-allocates all of that once at the worst-case shape -//! (`max_batch * span` rows) and refills it in place each step. -//! [`Qwen3Model::batch_prefill_into`] is the buffer-reusing twin of -//! 
`batch_prefill`: it issues the *same* kernel sequence -//! (`forward_layer_batch_paged`, identical to the allocating path) so the verify -//! result stays bit-for-bit equivalent. This is the pre-requisite refactor for -//! capturing the verify forward into a CUDA Graph — this phase does the buffer -//! reuse only, no graph capture yet. +//! (`max_batch * span` rows) and refills it in place each step, then captures the +//! forward into a **piecewise** CUDA Graph: the dense ops (embedding, RMSNorm, +//! every GEMM, SwiGLU, residual adds — ~84% of the per-step kernel-launch gap) +//! are captured per segment and replayed, while the attention op runs EAGER +//! between segments. Attention must stay eager because FlashInfer's paged-prefill +//! kernel fixes its KV-iteration count when the graph is recorded; with the verify +//! context growing every step, a captured attention would under-read KV and +//! corrupt later tokens. The dense segments are shape-stable in the fixed +//! `span`-row layout, so one capture replays losslessly for the request's life. use anyhow::Result; use cudarc::driver::CudaSlice; +use openinfer_core::cuda_graph::CudaGraphState; use openinfer_core::kv_pool::KvLayout; use openinfer_core::ops; use openinfer_core::ops::PrefillPagedPlan; use openinfer_core::tensor::HiddenStates; use openinfer_kv_cache::KvView; +use crate::batch_decode_buffers::BATCH_BUCKETS; use crate::config::PREFILL_ATTENTION_CTA_TILE_Q; -use crate::lora::DeviceLoraTokenGroup; use crate::prefill::PrefillBuffers; use crate::weights::Qwen3Model; @@ -50,6 +53,11 @@ pub(crate) struct VerifyGraphBuffers { token_ids_d: CudaSlice, /// Paged-attention plan, refilled in place each step. plan: PrefillPagedPlan, + /// Piecewise CUDA Graphs: `graphs[bucket_idx][segment]`. Each bucket's verify + /// forward is split into `num_layers + 1` dense segments (attention runs eager + /// between them); a segment is captured once at its exact-bucket batch and + /// replayed thereafter. 
Empty (`CudaGraphState::new()`) until first captured. + graphs: Vec<Vec<CudaGraphState>>, max_batch: usize, span: usize, } @@ -108,6 +116,15 @@ impl VerifyGraphBuffers { max_batch, max_tiles, )?, + // num_layers + 1 dense segments per bucket (attention is eager between). + graphs: BATCH_BUCKETS + .iter() + .map(|_| { + (0..model.config().num_hidden_layers + 1) + .map(|_| CudaGraphState::new()) + .collect() + }) + .collect(), max_batch, span, }) @@ -141,12 +158,13 @@ impl VerifyGraphBuffers { } impl Qwen3Model { - /// Buffer-reusing twin of [`Qwen3Model::batch_prefill`] for the DFlash verify - /// forward. Issues the identical kernel sequence - /// ([`Self::forward_layer_batch_paged`] is reused verbatim), so the all- - /// position logits and captured hidden states are bit-for-bit equal to the - /// allocating path; only the buffer *source* differs (reused vs. freshly - /// allocated). Results land in `bufs` (`all_logits()` / `captured_hidden()`). + /// Fixed-buffer, piecewise-CUDA-Graph twin of [`Qwen3Model::batch_prefill`] + /// for the DFlash verify forward. Issues the same per-op kernels as the + /// allocating path (split into `forward_layer_pre_attn` / `forward_layer_attn` + /// / `forward_layer_post_attn`), so the all-position logits and captured hidden + /// states match it; only the buffer *source* differs (reused vs. freshly + /// allocated) and the dense ops replay from a graph. Results land in `bufs` + /// (`all_logits()` / `captured_hidden()`). /// /// `capture_layer_ids` must be the strictly-increasing DFlash target layers /// whose count matches the `num_capture_layers` `bufs` was built with. @@ -199,8 +217,9 @@ impl Qwen3Model { bufs.set_rows(total_tokens); - // Embed the concatenated verify tokens directly into the fixed `hidden` - // buffer (device token staging, no host round-trip or allocation). + // --- prep: H2D staging that MUST stay outside the graph capture (CUDA + // Graph forbids host round-trips in a captured segment).
The embedding + // kernel itself runs inside graph segment 0 and reads this buffer. --- let all_tokens: Vec = prompts.iter().flat_map(|p| p.iter().copied()).collect(); anyhow::ensure!( all_tokens.len() == total_tokens, @@ -209,10 +228,9 @@ impl Qwen3Model { ); let ctx = self.device_ctx(); // Stage the active tokens into the front of the fixed device buffer; the - // embedding kernel reads exactly `bufs.hidden.seq_len` (= total_tokens) - // ids from its base pointer, so the unused tail is never touched. + // embedding kernel reads exactly `total_tokens` ids from its base pointer, + // so the unused tail is never touched. ctx.stream.memcpy_htod(&all_tokens, &mut bufs.token_ids_d)?; - self.get_embeddings_batch_into(&bufs.token_ids_d, &mut bufs.hidden)?; // Refill the paged plan in place (same host math as the allocating path). let page_indices: Vec> = @@ -233,40 +251,138 @@ impl Qwen3Model { PREFILL_ATTENTION_CTA_TILE_Q, )?; - // Verify never uses LoRA — the draft/verify boundary is a plain token span. - let lora_groups: [DeviceLoraTokenGroup<'_>; 0] = []; - let mut next_capture = 0usize; - let hidden_size = self.config().hidden_size; - for (layer_idx, layer) in self.layers.iter().enumerate() { - // `forward_layer_batch_paged` ping-pongs: it writes the layer output - // into `bufs.prefill_bufs.hidden_out`, then swaps it with `hidden`. - // After the call `bufs.hidden` again holds the live residual stream - // (and `hidden_out` holds the now-free previous buffer). Both are - // pre-allocated, so the pointer swap is harmless across steps. 
- self.forward_layer_batch_paged( - layer_idx, - layer, - &mut bufs.hidden, - kv_buffer, - layout, - &bufs.plan, - &lora_groups, - &mut bufs.prefill_bufs, - )?; - if capture_layer_ids.get(next_capture) == Some(&layer_idx) { - ops::copy_hidden_rows_into( - ctx, - &bufs.hidden, - &mut bufs.captured_hidden, - next_capture * hidden_size, - )?; - next_capture += 1; + // --- piecewise CUDA Graph: dense ops captured per segment, attention + // EAGER between segments. FlashInfer's prefill attention freezes its KV + // iteration count at capture time (it tracks the growing context), so + // capturing it corrupts later tokens; every other op is shape-stable in + // the fixed `span`-row layout. Segments: [embed + L0.pre] [L0.attn] + // [L0.post + L1.pre] [L1.attn] ... [L_last.post + lm_head]. --- + let num_layers = self.layers.len(); + match BATCH_BUCKETS.iter().position(|&b| b == batch_size) { + Some(bidx) => { + // Take the bucket's segment graphs out so the capture closures can + // borrow `bufs` mutably; restore them after (even on error). + let mut segs = std::mem::take(&mut bufs.graphs[bidx]); + let result = (|| -> Result<()> { + segs[0].run_or_capture(ctx, || self.verify_seg_embed_pre(bufs))?; + self.verify_attn(0, kv_buffer, layout, bufs)?; + for i in 1..num_layers { + segs[i].run_or_capture(ctx, || { + self.verify_seg_post_pre(i, capture_layer_ids, bufs) + })?; + self.verify_attn(i, kv_buffer, layout, bufs)?; + } + segs[num_layers].run_or_capture(ctx, || { + self.verify_seg_post_logits(capture_layer_ids, bufs) + })?; + Ok(()) + })(); + bufs.graphs[bidx] = segs; + result?; + } + None => { + // Off-bucket batch: run the same segments eager (no capture). 
+ self.verify_seg_embed_pre(bufs)?; + self.verify_attn(0, kv_buffer, layout, bufs)?; + for i in 1..num_layers { + self.verify_seg_post_pre(i, capture_layer_ids, bufs)?; + self.verify_attn(i, kv_buffer, layout, bufs)?; + } + self.verify_seg_post_logits(capture_layer_ids, bufs)?; } } - // All-position logits: final RMS norm into `all_logits_normed`, then the - // lm_head GEMM into `all_logits` (both pre-allocated). Mirrors - // `compute_all_position_logits` but writes into fixed buffers. + Ok(()) + } + + /// Graph segment 0: embedding (reads the staged `token_ids_d`) plus layer 0's + /// pre-attention dense ops. Verify never uses LoRA, so the LoRA group is empty. + fn verify_seg_embed_pre(&self, bufs: &mut VerifyGraphBuffers) -> Result<()> { + self.get_embeddings_batch_into(&bufs.token_ids_d, &mut bufs.hidden)?; + self.forward_layer_pre_attn( + 0, + &self.layers[0], + &bufs.hidden, + &[], + &mut bufs.prefill_bufs, + ) + } + + /// Eager attention for layer `i` — kept out of every graph (see + /// [`Self::forward_layer_attn`]). Touches only the fixed `prefill_bufs` and + /// the refilled `plan`. + fn verify_attn( + &self, + i: usize, + kv_buffer: &CudaSlice, + layout: &KvLayout, + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + self.forward_layer_attn( + i, + &self.layers[i], + kv_buffer, + layout, + &bufs.plan, + &mut bufs.prefill_bufs, + ) + } + + /// Middle graph segment `i` (`1..num_layers`): finish layer `i-1` + /// (post-attention dense + DFlash-context capture), then start layer `i` + /// (pre-attention dense). + /// + /// The ping-pong swap inside `post_attn` is graph-safe regardless of layer + /// parity: `run_or_capture` runs the closure (and thus the CPU-side swap) only + /// on the capture step, so the captured graph bakes the exact buffer pointers + /// for every op; replay just relaunches them. 
Each step's segment-0 embedding + /// overwrites the same baked buffer that this segment's first `post_attn` + /// reads, so no stale residual can leak across steps. The only live (eager) op + /// between segments — attention — touches just `q/k/v_batch` / `attn_output`, + /// which never participate in the swap, so it is independent of `hidden`'s + /// logical pointer. Parity only decides which physical buffer holds the final + /// hidden; that choice is baked into the last segment either way. + fn verify_seg_post_pre( + &self, + i: usize, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + let prev = i - 1; + self.forward_layer_post_attn( + prev, + &self.layers[prev], + &mut bufs.hidden, + &[], + &mut bufs.prefill_bufs, + )?; + self.verify_capture_if_needed(prev, capture_layer_ids, bufs)?; + self.forward_layer_pre_attn( + i, + &self.layers[i], + &bufs.hidden, + &[], + &mut bufs.prefill_bufs, + ) + } + + /// Final graph segment: finish the last layer (post-attention + capture), then + /// the all-position logits (final RMSNorm + lm_head GEMM) into `all_logits`. + fn verify_seg_post_logits( + &self, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + let last = self.layers.len() - 1; + self.forward_layer_post_attn( + last, + &self.layers[last], + &mut bufs.hidden, + &[], + &mut bufs.prefill_bufs, + )?; + self.verify_capture_if_needed(last, capture_layer_ids, bufs)?; + let ctx = self.device_ctx(); ops::rms_norm_batch_into( ctx, &bufs.hidden, @@ -280,7 +396,27 @@ impl Qwen3Model { &bufs.all_logits_normed, &mut bufs.all_logits, ); + Ok(()) + } + /// Copy layer `layer_idx`'s residual-stream hidden into the captured-hidden + /// buffer when that layer is a DFlash target. `capture_layer_ids` is strictly + /// increasing, so its position is the capture slot. 
+ fn verify_capture_if_needed( + &self, + layer_idx: usize, + capture_layer_ids: &[usize], + bufs: &mut VerifyGraphBuffers, + ) -> Result<()> { + if let Some(slot) = capture_layer_ids.iter().position(|&l| l == layer_idx) { + let hidden_size = self.config().hidden_size; + ops::copy_hidden_rows_into( + self.device_ctx(), + &bufs.hidden, + &mut bufs.captured_hidden, + slot * hidden_size, + )?; + } Ok(()) } } From 4e03b0e4fad36b5a62ca3492d839f2ee6611f154 Mon Sep 17 00:00:00 2001 From: xiaguan <751080330@qq.com> Date: Tue, 23 Jun 2026 02:31:45 +0800 Subject: [PATCH 9/9] docs(qwen3-dflash): record piecewise verify CUDA Graph closing the c1 gap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The launch gap is dense-op-dominated (nsys: 84% dense GEMM/norm, 8% attention), not draft-specific — correcting the earlier "draft launch is the c1 gap" read. Full-forward capture is infeasible (FlashInfer paged-prefill attention freezes its KV-iteration count at record time; captured attention corrupts tokens past ~60 as verify context grows). The piecewise graph keeps attention eager and captures the dense segments. 5090 greedy: c1 250.9 -> 274.3 (+9.3% graph; ~vLLM 278), c8 1346 -> 1525, c16 ~flat (1868 -> 1834). c8 is above vLLM (1240); c1 and c16 are within ~1.5% of it (278 / 1846). Draft-side piecewise graph tracked as the next step.
Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/index.md | 2 +- .../qwen3/dflash-speculative-decoding.md | 30 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/docs/index.md b/docs/index.md index ebf9e657..9ffd5baa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,7 +30,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `models/qwen3/roadmap.md` | Qwen3-4B roadmap (2026-06 review): line is the maturity bar; #220 RoPE OOB, batched greedy sampling (#307), mixed greedy/non-greedy sampling (#284), and pegaflow KV offload (#316) are landed; open set is zero TP coverage, zero-adapter-only LoRA gate, dropped prefix-cache observability, stale docs, and YaRN #8 follow-up. | | `models/qwen3/model-crate.md` | `openinfer-qwen3-4b` owns Qwen3 config/weights/executor/scheduler/tests/kernel plan; root sees generic `EngineHandle`; split-K retuned to `256/64`, with 4k/64 serving TPOT p50 at `6.46ms` on RTX 5090. | | `models/qwen3/prefix-cache.md` | Prefix caching on by default for Qwen3-4B: full-block kvbm radix matching at the executor, suffix-only prefill. Repeated ~1900-token prompt TTFT 141.8 → 16.3ms p50 (8.7×); warm TTFT ≈ TPOT + ~5ms setup. Includes the RoPE scalar-path corruption fix and the drain-the-stream TTFT measurement pitfall. | -| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (bit-identical multi-token accepts; lm-eval gsm8k strict-match identical spec on/off). Single-stream decode 1.82× on 5070 Ti, 1.56× on 5090. Concurrent throughput fixed by batching the draft forward (5090 greedy c8 1346 / c16 1868, both beating vLLM 1240 / 1846; draft@b16 24.8→5.6 ms). 
c1 still trails vLLM (237 vs 278) on launch-bound draft overhead — accept measured equal (9.1% vs 8.85%, same drafter), so CUDA-Graph draft is tracked next. Proposer trait deferred to EAGLE. | +| `models/qwen3/dflash-speculative-decoding.md` | DFlash speculative decoding behind `--dflash-draft-model-path`, modelled as an optimistic transaction (propose K → verify K+1 span → accept longest argmax prefix + 1 bonus → commit/roll back KV). Lossless up to bf16 tie-flips (bit-identical multi-token accepts; lm-eval gsm8k strict-match identical spec on/off). Single-stream decode 1.82× on 5070 Ti, 1.56× on 5090. Concurrent throughput fixed by batching the draft forward, then a piecewise verify CUDA Graph (dense ops captured, attention eager) closed single-stream: 5090 greedy c1 274 ≈ vLLM 278, c8 1525 > 1240, c16 1834 ≈ 1846; every batch size is now on par with or above vLLM. Accept measured equal (9.1% vs 8.85%, same drafter); draft-side piecewise graph tracked next. Proposer trait deferred to EAGLE. | | `models/qwen3/accuracy-gate.md` | Qwen3-4B instance of the logits golden gate (`tests/hf_golden_gate.rs`): 48 teacher-forced sequences / 816 positions vs a stored HF bf16 golden, replayed over bs=1 / batched eager / CUDA-graph. Strict guards: regret check + mean ≤ 0.06 + p99 ≤ 0.20; absolute max printed but not asserted (coverage-unstable). Methodology in `subsystems/correctness/`. | | `models/qwen3/kernels-crate.md` | Phase 1 split implemented and 5090-verified: Qwen3-4B kernel surface lives in `openinfer-kernels`; release build, test-target compile, accuracy gate, and bench snapshot pass. | | `models/qwen3/tp-design.md` | Qwen3 tensor-parallel design: `TP=2` milestone scope plus the controller/worker broadcast execution model, request identity, and coarse-grained step protocol for future TP/MoE work. 
| diff --git a/docs/models/qwen3/dflash-speculative-decoding.md b/docs/models/qwen3/dflash-speculative-decoding.md index a0e66cb6..fd67f33c 100644 --- a/docs/models/qwen3/dflash-speculative-decoding.md +++ b/docs/models/qwen3/dflash-speculative-decoding.md @@ -1,6 +1,6 @@ # DFlash Speculative Decoding (Qwen3-4B) -**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). The drafter's out-of-pool footprint (weights + per-request KV/scratch) is reserved during memory profiling so the KV pool shrinks to fit instead of OOMing under load, and requests landing in the draft's final in-fill block are rejected cleanly at admission rather than crashing mid-prefill. **Concurrent throughput is now competitive after batching the draft forward.** The draft used to run a per-request **serial** `for` loop — launch-bound (a skip-attention A/B showed attention compute <2%, so 24.8 ms/step at batch 16 was almost all kernel-launch overhead), which inverted the single-stream win (c16 −59%). Batching the dense ops into one N×block pass drops draft@batch16 to **5.6 ms** and lifts 5090 greedy throughput to **c8 1346 / c16 1868 tok/s — both now beating vLLM's 1240 / 1846**. Single-stream (c1 237 vs vLLM 278) still trails, but **not from accept**: measured 9.1% vs vLLM's 8.85% with the *same* drafter rules out draft quality. 
The c1 gap is the draft's ~1 ms kernel-launch overhead — openinfer's spec path, unlike base decode, isn't CUDA-Graph captured. **CUDA-Graph draft is the tracked next step.** See Performance § for the A/B tables. +**TL;DR:** Qwen3-4B gains DFlash speculative decoding behind `--dflash-draft-model-path`. Speculative decode is modelled as an **optimistic transaction**: the DFlash drafter proposes K tokens, one target "verify" forward over the K+1 span confirms them, and we commit the longest argmax-matching prefix + 1 bonus token (rolling back the rest of the speculative KV). Greedy decode is **lossless up to bf16 numerical tie-flips** — the same non-determinism that already affects plain greedy decode at genuine bifurcation points; lm-eval gsm8k strict-match is identical with spec on vs off. Measured single-stream decode A/B: **1.82× on RTX 5070 Ti** (93.4 → 170.0 tok/s), **1.56× on RTX 5090** (168.9 → 263.2 tok/s). The drafter's out-of-pool footprint (weights + per-request KV/scratch) is reserved during memory profiling so the KV pool shrinks to fit instead of OOMing under load, and requests landing in the draft's final in-fill block are rejected cleanly at admission rather than crashing mid-prefill. **Concurrent throughput is now competitive after batching the draft forward.** The draft used to run a per-request **serial** `for` loop — launch-bound (a skip-attention A/B showed attention compute <2%, so 24.8 ms/step at batch 16 was almost all kernel-launch overhead), which inverted the single-stream win (c16 −59%). Batching the dense ops into one N×block pass drops draft@batch16 to **5.6 ms** and lifts 5090 greedy throughput to **c8 1346 / c16 1868 tok/s — both now beating vLLM's 1240 / 1846**. The single-stream gap is now closed by a **piecewise CUDA Graph** over the verify forward: 5090 greedy **c1 237 → 274 tok/s (+16%), matching vLLM dflash (278)**; c8 1346 → 1525 and c16 ≈ flat (1868 → 1834; c8 stays above vLLM's 1240, c16 sits within 1% of its 1846).
Under greedy dflash the spec path ran with no CUDA Graph at all (the base-decode graph never fires), so nsys saw 1296 ms of GPU-idle launch gap — **~84% from dense-op kernel launches** (GEMM alone 48%), only 8% from attention. Capturing the *whole* forward fails: FlashInfer's paged-prefill attention freezes its KV-iteration count at capture time, so a captured attention under-reads the growing verify KV and corrupts tokens past ~60. The fix is **piecewise** — capture the dense ops (embedding, RMSNorm, every GEMM, SwiGLU, residual adds) into per-segment graphs and replay them, keeping **attention eager**. Accept is *not* the gap (9.1% vs vLLM's 8.85%, same drafter). Draft-side piecewise graph is the tracked next step. See Performance § for the A/B tables. Last touched: 2026-06 @@ -54,13 +54,13 @@ The speedup is smaller on the 5090: its higher memory bandwidth makes the baseli Under concurrent load the single-stream win inverts. Same greedy harness (`temperature=0`, sharegpt prompts, 128 out tokens), openinfer vs vLLM with the **same** DFlash-b16 drafter, RTX 5090: -| concurrency | OI plain | OI DFlash serial | **OI DFlash batched** | vLLM DFlash | -| --- | --- | --- | --- | --- | -| c1 | 170 | 245 | 237 | 278 | -| c8 | 1180 | 831 | **1346** | 1240 | -| c16 | 2277 | 1013 | **1868** | 1846 | +| concurrency | OI plain | OI DFlash serial | OI DFlash batched | **OI DFlash +graph** | vLLM DFlash | +| --- | --- | --- | --- | --- | --- | +| c1 | 170 | 245 | 237 | **274** | 278 | +| c8 | 1180 | 831 | 1346 | **1525** | 1240 | +| c16 | 2277 | 1013 | 1868 | **1834** | 1846 | -(tok/s, sharegpt out128, same-session serial-vs-batched A/B. The serial draft inverted the win — vLLM degraded gracefully while openinfer nearly halved; batching the draft restores it, c8/c16 now beat vLLM. c1 is unchanged — batch=1 has no batching win; see "Single-stream gap" below.) +(tok/s, sharegpt out128, RTX 5090, greedy. 
The serial draft inverted the win — vLLM degraded gracefully while openinfer nearly halved; batching the draft restored c8/c16 past vLLM. The **+graph** column adds the piecewise verify CUDA Graph (this branch): it closes c1 to vLLM and lifts c8, with c16 flat — see "Single-stream gap" below. Caveat: the serial→batched columns are a same-session A/B; the +graph column's c1 is a clean same-session A/B (251 fixed-buffer eager → 274 graph), while c8/c16 are single runs against the prior batched baseline, i.e. a no-regression check rather than a tight A/B.) #### Root cause (serial draft) and the fix (batched draft, landed) @@ -78,13 +78,19 @@ The serial draft scaled **exactly linearly** (`draft_x` 16.00 at batch 16); batc **The fix that landed: batch the whole draft forward (N requests in one pass), killing the N× launch.** `draft_logits_batched` (`dflash.rs`) runs the dense ops (rms_norm / GEMM / silu / MLP / embedding / logits) **once over an N×block buffer** — free, since cuBLAS takes any M and the ops are already row-batched (N×35 → 35 launches, no CUDA-kernel change). The varlen ops (tail concat / k·v GEMM / rope / KV-copy / attention) stay a per-request loop slicing the batched buffers at `row_offset = i·block_size` (the two DFlash-exclusive ops `dflash_qk_norm_rope_into` / `single_prefill_nhd_noncausal_into` advance the device pointer to the slice). A lane-level `DFlashBatchScratch` (sized for the largest batch bucket) replaces the per-request scratch — folding in the reservation-reclaim follow-up. Losslessness gate still passes (bf16 tie-flips only). 
-#### Single-stream gap (c1): launch-bound draft, not accept +#### Single-stream gap (c1): closed by a piecewise verify CUDA Graph + +c1 trailed vLLM (237 vs 278), and it is **not accept**: with the *same* drafter and spec config (vLLM `--speculative-config '{"method":"dflash",…,"num_speculative_tokens":16}'`), measured greedy accept is **9.1% (ours) vs 8.85% (vLLM)** — mean 1.29 vs 1.42 draft tokens/step, our pos-0 accept (60%) actually higher. 9% is the drafter's floor on sharegpt free-text (bimodal: structured spans accept up to 15, free text 0), shared by both engines. + +The gap is **launch exposure**. Under greedy dflash the spec path runs with *no* CUDA Graph: the base-decode graph (`execute_decode`) never fires because greedy routes through `SpeculativeDecode`, not `batch_decode`. nsys on c1 (single stream, fully serial): wall 7999 ms, GPU **busy** 6703 ms (memory-bound weight reads — irreducible, the same for vLLM), GPU **idle 1296 ms**, ~79% of it sub-3 µs kernel-launch gaps. Attributing that idle by kernel type (`gap_by_kernel.py`): **dense ops 84%** (GEMM 48%, RMSNorm 15%, embedding / KV-copy / residual-add / SwiGLU the rest), **attention only 8%**. Dense dominates because verify is 36 layers of GEMMs vs the draft's 5 — so the launch gap lives mostly in *verify*, not the draft. (This corrects the earlier read that pinned c1 on draft launch; that predated the batched draft and the per-kernel-type nsys breakdown.) + +**Why the whole forward can't go in one graph.** A first attempt captured the entire verify forward; output was correct up to ~token 60, then garbage. Root cause: FlashInfer's paged-prefill attention derives its KV-iteration count (`num_iterations`, from `kv_len`) and that loop bound is **frozen when the graph is recorded**. The verify context grows every step, so once it crosses the captured `CTA_TILE_KV` (~64) boundary the replayed attention under-reads KV. 
Base-decode's graph is safe only because its *decode* kernel's KV loop is purely device-driven; *prefill*'s is not. FlashInfer ships no graph-safe prefill variant; vLLM hits the same wall and keeps attention out of its piecewise cudagraph. -c1 trails vLLM (237 vs 278) and batching doesn't help it — batch=1 has nothing to batch. The cause is **not accept**: with the *same* drafter and spec config (vLLM `--speculative-config '{"method":"dflash",…,"num_speculative_tokens":16}'`), measured greedy accept is **9.1% (ours) vs 8.85% (vLLM)** — mean 1.29 vs 1.42 draft tokens/step, and our pos-0 accept (60%) is actually higher. 9% is the drafter's floor on sharegpt free-text (bimodal: structured spans accept up to 15, free text accepts 0), shared by both engines. +**The fix: piecewise graph.** Keep attention **eager**; capture only the dense ops, whose dims depend on the fixed `span` row count, never on KV length. `forward_layer_batch_paged` is split into `pre_attn` / `attn` / `post_attn`, and the verify forward becomes `num_layers+1` dense graph segments — `[embed+L0.pre] [L0.attn eager] [L0.post+L1.pre] … [L_last.post+lm_head]` — captured once per batch bucket and replayed (`verify_graph.rs`). The ping-pong residual swap sits inside the captured segments, so its pointer alternation is baked into each graph and reproduces on every replay regardless of layer parity: `run_or_capture` re-runs the CPU swap only on the capture step, and the one eager op (attention) touches just `q/k/v_batch` / `attn_output`, never the swapped `hidden`. -The real gap is in **step time**: normalized per output token, ours is 4.22 ms vs vLLM's 3.60 ms. c1 step = draft 1.55 ms + verify 7.9 ms. Verify is memory-bound (one 4B target forward, reads ~8 GB of weights) — irreducible and the same for vLLM. But the draft's 1.55 ms is **pure kernel-launch** (85 tiny 16-token kernels, ~18 µs/launch CPU enqueue, compute <2%), while vLLM's draft is CUDA-Graph captured (~0.5 ms). That ~1 ms/step is the c1 gap. 
+Result (5090, greedy, same-session A/B): fixed-buffer eager **250.9 → +graph 274.3 tok/s (+9.3% from the graph alone)**; 237 → 274 (+16%) over the pre-graph batched baseline, **matching vLLM's 278**. Concurrent (no-regression check): c8 1346 → 1525, c16 1868 → 1834 (both still ≥ vLLM). Losslessness gate passes (bf16 tie-flips only) — the dense ops replay bit-identically, and the eager attention is unchanged. -**Next step: CUDA-Graph the draft forward.** openinfer's whole spec path (`SpeculativeDraft` / `SpeculativeVerify`, `executor.rs`) is *not* CUDA-Graph captured, whereas base decode is (`execute_decode` split-graph cache). Capturing the draft is an architecture-level change — the per-request draft KV (`DFlashLayerCache`) needs pointer-stable buffers, the variable context length (0–16) needs a fixed/padded window, plus capture/replay wiring — predicted to drop draft@batch1 1.55→~0.3 ms (c1 → ~272, ≈ vLLM) and draft@batch16 5.6→~1.5 ms (c16 → ~2200). Tracked as its own PR after this one lands. +**Draft-side piecewise graph is the tracked next step.** The draft (5 layers, `dflash.rs`) is the other ~16% of the launch gap; it needs the same pre/attn/post split, with its variable-length contiguous KV (`DFlashLayerCache`) handled at the eager attention boundary. Tracked as its own PR after this one lands. The EAGLE proposer trait is still deferred to when EAGLE actually lands (see "no proposer trait yet" above); both the batched draft and the future CUDA-Graph draft are DFlash-internal changes behind the unchanged `DraftPlan→DraftResult` seam, so neither is thrown away by EAGLE. @@ -124,7 +130,7 @@ The fixed reservation lands exactly in the margin (+2822 MiB) and the per-token - **TP=1, primary rank only.** The DFlash lane runs on the worker thread of the primary rank; the launch path gates `!(dflash && lora)` and `tp_size == 1`. The server fails loud (`--dflash-draft-model-path` is rejected for non-Qwen3 model lines rather than silently ignored). 
- **Second proposer (EAGLE) → introduce the trait then.** A proposer trait wants two real methods to fix its shape; DFlash stays concrete until EAGLE lands. The batched-draft perf work (Performance §) is orthogonal — it's a DFlash-internal change behind the unchanged `DraftPlan→DraftResult` seam, so it can land before or with the trait without being redone, and EAGLE's autoregressive attention won't reuse DFlash's block attention anyway. - **File size.** `executor.rs` (~2.8k lines) and `scheduler/tests.rs` (~1.2k lines) exceed the 1k guideline; the spec arms in `execute_step_on_lane` are candidates to move into `executor/spec.rs` in a follow-up. -- **5090 validation — done** for correctness, single-stream, and concurrent throughput: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes (before and after the draft batching), single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Concurrent throughput is **fixed** (Performance §): the serial draft inverted the win (c16 −59%), batching the draft restores **c8 1346 / c16 1868 — both beating vLLM**. Single-stream c1 (237 vs vLLM 278) still trails on launch-bound draft overhead; CUDA-Graph draft is tracked next. +- **5090 validation — done** for correctness, single-stream, and concurrent throughput: `hf_golden_gate` passes (bs1 / batched / cuda-graph / tp2), the losslessness gate passes (batched draft *and* piecewise verify graph), single-stream A/B is 1.56×, and lm-eval gsm8k parity holds (above). Concurrent throughput is **fixed** (Performance §): batching the draft beat vLLM at c8/c16, then the **piecewise verify CUDA Graph** closed c1 — **c1 274 ≈ vLLM 278, c8 1525 > 1240, c16 1834 ≈ 1846**, all batch sizes now at or above vLLM. Draft-side piecewise graph (the remaining ~16% launch gap) is tracked next. 
- **Frontend `seed`.** `/v1/completions` 400s on a per-request `seed` instead of accepting-and-ignoring it under greedy — surfaced while wiring lm-eval; orthogonal to spec decode, worth a small separate fix. - **Reclaim the per-request reservation.** The fixed term is dominated by the block-sized scratch billed per decode-batch slot (~6.5 MiB × 256 ≈ 1.6 GiB), most of it the transient logits buffer (`vocab × block`). The per-token term carries the context/pending scratch, which today reallocs to prompt length on the first draft and never shrinks (sized for `[0, context_len)` but only `[committed_len, +block]` is live). Both are billed conservatively because they were per-request. **Partly reclaimed by the batched-draft PR:** the per-request scratch is gone — `DFlashBatchScratch` is allocated once at lane level, sharing the transient logits/dense buffers across the whole batch instead of per request. The reservation accounting (`from_config`) still bills the old per-request upper bound (a safe over-estimate, so no OOM); tightening it to the now-smaller lane-level footprint is the remaining follow-up.
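The fixed-term arithmetic in the reservation bullet above can be checked with a short sketch. The 6.5 MiB per slot and 256 slots come from the text; the `logits_buffer_mib` inputs in `main` (vocabulary size, block length, bytes per logit) are illustrative assumptions that this section does not state, and both function names are hypothetical rather than taken from `from_config`.

```rust
const MIB: f64 = 1024.0 * 1024.0;

/// Fixed scratch billed across all decode-batch slots, in GiB.
fn fixed_scratch_gib(per_slot_mib: f64, slots: u32) -> f64 {
    per_slot_mib * f64::from(slots) / 1024.0
}

/// Size of a transient `vocab x block` logits buffer, in MiB.
fn logits_buffer_mib(vocab: u64, block: u64, bytes_per_logit: u64) -> f64 {
    (vocab * block * bytes_per_logit) as f64 / MIB
}

fn main() {
    // Figures quoted in the text: ~6.5 MiB per slot, 256 slots.
    println!("fixed term: {:.3} GiB", fixed_scratch_gib(6.5, 256));

    // ASSUMED inputs, for illustration only: a 151,936-entry vocabulary,
    // a 16-token block, and 2-byte (bf16) logits.
    println!("logits buffer: {:.2} MiB", logits_buffer_mib(151_936, 16, 2));
}
```

Under those assumed inputs the logits buffer alone accounts for most of the per-slot figure, which matches the text's description. It also suggests why moving that buffer to a single lane-level `DFlashBatchScratch` shrinks the real footprint well below the per-request upper bound that is still billed.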