Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 130 additions & 25 deletions openinfer-qwen3-4b/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ impl PrefillStepItem {
self
}

/// Cap the prompt tokens this step forwards (chunked prefill).
///
/// Builder-style setter: stores `chunk_budget` on the item and returns the
/// item by value so calls can be chained.
///
/// The executor clamps the budget to the tokens actually remaining and
/// advances `chunk_start` from the request's KV position across calls, so
/// re-issuing the same request id with the same budget walks the prompt one
/// chunk at a time. Echo requests are chunked by the same budget; they are
/// not exempt.
///
/// The scheduler sets this from its own budget; it is also the public hook
/// for driving chunked prefill directly against the executor.
///
/// Note: the executor's prefill setup asserts that the clamped chunk is
/// non-empty, so a budget of `0` panics there ("zero-token prefill chunk")
/// rather than being rejected by this setter.
#[must_use]
pub fn with_chunk_budget(mut self, chunk_budget: usize) -> Self {
self.chunk_budget = chunk_budget;
self
}

/// Prompt tokens forwarded this step.
fn as_slice(&self) -> &[u32] {
&self.prompt_tokens[self.chunk_start..self.chunk_start + self.chunk_tokens]
Expand Down Expand Up @@ -130,6 +142,20 @@ impl DecodeStepItem {
}
}

/// Merge a single chunk's partial echo prompt logprobs into `acc`.
///
/// `partial` is walked position by position alongside `acc`: a `Some` entry
/// replaces the accumulator slot at the same index, a `None` entry leaves the
/// slot as it was. Positions beyond the shorter of the two sequences are not
/// touched.
///
/// Every prompt index is expected to be produced by exactly one chunk (see
/// `build_prefill_request_results`), so in practice each incoming `Some`
/// fills a slot that is still `None`.
fn fold_prompt_logprobs_chunk(
    acc: &mut [Option<TokenLogprob>],
    partial: Vec<Option<TokenLogprob>>,
) {
    acc.iter_mut()
        .zip(partial)
        // Only positions this chunk actually computed may touch the accumulator.
        .filter(|(_, incoming)| incoming.is_some())
        .for_each(|(slot, incoming)| *slot = incoming);
}

fn build_prefill_request_results(
lane: &mut LocalQwen3Lane,
requests: &[PrefillStepItem],
Expand All @@ -149,30 +175,33 @@ fn build_prefill_request_results(
} else {
None
};
// Echo emits a per-chunk *partial*: this chunk's local position `local`
// (logits column `token_offset + local`) predicts prompt token
// `chunk_start + local + 1`, so it fills that slot. Index 0 has no
// predecessor; the column predicting the first generated token
// (`target == prompt_len`) is skipped. Partials are stitched across
// chunks by `merge_echo_prompt_logprobs`, surfacing the full vector on
// the final chunk.
let prompt_logprobs = if req.echo {
let mut partial = vec![None; req.prompt_tokens.len()];
if compute_prompt_logprobs {
let mut echo_logprobs = Vec::with_capacity(req.prompt_tokens.len());
echo_logprobs.push(None);
if let Some(all_logits) = all_position_logits {
for j in 1..req.prompt_tokens.len() {
let prev_pos = token_offset + j - 1;
let target_token = req.prompt_tokens[j];
echo_logprobs.push(lane.extract_prompt_logprobs(
for local in 0..req.chunk_tokens {
let target = req.chunk_start + local + 1;
if target >= req.prompt_tokens.len() {
break;
}
let target_token = req.prompt_tokens[target];
partial[target] = lane.extract_prompt_logprobs(
all_logits,
prev_pos,
token_offset + local,
target_token,
req.logprobs,
));
}
} else {
for _ in 1..req.prompt_tokens.len() {
echo_logprobs.push(None);
);
}
}
Some(echo_logprobs)
} else {
Some(vec![None; req.prompt_tokens.len()])
}
Some(partial)
} else {
None
};
Expand Down Expand Up @@ -602,6 +631,12 @@ pub struct Qwen3Executor {
/// so prefix matching itself stays enabled). Set via
/// [`Self::set_no_prefix_cache`].
l1_retention_disabled: bool,
/// Per-request accumulator for echo prompt logprobs while the prompt is
/// chunk-prefilled. Each chunk computes the logprobs for the prompt
/// positions it forwarded; the full vector is moved into the result on the
/// final chunk (see [`Self::merge_echo_prompt_logprobs`]). Empty unless an
/// echo request is mid-prefill.
echo_prompt_logprobs: HashMap<RequestId, Vec<Option<TokenLogprob>>>,
}

/// One request's in-flight CPU-tier KV prefetch.
Expand Down Expand Up @@ -656,6 +691,7 @@ impl Qwen3Executor {
saved_cursor: HashMap::new(),
prefetch: HashMap::new(),
l1_retention_disabled: false,
echo_prompt_logprobs: HashMap::new(),
})
}

