[Feature] VLM DFlash Training: Multi-Model Support for Qwen3-VL / Qwen3.5 / Qwen3.6 #585

Open
zyk42 wants to merge 1 commit into sgl-project:main from zyk42:feature/vlm-multi-model-dflash
@zyk42 zyk42 commented Jun 18, 2026

Model support

| Model | Architecture | model_type | transformers |
|---|---|---|---|
| Qwen3-VL-8B-Instruct | Dense VLM, 36 layers | qwen3_vl | >= 5.7.0 |
| Qwen3-VL-30B-A3B-Thinking | MoE VLM (128 experts), 48 layers | qwen3_vl_moe | >= 5.7.0 |
| Qwen3.5-9B | Dense VLM + GDN, 32 layers | qwen3_5 | >= 5.7.0 |
| Qwen3.5-35B-A3B | MoE VLM (128 experts) + GDN, 40 layers | qwen3_5_moe | >= 5.7.0 |
| Qwen3.6-35B-A3B | MoE VLM (128 experts) + GDN, 40 layers | qwen3_5_moe | >= 5.7.0 |

Description

This PR extends SpecForge DFlash training to support multiple VLM model families with HF backend. Key changes:

  1. Extended VLM model types: QWEN3_VL_MODEL_TYPES now includes qwen3_5_moe and qwen3_5 in addition to qwen3_vl and qwen3_vl_moe.

  2. Auto VLM embedding detection — Automatically sets --embedding-key=model.language_model.embed_tokens.weight for all VLM model types (previously required manual specification for qwen3_vl).

  3. Qwen3.5/3.6 HF loading — Added Qwen3_5MoeForConditionalGeneration import path with graceful fallback for transformers < 5.7.0.

  4. Partial rotation RoPE — Qwen3.5/3.6 uses partial_rotary_factor=0.25 (only 64 out of 256 dims get RoPE). Updated apply_rotary_pos_emb in the DFlash draft model to handle rotary_dim < head_dim.

  5. Draft configs — New config files for all supported models with correct target_layer_ids (starting from layer 3+ for deepstack VLMs), mrope_section, and partial_rotary_factor.

  6. transformers 5.7.0 compatibility — Added mm_token_type_ids generation for Qwen3-VL models (required by transformers >= 5.7.0). Generates token type IDs from input_ids (image tokens → 1, video tokens → 2) and passes them to both forward() and get_rope_index().
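The token-type mapping in item 6 can be illustrated with a minimal sketch. This is not the PR's code: the function name and the placeholder token IDs are hypothetical, the real IDs would come from the target model's configuration, and the PR presumably works on tensors rather than nested lists.

```python
def build_mm_token_type_ids(input_ids, image_token_id, video_token_id):
    """Return per-token type IDs: 0 = text, 1 = image token, 2 = video token.

    `input_ids` is a batch given as a list of token-ID lists. The two special
    token IDs are parameters because they differ per tokenizer.
    """
    type_of = {image_token_id: 1, video_token_id: 2}
    return [[type_of.get(tok, 0) for tok in row] for row in input_ids]


# Placeholder IDs chosen for illustration only (assumption, not real Qwen IDs).
IMAGE_ID, VIDEO_ID = 9001, 9002
batch = [[11, IMAGE_ID, IMAGE_ID, 12], [VIDEO_ID, 13, 14, IMAGE_ID]]
mm_token_type_ids = build_mm_token_type_ids(batch, IMAGE_ID, VIDEO_ID)
print(mm_token_type_ids)
```

The result has the same shape as `input_ids`, which is what lets it be passed alongside the inputs to both `forward()` and `get_rope_index()` as described above.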

Training results (Qwen3-VL-30B-A3B, GUI Agent task)

Configuration

| Item | Value |
|---|---|
| Target model | Qwen3-VL-30B-A3B-Thinking (MoE, 128 experts, 48 layers) |
| Draft config | 5-layer dense decoder, block_size=8, target_layer_ids=[5,14,24,34,44] |
| Dataset | 278K greedy regen (target model self-generated GUI Agent dialogue data) |
| Training hardware | 8x H20Z 144GB, HF backend, 8-GPU DP |
| Training duration | ~5h/epoch × 5 epochs = ~25h |
| max_length | 4096 |
| transformers | >= 5.7.0 |

Best result (278K, 5-layer, block_size=8)

| Metric | Value |
|---|---|
| Final accept length | 3.52 |
| Inference speedup | +35.8% |
| Inference environment | 4x RTX 5090, TP=4, SGLang 0.5.12 |
| Training speed | ~1.5 it/s (HF DP=8, after warmup) |

Data scaling results

| Dataset size | Draft config | Accept length | Speedup | Note |
|---|---|---|---|---|
| 2K | 5-layer block8 | ~1.2 | N/A | Severe overfitting |
| 10K | 5-layer block8 | ~2.0 | -15% | Below break-even |
| 10K | 5-layer block16 | ~1.7 | -35% | block16 worse at low accept rate |
| 100K | 8-layer block8 | ~2.7 | +3% | Just at break-even |
| 278K | 5-layer block8 | 3.52 | +35.8% | Best cost-effectiveness |
| 278K | 8-layer block8 | 3.66 | +34.8% | Higher τ but slower draft, net similar |

Break-even analysis (4x RTX 5090, TP=4)

  • Accept length < 2.5: negative (draft overhead > saved decode time)
  • Accept length 2.5~3.0: break-even to slightly positive
  • Accept length > 3.0: clear speedup (measured 3.52 → +35.8%)
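One way to sanity-check these bands against the measured rows is a simple cost model. The model itself is an assumption and is not stated in the PR: if one draft-plus-verify step costs `c` plain decode steps and yields `accept_length` tokens, then speedup is roughly `accept_length / c`, so break-even sits at `accept_length = c`. The sketch below backs out the implied `c` from the data-scaling table; the function name is hypothetical.

```python
def implied_step_cost(accept_length, speedup_pct):
    """Cost of one speculative step in units of plain decode steps.

    Assumes speedup = accept_length / cost, which ignores batching effects
    and any difference in how accept length is counted.
    """
    return accept_length / (1.0 + speedup_pct / 100.0)


# (label, accept length, speedup in percent), copied from the table above.
measured = [
    ("10K, 5-layer block8", 2.0, -15.0),
    ("100K, 8-layer block8", 2.7, 3.0),
    ("278K, 5-layer block8", 3.52, 35.8),
    ("278K, 8-layer block8", 3.66, 34.8),
]

for label, tau, pct in measured:
    print(f"{label}: implied cost = {implied_step_cost(tau, pct):.2f}")
```

Under this assumed model, the 278K rows give an implied cost of about 2.59 for the 5-layer draft and about 2.72 for the 8-layer draft. That is in line with both the 2.5 to 3.0 break-even band and the note that the 8-layer draft is slower.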

Overfitting behavior (100K data)

| Checkpoint | Accept length | Speedup | Note |
|---|---|---|---|
| epoch_2 (15K step) | 2.67 | +2.5% | Normal |
| epoch_3 (20K step) | 2.70 | +3.0% | Best |
| epoch_6 (40K step) | 2.51 | -3.4% | Degrading |
| epoch_10 (60K step) | 2.32 | -10% | Overfitting |

Key observations

  • Data must be target model greedy regen — human-annotated data mismatches target hidden states, accept rate drops to ~1.x
  • 278K samples is the sweet spot; <10K causes severe overfitting, 100K barely breaks even
  • System prompt must match between training and inference — mismatch causes position offset and accept rate collapse
  • target_layer_ids: Qwen3-VL must skip first 3 layers (deepstack); Qwen3.5/3.6 starts from layer 1
  • 5 epochs optimal for 278K; beyond 6 epochs overfitting begins. For 100K, 3 epochs best
  • Install flash-attn for much faster training with Qwen3-VL (vision encoder bottleneck without it)
  • block_size=8 preferred for short-output tasks (GUI Agent); block_size=16 for long generation

Usage

torchrun --nproc_per_node=8 scripts/train_dflash.py \
  --target-model-path <MODEL_PATH> \
  --target-model-backend hf \
  --draft-config-path <DRAFT_CONFIG.json> \
  --train-data-path <VLM_JSONL> \
  --output-dir <OUTPUT> \
  --is-vlm \
  --batch-size 2 \
  --max-length 4096 \
  --num-epochs 5 \
  --chat-template guiagent \
  --attention-backend flex_attention \
  --block-size 8 \
  --num-anchors 512 \
  --loss-decay-gamma 4.0

Draft config design

| Target model | Config file | hidden_size | head_dim | target_layer_ids | partial_rotary_factor | mrope_section |
|---|---|---|---|---|---|---|
| Qwen3-VL-8B (36L) | qwen3-vl-8b-dflash-vlm-8layer.json | 4096 | 128 | [3,10,18,25,32] | n/a | [24,20,20] |
| Qwen3-VL-30B-A3B (48L) | qwen3-vl-30b-a3b-dflash-vlm-8layer.json | 2048 | 128 | [3,13,24,34,44] | n/a | [24,20,20] |
| Qwen3.5-35B-A3B (40L) | qwen3.5-35b-a3b-dflash-vlm-8layer.json | 2048 | 256 | [1,10,19,28,37] | 0.25 | [11,11,10] |
| Qwen3.5-9B (32L) | qwen3.5-9b-dflash-vlm-8layer.json | 4096 | 256 | [1,8,15,22,29] | 0.25 | [11,11,10] |
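The target_layer_ids in the table above are consistent with evenly spaced points between a first and a last layer, rounded to the nearest integer. The sketch below reproduces all four rows under that reading. The helper name, the round-half-up rule, and the endpoints (read off the table) are assumptions, not taken from the PR's code.

```python
def evenly_spaced_layer_ids(first, last, count=5):
    """Pick `count` layer indices from `first` to `last` inclusive, evenly spaced.

    Fractional positions are rounded half up, which is an assumed rule that
    happens to match the table rows.
    """
    if count < 2:
        raise ValueError("count must be at least 2")
    step = (last - first) / (count - 1)
    return [int(first + i * step + 0.5) for i in range(count)]


# Endpoints read off the table: Qwen3-VL starts at layer 3, Qwen3.5/3.6 at layer 1.
print(evenly_spaced_layer_ids(3, 32))  # Qwen3-VL-8B (36L)
print(evenly_spaced_layer_ids(3, 44))  # Qwen3-VL-30B-A3B (48L)
print(evenly_spaced_layer_ids(1, 37))  # Qwen3.5-35B-A3B (40L)
print(evenly_spaced_layer_ids(1, 29))  # Qwen3.5-9B (32L)
```

Note that the last layer is a few layers short of the model's final layer in every row, so the endpoint has to be chosen per model rather than derived from the layer count alone.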

Key design principles:

  • model_type always qwen3_vl_text (draft is always dense decoder, regardless of target architecture)
  • target_layer_ids: 5 layers, evenly distributed. Qwen3-VL starts from layer 3 (deepstack); Qwen3.5/3.6 starts from layer 1
  • Match hidden_size, num_attention_heads, num_key_value_heads, head_dim, vocab_size, rope_theta with target
  • Qwen3.5/3.6: partial_rotary_factor=0.25 (only 64/256 dims use RoPE), rope_theta=10000000
  • Qwen3-VL: no partial rotation, rope_theta=5000000
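To make the partial-rotation principle concrete, here is a minimal scalar sketch of RoPE applied to only the first `rotary_dim` entries of one head vector. It is an illustration, not the PR's `apply_rotary_pos_emb`: it uses plain lists instead of tensors, a single 1-D position instead of mRoPE sections, and the common rotate-half layout, all of which are assumptions.

```python
import math


def rotate_half(x):
    """Map [x1, x2] to [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def apply_partial_rope(vec, position, rotary_dim, rope_theta=10_000_000.0):
    """Rotate vec[:rotary_dim] by position-dependent angles; pass the rest through."""
    if rotary_dim % 2 or rotary_dim > len(vec):
        raise ValueError("rotary_dim must be even and at most head_dim")
    rot, passthrough = vec[:rotary_dim], vec[rotary_dim:]
    half = rotary_dim // 2
    inv_freq = [rope_theta ** (-2.0 * i / rotary_dim) for i in range(half)]
    # Angles are duplicated so both halves of each rotated pair share a frequency.
    angles = [position * f for f in inv_freq] * 2
    rotated = [
        r * math.cos(a) + h * math.sin(a)
        for r, h, a in zip(rot, rotate_half(rot), angles)
    ]
    return rotated + list(passthrough)


head_dim = 256
rotary_dim = int(head_dim * 0.25)  # 64 rotated dims, 192 untouched
query = [float(i % 7 - 3) for i in range(head_dim)]
rotated_query = apply_partial_rope(query, position=5, rotary_dim=rotary_dim)
print(rotary_dim, rotated_query[rotary_dim:] == query[rotary_dim:])
```

With `rotary_dim == head_dim` the same function reduces to ordinary full-rotation RoPE, which matches the Qwen3-VL case where no partial rotation is used.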

Files changed

| File | Change |
|---|---|
| scripts/train_dflash.py | Extend QWEN3_VL_MODEL_TYPES to include qwen3_5_moe, qwen3_5; auto-detect VLM embedding key; add Qwen3.5 HF loading |
| specforge/modeling/draft/dflash.py | apply_rotary_pos_emb supports partial rotation (rotary_dim < head_dim) |
| specforge/modeling/target/dflash_target_model.py | HF backend: add qwen3_5_moe / qwen3_5 loading; add mm_token_type_ids generation for transformers 5.7.0 compatibility |
| configs/qwen3.5-35b-a3b-dflash-vlm-8layer.json | New: Draft config for Qwen3.5/3.6 (mRoPE, partial_rotary_factor=0.25, head_dim=256) |
| configs/qwen3.5-9b-dflash-vlm-8layer.json | New: Draft config for Qwen3.5-9B dense (mRoPE, partial_rotary_factor=0.25, head_dim=256) |
| configs/qwen3-vl-30b-a3b-dflash-vlm-8layer.json | New: Draft config for Qwen3-VL-30B-A3B-Thinking (mRoPE, head_dim=128) |
| configs/qwen3-vl-8b-dflash-vlm-8layer.json | New: Draft config for Qwen3-VL-8B-Instruct (mRoPE, head_dim=128) |
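The embedding-key auto-detection listed for scripts/train_dflash.py could look roughly like the sketch below. The model types and the VLM key come from the description above. The function name, the precedence of a user-supplied key, and the non-VLM default key are assumptions made for illustration.

```python
QWEN3_VL_MODEL_TYPES = {"qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"}
VLM_EMBEDDING_KEY = "model.language_model.embed_tokens.weight"


def resolve_embedding_key(model_type, user_key=None,
                          default_key="model.embed_tokens.weight"):
    """Choose the checkpoint key that holds the token embedding weights.

    Assumed precedence: an explicit --embedding-key wins, then the VLM key
    for known VLM model types, then a text-only default (hypothetical).
    """
    if user_key is not None:
        return user_key
    if model_type in QWEN3_VL_MODEL_TYPES:
        return VLM_EMBEDDING_KEY
    return default_key


for mt in ("qwen3_vl", "qwen3_5_moe", "llama"):
    print(mt, "->", resolve_embedding_key(mt))
```

Keeping the override first means existing launch scripts that already pass the key for qwen3_vl keep working unchanged.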

@qibaoyuan

Good job!

@qibaoyuan

qibaoyuan commented Jun 23, 2026

@FrankLeeeee This PR adds DFlash training support for Qwen3-VL, Qwen3.5, and Qwen3.6 models, including HF loading, partial-RoPE support, automatic embedding detection, and transformers 5.7.0 compatibility. Validation on Qwen3-VL-30B-A3B achieved 3.52 accept length and +35.8% inference speedup.
Please help review when you have a chance. Thanks a lot!

…ckend)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@zyk42 zyk42 force-pushed the feature/vlm-multi-model-dflash branch from f3a65b4 to 9323a51 Compare June 23, 2026 07:49
@SSSSSSuger

The code has `from specforge.core.dflash import OnlineDFlashModel, QwenVLOnlineDFlashModel`, but it seems that QwenVLOnlineDFlashModel has not been uploaded yet?
