diff --git a/openinfer-qwen3-4b/src/batch_decode_buffers.rs b/openinfer-qwen3-4b/src/batch_decode_buffers.rs index 5b8c9fc7..6a0a3857 100644 --- a/openinfer-qwen3-4b/src/batch_decode_buffers.rs +++ b/openinfer-qwen3-4b/src/batch_decode_buffers.rs @@ -6,6 +6,7 @@ use cudarc::driver::CudaSlice; use openinfer_core::cuda_graph::CudaGraphState; use openinfer_core::tensor::{DeviceContext, HiddenStates}; +use openinfer_kernels::ops::{NumericPolicy, numeric_policy}; use openinfer_kv_cache::KvView; /// Bucket sizes for CUDA Graph capture. Actual batch is padded to the nearest bucket. @@ -34,6 +35,11 @@ const SPLIT_KV_CHUNK_TOKENS: usize = 64; const SPLIT_KV_MAX_CHUNKS_PER_REQUEST: usize = 64; const SPLIT_KV_MAX_BATCH_SIZE: usize = 32; +/// Chunk size bounding a `basis`-token request to `SPLIT_KV_MAX_CHUNKS_PER_REQUEST` chunks. +pub fn split_chunk_size_for(basis: usize) -> usize { + SPLIT_KV_CHUNK_TOKENS.max(basis.div_ceil(SPLIT_KV_MAX_CHUNKS_PER_REQUEST)) +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(crate) enum DecodeAttentionPath { NonPartition, @@ -110,6 +116,9 @@ pub(crate) struct BatchDecodeBuffers { pub(crate) split_tmp_s: CudaSlice, pub(crate) split_padded_slots: usize, max_seq_len: usize, + /// Model context limit (`max_position_embeddings`) — the `Pin` split-KV chunk basis (see + /// `split_chunk_size`). + max_context_tokens: usize, /// Padding page index for bucket CUDA Graph. Padding slots point here. padding_page_id: i32, @@ -137,6 +146,7 @@ impl BatchDecodeBuffers { max_total_pages: usize, padding_page_id: i32, num_qo_heads: usize, + max_context_tokens: usize, ) -> Result { let bs = max_batch_size; // The split-KV path is gated on padded_bs <= SPLIT_KV_MAX_BATCH_SIZE, @@ -178,6 +188,7 @@ impl BatchDecodeBuffers { split_tmp_s: ctx.stream.alloc_zeros(max_split_slots * num_qo_heads)?, split_padded_slots: 0, max_seq_len: 0, + max_context_tokens, padding_page_id, graphs: BATCH_BUCKETS .iter() @@ -263,6 +274,17 @@ impl BatchDecodeBuffers { Ok(()) } + /// The chunk count sets the online-softmax rescale order, hence the bf16 result. `Pin`/`PerToken` + /// key the chunk-size basis on the `max_context_tokens` constant, so the count does not vary with + /// the batch. + fn split_chunk_size(&self) -> usize { + let basis = match numeric_policy() { + NumericPolicy::Tuned => self.max_seq_len, + NumericPolicy::Pin | NumericPolicy::PerToken => self.max_context_tokens, + }; + split_chunk_size_for(basis) + } + fn sync_split_kv_meta( &mut self, ctx: &DeviceContext, @@ -274,8 +296,7 @@ impl BatchDecodeBuffers { if padded_bs > SPLIT_KV_MAX_BATCH_SIZE { return Ok(()); } - let split_chunk_size = - SPLIT_KV_CHUNK_TOKENS.max(self.max_seq_len.div_ceil(SPLIT_KV_MAX_CHUNKS_PER_REQUEST)); + let split_chunk_size = self.split_chunk_size(); let split_padded_slots = padded_bs * SPLIT_KV_MAX_CHUNKS_PER_REQUEST; let mut split_request_indices = Vec::with_capacity(split_padded_slots); let mut split_kv_tile_indices = Vec::with_capacity(split_padded_slots); @@ -285,7 +306,12 @@ impl BatchDecodeBuffers { for (request_idx, kv) in kv_views.iter().enumerate() { let chunks = kv.seq_len().div_ceil(split_chunk_size).max(1); - debug_assert!(chunks <= SPLIT_KV_MAX_CHUNKS_PER_REQUEST); + anyhow::ensure!( + chunks <= SPLIT_KV_MAX_CHUNKS_PER_REQUEST, + "split-KV chunk count {chunks} exceeds workspace bound {SPLIT_KV_MAX_CHUNKS_PER_REQUEST} \ + (seq_len={}, split_chunk_size={split_chunk_size}) — context limit misconfigured", + kv.seq_len() + ); for chunk_idx in 0..chunks { split_request_indices.push(request_idx as i32); split_kv_tile_indices.push(chunk_idx as i32); diff --git a/openinfer-qwen3-4b/src/batch_decode_trace.rs b/openinfer-qwen3-4b/src/batch_decode_trace.rs index 462b7b8e..22793bd9 100644 --- a/openinfer-qwen3-4b/src/batch_decode_trace.rs +++ b/openinfer-qwen3-4b/src/batch_decode_trace.rs @@ -88,6 +88,7 @@ pub fn trace_decode_kernel_calls( kv_mgr.pool().total_blocks(), kv_mgr.pool().padding_block_id(), model.local_num_attention_heads(), + model.config().max_position_embeddings, )?; let token_ids = vec![0_u32; batch_size]; let views: Vec<_> = rkvs.iter().map(|r| r.decode_view()).collect(); diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 95fe3d4b..4cc5518d 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -514,6 +514,11 @@ fn tune_decode_gemm_algos(model: &Qwen3Model) -> Result<()> { gemm_lt_pin_warmup(intermediate, hidden)?; gemm_lt_pin_warmup(hidden, intermediate)?; gemm_lt_pin_warmup(vocab, hidden)?; + let max_context = model.config().max_position_embeddings; + log::info!( + "Qwen3 split-KV decode chunk pinned: {} tokens (max_context_tokens={max_context})", + crate::batch_decode_buffers::split_chunk_size_for(max_context) + ); return Ok(()); } @@ -2088,6 +2093,7 @@ impl LocalQwen3Lane { total_blocks, padding_block_id, model.local_num_attention_heads(), + model.config().max_position_embeddings, )?; let sample_scratch = openinfer_sample::SampleScratch::new( model.device_ctx(), diff --git a/openinfer-qwen3-4b/src/lib.rs b/openinfer-qwen3-4b/src/lib.rs index e0d82010..618e6b99 100644 --- a/openinfer-qwen3-4b/src/lib.rs +++ b/openinfer-qwen3-4b/src/lib.rs @@ -117,6 +117,7 @@ impl Default for Qwen3OffloadOptions { /// This is the production phase boundary used by the Qwen3 scheduler and by /// model-local benchmarks. The root server should use `start_engine` instead. pub mod runtime { + pub use crate::batch_decode_buffers::split_chunk_size_for; pub use crate::executor::{ DecodePlan, DecodeRequestResult, DecodeResult, DecodeStepItem, PrefillPlan, PrefillRequestResult, PrefillResult, PrefillStepItem, Qwen3Executor, RequestId, diff --git a/openinfer-qwen3-4b/src/weights.rs b/openinfer-qwen3-4b/src/weights.rs index 560f5ac2..44831197 100644 --- a/openinfer-qwen3-4b/src/weights.rs +++ b/openinfer-qwen3-4b/src/weights.rs @@ -957,6 +957,7 @@ impl Qwen3Model { profile_blocks, 0, self.local_num_attention_heads(), + self.config.max_position_embeddings, ) .context("Qwen3 memory profile decode buffer alloc failed")?; record_peak()?; diff --git a/openinfer-qwen3-4b/tests/batch_invariance_decode_splitkv_graph.rs b/openinfer-qwen3-4b/tests/batch_invariance_decode_splitkv_graph.rs new file mode 100644 index 00000000..2d47b2b4 --- /dev/null +++ b/openinfer-qwen3-4b/tests/batch_invariance_decode_splitkv_graph.rs @@ -0,0 +1,190 @@ +//! Gate: SplitKv decode batch-invariance. Co-batching A with a longer B moves A's chunk count; +//! Tuned drifts, Pin/PerToken stay bit-identical. + +use openinfer_core::sampler::SamplingParams; +use openinfer_kernels::ops::{NumericPolicy, set_numeric_policy}; +use openinfer_qwen3_4b::runtime::{ + DecodePlan, DecodeStepItem, PrefillPlan, PrefillStepItem, Qwen3Executor, RequestId, + split_chunk_size_for, +}; + +const LOGPROBS: usize = 64; +const MAX_OUTPUT_TOKENS: usize = 4; +// A long enough that A's Tuned chunk size exceeds the 64-token floor (ceil(A_LEN/64) > 64); B +// longer than A so the decode max_seq_len (= max KV length in the batch) rises from A_LEN to +// B_LEN, changing A's Tuned chunk count. Batch 2 <= 32 selects SplitKv for both calls (bucket 2). +const A_LEN: usize = 5000; +const B_LEN: usize = 8000; +const B_SHORT: usize = 100; // << A_LEN, so call C's decode max_seq_len = A_LEN + +fn model_path_or_skip() -> Option { + let Ok(p) = std::env::var("OPENINFER_TEST_MODEL_PATH") else { + eprintln!( + "skipping qwen3 batch_invariance_decode_splitkv_graph: set OPENINFER_TEST_MODEL_PATH to Qwen3-4B-base" + ); + return None; + }; + Some(p) +} + +fn pitem(id: RequestId, prompt: Vec) -> PrefillStepItem { + PrefillStepItem::new( + id, + prompt, + MAX_OUTPUT_TOKENS, + SamplingParams::default(), + LOGPROBS, + false, + ) +} + +fn filler(len: usize, stride: u32) -> Vec { + (0..len as u32) + .map(|i| 1000 + (i * stride) % 50000) + .collect() +} + +/// Run A fixed, then decode A co-batched with B; returns A's `(prefill first_token, decode top-K)`. +fn a_decode_cobatched_with( + ex: &mut Qwen3Executor, + a_prompt: &[u32], + b_prompt: &[u32], +) -> (u32, Vec<(u32, f32)>) { + let id_a = RequestId::new(1); + let id_b = RequestId::new(2); + // Prefill A alone (batch 1); A's KV is identical regardless of B's length. + let pr_a = ex + .execute_prefill(PrefillPlan { + sample_seed: 0, + requests: &[pitem(id_a, a_prompt.to_vec())], + echo: false, + }) + .expect("prefill A"); + let a_first = pr_a.requests[0].first_token; + let pr_b = ex + .execute_prefill(PrefillPlan { + sample_seed: 0, + requests: &[pitem(id_b, b_prompt.to_vec())], + echo: false, + }) + .expect("prefill B"); + // Decode A+B together (batch 2, A row 0); B's KV length sets the batch max_seq_len. + let ditems = vec![ + DecodeStepItem::new(id_a, a_first, SamplingParams::default(), LOGPROBS), + DecodeStepItem::new( + id_b, + pr_b.requests[0].first_token, + SamplingParams::default(), + LOGPROBS, + ), + ]; + let dr = ex + .execute_decode(DecodePlan { + sample_seed: 0, + requests: &ditems, + }) + .expect("decode"); + let topk = dr.requests[0] + .logprob + .as_ref() + .expect("logprobs requested but none returned") + .top_logprobs + .clone(); + ex.drop_request(id_a).expect("drop A"); + ex.drop_request(id_b).expect("drop B"); + (a_first, topk) +} + +/// Fresh executor with `policy` active before first decode; returns A's +/// `(first_token_eq, decode_topk_eq)` across capture vs replay of the same SplitKv graph. +fn run_policy(policy: NumericPolicy, model_path: &str) -> (bool, bool) { + set_numeric_policy(policy); + let mut ex = Qwen3Executor::from_runtime(model_path, true, &[0]).expect("build executor"); + ex.set_prefix_cache_enabled(false); + let pl = match policy { + NumericPolicy::Tuned => "baseline ", + NumericPolicy::Pin => "pin ", + NumericPolicy::PerToken => "per_token", + }; + let a = filler(A_LEN, 7); + let b_short = filler(B_SHORT, 11); + let b_long = filler(B_LEN, 13); + + let (ft_c, tk_c) = a_decode_cobatched_with(&mut ex, &a, &b_short); // capture @ max_seq=A_LEN + let (ft_r, tk_r) = a_decode_cobatched_with(&mut ex, &a, &b_long); // replay @ max_seq=B_LEN + let ft_eq = ft_c == ft_r; + let tk_eq = tk_c == tk_r; + eprintln!( + "batch_invariance_decode_splitkv_graph [{pl}]: path=SplitKv phase=capture@max_seq={}->replay@max_seq={} \ + A_prefill=isolated(batch1) decode_GEMM_N=2(fixed) \ + first_token_eq={ft_eq} decode_topk_eq={tk_eq} lp0(C={:.6},R={:.6})", + A_LEN + 1, + B_LEN + 1, + tk_c[0].1, + tk_r[0].1 + ); + (ft_eq, tk_eq) +} + +#[test] +fn batch_invariance_decode_splitkv_graph() { + let Some(model_path) = model_path_or_skip() else { + return; + }; + eprintln!( + "batch_invariance_decode_splitkv_graph: A_LEN={A_LEN} B_LEN={B_LEN} cuda_graph=true, A's prefill isolated \ + (batch 1): SplitKv graph captured at max_seq=A_LEN, replayed at max_seq=B_LEN (same bucket 2). \ + Only A's Tuned split chunk size varies; Pin fixes it; A's prefill + decode GEMM N=2 are fixed." + ); + + let tuned_chunk_c = split_chunk_size_for(A_LEN + 1); + let tuned_chunk_r = split_chunk_size_for(B_LEN + 1); + assert_ne!( + tuned_chunk_c, tuned_chunk_r, + "Tuned chunk-size arithmetic did not drift; the split-KV control would be vacuous" + ); + + let (tuned_ft, tuned_tk) = run_policy(NumericPolicy::Tuned, &model_path); + let (pin_ft, pin_tk) = run_policy(NumericPolicy::Pin, &model_path); + let (pertoken_ft, pertoken_tk) = run_policy(NumericPolicy::PerToken, &model_path); + + eprintln!( + "batch_invariance_decode_splitkv_graph: RESULT Tuned_chunk_tokens(C={tuned_chunk_c},R={tuned_chunk_r}) \ + | decode_topk_eq baseline={tuned_tk} pin={pin_tk} per_token={pertoken_tk} \ + | first_token_eq baseline={tuned_ft} pin={pin_ft} per_token={pertoken_ft}" + ); + + // A's prefill is isolated, so first_token must match C-vs-R under every policy (else isolation broke). + assert!( + tuned_ft, + "baseline: A's prefill first_token differs C-vs-R, isolation broke" + ); + assert!( + pin_ft, + "pin: A's prefill first_token differs C-vs-R, isolation broke" + ); + assert!( + pertoken_ft, + "per_token: A's prefill first_token differs C-vs-R, isolation broke" + ); + + assert!( + !tuned_tk, + "baseline: A's decode top-K did not drift despite the Tuned chunk-size change; isolation suspect" + ); + + assert!( + pin_tk, + "pin: A's decode top-K changed across the SplitKv graph replay (capture@A_LEN, replay@B_LEN), \ + with prefill isolated, so (graph,SplitKv) is not batch-invariant under pin" + ); + assert!( + pertoken_tk, + "per_token: A's decode top-K changed across the SplitKv graph replay; harness/fix bug" + ); + + eprintln!( + "batch_invariance_decode_splitkv_graph: PASS with A's prefill isolated; Tuned chunk-size \ + arithmetic drifts and A's top-K follows, while Pin and PerToken replay bit-identically." + ); +} diff --git a/openinfer-server/src/config.rs b/openinfer-server/src/config.rs index 7ee423a6..ab74e3d3 100644 --- a/openinfer-server/src/config.rs +++ b/openinfer-server/src/config.rs @@ -119,8 +119,8 @@ pub(crate) struct Args { #[arg(long, default_value_t = 20)] pub decode_sm_pct: u32, - /// Enable Qwen3 projection-GEMM batch-invariant pinning. Off by default; - /// does not cover path-selection residuals. Qwen3-only. + /// Enable Qwen3 projection-GEMM and split-KV chunk-count batch-invariant + /// pinning. Off by default; does not cover path-selection residuals. Qwen3-only. #[arg(long, default_value_t = false)] pub batch_invariant: bool, }