[Feature] Add VP drafter training mode for DFlash #592
Code Review
This pull request introduces the vp_drafter training objective (variable-prefix drafter) to the DFlash framework, allowing for training with variable visible prefixes. It adds a configuration file for Qwen3-8B, an online training script, and updates the core DFlash model and training script to support prefix length sampling, VP noise embedding generation, and corresponding loss calculations. Unit tests are also added to validate the new loss implementation. The review feedback highlights three key improvements: replacing torch.distributions.Categorical with torch.multinomial to prevent graph breaks under torch.compile, handling potential None values for prefix_weight_base to avoid a TypeError, and using .reshape() instead of .view() on noise_ids to safely handle non-contiguous tensors.
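The three review suggestions can be illustrated with a minimal PyTorch sketch. This is not the PR's code: `sample_prefix_lengths` and `flatten_noise_ids` are hypothetical helpers, and both the candidate range (1 to `block_size - 1`) and the geometric weighting `prefix_weight_base ** k` are assumptions made only for illustration.

```python
import torch


def sample_prefix_lengths(num_blocks, block_size, prefix_weight_base=None, generator=None):
    """Sample one visible-prefix length per block from {1, ..., block_size - 1}."""
    # Treat a missing prefix_weight_base as uniform weighting instead of
    # letting arithmetic on None raise a TypeError.
    base = 1.0 if prefix_weight_base is None else float(prefix_weight_base)
    lengths = torch.arange(1, block_size, dtype=torch.float32)
    weights = torch.pow(base, lengths)  # assumed weighting: base ** k
    probs = weights / weights.sum()
    # torch.multinomial samples category indices with a plain tensor op,
    # without constructing a torch.distributions.Categorical object.
    idx = torch.multinomial(
        probs, num_samples=num_blocks, replacement=True, generator=generator
    )
    return idx + 1  # shift indices 0..block_size-2 to lengths 1..block_size-1


def flatten_noise_ids(noise_ids):
    """Flatten token ids; reshape copies when the layout requires it."""
    return noise_ids.reshape(-1)


prefix_lens = sample_prefix_lengths(num_blocks=4, block_size=8, prefix_weight_base=None)
noise_ids = torch.arange(12).reshape(3, 4).t()  # transposed, so non-contiguous
flat = flatten_noise_ids(noise_ids)  # noise_ids.view(-1) would raise here
```

`reshape` returns a view when the memory layout allows it and falls back to a copy otherwise, which is why it is the safer choice when the contiguity of `noise_ids` is not guaranteed.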
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Great work! We'll check it soon.
Motivation
This PR adds training support for the VP-Drafter used in D2SD (Dual Diffusion Draft Speculative Decoding). D2SD extends DFlash by using a first DFlash draft to estimate likely rejection boundaries, then training a second variable-prefix drafter to re-anchor at selected prefixes and generate alternative continuations.
The key training requirement is different from standard DFlash: the drafter must learn from variable-length visible prefixes instead of always seeing only the anchor token followed by masks. This PR implements that behavior as a DFlash training-mode branch, so the resulting model still uses the same `DFlashDraftModel` architecture and config format.
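The difference between the two input layouts can be sketched in plain Python. The helper `build_block_inputs`, the `MASK_ID` value, and the choice to compute the loss only on masked positions are assumptions for illustration, not details taken from the PR.

```python
MASK_ID = -1  # hypothetical mask-token id; a real one comes from the model config


def build_block_inputs(block_tokens, prefix_len, mask_id=MASK_ID):
    """Return (noise_ids, loss_mask) for a single block.

    The first `prefix_len` tokens stay visible and the rest are replaced by
    `mask_id`. The loss mask is 1 on masked positions (assumed convention).
    """
    n = len(block_tokens)
    if not 1 <= prefix_len < n:
        raise ValueError("prefix_len must leave at least one masked position")
    noise_ids = list(block_tokens[:prefix_len]) + [mask_id] * (n - prefix_len)
    loss_mask = [0] * prefix_len + [1] * (n - prefix_len)
    return noise_ids, loss_mask


block = [11, 12, 13, 14, 15]

# Standard DFlash layout: only the anchor token is visible.
standard_ids, standard_mask = build_block_inputs(block, prefix_len=1)

# Variable-prefix layout: a longer prefix (here 3 tokens) is visible.
vp_ids, vp_mask = build_block_inputs(block, prefix_len=3)
```

Under these assumptions the standard layout is the special case `prefix_len=1`, which is consistent with implementing the new behavior as a training-mode branch rather than a new architecture.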
Modifications
- Add `vp_drafter` training mode to `OnlineDFlashModel`.
- `vp_drafter` mode samples a variable visible prefix length per block.
- Update `scripts/train_dflash.py` to read `dflash_config.training_mode` / `dflash_config.loss_type` when `--loss-type` is not explicitly provided.
- Add `prefix_weight_base` support for variable-prefix sampling.
- Add `configs/qwen3-8b-dta.json`
- Add `examples/run_qwen3_8b_dta_online.sh`
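The fallback from `--loss-type` to the config can be sketched as follows. This is an assumption-laden illustration, not the logic in `scripts/train_dflash.py`: the helper `resolve_loss_type`, the precedence of `training_mode` over `loss_type`, and the `"dflash"` default are all hypothetical.

```python
import argparse


def resolve_loss_type(cli_loss_type, dflash_config, default="dflash"):
    """Pick the loss type: explicit CLI flag first, then config, then default."""
    if cli_loss_type is not None:
        return cli_loss_type
    # Assumed precedence: training_mode is checked before loss_type.
    for key in ("training_mode", "loss_type"):
        value = dflash_config.get(key)
        if value:
            return value
    return default


parser = argparse.ArgumentParser()
# default=None lets the script tell "flag not passed" apart from any real value.
parser.add_argument("--loss-type", default=None)
args = parser.parse_args([])  # simulate running without --loss-type

config = {"dflash_config": {"training_mode": "vp_drafter"}}
loss_type = resolve_loss_type(args.loss_type, config.get("dflash_config", {}))
```

Leaving the flag's default as `None` is the design choice that makes the fallback possible: a non-`None` default would be indistinguishable from an explicit user choice.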