fix(qwen3): opt-in GEMM-N reduction-order pin #428
I like the direction, but I think we should narrow the first step. Instead of making the numeric policy affect the whole shared GEMM path, can we first target one concrete subproblem: Qwen3 decode projection GEMMs only? Prefill should keep the current tuned cuBLASLt path; the 2.6-4.5x prefill cost is too high for a path we may eventually want to enable by default. I don’t think we need to solve every batch-dependence in one PR. A good first PR would be: one measured source, one narrow code path, batch-invariance proven by the existing gates, and decode perf within an acceptable regression window. Then split-KV chunking and other residual sources can be separate follow-ups. |
|
Thanks — you're right, the gates don't go far enough: none exercises the unified (mixed prefill+decode) step, where a decode row shares a projection GEMM with co-scheduled prefill chunks, so the under-load case isn't actually gated. That gap is on me — adding it now. I'm narrowing this PR to one source and have cross-arch measurements (Qwen3-4B on sm_89 and GH200/sm_90; commands + numbers below):
Two places I'd lean toward keeping the current approach — but both are open to discussion:
Plan (none of this has landed in the PR yet; I'll push the revised PR1 shortly): PR1 = GEMM-N pin + the two new gates + fallback observability; PR2 = split-KV + attention-path residual; decode↔unified path / cross-cuBLAS-version / uniform coverage = follow-ups.
Repro (Qwen3-4B, Single GPU x86_64/sm_89):
# existing gate, reproducible on the branch now:
OPENINFER_TEST_MODEL_PATH=<Qwen3-4B-base> \
cargo test --release -p openinfer-qwen3-4b --test batch_invariance_endtoend -- --nocapture
# perf / unified-step / production-envelope fallback harnesses land as PR1's gates.
Measured on sm_89 (3 runs) and GH200/sm_90. Summary, with raw output in the details below:
raw gate + perf output — sm_89 and GH200 (sm_90) |
|
Thanks, this is much clearer. Splitting split-KV out and adding the unified-step + fallback-envelope gates sounds like the right next step. |
|
One thing I want to clarify about the longer-term direction: I’d like batch-invariant behavior to become the default for most kernels when the perf cost is small enough. For paths where the cost is material, I’m fine keeping it opt-in, but I’d prefer an explicit server/config surface such as |
|
Sounds good, agreed on making this an explicit serving knob. I’ll replace the env var with a |
|
Direction looks good — the pin mechanism is careful, and the gates are solid (especially the baseline-drift guard so a Pin pass can't be vacuous). Plan is to land this as the opt-in, default-off foundation. Two things before merge: 1. Rebase onto current main. This now conflicts with #430 (profiled KV sizing), which touched the same launch-options plumbing — 2. Open a tracking issue for the workstream. #428 covers exactly one divergence source: the decode projection-GEMM N dimension. It deliberately doesn't cover the rest yet — path-selection residuals (per the flag's own help text) and the decode↔unified cross-path drift that survives the pin (your
That keeps #428 honest as PR1 (decode GEMM-N only) and gives the remaining work a place to live. |
Description
Refs #414
A request's next-token logits depend on which other requests share its batch. This adds an opt-in NumericPolicy (default Tuned, unchanged from today) that ~~removes the two measured reduction-order batch dependencies~~ removes the measured projection-GEMM N-dependence in the deterministic Qwen3 paths tested here. It pins ~~two batch-coupled reductions~~ one batch-coupled reduction:
Projection GEMMs: pin one cuBLASLt algo (fixed tile/stages/split-K factor) per weight shape (M,K), reused for every N, so the fp32-accumulate reduction order no longer changes with batch width. If the pinned algo cannot serve a live N, or would exceed the pinned workspace budget, the policy falls back to a per-token GEMM, which is itself batch-invariant.
~~Decode split-KV attention: derive the split chunk size from max_context_tokens (a process constant) instead of the live batch max, so a request's chunk count, and therefore its softmax-rescale order, depends only on its own sequence length.~~
Behavior
bf16-in / fp32-accumulate GEMMs reduce in an order cuBLAS picks as a function of N (the cuBLASLt plan is keyed on {M,N,K})~~; the split-KV chunk size was keyed on the batch's max_seq_len~~. Either shifts a shared row by about 1 ULP, which flips a near-tie argmax. Norm (per-row), embedding, and prefill attention (fixed cta_tile_q, no partition-KV) are already batch-invariant. Within the isolated deterministic path tested here (prefix cache and LoRA off, fixed sampling), these two are the measured reduction-order sources. Measured: down_proj split-K = 4 → 16 → 3 across N = 1 → 16 → 32~~; a 5000-token request decoded alone uses 64 chunks but 40 co-batched with an 8000-token request~~.
Test Env
Tests:
~~4~~ 5 batch-invariance gates, each with a baseline-drift guard against a vacuous pass. GEMM seam: maxΔ=0 across the N-sweep. ~~End-to-end and the two CUDA-graph guards: prompt_a's ordered top-K (token, logprob) table is bit-identical alone vs co-batched.~~ End-to-end prefill and decode CUDA-graph gates check ordered top-K bit-identity on the tested GEMM-N axis; the unified-step gate covers mixed prefill+decode within-path GEMM-N; the envelope gate asserts served>0 and fallback=0 through the swept Qwen3 production envelope. Verified on both archs, eager and production CUDA-graph: Single GPU (x86_64, sm_89) and GH200 (aarch64, sm_90).
Follow-ups
The fix is verified to make decode batch-invariant when enabled (OPENINFER_NUMERIC_POLICY=pin|pertoken); it is gated off by default, so the production Tuned path stays byte-identical to today. Promoting it to the default is a separate decision: it carries a prefill perf cost and wants the blast-radius data (in the supplementary notes) first, and is not part of delivering the fix.
~~the split-KV chunk fix and~~ the gates are qwen3-specific.
~~about free~~ low-overhead at decode and ~~2.6 to 4.5× on prefill-heavy shapes~~ about 2.6–4.5× at the per-GEMM level; whole-prefill e2e is ~1.8–1.85× on sm_89 and ~1.9–2.1× on sm_90. the hybrid (pin + per-token fallback) fits decode-dominated serving. The trade-off and the cuBLASLt-pin-vs-custom-kernel comparison are in the supplementary notes.
Supplementary notes for #414 (for reference)
Background for the scope and perf decisions in the PR. Everything here is measured. Hardware labels: Single GPU (x86_64, sm_89) and GH200 (aarch64, sm_90).
Blast radius: how often a near-tie actually flips
On a realistic workload (12 prompts of 24 tokens plus 16 decode steps), GH200:
So co-batching perturbs essentially every position by about 1 ULP, and roughly 0.5% sit close enough to a tie to flip the argmax. The flip rate is set by how many positions land in a near-tie. Treat this as an illustrative single run, an order-of-magnitude figure rather than a reproducible constant.
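The flip mechanism can be reproduced in isolation. The sketch below is illustrative only: the logit values and the helper names (`argmax`, `one_ulp_up`, `near_tie_logits`, `perturbed_logits`) are hypothetical and not taken from the PR. It shows that a one-ULP shift on a tied logit, standing in for a changed reduction order, is enough to change the greedy token, while the same shift on a logit that is not near a tie changes nothing.

```rust
/// Index of the largest logit; exact ties resolve to the lowest index.
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best
}

/// Next representable f32 above a positive, finite value (one ULP up).
fn one_ulp_up(x: f32) -> f32 {
    assert!(x.is_finite() && x > 0.0);
    f32::from_bits(x.to_bits() + 1)
}

/// Hypothetical logits for one position: tokens 0 and 1 are exactly tied.
fn near_tie_logits() -> [f32; 3] {
    [12.5, 12.5, 3.0]
}

/// The same position after a one-ULP shift on token 1, standing in for
/// the effect of a different reduction order on that row.
fn perturbed_logits() -> [f32; 3] {
    let mut logits = near_tie_logits();
    logits[1] = one_ulp_up(logits[1]);
    logits
}
```

With these values, `argmax(&near_tie_logits())` picks token 0 and `argmax(&perturbed_logits())` picks token 1. This matches the description above: almost every position is perturbed, but only positions sitting on a near-tie change their argmax.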
Alternatives measured
Two complementary measurements.
Eager per-GEMM is the un-diluted kernel ratio, each divided by its own cuBLAS, over down_proj / o_proj / qkv:
Production CUDA-graph end-to-end step latency, pin vs baseline, same harness, batch in {1,4,8,16,32,64}: about 1.0 to 1.05× for batch ≤ 32, up to 1.23× at batch 64, with fallback=0 at every tested batch. This e2e ratio is diluted toward 1.0 by common-mode host work (per-step sync, full-vocab argmax) that is identical in both arms, so it is the floor and the eager per-GEMM number is the un-diluted ceiling.
Prefill costs about 2.6 to 4.5× under the pin, which is why it is opt-in and decode-targeted. The 2.6–4.5× number is the per-GEMM ceiling; whole-prefill e2e is ~1.8–1.85× on sm_89 and ~1.9–2.1× on sm_90.
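The floor/ceiling relationship between the two measurements can be read off a simple additive model. This model is an assumption for illustration, not something the measurements establish: let $t_h$ be the per-step time that is identical in both arms (host sync, argmax, non-GEMM kernels), $t_g$ the baseline GEMM time, and $r \ge 1$ a single per-GEMM slowdown ratio under the pin.

```latex
R_{\text{e2e}}
  = \frac{t_h + r\,t_g}{t_h + t_g}
  = 1 + (r - 1)\,\frac{t_g}{t_h + t_g},
\qquad\text{so}\qquad
1 \le R_{\text{e2e}} \le r .
```

Under this model the end-to-end ratio approaches $r$ only when GEMM time dominates the step ($t_h \to 0$) and approaches 1 when common-mode work dominates, which is consistent with the e2e figures sitting below the per-GEMM figures.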
The attention side has a narrow-batch cost in the split-KV prong. We did not bench it separately. It is characterized from the same tradeoff between chunk count and SM occupancy that our production attention tuning already measured. It costs latency in one narrow regime: long context (over 4096 tokens) at small batch (about 1 to 8). Pinning the chunk size to a process constant gives a below-max-length request fewer chunks, so fewer attention CTAs and lower SM occupancy at small batch. It washes out by batch 8, and it only applies at batch ≤ 32, since above that the step uses the non-partitioned attention path, with no chunking and no cost.
Caveats
Out of scope (separate batch-dependences, not addressed here)
Prefix-cache KV provenance, sampling at temperature above 0 (Philox is keyed on the batch row index), LoRA delta-kernel N-dependence.
There is also the decode attention-path selection. NonPartition versus SplitKv keys on the batch's live max_seq_len, not on policy, so a request whose context is below 1024, co-batched with one at 1024 or more, runs a different attention kernel. That is a residual batch-variance of the same ULP class, and it survives the Pin. It is self-limiting, since it goes away once the request's own context passes 1024, and it is confined to batch ≤ 32. A clean follow-up if batch-invariant decode is ever promised by default.
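A minimal sketch of the path-selection residual described above. The enum, constants, and function name are hypothetical; only the 1024-token threshold, the batch ≤ 32 limit, and the fact that selection keys on the batch's live maximum come from the text. The real selection logic may have further conditions.

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum AttnPath {
    NonPartition,
    SplitKv,
}

// Values taken from the description; the constant names are assumptions.
const SPLIT_KV_MIN_CONTEXT: usize = 1024;
const SPLIT_KV_MAX_BATCH: usize = 32;

/// Models the described behavior: the whole step picks one decode attention
/// path from the batch's live maximum sequence length and the batch size,
/// independent of the numeric policy.
fn select_decode_attention_path(seq_lens: &[usize]) -> AttnPath {
    let batch_max = seq_lens.iter().copied().max().unwrap_or(0);
    if seq_lens.len() <= SPLIT_KV_MAX_BATCH && batch_max >= SPLIT_KV_MIN_CONTEXT {
        AttnPath::SplitKv
    } else {
        AttnPath::NonPartition
    }
}
```

In this model a 512-token request decoded alone takes `NonPartition`, but co-batched with a 2000-token request it takes `SplitKv`, which is the residual. A 1500-token request takes `SplitKv` either way, which is the self-limiting case, and any batch larger than 32 takes `NonPartition` regardless of lengths.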