Custom topk-logprobs kernel + Remove a redundant Device to Host logits copy.#434
Conversation
Move the DDTree draft top-K + log-prob extraction off the CPU and onto the GPU, operating directly on the draft logits' device buffer instead of D2H-ing the full [vocab x n_positions] logits and running the OpenMP heap top-K. New split-K kernel (draft_topk_cuda.cu) with register-resident per-thread top-K and online-logsumexp; ~10x faster kernel, ~10% end-to-end tok/s. Also add a GPU verify-argmax shortcut: read the in-graph batched per-node argmax (sg_.argmax_tokens) to skip the verify-step vocab x N logits D2H + CPU argmax, validating each row and falling back to CPU argmax on any bad index. Both paths fall back to the existing CPU code and are runtime-toggleable: - DFLASH_GPU_DRAFT_TOPK=0 disables the GPU top-K - DFLASH_GPU_VERIFY_ARGMAX=0 forces the legacy CPU verify argmax The GPU top-K is compiled in for CUDA builds (DFLASH27B_HAVE_DRAFT_TOPK_CUDA). Applied in both the library (qwen35_dflash_target.cpp) and the standalone end-to-end harness (test_dflash.cpp), which keep separate copies of the decode/verify loop. Adds test_draft_topk_cuda.cpp checking the kernel bit-for-bit against the CPU reference, registered in CMake. bench_llm.py: fix tokenizer (Qwen3.6-27B) and dataset names (openai/...).
There was a problem hiding this comment.
2 issues found across 7 files
Prompt for AI agents (unresolved issues)
Check if these issues are valid — if so, understand the root cause of each and fix them. If appropriate, use sub-agents to investigate and fix each issue separately.
<file name="server/src/common/geometric_draft_topk_cuda.h">
<violation number="1" location="server/src/common/geometric_draft_topk_cuda.h:28">
P2: Failure-signaling GPU top-k API should be marked `[[nodiscard]]` so callers cannot silently ignore fallback-critical errors.</violation>
</file>
<file name="server/src/qwen35/qwen35_dflash_target.cpp">
<violation number="1" location="server/src/qwen35/qwen35_dflash_target.cpp:308">
P2: GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.</violation>
</file>
Reply with feedback, questions, or to request a fix.
Re-trigger cubic
| // d_logits: device pointer to row-major [n_positions][vocab] f32 logits (the | ||
| // position stride is `vocab` floats — pass an offset pointer to skip | ||
| // leading positions). out_* are HOST buffers of size n_positions*K. | ||
| bool geometric_extract_draft_topk_cuda(const void * d_logits, |
There was a problem hiding this comment.
P2: Failure-signaling GPU top-k API should be marked [[nodiscard]] so callers cannot silently ignore fallback-critical errors.
Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At server/src/common/geometric_draft_topk_cuda.h, line 28:
<comment>Failure-signaling GPU top-k API should be marked `[[nodiscard]]` so callers cannot silently ignore fallback-critical errors.</comment>
<file context>
@@ -0,0 +1,34 @@
+// d_logits: device pointer to row-major [n_positions][vocab] f32 logits (the
+// position stride is `vocab` floats — pass an offset pointer to skip
+// leading positions). out_* are HOST buffers of size n_positions*K.
+bool geometric_extract_draft_topk_cuda(const void * d_logits,
+ int n_positions, int vocab, int K,
+ float * out_log_probs,
</file context>
| for (int i = 0; i < N_actual; i++) { | ||
| if (posterior_out[i] < 0 || posterior_out[i] >= vocab) { ok = false; break; } | ||
| } | ||
| if (ok) return true; // fast path; otherwise fall through to CPU argmax |
There was a problem hiding this comment.
P2: GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.
Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At server/src/qwen35/qwen35_dflash_target.cpp, line 308:
<comment>GPU verify-argmax fast path returns early after only an in-range bounds check, skipping CPU verification that would catch silent in-range argmax mismatches.</comment>
<file context>
@@ -278,13 +279,38 @@ bool Qwen35DFlashTarget::verify_tree(
+ for (int i = 0; i < N_actual; i++) {
+ if (posterior_out[i] < 0 || posterior_out[i] >= vocab) { ok = false; break; }
+ }
+ if (ok) return true; // fast path; otherwise fall through to CPU argmax
+ }
+
</file context>
|
Thanks for the work here. The CUDA top-k path looks promising and the standalone correctness test passes on lucebox/RTX 3090. One delivery/structure request: the new runtime flags should be documented outside the PR text, ideally in the README or a small tuning/debug section. That would make it clear which paths are default-on, how to disable them, and which flags are only for debugging/profiling. Right now this knowledge mostly lives in the PR description and code comments, so it will be easy |
|
Thanks for the feedback @davide221. I'll document the flags in the README.md in a new section. |
|
@davide221 I've update the |
|
lgtm |
What does the PR do?
The extraction of top-k vocab indices from the logit scores of the draft model currently happens on the cpu. This involves a costly transfer of the entire logit scores from the GPU to CPU i.e. Device to Host. In this PR we move the top-k computation to the GPU entirely.
This PR also adds a verify-argmax fastpath that skips the
logits transfer from Device to Host and re-uses the argmax indices stored as an attribute. It can be toggled via DFLASH_GPU_VERIFY_ARGMAX. It re-uses the same CPU fallback on any out-of-range index.
A comment stated that previous builds had issues with GPU based verify-argmax, this PR adds a check to detect failures and resort to using the prior CPU method for running the top-k computation in this scenario.
The GPU kernel and verify-argmax is enabled by default but can be disabled via the
DFLASH_GPU_DRAFT_TOPKandDFLASH_GPU_VERIFY_ARGMAXflags.Files edited
bench_llm.pyto use the right tokenizer + correct the naming conventions of datasets being used.geometric_draft_topk_cuda.cu/h: Creates a new cuda kernel used by the draft model to identify the topk (where k is in the range 1-8) vocab indices from the logit scores.qwen35_dflash_target.cpp: Invokes the new.cukernel and falls back to the cpu execution of topk incase there are errors.test_dflash.cpp: Invokes the.cukernel if theDFLASH_GPU_DRAFT_TOPKflag is enabled. Uses the fast-path verification of draft tokens ifDFLASH_GPU_VERIFY_ARGMAX=1, if it is set to 2 it identifies/reports any mismatches between cpu and gpu based verification.test_draft_topk_cuda.cpp: Correctness tests for the cuda kernel by comparing it against the previously used cpu function.CMakeLists.txt: Register the new cuda kernel.Results
Results in an ~27% increase in tok/s. (105.63-82.80)/(82.8)
Numbers can be reproduced by running bench_llm.py after re-building with the changes.
The benchmark was run on an RTX 3090.
Results before
Results after
Reproducing Results
Baseline:
DFLASH_GPU_DRAFT_TOPK=0 DFLASH_GPU_VERIFY_ARGMAX=0 python server/scripts/bench_llm.py --bench HumanEvalLatest Results:
DFLASH_GPU_DRAFT_TOPK=1 DFLASH_GPU_VERIFY_ARGMAX=1 python server/scripts/bench_llm.py --bench HumanEval