Expand Down Expand Up @@ -805,6 +841,7 @@ impl Qwen3Executor {
saved_cursor: HashMap::new(),
prefetch: HashMap::new(),
l1_retention_disabled: false,
echo_prompt_logprobs: HashMap::new(),
})
}

Expand Down Expand Up @@ -990,13 +1027,10 @@ impl Qwen3Executor {
.expect("inserted above");
req.chunk_start = rkv.kv_position();
let remaining = req.prompt_tokens.len() - req.chunk_start;
// Echo must produce all-position logits in a single forward, so it is
// exempt from chunking (the scheduler never splits echo requests).
req.chunk_tokens = if req.echo {
remaining
} else {
remaining.min(req.chunk_budget)
};
// Echo chunks like everything else: each chunk computes all-position
// logits for its own slice and the executor stitches the prompt
// logprobs across chunks (see `merge_echo_prompt_logprobs`).
req.chunk_tokens = remaining.min(req.chunk_budget);
assert!(
req.chunk_tokens > 0,
"zero-token prefill chunk for {:?} (budget {})",
Expand All @@ -1022,6 +1056,23 @@ impl Qwen3Executor {
}
}

/// Stitch one prefill result's partial echo prompt logprobs into the
/// per-request accumulator.
///
/// The partial is always taken out of `result`, so a non-final chunk leaves
/// `result.prompt_logprobs` as `None`. When `result.completed` is set, the
/// accumulated vector is removed from the map and handed back through
/// `result.prompt_logprobs`. Results that carry no prompt logprobs (non-echo
/// requests) are left untouched.
fn merge_echo_prompt_logprobs(&mut self, result: &mut PrefillRequestResult) {
    if let Some(chunk) = result.prompt_logprobs.take() {
        // The accumulator is sized from the first partial seen for this request.
        let slots = chunk.len();
        let pending = self
            .echo_prompt_logprobs
            .entry(result.request_id)
            .or_insert_with(|| vec![None; slots]);
        fold_prompt_logprobs_chunk(pending, chunk);

        if result.completed {
            // Final chunk: move the stitched vector out so nothing lingers in the map.
            result.prompt_logprobs = self.echo_prompt_logprobs.remove(&result.request_id);
        }
    }
}

// ── KV-offload LOAD (async CPU-tier prefetch) ──────────────────────
// The trait-facing prefetch hooks (`begin_kv_prefetch`,
// `drain_ready_prefetch`, `wait_ready_prefetch`, `has_pending_prefetch`)
Expand Down Expand Up @@ -1193,6 +1244,9 @@ impl ModelExecutor for Qwen3Executor {
}
}
self.saved_cursor.remove(&request_id);
// Drop any half-accumulated echo prompt logprobs for a request that
// disconnected mid-prefill, so the buffer doesn't leak.
self.echo_prompt_logprobs.remove(&request_id);
Ok(())
}

