diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 12a34d4e..83ea8bd4 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -86,6 +86,18 @@ impl PrefillStepItem { self } + /// Cap the prompt tokens this step forwards (chunked prefill). The executor + /// clamps it to the tokens actually remaining and advances `chunk_start` + /// from the request's KV position across calls, so re-issuing the same + /// request id with the same budget walks the prompt one chunk at a time. + /// The scheduler sets this from its own budget; it is also the public hook + /// for driving chunked prefill directly against the executor. + #[must_use] + pub fn with_chunk_budget(mut self, chunk_budget: usize) -> Self { + self.chunk_budget = chunk_budget; + self + } + /// Prompt tokens forwarded this step. fn as_slice(&self) -> &[u32] { &self.prompt_tokens[self.chunk_start..self.chunk_start + self.chunk_tokens] @@ -130,6 +142,20 @@ impl DecodeStepItem { } } +/// Fold one chunk's partial echo prompt logprobs into the accumulator. Each +/// prompt index is produced by exactly one chunk (see +/// `build_prefill_request_results`), so every `Some` lands on a `None`. +fn fold_prompt_logprobs_chunk( + acc: &mut [Option], + partial: Vec>, +) { + for (slot, value) in acc.iter_mut().zip(partial) { + if value.is_some() { + *slot = value; + } + } +} + fn build_prefill_request_results( lane: &mut LocalQwen3Lane, requests: &[PrefillStepItem], @@ -149,30 +175,33 @@ fn build_prefill_request_results( } else { None }; + // Echo emits a per-chunk *partial*: this chunk's local position `local` + // (logits column `token_offset + local`) predicts prompt token + // `chunk_start + local + 1`, so it fills that slot. Index 0 has no + // predecessor; the column predicting the first generated token + // (`target == prompt_len`) is skipped. Partials are stitched across + // chunks by `merge_echo_prompt_logprobs`, surfacing the full vector on + // the final chunk. let prompt_logprobs = if req.echo { + let mut partial = vec![None; req.prompt_tokens.len()]; if compute_prompt_logprobs { - let mut echo_logprobs = Vec::with_capacity(req.prompt_tokens.len()); - echo_logprobs.push(None); if let Some(all_logits) = all_position_logits { - for j in 1..req.prompt_tokens.len() { - let prev_pos = token_offset + j - 1; - let target_token = req.prompt_tokens[j]; - echo_logprobs.push(lane.extract_prompt_logprobs( + for local in 0..req.chunk_tokens { + let target = req.chunk_start + local + 1; + if target >= req.prompt_tokens.len() { + break; + } + let target_token = req.prompt_tokens[target]; + partial[target] = lane.extract_prompt_logprobs( all_logits, - prev_pos, + token_offset + local, target_token, req.logprobs, - )); - } - } else { - for _ in 1..req.prompt_tokens.len() { - echo_logprobs.push(None); + ); } } - Some(echo_logprobs) - } else { - Some(vec![None; req.prompt_tokens.len()]) } + Some(partial) } else { None }; @@ -602,6 +631,12 @@ pub struct Qwen3Executor { /// so prefix matching itself stays enabled). Set via /// [`Self::set_no_prefix_cache`]. l1_retention_disabled: bool, + /// Per-request accumulator for echo prompt logprobs while the prompt is + /// chunk-prefilled. Each chunk computes the logprobs for the prompt + /// positions it forwarded; the full vector is moved into the result on the + /// final chunk (see [`Self::merge_echo_prompt_logprobs`]). Empty unless an + /// echo request is mid-prefill. + echo_prompt_logprobs: HashMap>>, } /// One request's in-flight CPU-tier KV prefetch. @@ -656,6 +691,7 @@ impl Qwen3Executor { saved_cursor: HashMap::new(), prefetch: HashMap::new(), l1_retention_disabled: false, + echo_prompt_logprobs: HashMap::new(), }) } @@ -805,6 +841,7 @@ impl Qwen3Executor { saved_cursor: HashMap::new(), prefetch: HashMap::new(), l1_retention_disabled: false, + echo_prompt_logprobs: HashMap::new(), }) } @@ -990,13 +1027,10 @@ impl Qwen3Executor { .expect("inserted above"); req.chunk_start = rkv.kv_position(); let remaining = req.prompt_tokens.len() - req.chunk_start; - // Echo must produce all-position logits in a single forward, so it is - // exempt from chunking (the scheduler never splits echo requests). - req.chunk_tokens = if req.echo { - remaining - } else { - remaining.min(req.chunk_budget) - }; + // Echo chunks like everything else: each chunk computes all-position + // logits for its own slice and the executor stitches the prompt + // logprobs across chunks (see `merge_echo_prompt_logprobs`). + req.chunk_tokens = remaining.min(req.chunk_budget); assert!( req.chunk_tokens > 0, "zero-token prefill chunk for {:?} (budget {})", @@ -1022,6 +1056,23 @@ impl Qwen3Executor { } } + /// Accumulate each chunk's partial echo prompt logprobs, surfacing the full + /// vector on the final chunk and nothing on earlier ones. No-op for non-echo + /// results (their `prompt_logprobs` is `None`). + fn merge_echo_prompt_logprobs(&mut self, result: &mut PrefillRequestResult) { + let Some(partial) = result.prompt_logprobs.take() else { + return; + }; + let acc = self + .echo_prompt_logprobs + .entry(result.request_id) + .or_insert_with(|| vec![None; partial.len()]); + fold_prompt_logprobs_chunk(acc, partial); + if result.completed { + result.prompt_logprobs = self.echo_prompt_logprobs.remove(&result.request_id); + } + } + // ── KV-offload LOAD (async CPU-tier prefetch) ────────────────────── // The trait-facing prefetch hooks (`begin_kv_prefetch`, // `drain_ready_prefetch`, `wait_ready_prefetch`, `has_pending_prefetch`) @@ -1193,6 +1244,9 @@ impl ModelExecutor for Qwen3Executor { } } self.saved_cursor.remove(&request_id); + // Drop any half-accumulated echo prompt logprobs for a request that + // disconnected mid-prefill, so the buffer doesn't leak. + self.echo_prompt_logprobs.remove(&request_id); Ok(()) } @@ -1345,7 +1399,7 @@ impl ModelExecutor for Qwen3Executor { let outcome = self.run_step(&step)?; // 4. Apply prefill - let result = match outcome { + let mut result = match outcome { WorkerStepOutcome::Prefill(result) => result, other => { return Err(anyhow::anyhow!( @@ -1354,6 +1408,11 @@ impl ModelExecutor for Qwen3Executor { )); } }; + // Stitch each echo request's per-chunk prompt logprobs; the full vector + // surfaces on the final chunk, earlier chunks emit nothing. + for req_result in &mut result.requests { + self.merge_echo_prompt_logprobs(req_result); + } for req_result in &result.requests { self.apply_prefill_result(req_result)?; } @@ -1458,7 +1517,7 @@ impl ModelExecutor for Qwen3Executor { let outcome = self.run_step(&step)?; // 4. Apply both prefill and decode - let result = match outcome { + let mut result = match outcome { WorkerStepOutcome::Unified(result) => result, other => { return Err(anyhow::anyhow!( @@ -1467,6 +1526,9 @@ impl ModelExecutor for Qwen3Executor { )); } }; + for req_result in &mut result.prefill_requests { + self.merge_echo_prompt_logprobs(req_result); + } for req_result in &result.prefill_requests { self.apply_prefill_result(req_result)?; } @@ -1676,9 +1738,52 @@ impl ModelExecutor for Qwen3Executor { #[cfg(test)] mod tests { - use super::ensure_lora_capacity; + use super::{ensure_lora_capacity, fold_prompt_logprobs_chunk}; + use openinfer_core::engine::TokenLogprob; use std::collections::HashSet; + fn lp(logprob: f32) -> TokenLogprob { + TokenLogprob { + logprob, + top_logprobs: Vec::new(), + } + } + + // Echo chunking computes prompt token `k`'s logprob from the chunk that + // forwarded position `k - 1`. Folding each chunk's partial must reconstruct + // exactly the same vector a single whole-prompt pass would produce: index 0 + // stays None (no predecessor), every other index is filled once, and a + // boundary index whose predecessor is the previous chunk's last position is + // produced by that previous chunk — not dropped or double-written. + #[test] + fn folding_echo_chunks_reconstructs_the_full_prompt_logprobs() { + // 5-token prompt, chunk budget 2: positions [0,1], [2,3], [4]. + // Chunk [0,2) fills indices 1,2; chunk [2,4) fills 3,4; chunk [4,5) + // fills nothing in-prompt (position 4 predicts the first *generated* + // token). prompt_logprobs has indices 0..5. + let mut acc = vec![None; 5]; + + let chunk0 = vec![None, Some(lp(-1.0)), Some(lp(-2.0)), None, None]; + fold_prompt_logprobs_chunk(&mut acc, chunk0); + assert_eq!(acc[0], None, "index 0 has no predecessor"); + assert_eq!(acc[1].as_ref().map(|l| l.logprob), Some(-1.0)); + assert_eq!(acc[2].as_ref().map(|l| l.logprob), Some(-2.0)); + assert_eq!(acc[3], None, "index 3's predecessor is in the next chunk"); + + let chunk1 = vec![None, None, None, Some(lp(-3.0)), Some(lp(-4.0))]; + fold_prompt_logprobs_chunk(&mut acc, chunk1); + + let chunk2 = vec![None, None, None, None, None]; + fold_prompt_logprobs_chunk(&mut acc, chunk2); + + let got: Vec> = acc.iter().map(|l| l.as_ref().map(|l| l.logprob)).collect(); + assert_eq!( + got, + vec![None, Some(-1.0), Some(-2.0), Some(-3.0), Some(-4.0)], + "every prompt index is filled exactly once, index 0 stays None" + ); + } + #[test] fn lora_capacity_rejects_new_adapter_at_limit() { let loaded = HashSet::from(["adapter-a".to_string()]); diff --git a/openinfer-qwen3-4b/src/scheduler.rs b/openinfer-qwen3-4b/src/scheduler.rs index 42b6c23a..90ad64bc 100644 --- a/openinfer-qwen3-4b/src/scheduler.rs +++ b/openinfer-qwen3-4b/src/scheduler.rs @@ -98,35 +98,26 @@ impl PendingRequest { /// Pull the next prefill step set off the front of `prefilling`, capping the /// step's total forwarded tokens at `max_prefill_tokens`. Each taken request -/// gets its per-step chunk recorded in `step_chunk`. Echo requests need -/// logits for every prompt position in one forward, so they only run when -/// their whole remainder fits — or alone at the head of an empty step, which -/// also guarantees the queue always makes progress. +/// gets its per-step chunk recorded in `step_chunk`. Echo requests chunk like +/// any other: each chunk forwards only its own tokens and computes +/// all-position logits for that slice (`vocab × chunk`), and the executor +/// stitches the prompt logprobs across chunks, so a long echo prompt never +/// has to materialize logits for every position at once. fn take_prefill_chunks( prefilling: &mut Vec, max_prefill_tokens: usize, ) -> Vec { let mut budget = max_prefill_tokens; let mut taken: Vec = Vec::new(); - let mut i = 0; - while i < prefilling.len() && budget > 0 { - let remaining = prefilling[i].remaining_prompt_tokens(); - let chunk = if prefilling[i].echo { - if remaining > budget && !taken.is_empty() { - i += 1; - continue; - } - remaining - } else { - remaining.min(budget) - }; - let mut req = prefilling.remove(i); + while !prefilling.is_empty() && budget > 0 { + let remaining = prefilling[0].remaining_prompt_tokens(); + let chunk = remaining.min(budget); + let mut req = prefilling.remove(0); req.step_chunk = chunk; budget = budget.saturating_sub(chunk); taken.push(req); } - // Echo skips can take items out of arrival order; results come back - // sorted by request id, so the step set must be too. + // Results come back sorted by request id, so the step set must be too. taken.sort_by_key(|req| req.request_id); taken } @@ -820,9 +811,10 @@ fn prefilling_future_blocks( /// batch can eat the post-KV-pool VRAM headroom and OOM mid-serving under a /// request burst. Prompts longer than the budget are split across steps, so /// long prompts can't monopolize a step and starve running decodes. -/// Exception: echo requests need all-position logits in one forward and run -/// whole regardless of the budget — an oversized echo prompt still spikes -/// activation memory. +/// Echo requests chunk under the same budget: each chunk computes all-position +/// logits only for its own tokens and the executor stitches the prompt +/// logprobs across chunks, so echo activation scratch is bounded by one chunk +/// (`vocab × budget`) rather than the whole prompt. /// /// A unified step's duration scales with its prefill tokens, and every decode /// request in the batch stalls for the whole step — the budget bounds that @@ -945,16 +937,17 @@ fn failure_targets_for( ) -> Vec { let mut targets = Vec::new(); match plan { - self::plan::ExecutionPlan::Prefill { pending } => { + // A Prefill step can run with active decodes present (echo+logprobs is + // routed here even then), and a failure clears all active state, so the + // active requests must be failed too — same as Unified. + self::plan::ExecutionPlan::Prefill { pending } + | self::plan::ExecutionPlan::Unified { pending } => { + targets.extend(active.iter().map(active_failure_target)); targets.extend(pending.iter().map(pending_failure_target)); } self::plan::ExecutionPlan::Decode => { targets.extend(active.iter().map(active_failure_target)); } - self::plan::ExecutionPlan::Unified { pending } => { - targets.extend(active.iter().map(active_failure_target)); - targets.extend(pending.iter().map(pending_failure_target)); - } } targets } @@ -1569,39 +1562,43 @@ mod tests { } #[test] - fn echo_requests_never_split_but_run_alone_when_oversized() { + fn echo_requests_chunk_under_the_prefill_budget() { let mk_echo = |id: u64, prompt_len| { let (req, _rx) = request(prompt_len, 1); let mut pending = PendingRequest::from_scheduler_request(RequestId(id), req); pending.echo = true; pending }; - let mk = |id: u64, prompt_len| { - PendingRequest::from_scheduler_request(RequestId(id), request(prompt_len, 1).0) - }; - // Oversized echo at the head of an empty step runs whole — chunking it - // would lose the all-position logits echo needs. - let mut prefilling = vec![mk_echo(1, 64), mk(2, 16)]; + // An oversized echo prompt is split like any other: the first chunk + // fills the budget and the remainder stays queued for the next step. + let mut prefilling = vec![mk_echo(1, 64)]; let taken = take_prefill_chunks(&mut prefilling, 32); assert_eq!(taken.len(), 1); - assert_eq!(taken[0].step_chunk, 64, "echo takes its full prompt"); - assert_eq!(prefilling[0].request_id, RequestId(2)); + assert_eq!( + taken[0].step_chunk, 32, + "echo chunk is capped at the budget" + ); + assert_eq!( + taken[0].remaining_prompt_tokens(), + 64, + "prefill_pos only advances once the executor confirms the chunk, so the \ + remainder is re-queued for the next step" + ); + assert!(prefilling.is_empty()); - // An echo that doesn't fit behind earlier work is skipped, not split; - // later requests may still fill the leftover budget, and the step set - // stays sorted by request id. - let mut prefilling = vec![mk(3, 24), mk_echo(4, 16), mk(5, 8)]; + // Echo packs alongside other requests under the shared budget. + let mut prefilling = vec![mk_echo(2, 16), mk_echo(3, 8)]; let taken = take_prefill_chunks(&mut prefilling, 32); assert_eq!( taken .iter() .map(|r| (r.request_id.get(), r.step_chunk)) .collect::>(), - vec![(3, 24), (5, 8)], - "echo skipped, leftover budget goes to the next non-echo request" + vec![(2, 16), (3, 8)], + "two short echo prompts share one step" ); - assert_eq!(prefilling[0].request_id, RequestId(4)); + assert!(prefilling.is_empty()); } #[test] @@ -2262,6 +2259,46 @@ mod tests { ); } + // A Prefill step can run while decodes are active (echo+logprobs is routed + // to Prefill even then). On execution failure the scheduler clears all + // active state, so failure_targets_for must include the active requests too; + // otherwise they'd be dropped without an Error event or a drop_request, + // leaking KV and hanging the client. + #[test] + fn prefill_failure_targets_include_active_decodes() { + let (active_tx, _active_rx) = TokenSink::standalone(); + let active = vec![ActiveRequestState { + request_id: RequestId(100), + lora_adapter: None, + token_tx: active_tx, + last_token: 1, + generated_count: 1, + max_tokens: 4, + prompt_len: 16, + params: SamplingParams::default(), + logprobs: 0, + }]; + + let mut echo_lp = pending(200, true); + echo_lp.logprobs = 5; + let plan = self::plan::ExecutionPlan::Prefill { + pending: vec![echo_lp], + }; + + let ids: Vec = failure_targets_for(&active, &plan) + .iter() + .map(|t| t.request_id.get()) + .collect(); + assert!( + ids.contains(&100), + "active decode must be failed on a prefill error" + ); + assert!( + ids.contains(&200), + "the prefilling request must be failed too" + ); + } + #[test] fn active_receiver_drop_releases_request_state() { let dropped = Arc::new(Mutex::new(Vec::new())); diff --git a/openinfer-qwen3-4b/src/scheduler/plan.rs b/openinfer-qwen3-4b/src/scheduler/plan.rs index a4794f4b..efc69008 100644 --- a/openinfer-qwen3-4b/src/scheduler/plan.rs +++ b/openinfer-qwen3-4b/src/scheduler/plan.rs @@ -33,11 +33,20 @@ pub(super) enum ExecutionArtifacts { }, } +/// Whether the batch needs all-position prompt logprobs, i.e. it has an echo +/// request that also asked for logprobs. Echo alone only echoes ids back. +fn batch_needs_prompt_logprobs(pending: &[PendingRequest]) -> bool { + pending.iter().any(|req| req.echo && req.logprobs > 0) +} + pub(super) fn build_next_plan( have_active: bool, pending: Vec, ) -> Option { - if !pending.is_empty() && have_active { + // echo+logprobs needs a dedicated Prefill: the unified forward can't produce + // all-position logits. Active decodes wait those ticks; it's rare. + let needs_dedicated_prefill = batch_needs_prompt_logprobs(&pending); + if !pending.is_empty() && have_active && !needs_dedicated_prefill { Some(ExecutionPlan::Unified { pending }) } else if !pending.is_empty() { Some(ExecutionPlan::Prefill { pending }) @@ -59,10 +68,11 @@ pub(super) fn execute_plan( let scheduled_at_unix_s = openinfer_core::engine::unix_now_s(); let indices: Vec = (0..pending.len()).collect(); let requests = build_prefill_items(&pending, &indices); - let any_echo = pending.iter().any(|req| req.echo); + // `echo` here = "compute all-position logits"; only echo+logprobs + // needs them. let mut result = executor.execute_prefill(PrefillPlan { requests: &requests, - echo: any_echo, + echo: batch_needs_prompt_logprobs(&pending), sample_seed: rand::RngExt::random(rng), })?; sort_prefill_results(&mut result.requests); @@ -203,4 +213,88 @@ mod tests { "active + pending fuses prefill and decode into one unified step" ); } + + // Regression for #372: an echo+logprobs request must not be fused into a + // Unified step (the unified forward can't produce all-position prompt + // logprobs), even when decodes are active. It takes a dedicated Prefill. + #[test] + fn echo_logprobs_pending_forces_dedicated_prefill() { + let echo_logprobs = || { + let mut p = pending(); + p.echo = true; + p.logprobs = 5; + p + }; + assert!( + matches!( + build_next_plan(true, vec![echo_logprobs()]), + Some(ExecutionPlan::Prefill { pending }) if pending.len() == 1 + ), + "echo+logprobs must take a dedicated prefill, not a unified step" + ); + // A mixed batch with any echo+logprobs request also takes prefill. + assert!( + matches!( + build_next_plan(true, vec![pending(), echo_logprobs()]), + Some(ExecutionPlan::Prefill { .. }) + ), + "a batch containing an echo+logprobs request takes prefill" + ); + // echo without logprobs still fuses (no all-position logits needed). + let echo_only = || { + let mut p = pending(); + p.echo = true; + p + }; + assert!( + matches!( + build_next_plan(true, vec![echo_only()]), + Some(ExecutionPlan::Unified { .. }) + ), + "echo without logprobs does not need a dedicated prefill" + ); + } + + // All-position logits fire only for echo+logprobs: plain, echo-only, and + // logprobs-only batches must not, a mixed batch with one such request must. + #[test] + fn all_position_logits_gated_on_echo_plus_logprobs() { + let echo_logprobs = || { + let mut p = pending(); + p.echo = true; + p.logprobs = 5; + p + }; + let echo_only = || { + let mut p = pending(); + p.echo = true; + p + }; + let logprobs_only = || { + let mut p = pending(); + p.logprobs = 5; + p + }; + + assert!( + !batch_needs_prompt_logprobs(&[pending()]), + "a plain prompt needs no all-position logits" + ); + assert!( + !batch_needs_prompt_logprobs(&[echo_only()]), + "echo without logprobs only echoes ids back — no all-position logits" + ); + assert!( + !batch_needs_prompt_logprobs(&[logprobs_only()]), + "logprobs without echo only needs the sampled token's logprob" + ); + assert!( + batch_needs_prompt_logprobs(&[echo_logprobs()]), + "echo+logprobs needs all-position logits" + ); + assert!( + batch_needs_prompt_logprobs(&[echo_only(), echo_logprobs()]), + "one echo+logprobs request in the batch turns the scratch on" + ); + } } diff --git a/openinfer-qwen3-4b/tests/echo_chunked_prefill.rs b/openinfer-qwen3-4b/tests/echo_chunked_prefill.rs new file mode 100644 index 00000000..bbcf5969 --- /dev/null +++ b/openinfer-qwen3-4b/tests/echo_chunked_prefill.rs @@ -0,0 +1,310 @@ +//! Chunked-echo self-consistency gate for echo + logprobs prefill. +//! +//! Echo with `logprobs > 0` returns a logprob for every prompt token: the +//! logprob of prompt token `k` is read from the model's distribution at +//! position `k - 1`. The naive implementation materializes all-position logits +//! (`vocab × prompt_len`) in one forward, which OOMs on long prompts (#358). +//! The fix chunks the echo prefill like any other prompt: each chunk computes +//! all-position logits only for its own slice and the executor stitches the +//! per-chunk prompt logprobs back together (see `merge_echo_prompt_logprobs`). +//! +//! The numerically interesting part is the cross-chunk seam: the logprob of the +//! token at a chunk boundary comes from the *previous* chunk's last position, +//! and an off-by-one there would silently corrupt or drop one logprob per +//! boundary. There is no HF golden for within-prompt distributions, so this gate +//! uses a stronger, hardware-independent invariant instead: **the chunked echo +//! prompt logprobs must match the single-pass result** — same model, same GPU, +//! the only difference being how many tokens each forward processed. +//! +//! Crossing a seam moves later positions onto the `kv_len > q_len` attention +//! path, which drifts a few bf16 ULPs (exactly like the prefix-cache replay in +//! `hf_golden_gate`), so the match is asserted within a tight bf16 tolerance +//! rather than bit-exact. The comparison is over each *actual* prompt token's +//! logprob — the same token id in both runs — so a benign bf16 argmax tie +//! cannot trip it, while a seam off-by-one (reading the logprob from the wrong +//! position) moves it by far more than ULP noise. A realistic prompt (a golden +//! prompt plus its own teacher-forced continuation) keeps the per-position +//! distributions peaked, so that noise floor stays low. +//! +//! Requires a CUDA GPU, Qwen3-4B weights, and the HF golden token file; skips +//! cleanly when the model is absent (point `OPENINFER_TEST_MODEL_PATH` at the +//! weights to run it). + +use std::path::Path; + +use openinfer_core::engine::TokenLogprob; +use openinfer_core::sampler::SamplingParams; +use openinfer_qwen3_4b::runtime::{PrefillPlan, PrefillStepItem, Qwen3Executor, RequestId}; +use safetensors::{Dtype, SafeTensors}; + +const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3-4B"); +const GOLDEN: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/../test_data/qwen3-4b-hf-golden.safetensors" +); + +/// Top-K logprobs requested per position — wide enough that the chunked and +/// single-pass top-K sets overlap for the seam comparison. +const LOGPROBS: usize = 16; +const MAX_OUTPUT_TOKENS: usize = 8; +/// Small budget so even a short realistic sequence is split into several +/// chunks, crossing the seam multiple times (and ending on a partial chunk). +const CHUNK_BUDGET: usize = 4; + +/// Engine-vs-engine on a peaked sequence: the only difference is the forward +/// *shape* (a 4-row chunk vs a full-width pass picks different bf16 GEMM +/// reduction orders, and seams move later positions onto the `kv_len > q_len` +/// attention path). `MEAN_TOL` stays under the HF gate's own 0.06 floor — both +/// runs sit much closer to each other than either does to HF. `MAX_TOL` allows +/// the irreducible bf16 tail on a single position (observed ~0.26 over 160 +/// positions) while staying far below a real seam off-by-one, which reads a +/// peaked token's logprob from the wrong position and moves it by nats. +const MEAN_TOL: f32 = 0.04; +const MAX_TOL: f32 = 0.50; + +fn model_path_or_skip() -> Option { + match std::env::var("OPENINFER_TEST_MODEL_PATH") { + Ok(path) => Some(path), + Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => { + Some(MODEL_PATH.to_string()) + } + Err(_) => { + eprintln!( + "skipping echo_chunked_prefill: {MODEL_PATH}/config.json is missing; set OPENINFER_TEST_MODEL_PATH to run it" + ); + None + } + } +} + +fn as_i32(st: &SafeTensors, name: &str) -> (Vec, Vec) { + let t = st + .tensor(name) + .unwrap_or_else(|e| panic!("golden missing {name}: {e}")); + assert_eq!(t.dtype(), Dtype::I32, "{name} must be i32"); + let v = t + .data() + .chunks_exact(4) + .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(); + (v, t.shape().to_vec()) +} + +/// A realistic, peaked prompt: golden sequence 0's prompt followed by its own +/// teacher-forced decode tokens. Concatenating the model's own continuation +/// keeps every next-token distribution sharp, so the actual-token logprobs are +/// stable across the two forward shapes and the tolerance can stay tight. +fn realistic_prompt() -> Vec { + let bytes = std::fs::read(GOLDEN).unwrap_or_else(|e| panic!("read {GOLDEN}: {e}")); + let st = SafeTensors::deserialize(&bytes).expect("parse golden safetensors"); + let (prompt_tokens, _) = as_i32(&st, "prompt_tokens"); + let (prompt_lens, _) = as_i32(&st, "prompt_lens"); + let (decode_tokens, dshape) = as_i32(&st, "decode_tokens"); + let decode_len = dshape[1]; + + let p0_len = prompt_lens[0] as usize; + let mut seq: Vec = prompt_tokens[..p0_len].iter().map(|&t| t as u32).collect(); + seq.extend((0..decode_len).map(|step| decode_tokens[step] as u32)); + seq +} + +fn echo_item(id: RequestId, prompt: Vec) -> PrefillStepItem { + PrefillStepItem::new( + id, + prompt, + MAX_OUTPUT_TOKENS, + SamplingParams::default(), + LOGPROBS, + true, // echo + ) +} + +/// Echo-prefill `prompt` in a single forward (budget ≥ prompt_len) and return +/// the prompt logprobs. +fn single_pass(ex: &mut Qwen3Executor, prompt: &[u32]) -> Vec> { + let id = RequestId::new(1); + let result = ex + .execute_prefill(PrefillPlan { + requests: &[echo_item(id, prompt.to_vec())], + echo: true, + sample_seed: 0, + }) + .expect("single-pass echo prefill"); + let req = &result.requests[0]; + assert!( + req.completed, + "single pass must finish the prompt in one step" + ); + let lps = req + .prompt_logprobs + .clone() + .expect("echo prefill must return prompt logprobs"); + ex.drop_request(id).expect("drop single-pass request"); + lps +} + +/// Echo-prefill `prompt` in a single chunk via the chunked plumbing +/// (`with_chunk_budget` ≥ prompt_len). Exercises the per-chunk partial build +/// and the accumulator merge for the contiguous case, where they must be the +/// identity transform — so the result has to equal [`single_pass`] bit-for-bit. +fn single_chunk_via_budget(ex: &mut Qwen3Executor, prompt: &[u32]) -> Vec> { + let id = RequestId::new(3); + let result = ex + .execute_prefill(PrefillPlan { + requests: &[echo_item(id, prompt.to_vec()).with_chunk_budget(prompt.len())], + echo: true, + sample_seed: 0, + }) + .expect("single-chunk echo prefill"); + let req = &result.requests[0]; + assert!( + req.completed, + "budget == prompt_len must finish in one chunk" + ); + let lps = req.prompt_logprobs.clone().expect("prompt logprobs"); + ex.drop_request(id).expect("drop single-chunk request"); + lps +} + +/// Echo-prefill `prompt` one `budget`-token chunk per `execute_prefill` call, +/// mirroring how the scheduler drives a long prompt across steps, and return +/// the stitched prompt logprobs from the final chunk. Asserts that only the +/// final chunk surfaces prompt logprobs. +fn chunked(ex: &mut Qwen3Executor, prompt: &[u32], budget: usize) -> Vec> { + let id = RequestId::new(2); + let mut steps = 0; + loop { + let result = ex + .execute_prefill(PrefillPlan { + requests: &[echo_item(id, prompt.to_vec()).with_chunk_budget(budget)], + echo: true, + sample_seed: 0, + }) + .expect("chunked echo prefill"); + let req = &result.requests[0]; + steps += 1; + if req.completed { + let lps = req + .prompt_logprobs + .clone() + .expect("final chunk must return the stitched prompt logprobs"); + assert!( + steps > 1, + "budget {budget} < prompt {} must take more than one chunk", + prompt.len() + ); + ex.drop_request(id).expect("drop chunked request"); + return lps; + } + assert!( + req.prompt_logprobs.is_none(), + "non-final chunk must not surface prompt logprobs (step {steps})" + ); + } +} + +#[test] +fn chunked_echo_prompt_logprobs_match_single_pass() { + let Some(model_path) = model_path_or_skip() else { + return; + }; + if !Path::new(GOLDEN).exists() { + eprintln!("skipping echo_chunked_prefill: {GOLDEN} is missing"); + return; + } + let mut ex = Qwen3Executor::from_runtime(&model_path, false, &[0]) + .unwrap_or_else(|e| panic!("build executor: {e:#}")); + // Echo bypasses the prefix cache anyway, but disable it so neither run can + // reuse the other's blocks and shrink a forward. + ex.set_prefix_cache_enabled(false); + + let prompt = realistic_prompt(); + assert!( + prompt.len() > CHUNK_BUDGET, + "need a prompt longer than one chunk to cross a seam" + ); + let reference = single_pass(&mut ex, &prompt); + + // Control: routing the whole prompt through the chunked plumbing as a single + // chunk must reproduce the single-pass result *exactly*. This isolates a + // bug in the partial-build / accumulator merge (which would show here, with + // no numerical drift to hide behind) from the unavoidable bf16 path drift + // that crossing real seams introduces below. + let one_chunk = single_chunk_via_budget(&mut ex, &prompt); + assert_eq!( + one_chunk.len(), + reference.len(), + "single-chunk plumbing changed the prompt logprobs length" + ); + for (k, (r, o)) in reference.iter().zip(&one_chunk).enumerate() { + match (r, o) { + (None, None) => {} + (Some(r), Some(o)) => { + assert_eq!( + r.logprob.to_bits(), + o.logprob.to_bits(), + "single-chunk plumbing perturbed prompt index {k}'s logprob with no forward-shape change" + ); + assert_eq!( + r.top_logprobs, o.top_logprobs, + "single-chunk plumbing perturbed prompt index {k}'s top-K" + ); + } + _ => panic!("single-chunk plumbing changed Some/None at prompt index {k}"), + } + } + + let got = chunked(&mut ex, &prompt, CHUNK_BUDGET); + + assert_eq!( + reference.len(), + prompt.len(), + "prompt logprobs has one slot per prompt token" + ); + assert_eq!(got.len(), reference.len(), "chunked length must match"); + assert!( + reference[0].is_none() && got[0].is_none(), + "the first prompt token has no predecessor, so its logprob is None in both runs" + ); + + // Compare the logprob of each *actual* prompt token — the same token id in + // both runs, so this is apples-to-apples regardless of bf16 argmax ties + // (two near-equal tokens can swap the top-1 slot between a 4-row and a + // full-width forward without either being wrong; that is why + // `hf_golden_gate` also refuses to assert exact argmax). A seam off-by-one, + // by contrast, reads the token's logprob from the wrong position and moves + // it by far more than bf16 noise — which the per-position delta below + // catches, with the worst offender pinpointed by index. + let mut deltas: Vec<(usize, f32)> = Vec::new(); + for (k, (r, g)) in reference.iter().zip(&got).enumerate().skip(1) { + let r = r + .as_ref() + .unwrap_or_else(|| panic!("single pass missing logprob at prompt index {k}")); + let g = g + .as_ref() + .unwrap_or_else(|| panic!("chunked run missing logprob at prompt index {k}")); + deltas.push((k, (r.logprob - g.logprob).abs())); + } + + let count = deltas.len() as f32; + let mean = deltas.iter().map(|&(_, d)| d).sum::() / count; + let (worst_k, max) = deltas + .iter() + .copied() + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .unwrap(); + eprintln!( + "echo_chunked_prefill: prompt_len={}, budget={CHUNK_BUDGET}, {} positions — \ + mean Δlogprob {mean:.5}, max {max:.5} @ index {worst_k}", + prompt.len(), + deltas.len() + ); + assert!( + mean <= MEAN_TOL, + "mean |Δlogprob| {mean:.5} > {MEAN_TOL} — chunked echo drifted from the single pass beyond bf16 noise" + ); + assert!( + max <= MAX_TOL, + "max |Δlogprob| {max:.5} @ index {worst_k} > {MAX_TOL} — that prompt position is materially wrong (likely a seam off-by-one)" + ); +}