fix(qwen3): opt-in split-KV chunk-count reduction-order pin#438
fix(qwen3): opt-in split-KV chunk-count reduction-order pin#438FeathBow wants to merge 1 commit into
Conversation
afbdb9d to
24a5bc1
Compare
xiaguan
left a comment
There was a problem hiding this comment.
The fix is clean and the gate test is well isolated (baseline-drift guard, fixed GEMM-N=2, isolated A prefill). One change before merge.
Keep the split-KV runtime counters out of the production decode path. SPLIT_KV_STEPS / SPLIT_KV_LAST_CHUNK_TOKENS, the fetch_add/store in sync_split_kv_meta, split_kv_counters/reset_split_kv_counters, and their runtime:: re-export are test observability sitting in the hot path — please drop them.
The test doesn't need them. A's prefill is isolated (same first_token) and GEMM-N is fixed at 2, so the only capture-vs-replay variable is the decode max_seq_len → A's split chunk count. That makes the existing top-K assertions (!tuned_tk drift, pin/per_token bit-identical) a closed proof on their own. For the chunk-count guard, make split_chunk_size_for pub (it's pure arithmetic — legitimate API, not scaffolding) and assert it directly:
// Tuned keys the chunk size on batch max_seq_len, so it differs C-vs-R
assert_ne!(split_chunk_size_for(A_LEN + 1), split_chunk_size_for(B_LEN + 1));Net: zero test scaffolding in the default build, same proof strength.
Description
Refs #414, #435
Follow-up to #428. A request's split-KV decode output depends on which other requests share its batch: the split chunk size is derived from the live batch
max_seq_len, so co-batching a request with a longer one changes that request's chunk count and therefore the online-softmax rescale order over its split-KV partials. Under the opt-inNumericPolicy(defaultTuned, unchanged from today),split_chunk_sizekeys the chunk size onmax_context_tokens, a process constant, so a request's chunk count, and therefore its rescale order, depends only on its own sequence length.It covers one measured source from #435 (the split-KV chunk count). It does not address the other sources, such as the batch-dependent choice of decode attention kernel (NonPartition vs SplitKv) or decode/unified routing.
Behavior
bf16 split-KV decode accumulates per-chunk partials and merges them with an online-softmax rescale whose order is the chunk count. Keyed on the live batch
max_seq_len, that order can perturb the bf16 result enough to flip a near-tie argmax. Keying the chunk size on the model context limit makes the count request-local; for Qwen3-4B it resolves toceil(max_position_embeddings / 64) = 640tokens, the finest split the 64-chunk workspace allows. The default path (--batch-invariantoff,Tuned) is unchanged.Test Env
A new gate,
batch_invariance_decode_splitkv_graph, prefills request A alone so its KV is fixed, then co-batches it with a shorter then a longer B so that only the decodemax_seq_lenmoves: A's Tuned chunk count goes 64 -> 40 across the SplitKv graph's capture and replay. A runtime counter (split_kv_counters) lets the gate assert the chunk size directly: it drifts under Tuned (79 -> 126) and is constant under Pin/PerToken (640); A's top-K follows (Tuned drifts, the baseline-drift guard against a vacuous pass; Pin and PerToken bit-identical). Existing no-regression gates also pass. Verified on Single GPU (x86_64, sm_89) and GH200 (aarch64, sm_90).Decode perf
vllm bench serve, random dataset, output-len 128, median TPOT. To separate this PR from #428's GEMM-N pin, three builds:Tuned(no flag),PR1-only(--batch-invariantwith the chunk pin reverted tomax_seq_len, so #428's GEMM-N pin alone), andPR2-Pin(both). Isolated chunk-pin cost isPR2-Pin / PR1-only - 1:The cost is the fixed split forgoing SplitKv's batch-aware SM fill, a larger fraction on GH200 because its base decode is faster (bs=1 about 4.2 to 4.4ms vs sm_89 9.8 to 10.7ms across these contexts). It washes out by bs=8 on sm_89; on GH200 it persists into bs=8. At bs=32 the result is dominated by #428's GEMM-N pin and the isolated chunk delta is noisy: near-zero on sm_89, mixed on sm_90, and much smaller than the full
--batch-invariantcost.Raw median TPOT (ms)
Checklist
docs/conventions/coding-style.md).CLAUDE.md).