feat(qwen3-4b-dflash): add Qwen3-4B-DFlash draft model crate #439
kitty-eu-org wants to merge 7 commits into
Aligns the DFlash executor with the rest of the project (Qwen3 BatchDecodeBuffers, Kimi/DeepSeek scratch): one allocation sized for the worst case, narrowed per forward via set_active_shape, instead of the unique-per-crate HashMap<(batch, q, ctx), buffer> cache that grew a fresh GPU buffer set for every unseen (q_len, ctx_len) combo.

* DFlashExecutor now holds a single DFlashBatchBuffers, allocated once in load() for (max_batch_size, max_q_len, max_step_context_len). New DFlashExecutorOptions.max_q_len gates the q-axis capacity.
* set_active_batch(bs) -> set_active_shape(bs, q_len, ctx_len); both forward paths derive the shape from the requests themselves so callers no longer pre-set it.
* prepare_ragged_plan cache key now covers (batch_size, q_len, ctx_len): with a single instance the shape can change between forwards, so keying only on batch_size would reuse a stale plan.
* compact_host_inputs stitches all requests on the host and uploads noise/target with one H2D each (was one launch per request per tensor), matching Qwen3's sync_paged_meta upload pattern.
* Compact NoCache paths materialize the owned output via a single clone_batch_output dtod instead of zeros + copy_hidden over the full span.

Tests/benches keep working: create_batch_buffers keeps its 3-positional signature (now (max_batch, max_q, max_ctx)), and test struct literals gain the max_q_len field.
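The single-allocation pattern described in this commit can be sketched on the host as follows. This is a minimal illustration, not the crate's code: `Shape`, `BatchBuffers`, and `PlanCache` are hypothetical names, and a `Vec<f32>` stands in for the GPU buffers that `DFlashBatchBuffers` would own.

```rust
/// Hypothetical shape triple; mirrors the (batch, q_len, ctx_len) key in the commit text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Shape {
    pub batch: usize,
    pub q_len: usize,
    pub ctx_len: usize,
}

/// Hypothetical stand-in for a worst-case-sized buffer set.
pub struct BatchBuffers {
    capacity: Shape,
    active: Shape,
    hidden: usize,
    storage: Vec<f32>, // allocated once in `new`, never resized afterwards
}

impl BatchBuffers {
    pub fn new(capacity: Shape, hidden: usize) -> Self {
        let rows = capacity.batch * (capacity.q_len + capacity.ctx_len);
        Self { capacity, active: capacity, hidden, storage: vec![0.0; rows * hidden] }
    }

    /// Narrow the view for one forward; rejects shapes beyond the allocation.
    pub fn set_active_shape(&mut self, shape: Shape) -> Result<(), String> {
        if shape.batch > self.capacity.batch
            || shape.q_len > self.capacity.q_len
            || shape.ctx_len > self.capacity.ctx_len
        {
            return Err(format!("{:?} exceeds capacity {:?}", shape, self.capacity));
        }
        self.active = shape;
        Ok(())
    }

    /// The prefix of the allocation that the current forward may touch.
    pub fn active_slice(&mut self) -> &mut [f32] {
        let rows = self.active.batch * (self.active.q_len + self.active.ctx_len);
        &mut self.storage[..rows * self.hidden]
    }

    pub fn allocated_len(&self) -> usize {
        self.storage.len()
    }
}

/// Plan cache keyed on the full shape, so a changed q_len or ctx_len forces a rebuild.
pub struct PlanCache {
    key: Option<Shape>,
}

impl PlanCache {
    pub fn new() -> Self {
        Self { key: None }
    }

    /// Returns true when the plan had to be rebuilt for `shape`.
    pub fn prepare(&mut self, shape: Shape) -> bool {
        if self.key == Some(shape) {
            return false;
        }
        self.key = Some(shape);
        true
    }
}
```

Keying the plan on `batch` alone would make the second `prepare` call with a different `ctx_len` return `false`, which is the stale-plan case the commit message warns about.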
…qwen3-4b

The DFlash draft scheduler leaked two resources on long-running use: the GPU-owner thread (no JoinHandle, no shutdown) and per-request draft caches (grow-only HashMap, each carrying full ForwardBuffers + per-layer past K/V). Both are now bounded, mirroring qwen3-4b's EngineHandle / drop_request patterns.

Scheduler shutdown: DFlashSchedulerHandle now wraps Arc<Inner> holding an Option<JoinHandle>. The last clone's Drop closes the channel (the scheduler loop drains pending requests via send_stopped) and joins the thread, matching openinfer-engine EngineHandle::Drop. Dropping the handle without an explicit shutdown no longer leaks the thread.

Cache eviction: DFlashExecutorOptions gains max_caches (default 64). A new drop_cache(id), exposed on both executor and scheduler, removes a request's cache and lets RAII free the GPU buffers. It is idempotent (a missing cache is not an error), matching qwen3's drop_request. Over-cap admission fails closed until a retired request's cache is dropped.

Cleanup: remove submit_with_enqueued_ack, which sent its ack from the caller thread (not the scheduler) and only proved the message entered the channel buffer; unbounded-channel FIFO already guarantees the ordering it claimed to. The batch exact-shape validator now fully checks the first request and only shape-matches the rest, instead of re-running the full validator per request.

Gate: adds dflash_cache_drop_releases_and_capacity_fails_closed covering drop_cache release + idempotency and max_caches fail-closed/reuse. HF golden deltas unchanged (mean=0.034243, p99=0.125000, max=0.500000, n=7680); 8 tests pass.
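The join-on-drop shutdown described above can be sketched with the standard library alone. This is a minimal sketch under stated assumptions: `SchedulerHandle` and `Inner` are hypothetical stand-ins for `DFlashSchedulerHandle`, requests are plain `u64` ids, and the worker simply counts queued ids, whereas the PR's loop answers pending requests via `send_stopped`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};

/// Shared state; its Drop runs only when the last handle clone goes away.
struct Inner {
    tx: Mutex<Option<Sender<u64>>>,
    worker: Option<JoinHandle<()>>,
}

impl Drop for Inner {
    fn drop(&mut self) {
        // Dropping the only Sender closes the channel; the worker's recv()
        // returns Err once the remaining queued messages have been drained.
        if let Ok(slot) = self.tx.get_mut() {
            drop(slot.take());
        }
        // Join so the thread cannot outlive the handle.
        if let Some(handle) = self.worker.take() {
            let _ = handle.join();
        }
    }
}

#[derive(Clone)]
pub struct SchedulerHandle {
    inner: Arc<Inner>,
}

impl SchedulerHandle {
    /// Spawns the worker; `processed` is only here so the sketch is observable.
    pub fn spawn(processed: Arc<AtomicUsize>) -> Self {
        let (tx, rx) = channel::<u64>();
        let worker = thread::spawn(move || {
            while let Ok(_request_id) = rx.recv() {
                processed.fetch_add(1, Ordering::SeqCst);
            }
        });
        let inner = Inner { tx: Mutex::new(Some(tx)), worker: Some(worker) };
        Self { inner: Arc::new(inner) }
    }

    /// Returns false if the channel is already closed.
    pub fn submit(&self, id: u64) -> bool {
        match self.inner.tx.lock() {
            Ok(guard) => guard.as_ref().map_or(false, |tx| tx.send(id).is_ok()),
            Err(_) => false,
        }
    }
}
```

Because the join happens in `Inner::drop`, dropping one clone while another is still alive leaves the worker running; only the last clone's drop closes the channel and blocks until the thread exits.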
…d copy kernel

The batch forward path built the ragged-attention K/V layout [ctx | noise] per request by looping memcpy_dtod over each request: 2 * batch_size copies per K/V tensor per layer. At bs=32 that is 128 launches/layer (640 per forward), and at ~5us CPU launch overhead each this dominated the bs32 latency budget.

Add strided_segment_copy_kernel (csrc/shared/elementwise.cu): one launch copies an entire batch's segment (all requests' ctx rows, or all noise rows) from a contiguous source into the strided per-request destination layout. Each layer now issues 4 launches (k_ctx, k_noise, v_ctx, v_noise) instead of 2 * batch_size * 2, collapsing the bs32 per-layer count from 128 to 4.

Result (RTX 5070 Ti, WSL, ctx_len=2, q_len=16):

  bs 8:  4.36ms ->  3.34ms (1.31x)
  bs16:  6.85ms ->  4.70ms (1.46x)
  bs32: 12.17ms ->  8.18ms (1.49x)
  bs1->bs32 throughput: 5.3x -> 8.1x (7.7K -> 62.6K draft tok/s)

HF golden deltas unchanged (mean=0.034243, p99=0.125000, max=0.500000, n=7680); batch-vs-single stays at mean=0.000000. 8 tests pass.
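A host-side reference for the copy pattern may help when reading the kernel. The sketch below is an assumption-laden CPU analogue, not the CUDA code in csrc/shared/elementwise.cu: the function name matches the commit text, but its parameters and the row-major `f32` layout are hypothetical.

```rust
/// Copies one segment (all requests' ctx rows, or all noise rows) from a
/// contiguous source into a destination laid out as [ctx | noise] per request.
/// One call here corresponds to one fused launch in the commit's description.
pub fn strided_segment_copy(
    src: &[f32],
    dst: &mut [f32],
    batch: usize,
    seg_rows: usize,             // rows per request in this segment
    row_width: usize,            // elements per row
    dst_rows_per_request: usize, // ctx_len + q_len
    dst_row_offset: usize,       // 0 for ctx, ctx_len for noise
) {
    let seg = seg_rows * row_width;
    let stride = dst_rows_per_request * row_width;
    assert!(dst_row_offset + seg_rows <= dst_rows_per_request);
    assert!(src.len() >= batch * seg && dst.len() >= batch * stride);
    for b in 0..batch {
        let d = b * stride + dst_row_offset * row_width;
        dst[d..d + seg].copy_from_slice(&src[b * seg..(b + 1) * seg]);
    }
}

/// Launch count of the old per-request loop: 2 segments * batch * 2 tensors (K and V).
pub fn launches_per_layer_looped(batch: usize) -> usize {
    2 * batch * 2
}

/// Launch count after fusing: k_ctx, k_noise, v_ctx, v_noise.
pub const LAUNCHES_PER_LAYER_FUSED: usize = 4;
```

With `batch = 32`, `launches_per_layer_looped` gives the 128 launches per layer quoted in the commit, against a constant 4 after fusing.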
The batch path's context-hidden K projection needs RMSNorm + RoPE, but it has no corresponding Q: the draft Q comes only from the noise tokens. The code reused the joint qk_norm_rope kernel with a scratch Q buffer whose result was immediately discarded. For Qwen3-4B's 16:4 GQA ratio that wasted 80% of the kernel's work (num_q_heads of every num_q_heads + num_kv_heads blocks) on a dead Q branch.

Add k_norm_rope_batched_decode_cuda: same per-head RMSNorm + RoPE logic but launches only num_kv_heads blocks per token, restricted to the K tensor. Wire it into the batch context-K path.

HF golden deltas unchanged (mean=0.034243, p99=0.125000, max=0.500000, n=7680); batch-vs-single stays at mean=0.000000. 8 tests pass.

ctx_len=32 bs32: 9.50ms -> 9.26ms (+2.6%); benefit scales with ctx_len since the dead-Q work grows with context length.
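The 80% figure follows directly from the block counts stated in the commit. As a small worked example, assuming one block per head per token as the text describes (the function names are hypothetical):

```rust
/// Blocks launched by the joint Q+K kernel for `tokens` context tokens.
pub fn joint_blocks(tokens: usize, num_q_heads: usize, num_kv_heads: usize) -> usize {
    tokens * (num_q_heads + num_kv_heads)
}

/// Blocks launched by the K-only variant for the same tokens.
pub fn k_only_blocks(tokens: usize, num_kv_heads: usize) -> usize {
    tokens * num_kv_heads
}

/// Fraction of the joint kernel's blocks spent on the discarded Q branch.
pub fn dead_q_fraction(num_q_heads: usize, num_kv_heads: usize) -> f64 {
    num_q_heads as f64 / (num_q_heads + num_kv_heads) as f64
}
```

For the 16:4 ratio quoted above, `dead_q_fraction(16, 4)` is 16 / 20 = 0.8, and the block count per token drops from 20 to 4.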
…che-gate

# Conflicts:
#	openinfer-core/src/ops.rs
#	openinfer-kernels/src/ops.rs
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a66604f9f2
    "--warmup",
    str(args.warmup),
    "--iters",
    str(args.iters),
Forward benchmark dimensions to the Rust runner
When this script is run with non-default --ctx-len or --q-len, the fixture written above uses those dimensions, but this command launches qwen3_dflash_forward_bench without forwarding them, so the Rust binary keeps its own defaults (2/16) and rejects the fixture shape instead of producing the OpenInfer side of the comparison. Please pass the same --ctx-len/--q-len values here.
The Python forward bench script generated its fixture with the caller's --ctx-len/--q-len but launched the Rust runner without forwarding them, so the runner kept its defaults (2/16) and rejected the fixture shape for any non-default dimension. Pass --ctx-len/--q-len through to the runner, and make the Rust fixture path derive ctx_len/q_len from the fixture's actual tensor shapes so the two sides agree regardless of which flags the caller repeats.
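Deriving the dimensions from the fixture rather than from flags could look roughly like the sketch below. The tensor layouts are assumptions made for illustration (noise as `[batch, q_len, hidden]`, target context as `[batch, ctx_len, hidden]`), and `BenchDims` and `dims_from_fixture` are hypothetical names, not the runner's actual API.

```rust
#[derive(Debug, PartialEq, Eq)]
pub struct BenchDims {
    pub batch: usize,
    pub q_len: usize,
    pub ctx_len: usize,
    pub hidden: usize,
}

/// Reads q_len and ctx_len off the fixture's tensor shapes so the runner
/// cannot disagree with the script that wrote the fixture.
/// Assumed layouts: noise = [batch, q_len, hidden], target = [batch, ctx_len, hidden].
pub fn dims_from_fixture(
    noise_shape: &[usize],
    target_shape: &[usize],
) -> Result<BenchDims, String> {
    match (noise_shape, target_shape) {
        (&[batch, q_len, hidden], &[t_batch, ctx_len, t_hidden]) => {
            if batch != t_batch || hidden != t_hidden {
                return Err(format!(
                    "noise {:?} and target {:?} disagree on batch or hidden size",
                    noise_shape, target_shape
                ));
            }
            Ok(BenchDims { batch, q_len, ctx_len, hidden })
        }
        _ => Err("expected rank-3 noise and target tensors".to_string()),
    }
}
```

With this shape-driven approach the fixture is the single source of truth, so a caller that forgets to repeat `--ctx-len` or `--q-len` still gets a matching run.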
Description
Add the `openinfer-qwen3-4b-dflash` crate supporting the `z-lab/Qwen3-4B-DFlash-b16` draft model, laying the groundwork for future speculative decoding integration. The crate is behind its own cargo feature and does not affect existing model crates.

What's included:

- `openinfer-qwen3-4b-dflash`: config parsing, exact-key safetensor weight loading, draft forward with HF remote-code parity gate, unified `DFlashDraftCache` (mirrors reference `DynamicCache` lifecycle), draft-only batch executor/scheduler with FCFS exact-shape batching, single-instance batch buffers, bounded `max_caches` eviction with `drop_cache`, and join-on-drop scheduler shutdown aligned with `EngineHandle` / qwen3 `drop_request`.
- `openinfer-kernels`: non-causal dense/ragged-batch prefill attention, K-only norm+RoPE variant, fused `strided_segment_copy`.
- `hf_golden_gate.rs` (8 tests): HF golden forward parity (uncached/unified-cache/draft-cache), batch-vs-single, executor/scheduler smoke, cache control (fail-closed / `drop_cache` / `max_caches`).
- `docs/models/qwen3/dflash.md`, Python accuracy scripts (golden dump, forward bench, drafter-substitution comparison).

Performance (RTX 5070 Ti, CUDA 13.3, `ctx_len=2`, `q_len=16`): bs32 reaches 75,148 draft tok/s (9.0× scaling from bs1), with two perf passes: fused strided copy (1.5× at bs32) and K-only norm+RoPE for context-K.

Out of scope: target hidden extraction, target verification, acceptance, fallback token selection, OpenAI serving. These belong to the future speculative-decoding controller.
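The bounded-cache behaviour listed in this description (`max_caches`, idempotent `drop_cache`, fail-closed admission) can be illustrated with a small host-side registry. This is a sketch under assumptions: `CacheRegistry`, `DraftCache`, and `admit` are hypothetical names, a byte vector stands in for the GPU-backed cache, and treating an already-admitted id as a no-op is a choice made here, not something the PR states.

```rust
use std::collections::HashMap;

/// Stand-in for a per-request draft cache; the real one owns GPU buffers.
pub struct DraftCache {
    _storage: Vec<u8>,
}

pub struct CacheRegistry {
    max_caches: usize,
    caches: HashMap<u64, DraftCache>,
}

impl CacheRegistry {
    pub fn new(max_caches: usize) -> Self {
        Self { max_caches, caches: HashMap::new() }
    }

    /// Fails closed: once the cap is reached, new ids are rejected until a
    /// retired request's cache has been dropped.
    pub fn admit(&mut self, id: u64) -> Result<(), String> {
        if self.caches.contains_key(&id) {
            return Ok(()); // assumption: re-admitting a live id reuses its cache
        }
        if self.caches.len() >= self.max_caches {
            return Err(format!("cache cap {} reached", self.max_caches));
        }
        self.caches.insert(id, DraftCache { _storage: vec![0; 16] });
        Ok(())
    }

    /// Idempotent: a missing cache is not an error. Returns whether anything
    /// was removed; the removed value is freed when it goes out of scope.
    pub fn drop_cache(&mut self, id: u64) -> bool {
        self.caches.remove(&id).is_some()
    }

    pub fn len(&self) -> usize {
        self.caches.len()
    }
}
```

A caller would typically invoke `drop_cache` when a request retires; calling it twice for the same id is harmless, which keeps cleanup paths simple.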
Type of Change
Checklist
- … (`docs/conventions/coding-style.md`).
- … (`CLAUDE.md`).