From 9323a510e09c1f05be397a02039f9337ad960804 Mon Sep 17 00:00:00 2001 From: zyk42 <2931889928@qq.com> Date: Tue, 23 Jun 2026 07:49:15 +0000 Subject: [PATCH] feat: VLM DFlash multi-model support (Qwen3-VL/Qwen3.5/Qwen3.6, HF backend) Co-Authored-By: Claude Opus 4.6 --- .../qwen3-vl-30b-a3b-dflash-vlm-8layer.json | 46 ++ configs/qwen3-vl-8b-dflash-vlm-8layer.json | 46 ++ .../qwen3.5-35b-a3b-dflash-vlm-8layer.json | 47 ++ configs/qwen3.5-9b-dflash-vlm-8layer.json | 47 ++ scripts/train_dflash.py | 452 +++++++++++--- specforge/modeling/draft/dflash.py | 102 ++- .../modeling/target/dflash_target_model.py | 581 +++++++++++++++++- 7 files changed, 1166 insertions(+), 155 deletions(-) create mode 100644 configs/qwen3-vl-30b-a3b-dflash-vlm-8layer.json create mode 100644 configs/qwen3-vl-8b-dflash-vlm-8layer.json create mode 100644 configs/qwen3.5-35b-a3b-dflash-vlm-8layer.json create mode 100644 configs/qwen3.5-9b-dflash-vlm-8layer.json diff --git a/configs/qwen3-vl-30b-a3b-dflash-vlm-8layer.json b/configs/qwen3-vl-30b-a3b-dflash-vlm-8layer.json new file mode 100644 index 000000000..c4e2fa5f0 --- /dev/null +++ b/configs/qwen3-vl-30b-a3b-dflash-vlm-8layer.json @@ -0,0 +1,46 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "block_size": 8, + "bos_token_id": 151643, + "dflash_config": { + "target_layer_ids": [3, 13, 24, 34, 44] + }, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 4, + "num_target_layers": 48, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": true, + "mrope_section": [ + 24, + 20, + 20 + ], + 
"rope_type": "default" + }, + "rope_theta": 5000000, + "tie_word_embeddings": false, + "use_cache": true, + "vocab_size": 151936 +} diff --git a/configs/qwen3-vl-8b-dflash-vlm-8layer.json b/configs/qwen3-vl-8b-dflash-vlm-8layer.json new file mode 100644 index 000000000..e931fd53c --- /dev/null +++ b/configs/qwen3-vl-8b-dflash-vlm-8layer.json @@ -0,0 +1,46 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "block_size": 8, + "bos_token_id": 151643, + "dflash_config": { + "target_layer_ids": [3, 10, 18, 25, 32] + }, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": true, + "mrope_section": [ + 24, + 20, + 20 + ], + "rope_type": "default" + }, + "rope_theta": 5000000, + "tie_word_embeddings": false, + "use_cache": true, + "vocab_size": 151936 +} diff --git a/configs/qwen3.5-35b-a3b-dflash-vlm-8layer.json b/configs/qwen3.5-35b-a3b-dflash-vlm-8layer.json new file mode 100644 index 000000000..147ab949d --- /dev/null +++ b/configs/qwen3.5-35b-a3b-dflash-vlm-8layer.json @@ -0,0 +1,47 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "block_size": 8, + "bos_token_id": 248044, + "dflash_config": { + "target_layer_ids": [1, 10, 19, 28, 37] + }, + "dtype": "bfloat16", + "eos_token_id": 248046, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "layer_types": [ + "full_attention", + "full_attention", + 
"full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 16, + "num_hidden_layers": 5, + "num_key_value_heads": 2, + "num_target_layers": 40, + "partial_rotary_factor": 0.25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": true, + "mrope_section": [ + 11, + 11, + 10 + ], + "rope_type": "default" + }, + "rope_theta": 10000000, + "tie_word_embeddings": false, + "use_cache": true, + "vocab_size": 248320 +} diff --git a/configs/qwen3.5-9b-dflash-vlm-8layer.json b/configs/qwen3.5-9b-dflash-vlm-8layer.json new file mode 100644 index 000000000..cc0d69914 --- /dev/null +++ b/configs/qwen3.5-9b-dflash-vlm-8layer.json @@ -0,0 +1,47 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "block_size": 8, + "bos_token_id": 248044, + "dflash_config": { + "target_layer_ids": [1, 8, 15, 22, 29] + }, + "dtype": "bfloat16", + "eos_token_id": 248046, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 16, + "num_hidden_layers": 5, + "num_key_value_heads": 4, + "num_target_layers": 32, + "partial_rotary_factor": 0.25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": true, + "mrope_section": [ + 11, + 11, + 10 + ], + "rope_type": "default" + }, + "rope_theta": 10000000, + "tie_word_embeddings": false, + "use_cache": true, + "vocab_size": 248320 +} diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index d899d95a8..9eb6b0ceb 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -3,7 +3,7 @@ """DFlash Training Script.""" import argparse -import functools +import copy import logging 
import math import os @@ -15,22 +15,21 @@ import torch import torch.distributed as dist from accelerate.utils import set_seed -from torch.distributed.fsdp import BackwardPrefetch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoProcessor, AutoTokenizer from datasets import load_dataset from specforge.args import SGLangBackendArgs, TrackerArgs -from specforge.core.dflash import OnlineDFlashModel +from specforge.core.dflash import OnlineDFlashModel, QwenVLOnlineDFlashModel from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders from specforge.distributed import destroy_distributed, get_dp_group, init_distributed from specforge.modeling.draft.dflash import DFlashDraftModel from specforge.modeling.target.dflash_target_model import ( DFlashTargetModel, + HFDFlashTargetModel, get_dflash_target_model, ) from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead @@ -43,6 +42,97 @@ print_with_rank, ) +QWEN3_VL_MODEL_TYPES = {"qwen3_vl", "qwen3_vl_moe", "qwen3_5_moe", "qwen3_5"} + + +def _resolve_target_num_hidden_layers(target_config) -> int: + # For VLM configs (e.g. Qwen3-VL), top-level num_hidden_layers may refer to + # vision stack depth. DFlash target layers must follow language model depth. 
+ if hasattr(target_config, "text_config") and hasattr( + target_config.text_config, "num_hidden_layers" + ): + return target_config.text_config.num_hidden_layers + if hasattr(target_config, "num_hidden_layers"): + return target_config.num_hidden_layers + raise ValueError( + f"Cannot infer num_target_layers from config type {type(target_config)}" + ) + + +def _build_target_layer_ids( + num_target_layers: int, + num_draft_layers: int, + start_layer: int = 1, + end_layer: Optional[int] = None, +) -> list[int]: + """Build evenly spaced target layer ids.""" + if num_draft_layers <= 0: + raise ValueError("num_draft_layers must be positive.") + + if end_layer is None: + end_layer = num_target_layers - 3 + + max_layer_idx = num_target_layers - 1 + start_layer = max(0, min(start_layer, max_layer_idx)) + end_layer = max(0, min(end_layer, max_layer_idx)) + + if end_layer < start_layer: + raise ValueError( + f"Invalid layer range: start_layer={start_layer}, end_layer={end_layer}" + ) + + if num_draft_layers == 1: + midpoint = num_target_layers // 2 + return [max(start_layer, min(midpoint, end_layer))] + + span = end_layer - start_layer + return [ + int(start_layer + (i * span) / (num_draft_layers - 1)) + for i in range(num_draft_layers) + ] + + +def _resolve_draft_config(target_config): + model_type = getattr(target_config, "model_type", None) + if model_type in QWEN3_VL_MODEL_TYPES and hasattr(target_config, "text_config"): + draft_config = copy.deepcopy(target_config.text_config) + for attr_name in ( + "dflash_config", + "block_size", + "rope_scaling", + "rope_theta", + "max_position_embeddings", + ): + if not hasattr(target_config, attr_name): + continue + if not hasattr(draft_config, attr_name) or getattr( + draft_config, attr_name + ) is None: + setattr(draft_config, attr_name, getattr(target_config, attr_name)) + return draft_config + return copy.deepcopy(target_config) + + + +def _ensure_layer_types(draft_config) -> None: + if hasattr(draft_config, "layer_types") and 
draft_config.layer_types is not None: + return + + if not hasattr(draft_config, "num_hidden_layers"): + return + + num_hidden_layers = draft_config.num_hidden_layers + sliding_window = getattr(draft_config, "sliding_window", None) + max_window_layers = getattr(draft_config, "max_window_layers", num_hidden_layers) + if max_window_layers is None: + max_window_layers = num_hidden_layers + draft_config.layer_types = [ + "sliding_attention" + if sliding_window is not None and layer_idx >= max_window_layers + else "full_attention" + for layer_idx in range(num_hidden_layers) + ] + def parse_args(): parser = argparse.ArgumentParser(description="Train DFlash Draft Model") @@ -75,6 +165,11 @@ def parse_args(): model_group.add_argument( "--trust-remote-code", action="store_true", help="Trust remote code" ) + model_group.add_argument( + "--is-vlm", + action="store_true", + help="Whether to enable VLM training mode. If not set, will auto-detect from target model config.", + ) model_group.add_argument( "--num-anchors", type=int, @@ -86,26 +181,7 @@ def parse_args(): type=float, default=None, help="Gamma for exponential loss decay weighting (paper Eq.4). " - "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. None disables. " - "Only applies when --loss-type dflash.", - ) - model_group.add_argument( - "--loss-type", - type=str, - default="dflash", - choices=[ - "dflash", - "dpace", - "dpace-cumulative-confidence-only", - "dpace-continuation-value-only", - ], - help=("Loss variant. Use dpace for Dynamic Position-Aware Cross-Entropy."), - ) - model_group.add_argument( - "--dpace-alpha", - type=float, - default=0.5, - help="Smoothing alpha for D-PACE position weights.", + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 
None disables.", ) model_group.add_argument( "--embedding-key", @@ -133,6 +209,18 @@ def parse_args(): type=int, default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)), ) + dataset_group.add_argument( + "--min-pixels", + type=int, + default=50176, + help="Minimum image pixels for VLM processor.", + ) + dataset_group.add_argument( + "--max-pixels", + type=int, + default=802816, + help="Maximum image pixels for VLM processor.", + ) training_group = parser.add_argument_group("training") training_group.add_argument("--num-epochs", type=int, default=6) @@ -173,30 +261,81 @@ def parse_args(): return parser.parse_args() -def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: +def build_models( + args, target_config=None, is_vlm: bool = False +) -> Tuple[DFlashTargetModel, DFlashDraftModel, AutoConfig]: """Build target model (backend wrapper) and draft model.""" + device = get_local_device() + device_type = device.type + print_on_rank0( f"Loading target model from {args.target_model_path} using {args.target_model_backend} backend" ) - target_model_kwargs = {} - if args.target_model_backend == "sglang": - target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() - - device = get_local_device() - device_type = device.type + if target_config is None: + target_config = AutoConfig.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + target_model_type = getattr(target_config, "model_type", None) + + if ( + args.target_model_backend == "hf" + and is_vlm + and target_model_type == "qwen3_vl" + and args.tp_size == 1 + ): + from transformers import Qwen3VLForConditionalGeneration + + target_model = HFDFlashTargetModel( + Qwen3VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=args.trust_remote_code, + ) + .eval() + .to(device), + model_type=target_model_type, + ) + elif ( + args.target_model_backend == "hf" + and is_vlm + 
and target_model_type == "qwen3_vl_moe" + and args.tp_size == 1 + ): + from transformers import Qwen3VLMoeForConditionalGeneration + + target_model = HFDFlashTargetModel( + Qwen3VLMoeForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=args.trust_remote_code, + ) + .eval() + .to(device), + model_type=target_model_type, + ) + else: + target_model_kwargs = {} + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + + target_model = get_dflash_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device=device_type if args.target_model_backend == "hf" else None, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) - target_model = get_dflash_target_model( - pretrained_model_name_or_path=args.target_model_path, - backend=args.target_model_backend, - torch_dtype=torch.bfloat16, - device=device_type if args.target_model_backend == "hf" else None, - trust_remote_code=args.trust_remote_code, - **target_model_kwargs, - ) + # Resolve before draft config mutations to avoid reading modified values. + target_num_layers = _resolve_target_num_hidden_layers(target_config) if args.draft_config_path: - draft_config = AutoConfig.from_pretrained(args.draft_config_path) + draft_config = AutoConfig.from_pretrained( + args.draft_config_path, trust_remote_code=args.trust_remote_code + ) + draft_config = _resolve_draft_config(draft_config) print_on_rank0(f"Loaded draft config from {args.draft_config_path}") # Warn if command-line args differ from config if ( @@ -208,15 +347,50 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: f"command-line arg ({args.block_size}). Using checkpoint value." 
) else: - target_config = AutoConfig.from_pretrained(args.target_model_path) - draft_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config = _resolve_draft_config(target_config) draft_config.num_hidden_layers = args.num_draft_layers draft_config.block_size = args.block_size - draft_config.num_target_layers = target_config.num_hidden_layers print_on_rank0("Auto-generated draft config from target model") + # Always use target language model depth for capture layer mapping. + draft_config.num_target_layers = target_num_layers + _ensure_layer_types(draft_config) + if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None: draft_config.dflash_config = {} + elif not isinstance(draft_config.dflash_config, dict): + draft_config.dflash_config = dict(draft_config.dflash_config) + + model_type = getattr(target_config, "model_type", None) + if model_type in QWEN3_VL_MODEL_TYPES: + # Keep the original evenly spaced mapping, but force the first capture + # layer to skip Qwen3-VL deepstack layers (0-2). + recommended_layer_ids = _build_target_layer_ids( + num_target_layers=target_num_layers, + num_draft_layers=draft_config.num_hidden_layers, + ) + if recommended_layer_ids and recommended_layer_ids[0] < 3: + recommended_layer_ids[0] = 3 + if "target_layer_ids" not in draft_config.dflash_config: + draft_config.dflash_config["target_layer_ids"] = recommended_layer_ids + print_on_rank0( + "Qwen3-VL detected: default target_layer_ids set to " + f"{draft_config.dflash_config['target_layer_ids']} " + "(first layer forced to 3)." 
+ ) + elif ( + draft_config.dflash_config["target_layer_ids"] + and draft_config.dflash_config["target_layer_ids"][0] < 3 + ): + old_layer_ids = draft_config.dflash_config["target_layer_ids"] + new_layer_ids = list(old_layer_ids) + new_layer_ids[0] = 3 + draft_config.dflash_config["target_layer_ids"] = new_layer_ids + print_on_rank0( + "Qwen3-VL detected: overriding first target layer " + f"{old_layer_ids} -> {new_layer_ids} " + "to avoid deepstack train/serve mismatch." + ) draft_config._attn_implementation = args.attention_backend print_on_rank0(f"Using attention backend: {args.attention_backend}") @@ -234,10 +408,15 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}" ) - return target_model, draft_model + return target_model, draft_model, target_config -def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]: +def build_dataloader( + args, + tokenizer, + is_vlm: bool = False, + processor: Optional[AutoProcessor] = None, +) -> Tuple[DataLoader, Optional[DataLoader]]: """Build train and eval dataloaders.""" import hashlib @@ -248,26 +427,58 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] f"{args.target_model_path}" ) cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + cache_dir = os.path.join(args.cache_dir, "processed_dataset") + dist_enabled = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if dist_enabled else 0 + world_size = dist.get_world_size() if dist_enabled else 1 train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] - train_eagle3_dataset = build_eagle3_dataset( - dataset=train_dataset, - tokenizer=tokenizer, - chat_template=args.chat_template, - max_length=args.max_length, - is_preformatted=args.is_preformatted, - cache_dir=os.path.join(args.cache_dir, "processed_dataset"), - cache_key=cache_key, - 
num_proc=args.build_dataset_num_proc, - ) - - min_loss_tokens = 2 * args.block_size - original_size = len(train_eagle3_dataset) - train_eagle3_dataset = train_eagle3_dataset.filter( - lambda x: x["loss_mask"].sum() >= min_loss_tokens - ) + if world_size > 1: + if rank == 0: + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, + cache_dir=cache_dir, + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + dist.barrier() + else: + dist.barrier() + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, + cache_dir=cache_dir, + cache_key=cache_key, + # Rank 0 has finished preprocessing at this point; other ranks only need cache reads. 
+ num_proc=1, + ) + else: + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, + cache_dir=cache_dir, + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + print_on_rank0( - f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples" + f"Train dataset: {len(train_eagle3_dataset)} samples (filter skipped)" ) train_dataloader = prepare_dp_dataloaders( @@ -276,6 +487,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] num_workers=args.dataloader_num_workers, shuffle=True, process_group=get_dp_group(), + is_vlm=is_vlm, ) eval_dataloader = None @@ -287,6 +499,8 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] chat_template=args.chat_template, max_length=args.max_length, is_preformatted=args.is_preformatted, + is_vlm=is_vlm, + processor=processor, ) eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, @@ -294,6 +508,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]] num_workers=args.dataloader_num_workers, shuffle=False, process_group=get_dp_group(), + is_vlm=is_vlm, ) return train_dataloader, eval_dataloader @@ -388,6 +603,19 @@ def main(): init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) print_with_rank("Initialized distributed") + target_config = AutoConfig.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + detected_vlm = getattr(target_config, "model_type", None) in QWEN3_VL_MODEL_TYPES + is_vlm = args.is_vlm or detected_vlm + if detected_vlm and not args.is_vlm: + print_on_rank0( + "Detected Qwen3-VL target config; enabling VLM mode automatically." 
+ ) + print_on_rank0( + f"Detected target model_type={getattr(target_config, 'model_type', None)}, is_vlm={is_vlm}" + ) + draft_model_last_checkpoint = None ckpt_info = (0, 0) if args.resume and os.path.isdir(args.output_dir): @@ -403,7 +631,12 @@ def main(): print(f"Loading draft config from checkpoint: {checkpoint_config_path}") args.draft_config_path = checkpoint_config_path - target_model, draft_model = build_models(args) + device = get_local_device() + device_type = device.type + + target_model, draft_model, target_config = build_models( + args, target_config, is_vlm=is_vlm + ) resume_state = None if draft_model_last_checkpoint: @@ -426,7 +659,17 @@ def main(): f"step {resume_state['global_step']}" ) - tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + if is_vlm: + processor = AutoProcessor.from_pretrained( + args.target_model_path, + trust_remote_code=args.trust_remote_code, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + ) + tokenizer = processor.tokenizer + else: + processor = None + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) if args.mask_token_id is not None: mask_token_id = args.mask_token_id @@ -442,13 +685,31 @@ def main(): draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}") - train_dataloader, eval_dataloader = build_dataloader(args, tokenizer) + train_dataloader, eval_dataloader = build_dataloader( + args, + tokenizer, + is_vlm=is_vlm, + processor=processor, + ) steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps) total_steps = args.num_epochs * steps_per_epoch print_on_rank0(f"Total training steps: {total_steps}") - print_on_rank0("Loading target embeddings and head...") + if ( + getattr(target_config, "model_type", None) in QWEN3_VL_MODEL_TYPES + and (args.embedding_key is None or args.lm_head_key is None) + ): + args.embedding_key = 
"model.language_model.embed_tokens.weight" + args.lm_head_key = "lm_head.weight" + print_on_rank0( + f"VLM model detected ({target_config.model_type}): using " + f"--embedding-key={args.embedding_key} --lm-head-key={args.lm_head_key}" + ) + print_on_rank0( + "Loading target embeddings/head with keys: " + f"embed='{args.embedding_key}', head='{args.lm_head_key}'" + ) target_components = TargetEmbeddingsAndHead.from_pretrained( args.target_model_path, embed_key=args.embedding_key, @@ -457,7 +718,11 @@ def main(): trust_remote_code=args.trust_remote_code, ) - dflash_model = OnlineDFlashModel( + dflash_model_cls = OnlineDFlashModel + if getattr(target_config, "model_type", None) in QWEN3_VL_MODEL_TYPES: + dflash_model_cls = QwenVLOnlineDFlashModel + print_on_rank0(f"Using DFlash wrapper: {dflash_model_cls.__name__}") + dflash_model = dflash_model_cls( draft_model=draft_model, target_lm_head=target_components.lm_head, target_embed_tokens=target_components.embed_tokens, @@ -466,42 +731,17 @@ def main(): attention_backend=args.attention_backend, num_anchors=args.num_anchors, loss_decay_gamma=args.loss_decay_gamma, - loss_type=args.loss_type, - dpace_alpha=args.dpace_alpha, ) - # Wrap each transformer block as its own FSDP unit so that all-gather / - # reduce-scatter overlap with compute. Without an auto_wrap_policy the - # whole model is a single FSDP unit, forcing every collective onto the - # critical path with no overlap. The block class is resolved from the - # draft model's `_no_split_modules` so this stays architecture-agnostic - # rather than hardcoding a specific decoder-layer class. 
- fsdp_kwargs = dict( + dflash_model = FSDP( + dflash_model, use_orig_params=True, - forward_prefetch=True, - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, - limit_all_gathers=True, mixed_precision=MixedPrecision( param_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16, ), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, ) - block_names = set(getattr(draft_model, "_no_split_modules", None) or []) - block_classes = { - type(m) for m in dflash_model.modules() if type(m).__name__ in block_names - } - if block_classes: - fsdp_kwargs["auto_wrap_policy"] = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=block_classes, - ) - else: - print_with_rank( - "No _no_split_modules on draft model; falling back to single-unit " - "FSDP wrap (no compute-comm overlap)." - ) - dflash_model = FSDP(dflash_model, **fsdp_kwargs) print_with_rank("Initialized FSDP") start_epoch = ckpt_info[0] @@ -516,7 +756,7 @@ def main(): ) if resume_state is not None: - optimizer.load_state_dict(resume_state) + optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"]) start_epoch = resume_state["epoch"] global_step = resume_state["global_step"] del resume_state @@ -554,15 +794,33 @@ def main(): input_ids = data["input_ids"].to(device, non_blocking=True) attention_mask = data["attention_mask"].to(device, non_blocking=True) loss_mask = data["loss_mask"].to(device, non_blocking=True) + target_kwargs = {} + if is_vlm: + if "pixel_values" in data: + target_kwargs["pixel_values"] = data["pixel_values"].to(device, non_blocking=True) + if "pixel_values_videos" in data: + target_kwargs["pixel_values_videos"] = data["pixel_values_videos"].to(device, non_blocking=True) + if "image_grid_thw" in data: + target_kwargs["image_grid_thw"] = data["image_grid_thw"].to(device, non_blocking=True) + if "video_grid_thw" in data: + target_kwargs["video_grid_thw"] = data["video_grid_thw"].to(device, non_blocking=True) + if "second_per_grid_ts" in data: + 
target_kwargs["second_per_grid_ts"] = data["second_per_grid_ts"].to(device, non_blocking=True) target_output = target_model.generate_dflash_data( - input_ids, attention_mask, loss_mask + input_ids, attention_mask, loss_mask, **target_kwargs ) hidden_states = target_output.hidden_states.to(device, non_blocking=True) + position_ids = ( + target_output.position_ids.to(device, non_blocking=True) + if target_output.position_ids is not None + else None + ) loss, accuracy = dflash_model( input_ids=input_ids, hidden_states=hidden_states, loss_mask=loss_mask, + position_ids=position_ids, ) (loss / args.accumulation_steps).backward() diff --git a/specforge/modeling/draft/dflash.py b/specforge/modeling/draft/dflash.py index 0f750920e..3a9a41c1d 100644 --- a/specforge/modeling/draft/dflash.py +++ b/specforge/modeling/draft/dflash.py @@ -34,11 +34,79 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_len = q.size(-2) - q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) - k_embed = (k * cos) + (rotate_half(k) * sin) + rotary_dim = cos.shape[-1] + head_dim = q.shape[-1] + if rotary_dim < head_dim: + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + q_rot = (q_rot * cos[..., -q_len:, :]) + (rotate_half(q_rot) * sin[..., -q_len:, :]) + k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin) + q_embed = torch.cat([q_rot, q_pass], dim=-1) + k_embed = torch.cat([k_rot, k_pass], dim=-1) + else: + q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed +def get_rope_scaling_value(config: Qwen3Config, key: str, default=None): + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return default + if isinstance(rope_scaling, dict): + return rope_scaling.get(key, default) + 
return getattr(rope_scaling, key, default) + + +class Qwen3InterleavedMultiRotaryEmbedding(Qwen3RotaryEmbedding): + """Interleaved mRoPE for Qwen3-VL style multimodal position ids.""" + + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.mrope_section = get_rope_scaling_value( + config, "mrope_section", [24, 20, 20] + ) + + def _apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor: + freqs_t = freqs[0] + for dim_idx, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim_idx] * 3 + idx_slice = slice(offset, length, 3) + freqs_t[..., idx_slice] = freqs[dim_idx, ..., idx_slice] + return freqs_t + + @torch.no_grad() + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(2, 3) + interleaved_freqs = self._apply_interleaved_mrope(freqs) + emb = torch.cat((interleaved_freqs, interleaved_freqs), dim=-1) + scaling = getattr(self, "attention_scaling", 1.0) + cos = emb.cos() * scaling + sin = emb.sin() * scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class Qwen3DFlashAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -228,7 +296,13 @@ def __init__(self, config) -> None: build_target_layer_ids(config.num_target_layers, config.num_hidden_layers), ) self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = 
Qwen3RotaryEmbedding(config) + self.use_interleaved_mrope = bool( + get_rope_scaling_value(config, "mrope_interleaved", False) + ) + if self.use_interleaved_mrope: + self.rotary_emb = Qwen3InterleavedMultiRotaryEmbedding(config) + else: + self.rotary_emb = Qwen3RotaryEmbedding(config) self.fc = nn.Linear( len(self.target_layer_ids) * config.hidden_size, config.hidden_size, @@ -237,28 +311,6 @@ def __init__(self, config) -> None: self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.block_size = config.block_size self.mask_token_id = dflash_config.get("mask_token_id", None) - self.projector_type = dflash_config.get("projector_type", None) - self.pure_draft_prefix_len = dflash_config.get("pure_draft_prefix_len", 0) - self.shift_label = dflash_config.get("shift_label", False) - - if self.projector_type == "domino": - self.emb_dim = dflash_config["emb_dim"] - self.gru_hidden_dim = dflash_config["gru_hidden_dim"] - self.prefix_gru = nn.GRU( - input_size=config.hidden_size, - hidden_size=self.gru_hidden_dim, - num_layers=1, - batch_first=True, - bias=False, - ) - in_dim = config.hidden_size + self.gru_hidden_dim - self.embed_proj = nn.Sequential( - nn.Linear(in_dim, self.emb_dim, bias=False), - nn.SiLU(), - nn.Linear(self.emb_dim, config.vocab_size, bias=False), - ) - elif self.projector_type is not None: - raise ValueError(f"Unknown draft projector_type: {self.projector_type}") self.post_init() def forward( diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py index 0df938239..affd59cf1 100644 --- a/specforge/modeling/target/dflash_target_model.py +++ b/specforge/modeling/target/dflash_target_model.py @@ -1,25 +1,120 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple +import sglang.srt.managers.mm_utils as mm_utils import torch import torch.distributed as dist import torch.nn as nn from 
def _patch_vlm_model_for_layer_capture(model: nn.Module, layer_ids: List[int]):
    """Patch an SGLang VLM so its forward pass hands captured layer outputs on.

    Meant for models that lack a native ``set_eagle3_layers_to_capture``
    (Qwen3VLForConditionalGeneration, Qwen3_5(Moe)ForConditionalGeneration and
    their subclasses). Two things are changed in place:

    * the inner language model (``model.model``) is told which layers to
      capture, and
    * ``model.forward`` is replaced by a wrapper that unpacks the
      ``(hidden_states, aux_hidden_states)`` tuple and passes the auxiliary
      states to ``logits_processor``.

    The forward wrapper is installed only once per model; calling this helper
    again just refreshes the list of captured layers.

    Args:
        model: Outer conditional-generation model to patch.
        layer_ids: Layer ids to capture. Each id is stored as ``id + 1`` on
            ``model.model.layers_to_capture``.

    Returns:
        ``False`` when ``model`` has no ``model`` attribute (nothing is
        changed), ``True`` otherwise.
    """
    import functools
    import types

    inner_model = getattr(model, "model", None)
    if inner_model is None:
        return False

    inner_model.layers_to_capture = [layer_id + 1 for layer_id in layer_ids]
    model.capture_aux_hidden_states = True

    # Wrap forward only on the first call so repeated configuration does not
    # stack wrappers on top of each other.
    if hasattr(model, "_layer_capture_patched"):
        return True

    bound_forward = model.forward
    if isinstance(bound_forward, types.MethodType):
        original_forward = bound_forward.__func__
    else:
        original_forward = bound_forward

    @functools.wraps(original_forward)
    def patched_forward(
        self_model,
        input_ids,
        positions,
        forward_batch,
        get_embedding=False,
        pp_proxy_tensors=None,
        **extra_kwargs,
    ):
        # Imported lazily: only needed once a forward pass actually runs.
        from sglang.srt.managers.mm_utils import general_mm_embed_routine

        if hasattr(forward_batch, "mrope_positions") and getattr(
            self_model, "is_mrope_enabled", False
        ):
            positions = forward_batch.mrope_positions

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self_model.model,
            multimodal_model=self_model,
            positions=positions,
            use_deepstack=getattr(self_model, "use_deepstack", False),
            pp_proxy_tensors=pp_proxy_tensors,
        )

        # With capture enabled the language model returns a tuple; split it so
        # the auxiliary states can be forwarded separately.
        aux_hidden_states = None
        if self_model.capture_aux_hidden_states and isinstance(hidden_states, tuple):
            hidden_states, aux_hidden_states = hidden_states

        if not self_model.pp_group.is_last_rank:
            return hidden_states
        if get_embedding:
            return self_model.pooler(hidden_states, forward_batch)
        return self_model.logits_processor(
            input_ids,
            hidden_states,
            self_model.lm_head,
            forward_batch,
            aux_hidden_states,
        )

    model.forward = types.MethodType(patched_forward, model)
    model._layer_capture_patched = True
    return True
def _init_vlm_attributes(self):
    """Populate the vision-language attributes from ``self.hf_config``.

    ``self.is_vlm`` is true only when the config carries a non-``None``
    ``vision_config``. The remaining attributes (``model_type``,
    ``spatial_merge_size``, ``tokens_per_second`` and the image / video /
    vision-start / vision-end token ids) are always defined, and stay ``None``
    for non-VLM targets so later reads cannot fail with ``AttributeError``.

    For VLM targets the multimodal embedding cache is initialised with a value
    of 512 MiB before the attributes are read.
    """
    # NOTE(review): presumably a byte budget for the embedding cache — confirm
    # against init_mm_embedding_cache.
    mm_embedding_cache_size = 512 * 1024 * 1024

    hf_config = self.hf_config
    # ``getattr`` also covers ``hf_config is None`` and configs that expose the
    # field but leave it unset.
    vision_config = getattr(hf_config, "vision_config", None)
    self.is_vlm = vision_config is not None

    self.model_type = None
    self.spatial_merge_size = None
    self.tokens_per_second = None
    self.image_token_id = None
    self.video_token_id = None
    self.vision_start_token_id = None
    self.vision_end_token_id = None

    if not self.is_vlm:
        return

    init_mm_embedding_cache(mm_embedding_cache_size)
    self.model_type = getattr(hf_config, "model_type", None)
    self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2)
    self.tokens_per_second = getattr(vision_config, "tokens_per_second", None)
    self.image_token_id = getattr(hf_config, "image_token_id", None)
    self.video_token_id = getattr(hf_config, "video_token_id", None)
    self.vision_start_token_id = getattr(hf_config, "vision_start_token_id", None)
    self.vision_end_token_id = getattr(hf_config, "vision_end_token_id", None)
def set_capture_layers(self, layer_ids: List[int]) -> None:
    """Configure which target-model layers have their hidden states captured.

    Delegates to the base-class implementation first, then configures the
    SGLang model: through its own ``set_eagle3_layers_to_capture`` when it has
    one, otherwise through ``_patch_vlm_model_for_layer_capture``.

    Args:
        layer_ids: Ids of the layers to capture.

    Raises:
        RuntimeError: If the model has no native capture hook and the fallback
            patch reports that it could not be applied.
    """
    super().set_capture_layers(layer_ids)

    model = self.model_runner.model
    if hasattr(model, "set_eagle3_layers_to_capture"):
        model.set_eagle3_layers_to_capture(layer_ids)
        return

    # The fallback returns False when it finds nothing to patch; fail here
    # rather than continue with no layers being captured.
    if not _patch_vlm_model_for_layer_capture(model, layer_ids):
        raise RuntimeError(
            f"Cannot capture layers {list(layer_ids)} on "
            f"{type(model).__name__}: the model has no "
            "set_eagle3_layers_to_capture method and no inner `model` to patch."
        )
@staticmethod
def _split_per_sample_tensor(
    tensor: Optional[torch.Tensor], batch_size: int, name: str
) -> List[Optional[torch.Tensor]]:
    """Turn a batched multimodal argument into one entry per sample.

    Args:
        tensor: ``None``, a list/tuple with one entry per sample, a tensor
            whose first dimension is the batch, or any other object (which is
            repeated for every sample).
        batch_size: Number of samples in the batch.
        name: Argument name, used only in error messages.

    Returns:
        A list of length ``batch_size``.

    Raises:
        ValueError: If a list/tuple does not have ``batch_size`` entries, or a
            tensor cannot be split along its first dimension.
    """
    if tensor is None:
        return [None] * batch_size

    if isinstance(tensor, (list, tuple)):
        # A shorter list would otherwise silently drop samples when the caller
        # zips the per-sample lists together.
        if len(tensor) != batch_size:
            raise ValueError(
                f"Cannot split {name} with {len(tensor)} entries "
                f"across batch size {batch_size}."
            )
        return list(tensor)

    if not isinstance(tensor, torch.Tensor):
        return [tensor] * batch_size

    if batch_size == 1:
        # Drop a leading singleton batch dimension; otherwise the whole tensor
        # belongs to the only sample.
        if tensor.dim() > 1 and tensor.shape[0] == 1:
            return [tensor.squeeze(0)]
        return [tensor]

    if tensor.dim() > 0 and tensor.shape[0] == batch_size:
        return list(tensor.unbind(0))

    raise ValueError(
        f"Cannot split {name} with shape {tuple(tensor.shape)} across batch size {batch_size}."
    )


@staticmethod
def _normalize_grid_thw(
    grid_thw: Optional[torch.Tensor], name: str
) -> Optional[torch.Tensor]:
    """Return ``grid_thw`` as a 2D tensor with three columns, or ``None``.

    A 1D tensor is treated as a single row.

    Raises:
        ValueError: If the tensor is not 1D/2D or its rows do not hold exactly
            three values.
    """
    if grid_thw is None:
        return None

    if grid_thw.dim() == 1:
        grid_thw = grid_thw.unsqueeze(0)
    elif grid_thw.dim() != 2:
        raise ValueError(
            f"{name} must be a 1D or 2D tensor, got shape {tuple(grid_thw.shape)}"
        )

    # The patch count below multiplies columns 0, 1 and 2 of every row.
    if grid_thw.shape[-1] != 3:
        raise ValueError(
            f"{name} must hold 3 values per row, got shape {tuple(grid_thw.shape)}"
        )
    return grid_thw


@staticmethod
def _count_mm_patches(grid_thw: Optional[torch.Tensor]) -> int:
    """Sum the product of the three columns over all rows (0 for ``None``)."""
    if grid_thw is None:
        return 0
    per_row = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]
    return int(per_row.sum().item())
def _build_mm_inputs(
    self,
    input_id_flat: torch.Tensor,
    pixel_values: Optional[torch.Tensor],
    pixel_values_videos: Optional[torch.Tensor],
    image_grid_thw: Optional[torch.Tensor],
    video_grid_thw: Optional[torch.Tensor],
    second_per_grid_ts: Optional[torch.Tensor],
) -> Tuple[Optional[MultimodalInputs], Optional[torch.Tensor]]:
    """Assemble the SGLang multimodal inputs for one sample.

    An image item is created when ``pixel_values`` is given and a video item
    when ``pixel_values_videos`` is given; grid / timing metadata is attached
    to the matching item when present. M-RoPE positions are then computed for
    the flattened token ids.

    Args:
        input_id_flat: Token ids of the sample as a 1D tensor.
        pixel_values: Image features for this sample, or ``None``.
        pixel_values_videos: Video features for this sample, or ``None``.
        image_grid_thw: Image grid metadata, or ``None``.
        video_grid_thw: Video grid metadata, or ``None``.
        second_per_grid_ts: Per-grid timing metadata for videos, or ``None``.

    Returns:
        ``(mm_inputs, mrope_positions)``, or ``(None, None)`` when the sample
        carries neither image nor video features.
    """

    def on_cpu(value):
        # Called once per use so every consumer gets its own ``.cpu()`` result.
        return None if value is None else value.cpu()

    items = []

    if pixel_values is not None:
        image_item = MultimodalDataItem(
            modality=Modality.IMAGE,
            feature=pixel_values,
            pad_value=self.image_token_id,
            offsets=BaseMultimodalProcessor.get_mm_items_offset(
                input_id_flat, self.image_token_id
            ),
        )
        if image_grid_thw is not None:
            image_item.set("image_grid_thw", on_cpu(image_grid_thw))
        image_item.set_pad_value()
        items.append(image_item)

    if pixel_values_videos is not None:
        video_item = MultimodalDataItem(
            modality=Modality.VIDEO,
            feature=pixel_values_videos,
            pad_value=self.video_token_id,
            offsets=BaseMultimodalProcessor.get_mm_items_offset(
                input_id_flat, self.video_token_id
            ),
        )
        video_metadata = (
            ("video_grid_thw", video_grid_thw),
            ("second_per_grid_ts", second_per_grid_ts),
        )
        for key, value in video_metadata:
            if value is not None:
                video_item.set(key, on_cpu(value))
        video_item.set_pad_value()
        items.append(video_item)

    if not items:
        return None, None

    rope_positions, rope_delta = MRotaryEmbedding.get_rope_index(
        spatial_merge_size=self.spatial_merge_size,
        image_token_id=self.image_token_id,
        video_token_id=self.video_token_id,
        vision_start_token_id=self.vision_start_token_id,
        model_type=self.model_type,
        input_ids=input_id_flat.unsqueeze(0).cpu(),
        image_grid_thw=on_cpu(image_grid_thw),
        video_grid_thw=on_cpu(video_grid_thw),
        second_per_grid_ts=on_cpu(second_per_grid_ts),
        tokens_per_second=self.tokens_per_second,
    )
    # The ids were passed with a leading batch dimension of one; remove the
    # matching dimension from the returned positions.
    if rope_positions is not None:
        rope_positions = rope_positions.squeeze(1)

    mm_inputs = MultimodalInputs(
        mm_items=items,
        im_token_id=self.image_token_id,
        im_start_id=self.vision_start_token_id,
        im_end_id=self.vision_end_token_id,
        video_token_id=self.video_token_id,
        mrope_positions=rope_positions,
        mrope_position_delta=rope_delta,
    )
    return mm_inputs, mm_inputs.mrope_positions
@torch.no_grad()
def _extend_vlm(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.Tensor] = None,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Tuple[List[torch.Tensor], list, Optional[torch.Tensor]]:
    """Run the target model over a batch that may contain images or videos.

    The batch is processed sample by sample: each sample's grid metadata
    decides how many rows of ``pixel_values`` / ``pixel_values_videos`` belong
    to it, the rows are consumed in order, and one ``Req`` is built per
    sample before all requests go through ``self._extend``.

    Args:
        input_ids: Token ids, either a tensor split along dim 0 or a sequence
            of per-sample tensors.
        attention_mask: Attention mask, same layout as ``input_ids``.
        loss_mask: Loss mask, same layout as ``input_ids``.
        pixel_values: Image features of all samples concatenated row-wise.
        pixel_values_videos: Video features of all samples concatenated
            row-wise.
        image_grid_thw: Image grid metadata for the batch.
        video_grid_thw: Video grid metadata for the batch.
        second_per_grid_ts: Per-grid timing metadata for videos.

    Returns:
        ``(hidden_states_list, data_cache, position_ids)``. ``data_cache``
        holds one ``(ids, attention_mask, loss_mask)`` tuple per sample.
        ``position_ids`` is the per-sample positions stacked along dim 1, or
        ``None`` unless every sample produced positions.
    """
    mm_utils.embedding_cache.clear()
    sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1)

    if isinstance(input_ids, torch.Tensor):
        batch_size = input_ids.shape[0]
        ids_per_sample = torch.split(input_ids, 1, dim=0)
        attn_per_sample = torch.split(attention_mask, 1, dim=0)
        loss_per_sample = torch.split(loss_mask, 1, dim=0)
    else:
        batch_size = len(input_ids)
        ids_per_sample = input_ids
        attn_per_sample = attention_mask
        loss_per_sample = loss_mask

    image_grids = self._split_per_sample_tensor(
        image_grid_thw, batch_size, "image_grid_thw"
    )
    video_grids = self._split_per_sample_tensor(
        video_grid_thw, batch_size, "video_grid_thw"
    )
    grid_timings = self._split_per_sample_tensor(
        second_per_grid_ts, batch_size, "second_per_grid_ts"
    )

    requests = []
    data_cache = []
    positions_per_sample = []
    image_cursor = 0  # next unread row of pixel_values
    video_cursor = 0  # next unread row of pixel_values_videos
    padding_pattern = None  # created on the first multimodal sample

    samples = zip(
        ids_per_sample,
        attn_per_sample,
        loss_per_sample,
        image_grids,
        video_grids,
        grid_timings,
    )
    for index, (ids, attn, loss, image_grid, video_grid, timing) in enumerate(samples):
        image_grid = self._normalize_grid_thw(image_grid, "image_grid_thw")
        video_grid = self._normalize_grid_thw(video_grid, "video_grid_thw")

        image_rows = self._count_mm_patches(image_grid)
        video_rows = self._count_mm_patches(video_grid)

        sample_pixels = None
        if pixel_values is not None and image_rows > 0:
            image_end = image_cursor + image_rows
            sample_pixels = pixel_values[image_cursor:image_end]
            image_cursor = image_end

        sample_video_pixels = None
        if pixel_values_videos is not None and video_rows > 0:
            video_end = video_cursor + video_rows
            sample_video_pixels = pixel_values_videos[video_cursor:video_end]
            video_cursor = video_end

        flat_ids = ids.view(-1)
        mm_inputs, sample_positions = self._build_mm_inputs(
            input_id_flat=flat_ids,
            pixel_values=sample_pixels,
            pixel_values_videos=sample_video_pixels,
            image_grid_thw=image_grid,
            video_grid_thw=video_grid,
            second_per_grid_ts=timing,
        )

        token_ids = flat_ids.tolist()
        if mm_inputs is not None:
            if padding_pattern is None:
                from sglang.srt.managers.mm_utils import (
                    MultiModalityDataPaddingPatternMultimodalTokens,
                )

                padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()
            token_ids = padding_pattern.pad_input_tokens(token_ids, mm_inputs)

        req = Req(
            rid=str(index),
            origin_input_text="",
            origin_input_ids=token_ids,
            sampling_params=sampling_params,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        if mm_inputs is not None:
            req.multimodal_inputs = mm_inputs

        data_cache.append((ids, attn, loss))
        positions_per_sample.append(sample_positions)
        requests.append(req)

    hidden_states_list = self._extend(requests)

    # NOTE(review): a single sample without multimodal inputs makes the whole
    # batch return position_ids=None — confirm this fallback is intended for
    # mixed text/image batches.
    position_ids = None
    have_all_positions = bool(positions_per_sample) and all(
        pos is not None for pos in positions_per_sample
    )
    if have_all_positions:
        position_ids = torch.stack(positions_per_sample, dim=1)

    return hidden_states_list, data_cache, position_ids
video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> DFlashTargetOutput: + sampling_params = SamplingParams(temperature=0, max_new_tokens=1) + reqs, data_cache = [], [] + position_ids = None + + use_multimodal = any( + item is not None + for item in ( + pixel_values, + pixel_values_videos, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + ) + ) + + if use_multimodal: + if not self.is_vlm: + raise ValueError( + "Multimodal inputs were provided to a non-VLM SGLang target model." + ) + hidden_states_list, data_cache, position_ids = self._extend_vlm( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + else: + if isinstance(input_ids, torch.Tensor): + input_ids_list = torch.split(input_ids, 1, dim=0) + attn_mask_list = torch.split(attention_mask, 1, dim=0) + loss_mask_list = torch.split(loss_mask, 1, dim=0) + + for idx, (curr_ids, curr_attn, curr_loss) in enumerate( + zip(input_ids_list, attn_mask_list, loss_mask_list) + ): + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=curr_ids.view(-1).tolist(), + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + data_cache.append((curr_ids, curr_attn, curr_loss)) + reqs.append(req) + + hidden_states_list = self._extend(reqs) + # Stack back to batch hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0) input_ids = torch.cat([d[0] for d in data_cache], dim=0) @@ -222,13 +614,15 @@ def generate_dflash_data( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, + position_ids=position_ids, ) class HFDFlashTargetModel(DFlashTargetModel): - def __init__(self, model: nn.Module): + def __init__(self, model: 
nn.Module, model_type: Optional[str] = None): super().__init__() self.model = model + self.model_type = model_type @classmethod def from_pretrained( @@ -240,20 +634,76 @@ def from_pretrained( trust_remote_code: bool = True, **kwargs, ) -> "HFDFlashTargetModel": - - target_model = AutoModelForCausalLM.from_pretrained( + hf_config = AutoConfig.from_pretrained( pretrained_model_name_or_path, - torch_dtype=torch_dtype, cache_dir=cache_dir, - output_hidden_states=True, trust_remote_code=trust_remote_code, - **kwargs, - ).eval() + ) + model_type = getattr(hf_config, "model_type", None) + + if model_type in QWEN3_VL_MODEL_TYPES: + model_cls = None + if model_type == "qwen3_vl": + try: + from transformers import Qwen3VLForConditionalGeneration + + model_cls = Qwen3VLForConditionalGeneration + except ImportError as exc: + raise ImportError( + "Qwen3VLForConditionalGeneration is unavailable. " + "Please upgrade transformers to a version with qwen3_vl support." + ) from exc + elif model_type == "qwen3_vl_moe": + try: + from transformers import Qwen3VLMoeForConditionalGeneration + + model_cls = Qwen3VLMoeForConditionalGeneration + except ImportError as exc: + raise ImportError( + "Qwen3VLMoeForConditionalGeneration is unavailable. " + "Please upgrade transformers to a version with qwen3_vl_moe support." + ) from exc + elif model_type in ("qwen3_5_moe", "qwen3_5"): + # Qwen3.5/3.6: try transformers 5.x, then AutoModel fallback + try: + from transformers import Qwen3_5MoeForConditionalGeneration + + model_cls = Qwen3_5MoeForConditionalGeneration + except ImportError: + pass + if model_cls is None: + try: + from transformers import AutoModelForImageTextToText + + model_cls = AutoModelForImageTextToText + except ImportError: + pass + if model_cls is None: + raise ImportError( + f"model_type '{model_type}' requires transformers >= 5.0 for HF backend. " + "Use --target-model-backend sglang instead, or upgrade transformers." 
+ ) + target_model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + trust_remote_code=trust_remote_code, + **kwargs, + ).eval() + else: + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + output_hidden_states=True, + trust_remote_code=trust_remote_code, + **kwargs, + ).eval() if device: target_model = target_model.to(device) - return cls(target_model) + return cls(target_model, model_type=model_type) @torch.no_grad() def generate_dflash_data( @@ -261,13 +711,77 @@ def generate_dflash_data( input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, ) -> DFlashTargetOutput: - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - use_cache=False, - ) + target_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": True, + "use_cache": False, + } + if self.model_type in QWEN3_VL_MODEL_TYPES: + target_kwargs.update( + { + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + } + ) + if hasattr(self.model, "config") and hasattr(self.model.config, "image_token_id"): + image_token_id = self.model.config.image_token_id + video_token_id = getattr(self.model.config, "video_token_id", None) + mm_token_type_ids = torch.zeros_like(input_ids) + if image_token_id is not None: + mm_token_type_ids[input_ids == image_token_id] = 1 + if video_token_id is not None: + mm_token_type_ids[input_ids == video_token_id] = 2 + target_kwargs["mm_token_type_ids"] = mm_token_type_ids + + 
filtered_target_kwargs = {} + for key, value in target_kwargs.items(): + if key in { + "input_ids", + "attention_mask", + "output_hidden_states", + "use_cache", + } or value is not None: + filtered_target_kwargs[key] = value + + outputs = self.model(**filtered_target_kwargs) + if outputs.hidden_states is None: + raise ValueError( + "Target model did not return hidden states. Ensure output_hidden_states=True is supported." + ) + + position_ids = None + if self.model_type in QWEN3_VL_MODEL_TYPES: + target_inner_model = getattr(self.model, "model", None) + if target_inner_model is not None and hasattr( + target_inner_model, "get_rope_index" + ): + rope_kwargs = { + "input_ids": input_ids, + "image_grid_thw": image_grid_thw, + "attention_mask": attention_mask, + } + if video_grid_thw is not None: + rope_kwargs["video_grid_thw"] = video_grid_thw + if second_per_grid_ts is not None: + rope_kwargs["second_per_grid_ts"] = second_per_grid_ts + if "mm_token_type_ids" in target_kwargs: + rope_kwargs["mm_token_type_ids"] = target_kwargs["mm_token_type_ids"] + + filtered_rope_kwargs = { + key: value for key, value in rope_kwargs.items() if value is not None + } + position_ids, _ = target_inner_model.get_rope_index( + **filtered_rope_kwargs + ) # hidden_states[0] = embedding output; hidden_states[i+1] = layer i output offset = 1 @@ -284,6 +798,7 @@ def generate_dflash_data( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, + position_ids=position_ids, )