feat(qwen3-4b-dflash): add Qwen3-4B-DFlash draft model crate #439

Open
kitty-eu-org wants to merge 7 commits into openinfer-project:main from fagao-ai:codex/qwen3-dflash-cache-gate

@kitty-eu-org

Description

Add openinfer-qwen3-4b-dflash crate supporting the z-lab/Qwen3-4B-DFlash-b16 draft model, laying the groundwork for future speculative decoding integration. The crate is behind its own cargo feature and does not affect existing model crates.

What's included:

  • New model crate (openinfer-qwen3-4b-dflash): config parsing, exact-key safetensor weight loading, draft forward with HF remote-code parity gate, unified DFlashDraftCache (mirrors reference DynamicCache lifecycle), draft-only batch executor/scheduler with FCFS exact-shape batching, single-instance batch buffers, bounded max_caches eviction with drop_cache, and join-on-drop scheduler shutdown aligned with EngineHandle / qwen3 drop_request.
  • Shared kernels (openinfer-kernels): non-causal dense/ragged-batch prefill attention, K-only norm+RoPE variant, fused strided_segment_copy.
  • Correctness gates (hf_golden_gate.rs, 8 tests): HF golden forward parity (uncached/unified-cache/draft-cache), batch-vs-single, executor/scheduler smoke, cache control (fail-closed / drop_cache / max_caches).
  • Docs & tools: docs/models/qwen3/dflash.md, Python accuracy scripts (golden dump, forward bench, drafter-substitution comparison).
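
The "FCFS exact-shape batching" named in the list above could look roughly like the following. This is a minimal sketch, not the crate's scheduler: `DraftRequest` and `next_batch` are hypothetical names, and it assumes a batch is the longest run of same-shaped requests at the head of the queue, so arrival order is never reordered.

```rust
use std::collections::VecDeque;

/// Hypothetical request record; only the fields needed for batching are shown.
#[derive(Debug, Clone, PartialEq)]
pub struct DraftRequest {
    pub id: u64,
    pub q_len: usize,
    pub ctx_len: usize,
}

/// Pops the oldest request plus every immediately following request with the
/// same (q_len, ctx_len), up to `max_batch`. Stops at the first mismatch
/// (assumption: strict FCFS, no skipping ahead in the queue).
pub fn next_batch(queue: &mut VecDeque<DraftRequest>, max_batch: usize) -> Vec<DraftRequest> {
    let mut batch = Vec::new();
    let shape = match queue.front() {
        Some(head) => (head.q_len, head.ctx_len),
        None => return batch,
    };
    while batch.len() < max_batch {
        let same_shape = queue
            .front()
            .map_or(false, |r| (r.q_len, r.ctx_len) == shape);
        if !same_shape {
            break;
        }
        batch.push(queue.pop_front().unwrap());
    }
    batch
}
```

With this policy a differently shaped request in the middle of the queue splits the batch rather than being overtaken, which trades some batch size for predictable ordering.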

Performance (RTX 5070 Ti, CUDA 13.3, ctx_len=2, q_len=16): bs32 reaches 75,148 draft tok/s (9.0× scaling from bs1), with two perf passes — fused strided copy (1.5× at bs32) and K-only norm+RoPE for context-K.

Out of scope: target hidden extraction, target verification, acceptance, fallback token selection, OpenAI serving — these belong to the future speculative-decoding controller.

Type of Change

  • New feature (non-breaking change which adds functionality)

Checklist

  • My code follows the style guidelines of this project (see docs/conventions/coding-style.md).
  • I have performed a self-review of my own code.
  • I have formatted my commits according to Commitizen conventions.
  • I have run the local test suite and all tests pass (see CLAUDE.md).

kitty-eu-org and others added 6 commits June 18, 2026 19:30
Aligns the DFlash executor with the rest of the project (Qwen3
BatchDecodeBuffers, Kimi/DeepSeek scratch): one allocation sized for
the worst case, narrowed per forward via set_active_shape, instead of
the unique-per-crate HashMap<(batch, q, ctx), buffer> cache that grew
a fresh GPU buffer set for every unseen (q_len, ctx_len) combo.

* DFlashExecutor now holds a single DFlashBatchBuffers, allocated once
  in load() for (max_batch_size, max_q_len, max_step_context_len).
  New DFlashExecutorOptions.max_q_len gates the q-axis capacity.
* set_active_batch(bs) -> set_active_shape(bs, q_len, ctx_len); both
  forward paths derive the shape from the requests themselves so
  callers no longer pre-set it.
* prepare_ragged_plan cache key now covers (batch_size, q_len, ctx_len)
  — with a single instance the shape can change between forwards, so
  keying only on batch_size would reuse a stale plan.
* compact_host_inputs stitches all requests on the host and uploads
  noise/target with one H2D each (was one launch per request per
  tensor), matching Qwen3's sync_paged_meta upload pattern.
* Compact NoCache paths materialize the owned output via a single
  clone_batch_output dtod instead of zeros + copy_hidden over the full
  span.

Tests/benches keep working: create_batch_buffers keeps its 3-positional
signature (now (max_batch, max_q, max_ctx)), and test struct literals
gain the max_q_len field.
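
The single-allocation pattern and the widened plan cache key described above can be sketched on the host side as follows. This is an illustration only: `BatchBuffers`, its fields, and `prepare_plan` are hypothetical stand-ins for `DFlashBatchBuffers` and `prepare_ragged_plan`, and a `Vec<f32>` takes the place of GPU memory.

```rust
/// Hypothetical host-side stand-in for a worst-case-sized buffer set.
pub struct BatchBuffers {
    max: (usize, usize, usize),    // (max_batch, max_q_len, max_ctx_len)
    active: (usize, usize, usize), // shape used by the next forward
    hidden: Vec<f32>,              // allocated once for the worst case
    plan_key: Option<(usize, usize, usize)>,
    pub plan_rebuilds: usize,
}

impl BatchBuffers {
    pub fn new(max_batch: usize, max_q: usize, max_ctx: usize, hidden_dim: usize) -> Self {
        Self {
            max: (max_batch, max_q, max_ctx),
            active: (0, 0, 0),
            hidden: vec![0.0; max_batch * max_q * hidden_dim],
            plan_key: None,
            plan_rebuilds: 0,
        }
    }

    /// Narrows the active view; never reallocates and rejects over-capacity shapes.
    pub fn set_active_shape(&mut self, bs: usize, q_len: usize, ctx_len: usize) -> Result<(), String> {
        if bs > self.max.0 || q_len > self.max.1 || ctx_len > self.max.2 {
            return Err(format!("shape ({bs}, {q_len}, {ctx_len}) exceeds capacity {:?}", self.max));
        }
        self.active = (bs, q_len, ctx_len);
        Ok(())
    }

    /// Rebuilds the plan only when the full (batch, q_len, ctx_len) key changed.
    /// Keying on batch size alone would reuse a stale plan here.
    pub fn prepare_plan(&mut self) {
        if self.plan_key != Some(self.active) {
            self.plan_key = Some(self.active);
            self.plan_rebuilds += 1;
        }
    }

    pub fn capacity_elems(&self) -> usize {
        self.hidden.len()
    }
}
```

Changing only `ctx_len` between two forwards triggers a plan rebuild in this sketch, while the backing allocation stays the same size.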
…qwen3-4b

The DFlash draft scheduler leaked two resources on long-running use: the
GPU-owner thread (no JoinHandle, no shutdown) and per-request draft caches
(grow-only HashMap, each carrying full ForwardBuffers + per-layer past K/V).
Both are now bounded, mirroring qwen3-4b's EngineHandle / drop_request patterns.

Scheduler shutdown: DFlashSchedulerHandle now wraps Arc<Inner> holding an
Option<JoinHandle>. The last clone's Drop closes the channel (the scheduler
loop drains pending requests via send_stopped) and joins the thread, matching
openinfer-engine EngineHandle::Drop. Dropping the handle without an explicit
shutdown no longer leaks the thread.
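
A join-on-drop handle of this shape can be sketched with only the standard library. The names below (`SchedulerHandle`, `Inner`, `spawn`, `submit`) are hypothetical, and the worker here just counts messages; the PR's loop instead answers pending requests via `send_stopped`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{mpsc, Arc, Mutex};
use std::thread::{self, JoinHandle};

struct Inner {
    tx: Mutex<Option<mpsc::Sender<u64>>>,
    join: Option<JoinHandle<()>>,
}

impl Drop for Inner {
    fn drop(&mut self) {
        // Dropping the only sender closes the channel, so the worker loop ends
        // after draining whatever is still buffered.
        drop(self.tx.get_mut().unwrap_or_else(|e| e.into_inner()).take());
        if let Some(handle) = self.join.take() {
            let _ = handle.join();
        }
    }
}

/// Cloneable handle; `Inner::drop` runs only when the last clone goes away.
#[derive(Clone)]
pub struct SchedulerHandle {
    inner: Arc<Inner>,
}

impl SchedulerHandle {
    pub fn spawn(processed: Arc<AtomicUsize>) -> Self {
        let (tx, rx) = mpsc::channel::<u64>();
        let join = thread::spawn(move || {
            // Iteration stops once the sender is gone and the buffer is empty.
            for _id in rx {
                processed.fetch_add(1, Ordering::SeqCst);
            }
        });
        Self {
            inner: Arc::new(Inner { tx: Mutex::new(Some(tx)), join: Some(join) }),
        }
    }

    pub fn submit(&self, id: u64) -> bool {
        let guard = self.inner.tx.lock().unwrap();
        let sent = guard.as_ref().map_or(false, |tx| tx.send(id).is_ok());
        sent
    }
}
```

Because the join happens in `Drop`, a caller that simply lets every clone fall out of scope still gets a clean thread exit.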

Cache eviction: DFlashExecutorOptions gains max_caches (default 64). A new
drop_cache(id) — exposed on both executor and scheduler — removes a request's
cache and lets RAII free the GPU buffers. It is idempotent (a missing cache is
not an error), matching qwen3's drop_request. Over-cap admission fails closed
until a retired request's cache is dropped.
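
The fail-closed cap and idempotent drop can be illustrated with a small host-side map. This is a sketch under assumptions: `DraftCaches` and `get_or_create` are hypothetical names, and a `Vec<f32>` stands in for the GPU buffers that RAII would free.

```rust
use std::collections::HashMap;

/// Hypothetical bounded cache table keyed by request id.
pub struct DraftCaches {
    max_caches: usize,
    caches: HashMap<u64, Vec<f32>>, // Vec stands in for GPU-owned buffers
}

impl DraftCaches {
    pub fn new(max_caches: usize) -> Self {
        Self { max_caches, caches: HashMap::new() }
    }

    /// Returns the existing cache, or admits a new one if under the cap.
    /// Over-cap admission is an error: nothing is evicted implicitly.
    pub fn get_or_create(&mut self, id: u64) -> Result<&mut Vec<f32>, String> {
        if !self.caches.contains_key(&id) && self.caches.len() >= self.max_caches {
            return Err(format!("cache limit {} reached; drop a retired cache first", self.max_caches));
        }
        Ok(self.caches.entry(id).or_insert_with(Vec::new))
    }

    /// Idempotent: dropping a missing cache is not an error.
    /// Returns whether a cache was actually released.
    pub fn drop_cache(&mut self, id: u64) -> bool {
        self.caches.remove(&id).is_some()
    }

    pub fn live_caches(&self) -> usize {
        self.caches.len()
    }
}
```

Failing closed keeps memory bounded without guessing which request is safe to evict; the caller decides when a request has retired.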

Cleanup: remove submit_with_enqueued_ack, which sent its ack from the caller
thread (not the scheduler) and only proved the message entered the channel
buffer — unbounded-channel FIFO already guarantees the ordering it claimed to.
The batch exact-shape validator now fully checks the first request and only
shape-matches the rest, instead of re-running the full validator per request.

Gate: adds dflash_cache_drop_releases_and_capacity_fails_closed covering
drop_cache release + idempotency and max_caches fail-closed/reuse. HF golden
deltas unchanged (mean=0.034243, p99=0.125000, max=0.500000, n=7680); 8 tests
pass.
…d copy kernel

The batch forward path built the ragged-attention K/V layout [ctx | noise]
per request by looping memcpy_dtod over each request: 2 * batch_size copies
per K/V tensor per layer. At bs=32 that is 128 launches/layer (640 per
forward), and at ~5us CPU launch overhead each this dominated the bs32
latency budget.

Add strided_segment_copy_kernel (csrc/shared/elementwise.cu): one launch
copies an entire batch's segment (all requests' ctx rows, or all noise rows)
from a contiguous source into the strided per-request destination layout.
Each layer now issues 4 launches (k_ctx, k_noise, v_ctx, v_noise) instead of
2 * batch_size * 2, collapsing the bs32 per-layer count from 128 to 4.
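
For reference, the copy semantics described above can be written as a CPU loop. This is a sketch of the layout only, not the CUDA kernel: the function signature and parameter names are assumptions, and the real kernel replaces the per-request loop with a single launch.

```rust
/// Copies `seg_rows` rows per request from a contiguous `src` into `dst`,
/// where request `b` owns rows [b * dst_rows_per_req, (b + 1) * dst_rows_per_req)
/// and the segment starts `dst_row_offset` rows into that range.
pub fn strided_segment_copy(
    dst: &mut [f32],
    src: &[f32],
    batch: usize,
    seg_rows: usize,
    dst_rows_per_req: usize,
    dst_row_offset: usize,
    row_width: usize,
) {
    assert!(dst_row_offset + seg_rows <= dst_rows_per_req);
    assert_eq!(src.len(), batch * seg_rows * row_width);
    assert_eq!(dst.len(), batch * dst_rows_per_req * row_width);
    let seg = seg_rows * row_width;
    for b in 0..batch {
        let d0 = (b * dst_rows_per_req + dst_row_offset) * row_width;
        dst[d0..d0 + seg].copy_from_slice(&src[b * seg..(b + 1) * seg]);
    }
}

/// Per-layer launch count of the old per-request memcpy path:
/// (ctx + noise) copies per request, for each of K and V.
pub fn launches_per_layer_before(batch: usize) -> usize {
    2 * batch * 2
}

/// Per-layer launch count with one strided copy per segment: k_ctx, k_noise, v_ctx, v_noise.
pub fn launches_per_layer_after() -> usize {
    4
}
```

Calling it once with the ctx rows at offset 0 and once with the noise rows at offset `ctx_len` builds the `[ctx | noise]` layout for every request in the batch.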

Result (RTX 5070 Ti, WSL, ctx_len=2, q_len=16):
  bs 8:  4.36ms -> 3.34ms (1.31x)
  bs16:  6.85ms -> 4.70ms (1.46x)
  bs32: 12.17ms -> 8.18ms (1.49x)
  bs1->bs32 throughput scaling: 5.3x -> 8.1x (7.7K draft tok/s at bs1, 62.6K at bs32)
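
The throughput figure can be reproduced from the latencies, assuming draft tok/s is counted as batch_size * q_len tokens per forward. The helper names below are hypothetical.

```rust
/// Draft tokens per second, assuming each forward yields batch * q_len tokens.
pub fn draft_tok_per_s(batch: usize, q_len: usize, latency_ms: f64) -> f64 {
    (batch * q_len) as f64 / (latency_ms / 1000.0)
}

/// Speedup of a latency improvement, e.g. 12.17 ms -> 8.18 ms.
pub fn speedup(before_ms: f64, after_ms: f64) -> f64 {
    before_ms / after_ms
}
```

For bs32 with q_len=16 at 8.18 ms this gives about 62.6K draft tok/s, and 12.17 / 8.18 is about 1.49.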

HF golden deltas unchanged (mean=0.034243, p99=0.125000, max=0.500000,
n=7680); batch-vs-single stays at mean=0.000000. 8 tests pass.
The batch path's context-hidden K projection needs RMSNorm + RoPE, but it has
no corresponding Q — the draft Q comes only from the noise tokens. The code
reused the joint qk_norm_rope kernel with a scratch Q buffer whose result was
immediately discarded. For Qwen3-4B's 16:4 GQA ratio that wasted 80% of the
kernel's work (num_q_heads of every num_q_heads + num_kv_heads blocks) on a
dead Q branch.

Add k_norm_rope_batched_decode_cuda: same per-head RMSNorm + RoPE logic but
launches only num_kv_heads blocks per token, restricted to the K tensor. Wire
it into the batch context-K path.
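
A CPU reference for the K-only path might look like the sketch below. It is not the kernel's code: the function names are hypothetical, and it assumes a learned per-dimension RMSNorm weight and the half-rotation (rotate-half) RoPE layout used by Hugging Face Qwen models.

```rust
/// In-place RMSNorm followed by half-rotation RoPE on a single K head.
pub fn k_norm_rope_head(k: &mut [f32], weight: &[f32], pos: usize, eps: f32, theta: f32) {
    let d = k.len();
    assert!(d % 2 == 0 && weight.len() == d);
    let mean_sq = k.iter().map(|x| x * x).sum::<f32>() / d as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    for (x, w) in k.iter_mut().zip(weight) {
        *x *= inv_rms * w;
    }
    let half = d / 2;
    for i in 0..half {
        // Pair (i, i + half) rotates by pos * theta^(-2i/d).
        let freq = theta.powf(-2.0 * i as f32 / d as f32);
        let (sin, cos) = (pos as f32 * freq).sin_cos();
        let (a, b) = (k[i], k[i + half]);
        k[i] = a * cos - b * sin;
        k[i + half] = a * sin + b * cos;
    }
}

/// Applies the per-head transform to K laid out as [token][kv_head][head_dim].
/// Only num_kv_heads heads per token are touched; there is no Q branch.
pub fn k_norm_rope(
    k: &mut [f32],
    weight: &[f32],
    positions: &[usize],
    num_kv_heads: usize,
    head_dim: usize,
    eps: f32,
    theta: f32,
) {
    assert_eq!(k.len(), positions.len() * num_kv_heads * head_dim);
    for (t, &pos) in positions.iter().enumerate() {
        for h in 0..num_kv_heads {
            let start = (t * num_kv_heads + h) * head_dim;
            k_norm_rope_head(&mut k[start..start + head_dim], weight, pos, eps, theta);
        }
    }
}

/// Share of joint-kernel blocks spent on a discarded Q branch.
pub fn dead_q_fraction(num_q_heads: usize, num_kv_heads: usize) -> f64 {
    num_q_heads as f64 / (num_q_heads + num_kv_heads) as f64
}
```

With the 16:4 head ratio quoted above, `dead_q_fraction(16, 4)` evaluates to 0.8, which matches the 80% figure.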

HF golden deltas unchanged (mean=0.034243, p99=0.125000, max=0.500000,
n=7680); batch-vs-single stays at mean=0.000000. 8 tests pass.

ctx_len=32 bs32: 9.50ms -> 9.26ms (+2.6%); benefit scales with ctx_len
since the dead-Q work grows with context length.
…che-gate

# Conflicts:
#	openinfer-core/src/ops.rs
#	openinfer-kernels/src/ops.rs

chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a66604f9f2

Comment on lines +132 to +135
"--warmup",
str(args.warmup),
"--iters",
str(args.iters),

P2: Forward benchmark dimensions to the Rust runner

When this script is run with non-default --ctx-len or --q-len, the fixture written above uses those dimensions, but this command launches qwen3_dflash_forward_bench without forwarding them, so the Rust binary keeps its own defaults (2/16) and rejects the fixture shape instead of producing the OpenInfer side of the comparison. Please pass the same --ctx-len/--q-len values here.

The Python forward bench script generated its fixture with the caller's
--ctx-len/--q-len but launched the Rust runner without forwarding them, so
the runner kept its defaults (2/16) and rejected the fixture shape for any
non-default dimension.

Pass --ctx-len/--q-len through to the runner, and make the Rust fixture path
derive ctx_len/q_len from the fixture's actual tensor shapes so the two sides
agree regardless of which flags the caller repeats.
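
Deriving the dimensions from the fixture itself could be sketched as below. The tensor layout is an assumption made for illustration (noise as `[batch, q_len, hidden]`, context as `[batch, ctx_len, hidden]`), and `dims_from_fixture` is a hypothetical name, not the runner's actual function.

```rust
/// Returns (ctx_len, q_len) read from the fixture tensor shapes.
/// Assumed layouts: noise = [batch, q_len, hidden], ctx = [batch, ctx_len, hidden].
pub fn dims_from_fixture(noise_shape: &[usize], ctx_shape: &[usize]) -> Result<(usize, usize), String> {
    match (noise_shape, ctx_shape) {
        // Batch and hidden sizes must agree between the two tensors.
        ([nb, q_len, nh], [cb, ctx_len, ch]) if nb == cb && nh == ch => Ok((*ctx_len, *q_len)),
        _ => Err(format!(
            "inconsistent fixture shapes: noise={:?}, ctx={:?}",
            noise_shape, ctx_shape
        )),
    }
}
```

Reading the shapes from the file means a caller that omits `--ctx-len`/`--q-len` on one side can no longer produce a mismatch between the Python fixture and the Rust runner.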