Expand Down Expand Up @@ -1345,7 +1399,7 @@ impl ModelExecutor for Qwen3Executor {
let outcome = self.run_step(&step)?;

// 4. Apply prefill
let result = match outcome {
let mut result = match outcome {
WorkerStepOutcome::Prefill(result) => result,
other => {
return Err(anyhow::anyhow!(
Expand All @@ -1354,6 +1408,11 @@ impl ModelExecutor for Qwen3Executor {
));
}
};
// Stitch each echo request's per-chunk prompt logprobs; the full vector
// surfaces on the final chunk, earlier chunks emit nothing.
for req_result in &mut result.requests {
self.merge_echo_prompt_logprobs(req_result);
}
for req_result in &result.requests {
self.apply_prefill_result(req_result)?;
}
Expand Down Expand Up @@ -1458,7 +1517,7 @@ impl ModelExecutor for Qwen3Executor {
let outcome = self.run_step(&step)?;

// 4. Apply both prefill and decode
let result = match outcome {
let mut result = match outcome {
WorkerStepOutcome::Unified(result) => result,
other => {
return Err(anyhow::anyhow!(
Expand All @@ -1467,6 +1526,9 @@ impl ModelExecutor for Qwen3Executor {
));
}
};
for req_result in &mut result.prefill_requests {
self.merge_echo_prompt_logprobs(req_result);
}
for req_result in &result.prefill_requests {
self.apply_prefill_result(req_result)?;
}
Expand Down Expand Up @@ -1676,9 +1738,52 @@ impl ModelExecutor for Qwen3Executor {

#[cfg(test)]
mod tests {
use super::ensure_lora_capacity;
use super::{ensure_lora_capacity, fold_prompt_logprobs_chunk};
use openinfer_core::engine::TokenLogprob;
use std::collections::HashSet;

/// Test fixture: a `TokenLogprob` carrying `logprob` and no top-k alternatives.
fn lp(logprob: f32) -> TokenLogprob {
    let top_logprobs = Vec::new();
    TokenLogprob {
        logprob,
        top_logprobs,
    }
}

// With echo chunking, the logprob of prompt token `k` comes from whichever
// chunk forwarded position `k - 1`. Folding the per-chunk partials therefore
// has to rebuild exactly what one whole-prompt pass would give: index 0 stays
// None (nothing precedes it), every later index is written once, and an index
// sitting on a chunk boundary is supplied by the *previous* chunk (whose last
// position is its predecessor) — neither lost nor written twice.
#[test]
fn folding_echo_chunks_reconstructs_the_full_prompt_logprobs() {
    // Prompt of 5 tokens, chunk budget 2 => forwarded positions [0,1], [2,3], [4].
    //   chunk [0,2) -> prompt indices 1 and 2
    //   chunk [2,4) -> prompt indices 3 and 4
    //   chunk [4,5) -> nothing in-prompt (position 4 predicts the first
    //                  *generated* token)
    // The accumulator covers prompt indices 0..5.
    let mut acc: Vec<Option<TokenLogprob>> = vec![None; 5];

    // First chunk: check the intermediate state before anything else is folded.
    fold_prompt_logprobs_chunk(
        &mut acc,
        vec![None, Some(lp(-1.0)), Some(lp(-2.0)), None, None],
    );
    assert_eq!(acc[0], None, "index 0 has no predecessor");
    assert_eq!(acc[1].as_ref().map(|l| l.logprob), Some(-1.0));
    assert_eq!(acc[2].as_ref().map(|l| l.logprob), Some(-2.0));
    assert_eq!(acc[3], None, "index 3's predecessor is in the next chunk");

    // Remaining chunks, in order; the last one contributes no prompt index.
    for later_chunk in [
        vec![None, None, None, Some(lp(-3.0)), Some(lp(-4.0))],
        vec![None, None, None, None, None],
    ] {
        fold_prompt_logprobs_chunk(&mut acc, later_chunk);
    }

    let got: Vec<Option<f32>> = acc
        .iter()
        .map(|slot| slot.as_ref().map(|entry| entry.logprob))
        .collect();
    assert_eq!(
        got,
        vec![None, Some(-1.0), Some(-2.0), Some(-3.0), Some(-4.0)],
        "every prompt index is filled exactly once, index 0 stays None"
    );
}

#[test]
fn lora_capacity_rejects_new_adapter_at_limit() {
let loaded = HashSet::from(["adapter-a".to_string()]);
Expand Down
Loading