fix(qwen3): opt-in split-KV chunk-count reduction-order pin #438

Open
FeathBow wants to merge 1 commit into openinfer-project:main from FeathBow:fix/qwen3-435-splitkv-chunk

@FeathBow

Contributor

Description

Refs #414, #435

Follow-up to #428. A request's split-KV decode output depends on which other requests share its batch: the split chunk size is derived from the live batch max_seq_len, so co-batching a request with a longer one changes that request's chunk count and therefore the online-softmax rescale order over its split-KV partials. Under the opt-in NumericPolicy (default Tuned, unchanged from today), split_chunk_size keys the chunk size on max_context_tokens, a process constant, so a request's chunk count, and therefore its rescale order, depends only on its own sequence length.
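To make the rescale-order dependence concrete, here is a minimal sketch of a split-KV reduction with an online-softmax merge. It is not the project's kernel: the names `Partial`, `chunk_partial`, `merge`, and `split_kv_attention` are hypothetical, the value dimension is a single scalar, and the arithmetic is f64 rather than bf16. The point is that the merge is mathematically exact for any chunk size, so any difference between chunkings comes purely from rounding along a different merge order.

```rust
/// Partial attention state for one KV chunk (scalar value dimension for brevity).
#[derive(Clone, Copy, Debug)]
struct Partial {
    max: f64, // largest score seen in this partial
    sum: f64, // sum of exp(score - max)
    acc: f64, // sum of exp(score - max) * value
}

/// Reduce one chunk of (score, value) pairs into a partial.
fn chunk_partial(scores: &[f64], values: &[f64]) -> Partial {
    let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut sum = 0.0;
    let mut acc = 0.0;
    for (s, v) in scores.iter().zip(values) {
        let w = (s - max).exp();
        sum += w;
        acc += w * v;
    }
    Partial { max, sum, acc }
}

/// Online-softmax merge: rescale both sides to the shared max, then add.
fn merge(a: Partial, b: Partial) -> Partial {
    let max = a.max.max(b.max);
    let (ra, rb) = ((a.max - max).exp(), (b.max - max).exp());
    Partial {
        max,
        sum: a.sum * ra + b.sum * rb,
        acc: a.acc * ra + b.acc * rb,
    }
}

/// Split the sequence into chunks of `chunk_tokens` and merge left to right.
/// The number of merges, and therefore the rounding sequence, depends on
/// `chunk_tokens`.
fn split_kv_attention(scores: &[f64], values: &[f64], chunk_tokens: usize) -> f64 {
    let merged = scores
        .chunks(chunk_tokens)
        .zip(values.chunks(chunk_tokens))
        .map(|(s, v)| chunk_partial(s, v))
        .reduce(merge)
        .expect("non-empty sequence");
    merged.acc / merged.sum
}
```

In f64 the results for different `chunk_tokens` agree to within rounding noise. In bf16 the same noise is many orders of magnitude larger, which is how a changed chunk count can move a near-tie logit.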

This PR covers one measured source of batch dependence from #435 (the split-KV chunk count). It does not address the other sources, such as the batch-dependent choice of decode attention kernel (NonPartition vs SplitKv) or decode/unified routing.

Behavior

bf16 split-KV decode accumulates per-chunk partials and merges them with an online-softmax rescale whose order is set by the chunk count. When the chunk size is keyed on the live batch max_seq_len, that order changes with batch composition and can perturb the bf16 result enough to flip a near-tie argmax. Keying the chunk size on the model context limit makes the count request-local; for Qwen3-4B it resolves to ceil(max_position_embeddings / 64) = 640 tokens, the finest split the 64-chunk workspace allows. The default path (--batch-invariant off, Tuned) is unchanged.
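The arithmetic can be checked with a short sketch. The helper names `chunk_tokens_for` and `chunk_count` are hypothetical and are not the PR's `split_chunk_size`; the rule is assumed to be a plain ceiling division by the 64-chunk workspace. The inputs are the max_seq values from the gate log in this description (5001 at capture, 8001 at replay) and an assumed max_position_embeddings of 40960, which is consistent with the 640-token figure above. A's decode length of 5001 is inferred from the capture max_seq.

```rust
/// Number of split-KV chunks the workspace can hold (64, per the description).
const MAX_CHUNKS: usize = 64;

/// Hypothetical stand-in for the chunk-size rule: the smallest chunk size that
/// covers `key_len` tokens using at most MAX_CHUNKS chunks.
fn chunk_tokens_for(key_len: usize) -> usize {
    key_len.div_ceil(MAX_CHUNKS)
}

/// Number of chunks a request of `seq_len` tokens occupies at a given chunk size.
fn chunk_count(seq_len: usize, chunk_tokens: usize) -> usize {
    seq_len.div_ceil(chunk_tokens)
}

/// Chunk count for request A when the chunk size is keyed on `key_len`.
/// Tuned passes the live batch max_seq_len; the pinned modes pass the
/// (assumed) model context limit.
fn chunk_count_keyed_on(a_seq_len: usize, key_len: usize) -> usize {
    chunk_count(a_seq_len, chunk_tokens_for(key_len))
}
```

Under these assumptions, keying on 5001 and then 8001 gives chunk sizes 79 and 126 and moves A from 64 chunks to 40, while keying on 40960 gives 640 tokens and 8 chunks regardless of the batch.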

Test Env

A new gate, batch_invariance_decode_splitkv_graph, prefills request A alone so that its KV is fixed, then co-batches it first with a shorter and then with a longer request B, so that only the decode max_seq_len moves: under Tuned, A's chunk count goes 64 -> 40 between the SplitKv graph's capture and its replay. A runtime counter (split_kv_counters) lets the gate assert the chunk size directly: it drifts under Tuned (79 -> 126 tokens) and stays constant under Pin/PerToken (640 tokens). A's top-K follows the chunk size: it drifts under Tuned, which serves as the baseline-drift guard against a vacuous pass, and is bit-identical under Pin and PerToken. Existing no-regression gates also pass. Verified on a single GPU (x86_64, sm_89) and on GH200 (aarch64, sm_90).

```shell
OPENINFER_TEST_MODEL_PATH=<Qwen3-4B-base> cargo test --release -p openinfer-qwen3-4b --test batch_invariance_decode_splitkv_graph -- --nocapture
```

```
[baseline ] path=SplitKv capture@max_seq=5001->replay@max_seq=8001 A_prefill=isolated(batch1) decode_GEMM_N=2(fixed) runtime_chunk_tokens(C=79,R=126) decode_topk_eq=false chunk_eq=false
[pin      ] ... runtime_chunk_tokens(C=640,R=640) decode_topk_eq=true chunk_eq=true
[per_token] ... runtime_chunk_tokens(C=640,R=640) decode_topk_eq=true chunk_eq=true
RESULT chunk_eq baseline=false pin=true per_token=true | decode_topk_eq baseline=false pin=true per_token=true | first_token_eq baseline=true pin=true per_token=true
```
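The pass condition implied by the RESULT line can be sketched as follows. This is an illustration of the baseline-drift-guard idea, not the gate's actual code: `ModeOutcome` and `check_gate` are hypothetical names, and treating any non-drifting baseline field as a vacuous pass is an assumption.

```rust
/// Capture-vs-replay comparison for one mode, mirroring two fields of the
/// RESULT line.
#[derive(Clone, Copy)]
struct ModeOutcome {
    chunk_eq: bool,       // chunk size identical at capture and replay
    decode_topk_eq: bool, // A's decode top-K bit-identical at capture and replay
}

/// Returns Ok only if the baseline drifts and both pinned modes do not.
fn check_gate(
    baseline: ModeOutcome,
    pin: ModeOutcome,
    per_token: ModeOutcome,
) -> Result<(), String> {
    // Baseline-drift guard: if Tuned does not drift, the scenario never
    // exercised the batch dependence, so a green result would be vacuous.
    if baseline.chunk_eq || baseline.decode_topk_eq {
        return Err("vacuous pass: baseline did not drift".to_string());
    }
    for (name, mode) in [("pin", pin), ("per_token", per_token)] {
        if !mode.chunk_eq {
            return Err(format!("{name}: chunk size changed between capture and replay"));
        }
        if !mode.decode_topk_eq {
            return Err(format!("{name}: decode top-K changed between capture and replay"));
        }
    }
    Ok(())
}
```

Fed the values from the RESULT line above (baseline false/false, pin and per_token true/true), `check_gate` returns `Ok(())`.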

Decode perf

Measured with vllm bench serve on the random dataset, output-len 128, reporting median TPOT. To separate this PR from #428's GEMM-N pin, three builds are compared: Tuned (no flag), PR1-only (--batch-invariant with the chunk pin reverted to max_seq_len, so #428's GEMM-N pin alone), and PR2-Pin (both). The isolated chunk-pin cost is PR2-Pin / PR1-only - 1:

| arch | bs | ctx256 | ctx512 | ctx1024 | ctx2048 | ctx4096 |
|---|---|---|---|---|---|---|
| sm_89 | 1 | +2.3% | +5.9% | +6.8% | +8.4% | +8.1% |
| sm_89 | 8 | +2.9% | +6.6% | +2.4% | -1.0% | -0.8% |
| sm_89 | 32 | +0.1% | -0.3% | -0.1% | +0.5% | -0.4% |
| sm_90 | 1 | +11.4% | +23.3% | +24.8% | +24.6% | +31.4% |
| sm_90 | 8 | +12.6% | +17.5% | +22.3% | +22.3% | +16.7% |
| sm_90 | 32 | +7.3% | +18.5% | -7.6% | +4.3% | -2.1% |

The cost comes from the fixed split forgoing SplitKv's batch-aware SM fill. It is a larger fraction on GH200 because its base decode is faster (bs=1 takes about 4.2 to 4.4 ms, versus 9.8 to 10.7 ms on sm_89 across these contexts). On sm_89 it washes out by bs=8 at the longer contexts (about -1% at ctx2048 and ctx4096); on GH200 it persists into bs=8. At bs=32 the result is dominated by #428's GEMM-N pin and the isolated chunk delta is noisy: near-zero on sm_89, mixed on sm_90, and much smaller than the full --batch-invariant cost.
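The percentages in the isolated-cost table can be reproduced from the raw medians with the formula PR2-Pin / PR1-only - 1. A minimal sketch follows; the function name is hypothetical.

```rust
/// Isolated chunk-pin cost in percent: PR2-Pin / PR1-only - 1.
/// Both arguments are median TPOT values in milliseconds.
fn isolated_chunk_pin_cost_pct(pr2_pin_ms: f64, pr1_only_ms: f64) -> f64 {
    (pr2_pin_ms / pr1_only_ms - 1.0) * 100.0
}
```

For example, the sm_89 bs=1 ctx256 medians (10.78 ms and 10.54 ms) give about +2.3%, and the sm_90 bs=1 ctx4096 medians (5.56 ms and 4.23 ms) give about +31.4%.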

Raw median TPOT (ms)
```
sm_89 (x86_64)        ctx256  ctx512  ctx1024 ctx2048 ctx4096
tuned   bs1            9.82    9.94    9.97    10.29   10.68
pr1only bs1            10.54   10.47   10.58   10.83   11.18
pr2pin  bs1            10.78   11.09   11.30   11.74   12.09
tuned   bs8            11.13   11.73   13.84   16.96   25.95
pr1only bs8            11.81   12.57   15.93   20.32   33.15
pr2pin  bs8            12.15   13.40   16.31   20.11   32.89
tuned   bs32           16.78   19.21   27.80   44.76   127.50
pr1only bs32           20.24   23.22   36.39   62.06   201.36
pr2pin  bs32           20.27   23.15   36.36   62.39   200.59

sm_90 (GH200)         ctx256  ctx512  ctx1024 ctx2048 ctx4096
tuned   bs1            4.20    4.23    4.29    4.33    4.42
pr1only bs1            4.02    4.04    4.11    4.15    4.23
pr2pin  bs1            4.48    4.98    5.13    5.17    5.56
tuned   bs8            4.59    4.70    5.02    6.16    8.89
pr1only bs8            4.51    4.80    5.48    7.07    10.83
pr2pin  bs8            5.08    5.64    6.70    8.65    12.64
tuned   bs32           5.59    5.87    8.58    13.24   23.30
pr1only bs32           5.90    6.07    10.23   18.30   33.48
pr2pin  bs32           6.33    7.19    9.45    19.09   32.79
```

Checklist

  • My code follows the style guidelines of this project (see docs/conventions/coding-style.md).
  • I have performed a self-review of my own code.
  • I have formatted my commits according to Commitizen conventions.
  • I have run the local test suite and all tests pass (see CLAUDE.md).

@FeathBow force-pushed the fix/qwen3-435-splitkv-chunk branch from afbdb9d to 24a5bc1 on June 22, 2026 21:41

@xiaguan left a comment

Collaborator

The fix is clean and the gate test is well isolated (baseline-drift guard, fixed GEMM-N=2, isolated A prefill). One change before merge.

Keep the split-KV runtime counters out of the production decode path. SPLIT_KV_STEPS / SPLIT_KV_LAST_CHUNK_TOKENS, the fetch_add/store in sync_split_kv_meta, split_kv_counters/reset_split_kv_counters, and their runtime:: re-export are test observability sitting in the hot path — please drop them.

The test doesn't need them. A's prefill is isolated (same first_token) and GEMM-N is fixed at 2, so the only capture-vs-replay variable is the decode max_seq_len → A's split chunk count. That makes the existing top-K assertions (!tuned_tk drift, pin/per_token bit-identical) a closed proof on their own. For the chunk-count guard, make split_chunk_size_for pub (it's pure arithmetic — legitimate API, not scaffolding) and assert it directly:

```rust
// Tuned keys the chunk size on batch max_seq_len, so it differs C-vs-R
assert_ne!(split_chunk_size_for(A_LEN + 1), split_chunk_size_for(B_LEN + 1));
```

Net: zero test scaffolding in the default build, same proof strength.
