fix(qwen3): opt-in GEMM-N reduction-order pin #428

Merged
xiaguan merged 3 commits into openinfer-project:main from FeathBow:fix/qwen3-414-batch-invariant-gemm
Jun 22, 2026

@FeathBow

@FeathBow FeathBow commented Jun 20, 2026

Contributor

Description

Refs #414

A request's next-token logits depend on which other requests share its batch. This adds an opt-in NumericPolicy (default Tuned, unchanged from today) that ~~removes the two measured reduction-order batch dependencies~~ removes the measured projection-GEMM N-dependence in the deterministic Qwen3 paths tested here. It pins ~~two batch-coupled reductions~~ one batch-coupled reduction:

Projection GEMMs: pin one cuBLASLt algo (fixed tile/stages/split-K factor) per weight shape (M,K), reused for every N, so the fp32-accumulate reduction order no longer changes with batch width. If the pinned algo cannot serve a live N, or would exceed the pinned workspace budget, the policy falls back to a per-token GEMM, which is itself batch-invariant.

Decode split-KV attention: derive the split chunk size from max_context_tokens (a process constant) instead of the live batch max, so a request's chunk count, and therefore its softmax-rescale order, depends only on its own sequence length.
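The projection-GEMM routing rule described above (one pinned algo per weight shape (M,K), reused for every N, with a per-token fallback) can be sketched roughly as follows. This is an illustration only: `PinnedAlgo`, `PinTable`, `Route`, and the `workspace_per_column` cost model are hypothetical names and assumptions, not the PR's actual types or cuBLASLt's API.

```rust
use std::collections::HashMap;

/// Hypothetical record of one pinned algorithm (fields are illustrative).
#[derive(Clone, Copy, Debug, PartialEq)]
struct PinnedAlgo {
    split_k: u32,                // fixed split-K factor, never re-chosen per N
    max_n: usize,                // largest N this algo is assumed to serve
    workspace_per_column: usize, // assumed workspace model: bytes per column of N
}

#[derive(Debug, PartialEq)]
enum Route {
    Pinned(PinnedAlgo),
    PerToken, // fallback leg: one fixed-order GEMM per token
}

struct PinTable {
    /// Keyed on the weight shape (M, K) only, so N never selects the algorithm.
    by_weight_shape: HashMap<(usize, usize), PinnedAlgo>,
    workspace_budget: usize,
}

impl PinTable {
    fn route(&self, m: usize, n: usize, k: usize) -> Route {
        match self.by_weight_shape.get(&(m, k)) {
            // Serve with the pinned algo only if it covers this N within budget.
            Some(algo)
                if n <= algo.max_n
                    && algo.workspace_per_column * n <= self.workspace_budget =>
            {
                Route::Pinned(*algo)
            }
            // Unknown shape, N out of range, or over budget: per-token fallback.
            _ => Route::PerToken,
        }
    }
}
```

The point of the sketch is the lookup key: because N is absent from it, every batch width that the pin can serve gets the same reduction order.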

Behavior

bf16-in / fp32-accumulate GEMMs reduce in an order cuBLAS picks as a function of N (the cuBLASLt plan is keyed on {M,N,K}); the split-KV chunk size was keyed on the batch's max_seq_len. Either shifts a shared row by about 1 ULP, which flips a near-tie argmax. Norm (per-row), embedding, and prefill attention (fixed cta_tile_q, no partition-KV) are already batch-invariant. Within the isolated deterministic path tested here (prefix cache and LoRA off, fixed sampling), these two are the measured reduction-order sources. Measured: down_proj split-K = 4 → 16 → 3 across N = 1 → 16 → 32~~; a 5000-token request decoded alone uses 64 chunks but 40 co-batched with an 8000-token request~~.
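The underlying effect is that fp32 addition is not associative, so a different split-K factor can change the result. A minimal CPU sketch, assuming a simple contiguous split (`split_k_sum` is a hypothetical helper, and the inputs are chosen to exaggerate the effect far beyond the ~1 ULP seen in practice):

```rust
/// Sum `xs` the way a split-K reduction would: cut the K dimension into
/// `split_k` contiguous chunks, accumulate each chunk left to right in f32,
/// then add the partial sums in chunk order.
fn split_k_sum(xs: &[f32], split_k: usize) -> f32 {
    assert!(split_k > 0 && !xs.is_empty());
    let chunk = (xs.len() + split_k - 1) / split_k; // ceil(len / split_k)
    xs.chunks(chunk)
        .map(|c| c.iter().fold(0.0f32, |acc, &x| acc + x))
        .fold(0.0f32, |acc, partial| acc + partial)
}
```

For `[1.0e8, 1.0, -1.0e8, 1.0]`, `split_k_sum(&xs, 1)` yields `1.0` while `split_k_sum(&xs, 2)` yields `0.0`: same inputs, different reduction order, different answer. Holding the split factor fixed makes the result repeatable regardless of what else runs alongside.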

Test Env

Tests: ~~4~~ 5 batch-invariance gates, each with a baseline-drift guard against a vacuous pass. GEMM seam: maxΔ=0 across the N-sweep. ~~End-to-end and the two CUDA-graph guards: prompt_a's ordered top-K (token, logprob) table is bit-identical alone vs co-batched.~~ End-to-end prefill and decode CUDA-graph gates check ordered top-K bit-identity on the tested GEMM-N axis; the unified-step gate covers mixed prefill+decode within-path GEMM-N; the envelope gate asserts served>0 and fallback=0 through the swept Qwen3 production envelope. Verified on both archs, eager and production CUDA-graph: Single GPU (x86_64, sm_89) and GH200 (aarch64, sm_90).
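The shape of such a gate, including the baseline-drift guard, might look like the sketch below. It assumes the guard works by requiring the Tuned baseline to drift on the same inputs; `TopK`, `GateResult`, and `batch_invariance_gate` are hypothetical names, not the test suite's actual code.

```rust
/// Ordered top-K table for one position: (token id, logprob).
type TopK = Vec<(u32, f32)>;

/// Bit-exact equality: same tokens in the same order, same f32 bit patterns.
fn bit_identical(a: &TopK, b: &TopK) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| x.0 == y.0 && x.1.to_bits() == y.1.to_bits())
}

#[derive(Debug, PartialEq)]
enum GateResult {
    Pass,
    PinDrifted, // the pinned path changed when co-batched: the fix failed
    Vacuous,    // the Tuned baseline did not drift: the test proves nothing
}

fn batch_invariance_gate(
    pin_alone: &TopK,
    pin_cobatched: &TopK,
    tuned_alone: &TopK,
    tuned_cobatched: &TopK,
) -> GateResult {
    if !bit_identical(pin_alone, pin_cobatched) {
        return GateResult::PinDrifted;
    }
    // Guard: if the baseline is already invariant on this input, a Pin pass
    // would not demonstrate anything.
    if bit_identical(tuned_alone, tuned_cobatched) {
        return GateResult::Vacuous;
    }
    GateResult::Pass
}
```

Comparing `to_bits()` rather than using a tolerance is what makes the check bit-identity instead of approximate equality.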

Follow-ups

  • Opt-in, default-off (by design). The fix is verified to make decode batch-invariant when enabled (OPENINFER_NUMERIC_POLICY=pin|pertoken); it is gated off by default, so the production Tuned path stays byte-identical to today. Promoting it to the default is a separate decision: it carries a prefill perf cost and wants the blast-radius data (in the supplementary notes) first, and is not part of delivering the fix.
  • Set the policy before CUDA-graph capture. The decode graph cache key excludes the numeric policy, and numeric_policy() lazily reads OPENINFER_NUMERIC_POLICY, so the first capture of each bucket bakes whatever policy is active. Deployment is expected to set the env before startup, before any capture; the tests set it before executor construction. Switching after a capture replays the captured policy's behavior, held by convention rather than by API.
  • Global switch. NumericPolicy lives in the shared GEMM layer, so the GEMM pin applies to any model routing through launch_gemm; the split-KV chunk fix and the gates are qwen3-specific.
  • TP=1. The pin gives invariance at a fixed TP degree (fixed NCCL order), not TP-size (TP=1 vs TP=2) invariance, which is a separate problem.
  • Perf. The pin is ~~about free~~ low-overhead at decode and ~~2.6 to 4.5× on prefill-heavy shapes~~ about 2.6–4.5× at the per-GEMM level; whole-prefill e2e is ~1.8–1.85× on sm_89 and ~1.9–2.1× on sm_90. The hybrid (pin + per-token fallback) fits decode-dominated serving. The trade-off and the cuBLASLt-pin-vs-custom-kernel comparison are in the supplementary notes.
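To illustrate the "set the policy before capture" note, here is a minimal sketch of a lazily read, process-wide policy. The env values `pin` and `pertoken` and the default `Tuned` come from the description above; the variant names `Pin` and `PerToken`, the `OnceLock` caching, and the choice to treat unrecognized values as `Tuned` are assumptions made for illustration.

```rust
use std::sync::OnceLock;

#[derive(Clone, Copy, Debug, PartialEq)]
enum NumericPolicy {
    Tuned,    // default: today's behavior
    Pin,      // pinned algo per (M, K)
    PerToken, // per-token GEMM
}

/// Pure parsing step, kept separate so it can be tested without touching
/// the process environment.
fn parse_policy(raw: Option<&str>) -> NumericPolicy {
    match raw {
        Some("pin") => NumericPolicy::Pin,
        Some("pertoken") => NumericPolicy::PerToken,
        // Assumption: unset or unrecognized values keep the default.
        _ => NumericPolicy::Tuned,
    }
}

/// Resolved on first call and then frozen for the life of the process, so a
/// later change to the environment variable has no effect.
fn numeric_policy() -> NumericPolicy {
    static POLICY: OnceLock<NumericPolicy> = OnceLock::new();
    *POLICY.get_or_init(|| {
        parse_policy(std::env::var("OPENINFER_NUMERIC_POLICY").ok().as_deref())
    })
}
```

A value frozen at first use behaves like a value baked into a captured graph: whatever was active at that moment is what gets replayed, which is why the environment has to be set before startup.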

Supplementary notes for #414 (for reference)

Background for the scope and perf decisions in the PR. Everything here is measured. Hardware labels: Single GPU (x86_64, sm_89) and GH200 (aarch64, sm_90).

Blast radius: how often a near-tie actually flips

On a realistic workload (12 prompts of 24 tokens plus 16 decode steps), GH200:

  • positions whose logits moved when co-batched: 204 / 204
  • near-ties (top1 vs top2 logit gap below 0.20 nat): 7 / 204
  • argmax flips: 1 / 204, that is 1 of the 7 near-ties

So co-batching perturbs essentially every position by about 1 ULP, and roughly 0.5% sit close enough to a tie to flip the argmax. The flip rate is set by how many positions land in a near-tie. Treat this as an illustrative single run, an order-of-magnitude figure rather than a reproducible constant.
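The near-tie criterion used in this count can be written down directly. A small sketch, where the 0.20 threshold and the 7/204 and 1/204 counts are from the run above and the helper names are hypothetical:

```rust
/// Gap between the largest and second-largest logit, or None if there are
/// fewer than two logits.
fn top2_gap(logits: &[f32]) -> Option<f32> {
    if logits.len() < 2 {
        return None;
    }
    let mut best = f32::NEG_INFINITY;
    let mut second = f32::NEG_INFINITY;
    for &x in logits {
        if x > best {
            second = best;
            best = x;
        } else if x > second {
            second = x;
        }
    }
    Some(best - second)
}

/// A position is a near-tie when the top-1 vs top-2 gap is below `threshold`.
fn is_near_tie(logits: &[f32], threshold: f32) -> bool {
    matches!(top2_gap(logits), Some(gap) if gap < threshold)
}

/// Fraction of positions, e.g. rate(1, 204) for the argmax-flip rate.
fn rate(count: u32, total: u32) -> f64 {
    count as f64 / total as f64
}
```

With the counts above, `rate(7, 204)` is about 3.4% of positions in a near-tie and `rate(1, 204)` is about 0.5% flipping.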

Alternatives measured

Two complementary measurements.

Eager per-GEMM is the un-diluted kernel ratio, each divided by its own cuBLAS, over down_proj / o_proj / qkv:

  • cuBLASLt pin at decode (N ≤ 64): about 1.0× on both archs (sm_89 0.96 to 1.14, sm_90 1.00 to 1.02 at N ≤ 32; the worst small-batch point measured is 1.35× at o_proj N=64).
  • The two open-source batch-invariant reference kernels, TML's batch_invariant_ops and the llm_reproducibility (TBIK) repo, both verified batch-invariant, hit a decode cliff: 2 to 4× on sm_89 (N=1 spikes to about 9×) and 3 to 32× on GH200 at N ≤ 64. A persistent kernel wastes its grid at small N and recovers only at prefill.
  • The per-token oracle runs 5 to 50×, useful only as the hybrid's fallback leg.
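Why the per-token leg is batch-invariant by construction: each token's output is computed on its own with a fixed accumulation order, so nothing about its batch-mates can enter the arithmetic. A CPU analogue in plain f32 (the real path is bf16-in / fp32-accumulate on GPU; the function names are hypothetical):

```rust
/// y = W x for a single token. W is row-major with shape m x k, and each
/// dot product accumulates strictly left to right.
fn gemv_fixed_order(w: &[f32], m: usize, k: usize, x: &[f32]) -> Vec<f32> {
    assert_eq!(w.len(), m * k);
    assert_eq!(x.len(), k);
    (0..m)
        .map(|r| {
            w[r * k..(r + 1) * k]
                .iter()
                .zip(x)
                .fold(0.0f32, |acc, (&a, &b)| acc + a * b)
        })
        .collect()
}

/// "GEMM" as N independent GEMVs: token i's result never depends on N or on
/// the other tokens, at the cost of giving up batched-kernel efficiency.
fn gemm_per_token(w: &[f32], m: usize, k: usize, tokens: &[Vec<f32>]) -> Vec<Vec<f32>> {
    tokens.iter().map(|x| gemv_fixed_order(w, m, k, x)).collect()
}
```

The cost side of that trade is what the 5 to 50× figure above reflects.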
[Figure: methods_chart]

Production CUDA-graph end-to-end step latency, pin vs baseline, same harness, batch in {1,4,8,16,32,64}: about 1.0 to 1.05× for batch ≤ 32, up to 1.23× at batch 64, with fallback=0 at every tested batch. This e2e ratio is diluted toward 1.0 by common-mode host work (per-step sync, full-vocab argmax) that is identical in both arms, so it is the floor and the eager per-GEMM number is the un-diluted ceiling.

Prefill costs about 2.6 to 4.5× under the pin, which is why it is opt-in and decode-targeted.
The 2.6–4.5× number is the per-GEMM ceiling; whole-prefill e2e is ~1.8–1.85× on sm_89 and ~1.9–2.1× on sm_90.
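One way to relate the per-GEMM and e2e numbers is a simple dilution model. This is a sketch under an assumption that is not measured here: a fraction $f$ of the baseline step time $T$ is spent in the affected GEMMs and slows down by the per-GEMM ratio $r$, while the remainder is unchanged. The symbols $f$, $r$, and $T$ are introduced only for this illustration.

```latex
\[
R_{\text{e2e}}
  = \frac{(1 - f)\,T + f\,T\,r}{T}
  = 1 + f\,(r - 1),
\qquad 0 \le f \le 1,\quad r \ge 1,
\]
\[
\text{so}\quad 1 \le R_{\text{e2e}} \le r .
\]
```

As illustrative arithmetic only, a per-GEMM ratio of $r = 2.6$ together with an e2e ratio of $1.8$ would correspond to $f = 0.8 / 1.6 = 0.5$ under this model; the real per-shape ratios differ, so this is not a measured time breakdown.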

The attention side has a narrow-batch cost in the split-KV prong. We did not bench it separately. It is characterized from the same tradeoff between chunk count and SM occupancy that our production attention tuning already measured. It costs latency in one narrow regime: long context (over 4096 tokens) at small batch (about 1 to 8). Pinning the chunk size to a process constant gives a below-max-length request fewer chunks, so fewer attention CTAs and lower SM occupancy at small batch. It washes out by batch 8, and it only applies at batch ≤ 32, since above that the step uses the non-partitioned attention path, with no chunking and no cost.
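The chunk-count coupling can be illustrated with a small model. It assumes the chunk size is `ceil(reference_len / 64)`; the target of 64 chunks and the helper names are assumptions, and the real kernel heuristic may differ. Under that assumption the numbers coincide with the example in the description (a 5000-token request gets 64 chunks alone and 40 when co-batched with an 8000-token request).

```rust
/// Assumed target number of chunks for the longest sequence considered.
const TARGET_CHUNKS: usize = 64;

fn ceil_div(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

/// Chunk size derived from a reference length: the live batch max in the
/// batch-coupled version, or a process constant in the pinned version.
fn chunk_size(reference_len: usize) -> usize {
    ceil_div(reference_len, TARGET_CHUNKS).max(1)
}

/// Number of split-KV chunks a request of `seq_len` tokens is cut into.
fn chunk_count(seq_len: usize, reference_len: usize) -> usize {
    ceil_div(seq_len, chunk_size(reference_len))
}
```

Passing a process constant as `reference_len` (for example a hypothetical `max_context_tokens` of 32768) gives the 5000-token request the same chunk count, 10 in this model, whatever else is in the batch. The lower count is also the source of the small-batch occupancy cost described above.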

Caveats

  • The graph correctness guard covers decode buckets up to 256 (bit-identical, zero fallback). The production end-to-end throughput sweep only runs to batch 64, so throughput at the large buckets is unmeasured; a fallback there would cost throughput, not correctness.
  • Determinism is stable across restarts within a fixed cuBLAS and driver version. Cross-version is untested.
  • All tests force the prefix cache off, to isolate reduction order. With the cache on, a cache-hit versus cold-prefill request still drifts by bf16 ULPs, which is separate and filed separately.
  • CUDA-graph replay is validated bit-identical on both archs. The pin rebuilds its transient B/C cuBLASLt layouts per call, so under capture a layout is freed before replay; this held on both archs (cuBLASLt appears to snapshot the layout at capture). Persisting the layouts per bucket is a hardening follow-up.

Out of scope (separate batch-dependences, not addressed here)

Prefix-cache KV provenance, sampling at temperature above 0 (Philox is keyed on the batch row index), LoRA delta-kernel N-dependence.

There is also the decode attention-path selection. NonPartition versus SplitKv keys on the batch's live max_seq_len, not on policy, so a request whose context is below 1024, co-batched with one at 1024 or more, runs a different attention kernel. That is a residual batch-variance of the same ULP class, and it survives the Pin. It is self-limiting, since it goes away once the request's own context passes 1024, and it is confined to batch ≤ 32. A clean follow-up if batch-invariant decode is ever promised by default.
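A sketch of that residual, using the two thresholds mentioned in these notes (context of 1024 or more, batch ≤ 32). The selection function and constant names are hypothetical, and the real selector may look at more than these two inputs:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum AttnPath {
    NonPartition,
    SplitKv,
}

const SPLIT_KV_MIN_SEQ_LEN: usize = 1024; // from the notes above
const SPLIT_KV_MAX_BATCH: usize = 32; // from the notes above

/// Path choice keyed on the batch's live max sequence length, so one long
/// request switches the kernel for every request in the batch.
fn select_path(seq_lens: &[usize]) -> AttnPath {
    let live_max = seq_lens.iter().copied().max().unwrap_or(0);
    if seq_lens.len() <= SPLIT_KV_MAX_BATCH && live_max >= SPLIT_KV_MIN_SEQ_LEN {
        AttnPath::SplitKv
    } else {
        AttnPath::NonPartition
    }
}
```

In this model a 500-token request runs `NonPartition` alone but `SplitKv` next to a 2000-token request, while a 1500-token request gets `SplitKv` either way, which is the self-limiting behavior described above.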

@xiaguan xiaguan self-requested a review June 20, 2026 17:18
@xiaguan

xiaguan commented Jun 20, 2026

Collaborator

I like the direction, but I think we should narrow the first step.

Instead of making the numeric policy affect the whole shared GEMM path, can we first target one concrete subproblem: Qwen3 decode projection GEMMs only? Prefill should keep the current tuned cuBLASLt path; the 2.6-4.5x prefill cost is too high for a path we may eventually want to enable by default.

I don’t think we need to solve every batch-dependence in one PR. A good first PR would be: one measured source, one narrow code path, batch-invariance proven by the existing gates, and decode perf within an acceptable regression window. Then split-KV chunking and other residual sources can be separate follow-ups.

@xiaguan xiaguan self-assigned this Jun 20, 2026
@FeathBow

Contributor Author

Thanks — you're right, the gates don't go far enough: none exercises the unified (mixed prefill+decode) step, where a decode row shares a projection GEMM with co-scheduled prefill chunks, so the under-load case isn't actually gated. That gap is on me — adding it now.

I'm narrowing this PR to one source and have cross-arch measurements (Qwen3-4B on sm_89 and GH200/sm_90; commands + numbers below):

  • split-KV → its own follow-up. I'm carving this PR down to GEMM-N only; split-KV chunk pinning is a separate source (rescale order, not GEMM reduction) and carves cleanly.
  • Adding the missing gates in PR1 (measured already, not yet committed): a unified-step gate (a decode row's top-K bit-identical across unified steps with different co-scheduled prefill chunk sizes, i.e. invariant to the Unified GEMM-N within that path, under the pin) + a production-envelope fallback==0 guard. Both pass on sm_89 and GH200/sm_90.
  • decode perf in window: Pin/Tuned ~1.00–1.06× (batch ≤32) and ~1.09× (batch 64) on sm_89 (3 runs); ~1.00–1.08× on GH200/sm_90 — near-free on both.
  • but two residuals survive the pin — attention-path selection (NonPartition↔SplitKv) and the decode-vs-unified path (~0.11–0.18 nat, sm_89; the within-path GEMM-N gates pass on both archs, but these cross-path/kernel-selection residuals remain).

Two places I'd lean toward keeping the current approach — but both are open to discussion:

  • On scoping — I'd lean toward keeping the policy process-global/default-off for PR1. In OpenInfer serving one process loads one model (detect_model_type → one launch), and this matches the process-level deterministic-mode pattern used elsewhere. Threading a per-call policy through the shared GEMM stack would be a larger signature change without improving the property being tested, so I'd rather document the coverage gap (typed_ops bypasses the seam) instead — but if you feel strongly about scoping it to qwen3, happy to discuss.
  • On prefill: I don't think it can stay tuned if the goal is #414 (qwen3: batched decode is not batch-invariant, a request's tokens depend on its batch-mates). Leaving prefill tuned would reopen the original prefill N-dependence: #414's own repro is a prefill flip ([9707] 1→27, "the divergence happens at prefill"), and with chunked prefill a prompt's chunks land in unified steps where prefill+decode share one projection GEMM. The cost is e2e ~1.8–1.85× (sm_89) / ~1.9–2.1× (sm_90); the 2.6–4.5× is the un-diluted per-GEMM ratio (sm_89 2.6–2.7× / sm_90 up to 4.5×), not e2e. Since it's default-off, that cost is only paid when enabled. Open to your view here too.

Plan — none of this has landed in the PR yet; I'll push the revised PR1 shortly: PR1 = GEMM-N pin + the two new gates + fallback observability; PR2 = split-KV + attention-path residual; decode↔unified path / cross-cuBLAS-version / uniform coverage = follow-ups.

Repro (Qwen3-4B, Single GPU x86_64/sm_89):

# existing gate, reproducible on the branch now:
OPENINFER_TEST_MODEL_PATH=<Qwen3-4B-base> \
  cargo test --release -p openinfer-qwen3-4b --test batch_invariance_endtoend -- --nocapture
# perf / unified-step / production-envelope fallback harnesses land as PR1's gates.

Measured on sm_89 (3 runs) and GH200/sm_90 — summary, with raw output in the details below:

  • gates green on both archs: unified within-path bit-eq, production-envelope served>0/fb=0 through N=1279, carve-safe (decode_gemm_graph/endtoend still pass)
  • decode near-free on both; prefill e2e ~1.8–1.85× (sm_89) / ~1.9–2.1× (sm_90)
  • residuals (sm_89, the tracked follow-ups): decode↔unified path top-K maxΔ 0.107 nat; attention-path (NonPartition↔SplitKv) maxΔ 0.180 nat
raw gate + perf output — sm_89 and GH200 (sm_90)
Single GPU (sm_89)
  unified   Pin N=101/201/513/1024: served=253 fb=0 bit-eq-vs-N101=ok ; Tuned drifts ; result ok (1 passed, 1 ignored)
  envelope  Unified N=101/201/513/1024/1279 + pure-Decode bs=256: served=253 fallback=0 ; result ok
  decode_gemm_graph ok ; endtoend ok ; kernels batch_invariance_gemm ok
  perf decode  (ctx=512, Pin/Tuned, 3 runs): bs1-32 <=1.06x ; bs64 1.075 / 1.090 / 1.110x
  perf prefill (Pin/Tuned, best-of-5): len1024 ~1.80x ; 2048 ~1.85x ; 4096 ~1.83x (len512 baseline-noisy)

GH200 (sm_90)
  unified   Pin N=101/201/513/1024: served=253 fb=0 bit-eq-vs-N101=ok ; Tuned drifts ; result ok (1 passed, 1 ignored)
  envelope  Unified N=101/201/513/1024/1279 + pure-Decode bs=256: served=253 fallback=0 ; result ok
  decode_gemm_graph ok ; endtoend ok ; kernels batch_invariance_gemm ok
  perf decode  (ctx=512, Pin/Tuned): bs1 1.01 / bs8 1.01 / bs16 1.00 / bs32 1.01 / bs64 1.08x
  perf prefill (Pin/Tuned, best-of-5): len512 2.07 / 1024 2.07 / 2048 1.97 / 4096 1.89x

@FeathBow FeathBow force-pushed the fix/qwen3-414-batch-invariant-gemm branch 2 times, most recently from 24876bd to ad6f16a Compare June 21, 2026 01:00
@FeathBow FeathBow changed the title fix(qwen3): opt-in batch-invariant decode path fix(qwen3): opt-in GEMM-N reduction-order pin Jun 21, 2026
@FeathBow FeathBow marked this pull request as draft June 21, 2026 01:16
@FeathBow FeathBow force-pushed the fix/qwen3-414-batch-invariant-gemm branch 3 times, most recently from 933802e to c90e2d9 Compare June 21, 2026 01:48
@FeathBow FeathBow marked this pull request as ready for review June 21, 2026 01:48
@xiaguan

xiaguan commented Jun 21, 2026

Collaborator

Thanks, this is much clearer. Splitting split-KV out and adding the unified-step + fallback-envelope gates sounds like the right next step.

@xiaguan

xiaguan commented Jun 21, 2026

Collaborator

One thing I want to clarify about the longer-term direction: I’d like batch-invariant behavior to become the default for most kernels when the perf cost is small enough. For paths where the cost is material, I’m fine keeping it opt-in, but I’d prefer an explicit server/config surface such as --batch-invariant rather than an environment variable. Env vars make important serving behavior implicit and harder to reason about in production.

@FeathBow

Contributor Author

Sounds good, agreed on making this an explicit serving knob.

I’ll replace the env var with a --batch-invariant flag. The server will only pass a bool; Qwen3 will map it to NumericPolicy before executor construction / CUDA graph capture and log the resolved policy. PR1 stays default-off.

@xiaguan

xiaguan commented Jun 21, 2026

Collaborator

Direction looks good — the pin mechanism is careful, and the gates are solid (especially the baseline-drift guard so a Pin pass can't be vacuous). Plan is to land this as the opt-in, default-off foundation. Two things before merge:

1. Rebase onto current main. This now conflicts with #430 (profiled KV sizing), which touched the same launch-options plumbing — Qwen3LaunchOptions, the start_engine_* signatures, and config.rs Args. Both PRs add a field to the same structs, so it's a mechanical but real conflict.

2. Open a tracking issue for the workstream. #428 covers exactly one divergence source: the decode projection-GEMM N dimension. It deliberately doesn't cover the rest yet — path-selection residuals (per the flag's own help text) and the decode↔unified cross-path drift that survives the pin (your #[ignore]'d test). Let's give the whole thing one home:

  • remaining sources: prefill-side GEMM invariance, path-selection residuals, attention / sampling / all-reduce as needed;
  • the real bar: an end-to-end test that is bit-identical under concurrent / batched load — the same request must produce the same tokens regardless of what it gets batched with — verified at the output level, not just per-kernel.

That keeps #428 honest as PR1 (decode GEMM-N only) and gives the remaining work a place to live.

@FeathBow

FeathBow commented Jun 21, 2026

Contributor Author

One gate is failing after the rebase. I’ll look into it first. Update: done.

@xiaguan xiaguan merged commit 39a5935 into openinfer-project:main Jun 22, 2026
1 check passed