Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions openinfer-qwen3-4b/src/batch_decode_buffers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use cudarc::driver::CudaSlice;

use openinfer_core::cuda_graph::CudaGraphState;
use openinfer_core::tensor::{DeviceContext, HiddenStates};
use openinfer_kernels::ops::{NumericPolicy, numeric_policy};
use openinfer_kv_cache::KvView;

/// Bucket sizes for CUDA Graph capture. Actual batch is padded to the nearest bucket.
Expand Down Expand Up @@ -34,6 +35,11 @@ const SPLIT_KV_CHUNK_TOKENS: usize = 64;
const SPLIT_KV_MAX_CHUNKS_PER_REQUEST: usize = 64;
const SPLIT_KV_MAX_BATCH_SIZE: usize = 32;

/// Returns the split-KV chunk size, in tokens, for a request of `basis` tokens.
///
/// The size is `basis / SPLIT_KV_MAX_CHUNKS_PER_REQUEST` rounded up, so a `basis`-token request
/// is covered by at most `SPLIT_KV_MAX_CHUNKS_PER_REQUEST` chunks, and it never drops below the
/// `SPLIT_KV_CHUNK_TOKENS` floor.
pub fn split_chunk_size_for(basis: usize) -> usize {
    let even_split = basis.div_ceil(SPLIT_KV_MAX_CHUNKS_PER_REQUEST);
    even_split.max(SPLIT_KV_CHUNK_TOKENS)
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum DecodeAttentionPath {
NonPartition,
Expand Down Expand Up @@ -110,6 +116,9 @@ pub(crate) struct BatchDecodeBuffers {
pub(crate) split_tmp_s: CudaSlice<f32>,
pub(crate) split_padded_slots: usize,
max_seq_len: usize,
/// Model context limit (`max_position_embeddings`) — the `Pin` split-KV chunk basis (see
/// `split_chunk_size`).
max_context_tokens: usize,

/// Padding page index for bucket CUDA Graph. Padding slots point here.
padding_page_id: i32,
Expand Down Expand Up @@ -137,6 +146,7 @@ impl BatchDecodeBuffers {
max_total_pages: usize,
padding_page_id: i32,
num_qo_heads: usize,
max_context_tokens: usize,
) -> Result<Self> {
let bs = max_batch_size;
// The split-KV path is gated on padded_bs <= SPLIT_KV_MAX_BATCH_SIZE,
Expand Down Expand Up @@ -178,6 +188,7 @@ impl BatchDecodeBuffers {
split_tmp_s: ctx.stream.alloc_zeros(max_split_slots * num_qo_heads)?,
split_padded_slots: 0,
max_seq_len: 0,
max_context_tokens,
padding_page_id,
graphs: BATCH_BUCKETS
.iter()
Expand Down Expand Up @@ -263,6 +274,17 @@ impl BatchDecodeBuffers {
Ok(())
}

/// Chunk size used when building the split-KV metadata, chosen by the active numeric policy.
///
/// The chunk count sets the online-softmax rescale order, hence the bf16 result. `Tuned` bases
/// the size on `self.max_seq_len`; `Pin` and `PerToken` base it on the `max_context_tokens`
/// constant, so the count does not vary with the batch.
fn split_chunk_size(&self) -> usize {
    split_chunk_size_for(match numeric_policy() {
        NumericPolicy::Tuned => self.max_seq_len,
        NumericPolicy::Pin | NumericPolicy::PerToken => self.max_context_tokens,
    })
}

fn sync_split_kv_meta(
&mut self,
ctx: &DeviceContext,
Expand All @@ -274,8 +296,7 @@ impl BatchDecodeBuffers {
if padded_bs > SPLIT_KV_MAX_BATCH_SIZE {
return Ok(());
}
let split_chunk_size =
SPLIT_KV_CHUNK_TOKENS.max(self.max_seq_len.div_ceil(SPLIT_KV_MAX_CHUNKS_PER_REQUEST));
let split_chunk_size = self.split_chunk_size();
let split_padded_slots = padded_bs * SPLIT_KV_MAX_CHUNKS_PER_REQUEST;
let mut split_request_indices = Vec::with_capacity(split_padded_slots);
let mut split_kv_tile_indices = Vec::with_capacity(split_padded_slots);
Expand All @@ -285,7 +306,12 @@ impl BatchDecodeBuffers {

for (request_idx, kv) in kv_views.iter().enumerate() {
let chunks = kv.seq_len().div_ceil(split_chunk_size).max(1);
debug_assert!(chunks <= SPLIT_KV_MAX_CHUNKS_PER_REQUEST);
anyhow::ensure!(
chunks <= SPLIT_KV_MAX_CHUNKS_PER_REQUEST,
"split-KV chunk count {chunks} exceeds workspace bound {SPLIT_KV_MAX_CHUNKS_PER_REQUEST} \
(seq_len={}, split_chunk_size={split_chunk_size}) — context limit misconfigured",
kv.seq_len()
);
for chunk_idx in 0..chunks {
split_request_indices.push(request_idx as i32);
split_kv_tile_indices.push(chunk_idx as i32);
Expand Down
1 change: 1 addition & 0 deletions openinfer-qwen3-4b/src/batch_decode_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub fn trace_decode_kernel_calls(
kv_mgr.pool().total_blocks(),
kv_mgr.pool().padding_block_id(),
model.local_num_attention_heads(),
model.config().max_position_embeddings,
)?;
let token_ids = vec![0_u32; batch_size];
let views: Vec<_> = rkvs.iter().map(|r| r.decode_view()).collect();
Expand Down
6 changes: 6 additions & 0 deletions openinfer-qwen3-4b/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,11 @@ fn tune_decode_gemm_algos(model: &Qwen3Model) -> Result<()> {
gemm_lt_pin_warmup(intermediate, hidden)?;
gemm_lt_pin_warmup(hidden, intermediate)?;
gemm_lt_pin_warmup(vocab, hidden)?;
let max_context = model.config().max_position_embeddings;
log::info!(
"Qwen3 split-KV decode chunk pinned: {} tokens (max_context_tokens={max_context})",
crate::batch_decode_buffers::split_chunk_size_for(max_context)
);
return Ok(());
}

Expand Down Expand Up @@ -2088,6 +2093,7 @@ impl LocalQwen3Lane {
total_blocks,
padding_block_id,
model.local_num_attention_heads(),
model.config().max_position_embeddings,
)?;
let sample_scratch = openinfer_sample::SampleScratch::new(
model.device_ctx(),
Expand Down
1 change: 1 addition & 0 deletions openinfer-qwen3-4b/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl Default for Qwen3OffloadOptions {
/// This is the production phase boundary used by the Qwen3 scheduler and by
/// model-local benchmarks. The root server should use `start_engine` instead.
pub mod runtime {
pub use crate::batch_decode_buffers::split_chunk_size_for;
pub use crate::executor::{
DecodePlan, DecodeRequestResult, DecodeResult, DecodeStepItem, PrefillPlan,
PrefillRequestResult, PrefillResult, PrefillStepItem, Qwen3Executor, RequestId,
Expand Down
1 change: 1 addition & 0 deletions openinfer-qwen3-4b/src/weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,7 @@ impl Qwen3Model {
profile_blocks,
0,
self.local_num_attention_heads(),
self.config.max_position_embeddings,
)
.context("Qwen3 memory profile decode buffer alloc failed")?;
record_peak()?;
Expand Down
190 changes: 190 additions & 0 deletions openinfer-qwen3-4b/tests/batch_invariance_decode_splitkv_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
//! Gate: SplitKv decode batch-invariance. Co-batching A with a longer B moves A's chunk count;
//! Tuned drifts, Pin/PerToken stay bit-identical.

use openinfer_core::sampler::SamplingParams;
use openinfer_kernels::ops::{NumericPolicy, set_numeric_policy};
use openinfer_qwen3_4b::runtime::{
DecodePlan, DecodeStepItem, PrefillPlan, PrefillStepItem, Qwen3Executor, RequestId,
split_chunk_size_for,
};

const LOGPROBS: usize = 64;
const MAX_OUTPUT_TOKENS: usize = 4;
// A is long enough that its Tuned chunk size exceeds the 64-token floor (ceil(A_LEN/64) > 64). B is
// longer than A, so the decode max_seq_len (= max KV length in the batch) rises from roughly A_LEN
// to roughly B_LEN (the test evaluates the chunk size at A_LEN + 1 and B_LEN + 1), changing A's
// Tuned chunk count. Batch 2 <= 32 selects SplitKv for both calls (bucket 2).
const A_LEN: usize = 5000;
const B_LEN: usize = 8000;
const B_SHORT: usize = 100; // << A_LEN, so call C's decode max_seq_len = A_LEN

/// Reads the model location from the `OPENINFER_TEST_MODEL_PATH` environment variable.
///
/// Returns `None`, after printing a skip notice to stderr, when the variable cannot be read, so
/// the caller can return early instead of failing.
fn model_path_or_skip() -> Option<String> {
    match std::env::var("OPENINFER_TEST_MODEL_PATH") {
        Ok(path) => Some(path),
        Err(_) => {
            eprintln!(
                "skipping qwen3 batch_invariance_decode_splitkv_graph: set OPENINFER_TEST_MODEL_PATH to Qwen3-4B-base"
            );
            None
        }
    }
}

/// Builds the prefill item for request `id` over `prompt`, using default sampling parameters and
/// this file's `MAX_OUTPUT_TOKENS` and `LOGPROBS` settings.
fn pitem(id: RequestId, prompt: Vec<u32>) -> PrefillStepItem {
    let sampling = SamplingParams::default();
    // NOTE(review): the meaning of the trailing `false` flag is not visible here; confirm against
    // `PrefillStepItem::new`.
    PrefillStepItem::new(id, prompt, MAX_OUTPUT_TOKENS, sampling, LOGPROBS, false)
}

/// Generates a deterministic prompt of `len` token ids.
///
/// Element `i` is `1000 + (i * stride) % 50000`, so every id lies in `1000..51000`. The product is
/// formed in `u64` so a large `stride` cannot overflow `u32`; with `u32` arithmetic the result
/// would panic in debug builds and silently wrap in release builds.
fn filler(len: usize, stride: u32) -> Vec<u32> {
    let stride = u64::from(stride);
    // `len as u32` is kept from the original; prompts here are a few thousand tokens long.
    (0..len as u32)
        .map(|i| {
            let offset = (u64::from(i) * stride) % 50_000;
            // `offset` is below 50_000, so narrowing back to `u32` is lossless.
            1000 + offset as u32
        })
        .collect()
}

/// Prefills A and B separately, runs one decode step with both in the same batch, and returns A's
/// `(prefill first_token, decode top-K)`.
///
/// A is always prefilled on its own, so only the co-batched decode step can be affected by
/// `b_prompt`. Both requests are dropped from the executor before returning.
fn a_decode_cobatched_with(
    ex: &mut Qwen3Executor,
    a_prompt: &[u32],
    b_prompt: &[u32],
) -> (u32, Vec<(u32, f32)>) {
    let (id_a, id_b) = (RequestId::new(1), RequestId::new(2));

    // A goes through prefill alone (batch 1), so its KV is the same whatever B's length is.
    let prefill_a = ex
        .execute_prefill(PrefillPlan {
            sample_seed: 0,
            requests: &[pitem(id_a, a_prompt.to_vec())],
            echo: false,
        })
        .expect("prefill A");
    let a_first = prefill_a.requests[0].first_token;

    let prefill_b = ex
        .execute_prefill(PrefillPlan {
            sample_seed: 0,
            requests: &[pitem(id_b, b_prompt.to_vec())],
            echo: false,
        })
        .expect("prefill B");

    // One decode step over [A, B] with A in row 0; the longer KV in the batch sets max_seq_len.
    let step_a = DecodeStepItem::new(id_a, a_first, SamplingParams::default(), LOGPROBS);
    let step_b = DecodeStepItem::new(
        id_b,
        prefill_b.requests[0].first_token,
        SamplingParams::default(),
        LOGPROBS,
    );
    let batch = vec![step_a, step_b];
    let decoded = ex
        .execute_decode(DecodePlan {
            sample_seed: 0,
            requests: &batch,
        })
        .expect("decode");

    // Only A's row is of interest; copy its top-K out before releasing the requests.
    let a_row = &decoded.requests[0];
    let a_logprob = a_row
        .logprob
        .as_ref()
        .expect("logprobs requested but none returned");
    let topk = a_logprob.top_logprobs.clone();

    ex.drop_request(id_a).expect("drop A");
    ex.drop_request(id_b).expect("drop B");
    (a_first, topk)
}

/// Builds a fresh executor with `policy` set before the first decode and decodes A twice: once
/// next to a short B (graph capture) and once next to a long B (graph replay).
///
/// Returns `(first_token_eq, decode_topk_eq)`: whether A's prefill first token and A's decode
/// top-K were identical between the two calls.
fn run_policy(policy: NumericPolicy, model_path: &str) -> (bool, bool) {
    set_numeric_policy(policy);
    // NOTE(review): `true` presumably enables CUDA graphs (the test banner prints
    // `cuda_graph=true`) and `&[0]` presumably selects device 0; confirm against `from_runtime`.
    let mut ex = Qwen3Executor::from_runtime(model_path, true, &[0]).expect("build executor");
    ex.set_prefix_cache_enabled(false);

    let prompt_a = filler(A_LEN, 7);
    let prompt_b_short = filler(B_SHORT, 11);
    let prompt_b_long = filler(B_LEN, 13);

    // Capture with A as the longest request, then replay with B as the longest request.
    let (capture_first, capture_topk) = a_decode_cobatched_with(&mut ex, &prompt_a, &prompt_b_short);
    let (replay_first, replay_topk) = a_decode_cobatched_with(&mut ex, &prompt_a, &prompt_b_long);

    let ft_eq = capture_first == replay_first;
    let tk_eq = capture_topk == replay_topk;
    let pl = match policy {
        NumericPolicy::Tuned => "baseline ",
        NumericPolicy::Pin => "pin ",
        NumericPolicy::PerToken => "per_token",
    };
    eprintln!(
        "batch_invariance_decode_splitkv_graph [{pl}]: path=SplitKv phase=capture@max_seq={}->replay@max_seq={} \
         A_prefill=isolated(batch1) decode_GEMM_N=2(fixed) \
         first_token_eq={ft_eq} decode_topk_eq={tk_eq} lp0(C={:.6},R={:.6})",
        A_LEN + 1,
        B_LEN + 1,
        capture_topk[0].1,
        replay_topk[0].1
    );
    (ft_eq, tk_eq)
}

#[test]
fn batch_invariance_decode_splitkv_graph() {
    // Without a model path the test is skipped rather than failed.
    let model_path = match model_path_or_skip() {
        Some(path) => path,
        None => return,
    };
    eprintln!(
        "batch_invariance_decode_splitkv_graph: A_LEN={A_LEN} B_LEN={B_LEN} cuda_graph=true, A's prefill isolated \
         (batch 1): SplitKv graph captured at max_seq=A_LEN, replayed at max_seq=B_LEN (same bucket 2). \
         Only A's Tuned split chunk size varies; Pin fixes it; A's prefill + decode GEMM N=2 are fixed."
    );

    // Sanity check on the constants: if both lengths gave the same Tuned chunk size, the
    // "Tuned drifts" control below would prove nothing.
    let (tuned_chunk_c, tuned_chunk_r) =
        (split_chunk_size_for(A_LEN + 1), split_chunk_size_for(B_LEN + 1));
    assert_ne!(
        tuned_chunk_c, tuned_chunk_r,
        "Tuned chunk-size arithmetic did not drift; the split-KV control would be vacuous"
    );

    // Each policy gets its own executor; the runs happen in this order.
    let measure = |policy: NumericPolicy| run_policy(policy, &model_path);
    let (tuned_ft, tuned_tk) = measure(NumericPolicy::Tuned);
    let (pin_ft, pin_tk) = measure(NumericPolicy::Pin);
    let (pertoken_ft, pertoken_tk) = measure(NumericPolicy::PerToken);

    eprintln!(
        "batch_invariance_decode_splitkv_graph: RESULT Tuned_chunk_tokens(C={tuned_chunk_c},R={tuned_chunk_r}) \
         | decode_topk_eq baseline={tuned_tk} pin={pin_tk} per_token={pertoken_tk} \
         | first_token_eq baseline={tuned_ft} pin={pin_ft} per_token={pertoken_ft}"
    );

    // A is prefilled alone, so its first token has to agree between capture and replay under
    // every policy; a mismatch means the harness leaked B into A's prefill.
    assert!(
        tuned_ft,
        "baseline: A's prefill first_token differs C-vs-R, isolation broke"
    );
    assert!(
        pin_ft,
        "pin: A's prefill first_token differs C-vs-R, isolation broke"
    );
    assert!(
        pertoken_ft,
        "per_token: A's prefill first_token differs C-vs-R, isolation broke"
    );

    // Control: under Tuned the chunk size changed, so A's decode top-K is expected to move.
    assert!(
        !tuned_tk,
        "baseline: A's decode top-K did not drift despite the Tuned chunk-size change; isolation suspect"
    );

    // Gate: Pin and PerToken must reproduce A's decode top-K exactly.
    assert!(
        pin_tk,
        "pin: A's decode top-K changed across the SplitKv graph replay (capture@A_LEN, replay@B_LEN), \
         with prefill isolated, so (graph,SplitKv) is not batch-invariant under pin"
    );
    assert!(
        pertoken_tk,
        "per_token: A's decode top-K changed across the SplitKv graph replay; harness/fix bug"
    );

    eprintln!(
        "batch_invariance_decode_splitkv_graph: PASS with A's prefill isolated; Tuned chunk-size \
         arithmetic drifts and A's top-K follows, while Pin and PerToken replay bit-identically."
    );
}
4 changes: 2 additions & 2 deletions openinfer-server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ pub(crate) struct Args {
#[arg(long, default_value_t = 20)]
pub decode_sm_pct: u32,

/// Enable Qwen3 projection-GEMM batch-invariant pinning. Off by default;
/// does not cover path-selection residuals. Qwen3-only.
/// Enable Qwen3 projection-GEMM and split-KV chunk-count batch-invariant
/// pinning. Off by default; does not cover path-selection residuals. Qwen3-only.
#[arg(long, default_value_t = false)]
pub batch_invariant: bool,
}
Expand Down