diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index b1947fa7f072..d6e2181aede3 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -569,6 +569,8 @@ Specified using `--task generate`. | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` +| `Qwen3_5ForConditionalGeneration` | Qwen3.5 | T + IE+ + VE+ | `Qwen/Qwen3.5-9B-Instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen3_5MoeForConditionalGeneration` | Qwen3.5-MOE | T + IE+ + VE+ | `Qwen/Qwen3.5-35B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | | `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + IE+ + VE+ | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + IE+ + VE+ | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | | ✅︎ | ✅︎\* | | `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index d078c517d00e..7f4e15249201 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -1,80 +1,102 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +import argparse +import os + +os.environ["PT_HPU_LAZY_MODE"] = "1" + +from vllm import LLM, EngineArgs, SamplingParams + +# Parse the command-line arguments. +parser = argparse.ArgumentParser() +parser.add_argument( + "--model", + type=str, + default="facebook/opt-125m", + help="The model path.", +) +parser.add_argument("--tp-size", type=int, default=2, help="The number of threads.") +parser.add_argument( + "--output-tokens", type=int, default=512, help="The number of output tokens." +) +parser.add_argument( + "--max-model-length", type=int, default=16384, help="Max model length." +) +parser.add_argument("--enable-ep", action="store_true", help="Enable EP for MOE models") +parser.add_argument("--temperature", type=float, default=0.8) +parser.add_argument("--top-p", type=float, default=0.95) +parser.add_argument("--enable-thinking", action="store_true", help="Enable think mode for inference") +# Add example params +parser.add_argument("--chat-template-path", type=str) +args = parser.parse_args() + +os.environ["VLLM_SKIP_WARMUP"] = "true" +os.environ["HABANA_VISIBLE_DEVICES"] = "ALL" +os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true" +os.environ["PT_HPU_WEIGHT_SHARING"] = "0" -def create_parser(): - parser = FlexibleArgumentParser() - # Add engine args - EngineArgs.add_cli_args(parser) - parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) - # Add example params - parser.add_argument("--chat-template-path", type=str) - - return parser - - -def main(args: dict): - # Pop arguments not used by LLM - max_tokens = args.pop("max_tokens") - temperature = args.pop("temperature") - top_p = args.pop("top_p") - top_k = args.pop("top_k") - chat_template_path = args.pop("chat_template_path") - - # Create an LLM - llm = LLM(**args) - - # Create sampling params object - sampling_params = llm.get_default_sampling_params() - if max_tokens is not None: - sampling_params.max_tokens = max_tokens - if temperature is not None: - sampling_params.temperature = temperature - if top_p is not None: - sampling_params.top_p = top_p - if top_k is not None: - sampling_params.top_k = top_k +if __name__ == "__main__": + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + messages = [] + for idx in range(len(prompts)): + conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompts[idx] + }, + ] + messages.append(conversation) + # Create a sampling params object. + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.output_tokens, + ) + chat_template_path = args.chat_template_path + model = args.model + if args.tp_size == 1: + llm = LLM( + model=model, + tokenizer=model, + trust_remote_code=True, + dtype="bfloat16", + max_model_len=args.max_model_length, + ) + else: + llm = LLM( + model=model, + tokenizer=model, + tensor_parallel_size=args.tp_size, + distributed_executor_backend="mp", + trust_remote_code=True, + max_model_len=args.max_model_length, + enable_expert_parallel=args.enable_ep, + dtype="bfloat16", + ) def print_outputs(outputs): print("\nGenerated Outputs:\n" + "-" * 80) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text + for idx in range(len(outputs)): + prompt = prompts[idx] + generated_text = outputs[idx].outputs[0].text print(f"Prompt: {prompt!r}\n") print(f"Generated text: {generated_text!r}") print("-" * 80) print("=" * 80) - # In this script, we demonstrate how to pass input to the chat method: - conversation = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hello! How can I assist you today?"}, - { - "role": "user", - "content": "Write an essay about the importance of higher education.", - }, - ] - outputs = llm.chat(conversation, sampling_params, use_tqdm=False) - print_outputs(outputs) - - # You can run batch inference with llm.chat API - conversations = [conversation for _ in range(10)] - - # We turn on tqdm progress bar to verify it's indeed running batch inference - outputs = llm.chat(conversations, sampling_params, use_tqdm=True) - print_outputs(outputs) - # A chat template can be optionally supplied. # If not, the model will use its default chat template. if chat_template_path is not None: @@ -82,14 +104,16 @@ def print_outputs(outputs): chat_template = f.read() outputs = llm.chat( - conversations, + messages, sampling_params, use_tqdm=False, chat_template=chat_template, + chat_template_kwargs={"enable_thinking": args.enable_thinking}, ) - - -if __name__ == "__main__": - parser = create_parser() - args: dict = vars(parser.parse_args()) - main(args) + else: + outputs = llm.chat( + messages, + sampling_params, + chat_template_kwargs={"enable_thinking": args.enable_thinking} + ) + print_outputs(outputs) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index aaac656f1ad0..5d8004463ad3 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1435,6 +1435,80 @@ def run_qwen3_omni_moe(questions: list[str], modality: str) -> ModelRequestData: ) +# Qwen3.5 Dense +def run_qwen3_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3.5-4B-Base" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + f"<|vision_start|>{placeholder}<|vision_end|>" + f"{question}" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# Qwen3.5 MoE +def run_qwen3_5_moe(questions: list[str], modality: str) -> ModelRequestData: + model_name = "/data/Qwen3.5-397B-A17B-FP8-G2" + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=5, + enable_expert_parallel=True, + trust_remote_code=True, + tensor_parallel_size=8, + distributed_executor_backend="mp", + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + f"<|vision_start|>{placeholder}<|vision_end|>" + f"{question}" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1516,6 +1590,8 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "qwen3_vl": run_qwen3_vl, "qwen3_vl_moe": run_qwen3_vl_moe, "qwen3_omni_moe": run_qwen3_omni_moe, + "qwen3_5": run_qwen3_5, + "qwen3_5_moe": run_qwen3_5_moe, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "tarsier": run_tarsier, @@ -1526,6 +1602,8 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "glm4_5v", "qwen3_vl", "qwen3_vl_moe", + "qwen3_5", + "qwen3_5_moe", ] diff --git a/tests/models/registry.py b/tests/models/registry.py index f4180da268ff..11c25276ee8e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -421,6 +421,26 @@ def check_available_online( "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 max_model_len=4096, min_transformers_version="4.57"), + "Qwen3_5ForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3.5-9B-Instruct", + max_model_len=4096, + min_transformers_version="5.1.0", + ), + "Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3.5-35B-A3B-Instruct", + max_model_len=4096, + min_transformers_version="5.1.0", + ), + "Qwen3_5MTP": _HfExamplesInfo( + "Qwen/Qwen3.5-9B-Instruct", + speculative_model="Qwen/Qwen3.5-9B-Instruct", + min_transformers_version="5.1.0", + ), + "Qwen3_5MoeMTP": _HfExamplesInfo( + "Qwen/Qwen3.5-35B-A3B-Instruct", + speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct", + min_transformers_version="5.1.0", + ), "Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-Omni-30B-A3B-Instruct", # noqa: E501 max_model_len=4096, # noqa: E501 min_transformers_version="4.57"), # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index 394867537e09..70d90c9e64a7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1299,8 +1299,8 @@ def get_num_layers_by_block_type( if attn_type_list: return sum(t == 1 for t in attn_type_list[start:end]) - # Hybrid model Qwen3Next - layer_types_value = getattr(self.hf_config, "layer_types", None) + # Hybrid model Qwen3Next Qwen3.5 Series + layer_types_value = getattr(self.hf_text_config, "layer_types", None) if layer_types_value is not None: if getattr(block_type, "value", block_type) == "attention": return sum(t == "full_attention" @@ -2575,6 +2575,16 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: "n_predict": n_predict, "architectures": ["Glm4MoeMTPModel"] }) + if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"): + is_moe = hf_config.model_type == "qwen3_5_moe" + hf_config.model_type = "qwen3_5_mtp" + n_predict = getattr(hf_config, "mtp_num_hidden_layers", None) + hf_config.update( + { + "n_predict": n_predict, + "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"], + } + ) return hf_config diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index b6d1a02aa434..0c670d8954d8 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -321,7 +321,8 @@ def forward_cuda( return self.forward_native(x, residual) -class RMSNormGated(nn.Module): +@CustomOp.register("rms_normgated") +class RMSNormGated(CustomOp): def __init__( self, @@ -347,20 +348,30 @@ def __init__( def reset_parameters(self): torch.nn.init.ones_(self.weight) - def forward(self, hidden_states, gate=None): + def forward_native(self, hidden_states, gate=None): """ If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) """ - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) # Norm before gate hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - hidden_states = self.weight * hidden_states.to(input_dtype) - hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + hidden_states = self.weight.to(hidden_states.dtype) * hidden_states + if gate is not None: + hidden_states = hidden_states * F.silu(gate.to(hidden_states.dtype)) - return hidden_states.to(input_dtype) + return hidden_states + + def forward_hpu(self, hidden_states, gate=None): + from vllm_hpu_extension.kernels import rms_norm + HPUFusedRMSNorm = rms_norm() + + hidden_states = HPUFusedRMSNorm.apply(hidden_states, + self.weight.to(hidden_states.dtype), + self.eps) + if gate is not None: + hidden_states = hidden_states * F.silu(gate.to(hidden_states.dtype)) + return hidden_states class MiniMaxText01RMSNormTP(CustomOp): diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index a524e1340580..2281a582e543 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -6,10 +6,12 @@ import torch +from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec class MambaBase(AttentionLayerBase): @@ -43,3 +45,29 @@ def mamba_type(self) -> str: def get_attn_backend(self) -> type["AttentionBackend"]: """Get the attention backend class for this Mamba layer.""" pass + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: + if ( + vllm_config.speculative_config is not None + and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"] + and vllm_config.model_config.hf_config.model_type + not in ["qwen3_next", "qwen3_5", "qwen3_5_moe"] + ): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet." + ) + mamba_block_size = vllm_config.cache_config.mamba_block_size + page_size_padded = vllm_config.cache_config.mamba_page_size_padded + return MambaSpec( + shapes=self.get_state_shape(), + dtypes=self.get_state_dtype(), + block_size=mamba_block_size, + page_size_padded=page_size_padded, + mamba_type=self.mamba_type, + mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode, + num_speculative_blocks=( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ), + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index a6c1af91de42..42da62e759b0 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union +from collections.abc import Callable +from dataclasses import dataclass +from typing import Union, TypeAlias import torch @@ -200,3 +202,89 @@ def gated_delta_net_state_shape( temporal_state_shape = (divide(num_v_heads, tp_world_size), head_k_dim, head_v_dim) return conv_state_shape, temporal_state_shape + +@dataclass +class MambaCopySpec: + """ + Data class specifying the memory-copy parameters for Mamba states used for + prefix caching in align mode. + + Attributes: + start_addr (int): Starting address for the memory copy operation. + num_elements (int): Number of elements to copy from the starting address. + """ + + start_addr: int + num_elements: int + +MambaStateCopyFunc: TypeAlias = Callable[ + [torch.Tensor, list[int], int, int], MambaCopySpec +] +""" +Type alias for a function that computes a MambaCopySpec for copying state slices. +Parameters: + state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states). + block_ids: list[int] - the list of block indices for the state to copy. + cur_block_idx: int - current block index within `block_ids` to copy from. + num_accepted_tokens: int - number of accepted tokens used to compute the copy offset. + Range: 1 .. 1 + num_speculative_tokens (inclusive). +""" + +def get_conv_copy_spec( + state: torch.Tensor, + block_ids: list[int], + cur_block_idx: int, + num_accepted_tokens: int, +) -> MambaCopySpec: + """Return a MambaCopySpec for copying a convolutional state slice.""" + src_block_id = block_ids[cur_block_idx] + src_state = state[src_block_id, num_accepted_tokens - 1 :] + return MambaCopySpec( + start_addr=src_state.data_ptr(), num_elements=src_state.numel() + ) + + +def get_temporal_copy_spec( + state: torch.Tensor, + block_ids: list[int], + cur_block_idx: int, + num_accepted_tokens: int, +) -> MambaCopySpec: + """Return a MambaCopySpec for copying a temporal state slice.""" + src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1] + src_state = state[src_block_id] + return MambaCopySpec( + start_addr=src_state.data_ptr(), num_elements=src_state.numel() + ) + +get_full_copy_spec = get_temporal_copy_spec + +class MambaStateCopyFuncCalculator: + @classmethod + def linear_attention_state_copy_func(cls): + return (get_temporal_copy_spec,) + + @classmethod + def mamba1_state_copy_func(cls): + return (get_conv_copy_spec, get_temporal_copy_spec) + + @classmethod + def mamba2_state_copy_func(cls): + return get_conv_copy_spec, get_temporal_copy_spec + + @classmethod + def short_conv_state_copy_func(cls): + return (get_conv_copy_spec,) + + @classmethod + def gated_delta_net_state_copy_func(cls): + return (get_conv_copy_spec, get_temporal_copy_spec) + + @classmethod + def kda_state_copy_func(cls): + return ( + get_conv_copy_spec, + get_conv_copy_spec, + get_conv_copy_spec, + get_temporal_copy_spec, + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/ops/torch_gated_delta_relu.py b/vllm/model_executor/layers/mamba/ops/torch_gated_delta_relu.py new file mode 100644 index 000000000000..1ee51566d7c0 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/torch_gated_delta_relu.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py + +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + + +is_hpu = current_platform.is_hpu() + +if is_hpu: + import habana_frameworks.torch as htorch + import habana_frameworks.torch.core as htcore + + +def torch_chunk_gated_delta_rule_opt( + query, + key, + value, + g, + beta, + eye_constant, + chunk_size=64, + inv_loop=12, + initial_state=None, + output_final_state=True, + use_qk_l2norm_in_kernel=True, +): + ssm_dtype = g.dtype + if use_qk_l2norm_in_kernel: + head_dim = query.size(-1) + inv_scale = head_dim**-0.5 + query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale + key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous() + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_len = sequence_length + pad_size + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.ones(chunk_size, + chunk_size, + dtype=value.dtype, + device=query.device).tril(-1) + + # chunk decay + g = g.cumsum(dim=-1) + g_exp = g.exp().to(value.dtype) + decay_mask = ((g.unsqueeze(-1) - + g.unsqueeze(-2)).exp().to(value.dtype)).tril() + + attn = torch.matmul(k_beta, + key.transpose(-1, -2).contiguous()) * \ + decay_mask * mask + eye_constant + inv_attn = torch.zeros_like(attn) + eye_constant + htcore.mark_step() + for _ in range(inv_loop): + prod = torch.matmul(attn, inv_attn) + err = prod * mask + update = torch.matmul(inv_attn, err) + inv_attn.sub_(update) + htcore.mark_step() + attn = inv_attn + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1)) + last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim, + v_head_dim).to(value) if initial_state + is None else initial_state.to(value)) + mask = torch.tril(torch.ones(chunk_size, + chunk_size, + dtype=value.dtype, + device=value.device), + diagonal=0) + attn = torch.matmul(query, key.transpose(-1, -2).contiguous()) * decay_mask * mask + qg = query * g_exp[..., None] + delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None].to(value.dtype) + k_term = key * delta_g_exp + + num_chunks = tot_len // chunk_size + k_eye = torch.eye(k_head_dim, dtype=value.dtype, device=value.device) + k_eye = k_eye.view(1, 1, 1, k_head_dim, k_head_dim) + + alpha = g_exp[:, :, :, -1, None, None] + B = k_term.transpose(-1, -2).contiguous() + K = k_cumdecay + V = value + Q = qg + A = attn + + M = alpha * k_eye - torch.matmul(B, K) + N = torch.matmul(B, V) + C = Q - torch.matmul(A, K) + core_attn_out = torch.matmul(A, V) + + # for each chunk + htcore.mark_step() + for i in range(num_chunks): + core_attn_out[:, :, i].add_(torch.matmul(C[:, :, i], last_recurrent_state)) + last_recurrent_state = torch.matmul(M[:, :, i], last_recurrent_state) + N[:, :, i] + htcore.mark_step() + + if not output_final_state: + last_recurrent_state = None + else: + last_recurrent_state = last_recurrent_state.to(ssm_dtype) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1, + core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2) + return core_attn_out, last_recurrent_state + + +def torch_recurrent_gated_delta_rule_opt( + query, + key, + value, + g, + beta, + recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, +): + ssm_dtype = g.dtype + if use_qk_l2norm_in_kernel: + head_dim = query.size(-1) + inv_scale = head_dim**-0.5 + query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale + key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous() + for x in (query, key, value, beta, g) + ] + + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + recurrent_state = recurrent_state.to(value.dtype) + + q_t = query.squeeze(-2) + k_t = key.squeeze(-2) + v_t = value.squeeze(-2) + g_t = g.squeeze(-1).exp().to(value.dtype).unsqueeze(-1).unsqueeze(-1) + + recurrent_state = recurrent_state * g_t + kv_mem = torch.matmul(k_t.unsqueeze(-2), recurrent_state).squeeze(-2) + delta = (v_t - kv_mem) * beta + recurrent_state.add_(k_t.unsqueeze(-1) * delta.unsqueeze(-2)) + core_attn_out = torch.matmul(q_t.unsqueeze(-2), recurrent_state) + + if not output_final_state: + recurrent_state = None + else: + recurrent_state = recurrent_state.to(ssm_dtype) + core_attn_out = core_attn_out.transpose(1, 2) + return core_attn_out, recurrent_state + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + eye_constant, + chunk_size=64, + initial_state=None, + output_final_state=True, + use_qk_l2norm_in_kernel=True, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + head_dim = query.size(-1) + inv_scale = head_dim**-0.5 + query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale + key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_len = sequence_length + pad_size + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + g_exp = g.exp() + decay_mask = ((g.unsqueeze(-1) - + g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((torch.matmul(k_beta.contiguous(), + key.transpose(-1, -2).contiguous())) * + decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].contiguous() + sub = attn[..., :i, :] + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)[..., :i] + attn = attn + eye_constant + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1)) + last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim, + v_head_dim).to(value) if initial_state + is None else initial_state.to(value)) + core_attn_out = torch.zeros_like(value) + mask = torch.tril(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=0) + mask = mask.view(1, 1, 1, chunk_size, chunk_size) + attn = (query @ key.transpose(-1, -2)) * decay_mask * mask + qg = query * g_exp[..., None] + delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None] + k_term = (key * delta_g_exp) + + # for each chunk + for i in range(0, tot_len // chunk_size): + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = value[:, :, i] - v_prime + attn_inter = qg[:, :, i] @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn[:, :, i] @ v_new + last_recurrent_state = ( + last_recurrent_state * g_exp[:, :, i, -1, None, None] + + k_term[:, :, i].transpose(-1, -2) @ v_new) + + if not output_final_state: + last_recurrent_state = None + else: + last_recurrent_state = last_recurrent_state.to(initial_dtype) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1, + core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def torch_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + head_dim = query.size(-1) + inv_scale = head_dim**-0.5 + query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale + key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1]**0.5) + query = query * scale + + recurrent_state = recurrent_state.to(value) + + q_t = query.squeeze(-2) + k_t = key.squeeze(-2) + v_t = value.squeeze(-2) + g_t = g.squeeze(-1).exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta + + recurrent_state = recurrent_state * g_t + kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + recurrent_state.add_(k_t.unsqueeze(-1) * delta.unsqueeze(-2)) + core_attn_out = (recurrent_state * + q_t.unsqueeze(-1)).sum(dim=-2).unsqueeze(-2) + + if not output_final_state: + recurrent_state = None + else: + recurrent_state = recurrent_state.to(initial_dtype) + core_attn_out = core_attn_out.transpose(1, 2).to(initial_dtype) + return core_attn_out, recurrent_state diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 9c8cbf52fb63..f009175474e5 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -1449,7 +1449,7 @@ def get_input_positions_tensor( context_len=context_len, seq_len=seq_len, ) - elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]: return cls._qwen3vl_get_input_positions_tensor( input_tokens=input_tokens, hf_config=hf_config, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 350068b2c3d6..b53d06b34e7a 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -32,6 +32,22 @@ """ +def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor: + """ + A helper function to be used in the context of + [vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids][] + to provide a better error message. + """ + if is_multimodal is None: + raise ValueError( + "`embed_input_ids` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229." + ) + + return is_multimodal + + @runtime_checkable class SupportsMultiModal(Protocol): """The interface required for all multi-modal models.""" diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py new file mode 100644 index 000000000000..cfc84f10e196 --- /dev/null +++ b/vllm/model_executor/models/qwen3_5.py @@ -0,0 +1,1166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3.5 Series compatible with HuggingFace weights.""" + +import os +import typing +from collections.abc import Callable, Iterable + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN +from transformers.models.qwen3_5.configuration_qwen3_5 import ( + Qwen3_5Config, + Qwen3_5TextConfig, +) +from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import ( + Qwen3_5MoeConfig, + Qwen3_5MoeTextConfig, +) +from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + divide, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import ( + GemmaRMSNorm as Qwen3_5RMSNorm, +) +from vllm.model_executor.layers.layernorm import RMSNormGated +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + mamba_v2_sharded_weight_loader, +) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFunc, + MambaStateCopyFuncCalculator, + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.torch_gated_delta_relu import ( + torch_chunk_gated_delta_rule_opt, + torch_recurrent_gated_delta_rule_opt, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + MultiModalEmbeddings, + SupportsLoRA, + SupportsPP, + _require_is_multimodal, +) +from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from .qwen3_next import ( + Qwen3NextAttention, + Qwen3NextDecoderLayer, + Qwen3NextGatedDeltaNet, + Qwen3NextModel, + Qwen3NextSparseMoeBlock, + QwenNextMixtureOfExperts, +) +from .qwen3_vl import ( + Qwen3_VisionTransformer, + Qwen3_VisionTransformerStaticShape, + Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, + Qwen3VLProcessingInfo, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + _merge_multimodal_embeddings, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +is_hpu = current_platform.is_hpu() +logger = init_logger(__name__) + + +class Qwen3_5ProcessingInfo(Qwen3VLProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3_5Config) + + +class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3_5MoeConfig) + + +class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): + def __init__( + self, + vllm_config: VllmConfig, + config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, + prefix: str = "", + ) -> None: + super(Qwen3NextGatedDeltaNet, self).__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + self.conv1d_weight = None + + self.in_proj_qkv = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[self.key_dim, self.key_dim, self.value_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkv", + ) + self.in_proj_z = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.value_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_z", + ) + self.in_proj_b = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=None, + prefix=f"{prefix}.in_proj_ba", + ) + self.in_proj_a = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=None, + prefix=f"{prefix}.in_proj_a", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) + + max_prefill_bs = vllm_config.scheduler_config.max_num_prefill_seqs + max_decode_bs = vllm_config.scheduler_config.max_num_seqs + + mamba_cache_bs = max_decode_bs + max(8, max_decode_bs) + if max_prefill_bs is not None: + mamba_cache_bs += max_prefill_bs + else: + mamba_cache_bs += max_decode_bs + + conv_state_shape = ( + mamba_cache_bs, + self.conv_kernel_size - 1, + divide(self.conv_dim, self.tp_size), + ) + temporal_state_shape = (mamba_cache_bs, + divide(self.num_v_heads, self.tp_size), + self.head_k_dim, self.head_v_dim) + + self.conv_state = torch.empty(conv_state_shape, + dtype=torch.float32, + device=self.conv1d.weight.device) + self.ssm_state = torch.empty(temporal_state_shape, + dtype=torch.float32, + device=self.conv1d.weight.device) + + self.chunk_size = 64 + self.eye_constant = torch.eye(self.chunk_size, + dtype=torch.bfloat16, + device=self.conv1d.weight.device) + self.inv_loop = int(os.environ.get("VLLM_GDN_INV_LOOP", 12)) + + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), + ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + ) + ) + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + dtype=config.dtype, + ) + + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkv, + z, + b, + a, + ): + raise NotImplementedError( + "Qwen3.5 Series dont need to fix query key value ordering" + ) + + def forward( + self, + hidden_states: torch.Tensor, + # output: torch.Tensor, + ): + """ + Forward pass with three parts: + 1. Input projection + 2. Core attention (custom op) + 3. Output projection + """ + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + conv_state = self.conv_state + ssm_state = self.ssm_state + + + num_tokens = hidden_states.size(0) + + # ============================================================ + # Part 1: Input Projection + # ============================================================ + mixed_qkv, _ = self.in_proj_qkv(hidden_states) + z, _ = self.in_proj_z(hidden_states) + z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim) + b, _ = self.in_proj_b(hidden_states) + a, _ = self.in_proj_a(hidden_states) + + b = b.contiguous().float() + a = a.contiguous().float() + mixed_qkv = mixed_qkv.float() + + # ============================================================ + # Part 2: Core Attention (Custom Op) + # ============================================================ + # Note: we should not use torch.empty here like other attention backends, + # see discussions in https://github.com/vllm-project/vllm/pull/28182 + mamba_cache_prefill_indices = attn_metadata.mamba_cache_prefill_indices + mamba_cache_decode_indices = attn_metadata.mamba_cache_decode_indices + + if self.conv1d_weight is None: + self.conv1d_weight = self.conv1d.weight.squeeze(1).transpose( + 0, 1).flatten().reshape(self.conv_kernel_size, + self.conv_dim // self.tp_size).float() + del self.conv1d.weight + + if attn_metadata.is_prompt: + bs, seq_len, qkv_dim = mixed_qkv.shape + conv_state_indices = attn_metadata.conv_state_indices + prefill_conv_state = torch.index_select( + mixed_qkv.reshape(-1, qkv_dim), + dim=0, + index=conv_state_indices).reshape(bs, -1, qkv_dim) + conv_state.index_copy_(dim=0, + index=mamba_cache_prefill_indices, + source=prefill_conv_state) + + mixed_qkv_with_pad = F.pad(mixed_qkv, + (0, 0, self.conv_kernel_size - 1, 0)) + for idx in range(self.conv_kernel_size): + qkv_slice = mixed_qkv_with_pad[:, idx:(idx + seq_len), :] + conv1d_weight_slice = self.conv1d_weight[idx] + qkv_conv = qkv_slice * conv1d_weight_slice + if idx == 0: + mixed_qkv_non_spec = qkv_conv + else: + mixed_qkv_non_spec.add_(qkv_conv) + + mixed_qkv_non_spec = F.silu(mixed_qkv_non_spec) + + else: + mixed_qkv_non_spec, cur_conv_state = causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d_weight, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_decode_indices, + ) + conv_state.index_copy_(0, mamba_cache_decode_indices, + cur_conv_state) + + query, key, value = torch.split( + mixed_qkv_non_spec.to(hidden_states.dtype), + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query_non_spec = query.reshape(query.shape[0], query.shape[1], -1, + self.head_k_dim) + key_non_spec = key.reshape(key.shape[0], key.shape[1], -1, + self.head_k_dim) + value_non_spec = value.reshape(value.shape[0], value.shape[1], -1, + self.head_v_dim) + + beta = b.sigmoid().to(hidden_states.dtype) + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + + if self.num_v_heads // self.num_k_heads > 1: + query_non_spec = query_non_spec.repeat_interleave( + self.num_v_heads // self.num_k_heads, dim=2) + key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads // + self.num_k_heads, + dim=2) + + if attn_metadata.is_prompt: + core_attn_out, last_recurrent_state = ( + torch_chunk_gated_delta_rule_opt( + query_non_spec, + key_non_spec, + value_non_spec, + g=g, + beta=beta, + eye_constant=self.eye_constant, + chunk_size=self.chunk_size, + inv_loop=self.inv_loop, + initial_state=None, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + )) + ssm_state.index_copy_(dim=0, + index=mamba_cache_prefill_indices, + source=last_recurrent_state) + else: + recurrent_state = torch.index_select( + ssm_state, + dim=0, + index=mamba_cache_decode_indices, + ) + core_attn_out, last_recurrent_state = ( + torch_recurrent_gated_delta_rule_opt( + query_non_spec, + key_non_spec, + value_non_spec, + g=g, + beta=beta, + recurrent_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + )) + ssm_state.index_copy_( + dim=0, + index=mamba_cache_decode_indices, + source=last_recurrent_state, + ) + + # ============================================================ + # Part 3: Output Projection + # ============================================================ + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z).to(hidden_states.dtype) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1) + + output, _ = self.out_proj(core_attn_out) + return output + + +class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): + def __init__( + self, + vllm_config: VllmConfig, + layer_type: str, + prefix: str = "", + ) -> None: + super(Qwen3NextDecoderLayer, self).__init__() + + config = vllm_config.model_config.hf_text_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + speculative_config = vllm_config.speculative_config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3_5GatedDeltaNet( + vllm_config, + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f"{prefix}.linear_attn", + ) + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + # NOTE: Determine the MLP type based on the model type + # Qwen3.5 use all layers for MLP / Qwen3.5-MoE use sparse MoE blocks + if config.model_type == "qwen3_5_moe_text": + self.mlp = Qwen3NextSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + elif config.model_type == "qwen3_5_text": + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + # prefix=f"{prefix}.mlp", + ) + else: + raise ValueError(f"Invalid model_type {config.model_type}") + + self.input_layernorm = Qwen3_5RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3_5RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.dtype, + ), + ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + config.hidden_size, + dtype=config.dtype, + ), + ) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + } +) +class Qwen3_5Model(Qwen3NextModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3NextModel, self).__init__() + + config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = ( + vllm_config.model_config.hf_text_config + ) + parallel_config = vllm_config.parallel_config + + # eplb_config = parallel_config.eplb_config + # self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + + def get_layer(prefix: str): + return Qwen3_5DecoderLayer( + vllm_config, + layer_type=config.layer_types[extract_layer_index(prefix)], + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + if get_pp_group().is_last_rank: + self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def load_fused_expert_weights( + self, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: + param = params_dict[name] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + loaded_local_expert = True #False + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + success = weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, +# return_success=True, + ) + if success: + loaded_local_expert = True + + return loaded_local_expert + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + num_experts = ( + self.config.num_experts if hasattr(self.config, "num_experts") else 0 + ) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name_mapped, self): + continue + if is_fused_expert: + # qwen3.5 no need to transpose + # loaded_weight = loaded_weight.transpose(-1, -2) + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + success_w1 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) + success_w3 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) + success = success_w1 and success_w3 + else: + # down_proj + success = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) + if success: + name = name_mapped + break + else: + # Skip loading extra bias for GPTQ models. + if ( + name_mapped.endswith(".bias") + or name_mapped.endswith("_bias") + ) and name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + weight_loader = param.weight_loader + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, +# return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + logger.warning_once( + f"Parameter {name} not found in params_dict, skip loading" + ) + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3_5ForCausalLMBase( + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_text_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + + scheduler_config = vllm_config.scheduler_config + # if cache_config.mamba_cache_mode == "all": + # raise NotImplementedError( + # "Qwen3.5 currently does not support 'all' prefix caching, " + # "please use '--mamba-cache-mode=align' instead" + # ) + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3_5Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + +class Qwen3_5ForCausalLM(Qwen3_5ForCausalLMBase): + pass + + +class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # set MoE hyperparameters + self.set_moe_parameters() + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +######################################################## +# Qwen3_5-Dense +######################################################## + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3_5ProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # protocols have not __init__ method, so we need to use nn.Module.__init__ + nn.Module.__init__(self) + config: Qwen3_5Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + # self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + # self.video_pruning_rate = multimodal_config.video_pruning_rate + # self.is_multimodal_pruning_enabled = ( + # multimodal_config.is_multimodal_pruning_enabled() + # ) + + # with self._mark_tower_model(vllm_config, {"image", "video"}): + if is_hpu: + qwen3_visionTransformer = Qwen3_VisionTransformerStaticShape + else: + qwen3_visionTransformer = Qwen3_VisionTransformer + self.visual = qwen3_visionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=None, + prefix=maybe_prefix(prefix, "visual"), + ) + + # with self._mark_language_model(vllm_config): + self.language_model = Qwen3_5ForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = False + self.text_dim = config.text_config.hidden_size + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self._embed_text_input_ids( + input_ids, + self.language_model.embed_input_ids, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + is_multimodal = _require_is_multimodal(is_multimodal) + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + """Run forward pass for Qwen3.5. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen3VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + intermediate_tensors: Intermediate tensors from previous pipeline + stages. + inputs_embeds: Pre-computed input embeddings. + **kwargs: Additional keyword arguments including: + - pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in + LLM. `None` if no images are passed. + - pixel_values_videos: Pixel values of videos to be fed to a + model. `None` if no videos are passed. + - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in + LLM. `None` if no videos are passed. + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_text_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) + + @classmethod + def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]: + return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() + + +######################################################## +# Qwen3_5-MoE +######################################################## + + +class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.language_model.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.moe_layers = [] + example_moe = None + for layer in self.language_model.model.layers: + if isinstance(layer, Qwen3_5DecoderLayer) and isinstance( + layer.mlp, Qwen3NextSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError( + "No Qwen3_5 layer found in the language_model.model.layers." + ) + + # Set MoE hyperparameters + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3_5MoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3_5MoeForConditionalGeneration( + Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts +): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # protocols have not __init__ method, so we need to use nn.Module.__init__ + nn.Module.__init__(self) + config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + # self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + # self.video_pruning_rate = multimodal_config.video_pruning_rate + # self.is_multimodal_pruning_enabled = ( + # multimodal_config.is_multimodal_pruning_enabled() + # ) + + # with self._mark_tower_model(vllm_config, {"image", "video"}): + if is_hpu: + qwen3_visionTransformer = Qwen3_VisionTransformerStaticShape + else: + qwen3_visionTransformer = Qwen3_VisionTransformer + self.visual = qwen3_visionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=None, + prefix=maybe_prefix(prefix, "visual"), + ) + + # with self._mark_language_model(vllm_config): + self.language_model = Qwen3_5MoeForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = False + self.text_dim = config.text_config.hidden_size + + # set MoE hyperparameters + self.set_moe_parameters() diff --git a/vllm/model_executor/models/qwen3_5_mtp.py b/vllm/model_executor/models/qwen3_5_mtp.py new file mode 100644 index 000000000000..8bd29f352dbf --- /dev/null +++ b/vllm/model_executor/models/qwen3_5_mtp.py @@ -0,0 +1,447 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3_5 MTP model.""" + +import typing +from collections.abc import Callable, Iterable + +import torch +from torch import nn +from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig +from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import ( + Qwen3_5MoeTextConfig, +) + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm +from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts +from vllm.sequence import IntermediateTensors + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + _require_is_multimodal, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + _merge_multimodal_embeddings, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) + +logger = init_logger(__name__) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "hidden_states": 0, + } +) +class Qwen3_5MultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + + config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config + + self.config = config + + self.vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + + self.fc = ColumnParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fc", + ) + + self.layers = torch.nn.ModuleList( + Qwen3_5DecoderLayer( + vllm_config, + layer_type="full_attention", + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(self.num_mtp_layers) + ) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3_5RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_fc_norm_embedding = Qwen3_5RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.embed_input_ids(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = spec_step_idx % self.num_mtp_layers + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_fused_expert_weights( + self, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: + param = params_dict[name] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + loaded_local_expert = False + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + success = weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + + return loaded_local_expert + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts + if hasattr(self.config, "num_experts") + else 0, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + num_experts = ( + self.config.num_experts if hasattr(self.config, "num_experts") else 0 + ) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name_mapped, self): + continue + if is_fused_expert: + # qwen3.5 no need to transpose + # loaded_weight = loaded_weight.transpose(-1, -2) + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + success_w1 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) + success_w3 = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) + success = success_w1 and success_w3 + else: + # down_proj + success = self.load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) + if success: + name = name_mapped + break + else: + # Skip loading extra bias for GPTQ models. + if ( + name_mapped.endswith(".bias") + or name_mapped.endswith("_bias") + ) and name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + weight_loader = param.weight_loader + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + logger.warning_once( + f"Parameter {name} not found in params_dict, skip loading" + ) + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "hidden_states": 0, + } +) +class Qwen3_5MTP(nn.Module, SupportsMultiModal): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_text_config + self.vllm_config = vllm_config + cache_config = vllm_config.cache_config + if cache_config.mamba_cache_mode == "all": + raise NotImplementedError( + "Qwen3_5MTP currently does not support 'all' prefix caching, " + "please use '--mamba-cache-mode=align' instead" + ) + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.model = Qwen3_5MultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") + ) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self._embed_text_input_ids( + input_ids, + self.model.embed_input_ids, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + is_multimodal = _require_is_multimodal(is_multimodal) + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + hidden_states = self.model( + input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def remap_weight_names(weights): + for name, weight in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif any(key in name for key in ["embed_tokens", "lm_head"]): + if "embed_tokens" in name: + name = name.replace("language_model.", "") + else: + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) + + +class Qwen3_5MoeMTP(Qwen3_5MTP, QwenNextMixtureOfExperts): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.set_moe_parameters() diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 4e8d81b7f1b7..46b9f1579be3 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Qwen3Next model.""" +import os from collections.abc import Iterable from typing import Optional @@ -38,6 +39,8 @@ MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.torch_gated_delta_relu import ( + torch_chunk_gated_delta_rule, torch_recurrent_gated_delta_rule) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -62,177 +65,6 @@ KVCache = tuple[torch.Tensor, torch.Tensor] -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - eye_constant, - chunk_size=64, - initial_state=None, - output_final_state=True, - use_qk_l2norm_in_kernel=True, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - head_dim = query.size(-1) - inv_scale = head_dim**-0.5 - query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale - key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size - if pad_size > 0: - query = F.pad(query, (0, 0, 0, pad_size)) - key = F.pad(key, (0, 0, 0, pad_size)) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - tot_len = sequence_length + pad_size - scale = 1 / (query.shape[-1]**0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = [ - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=0) - - # chunk decay - g = g.cumsum(dim=-1) - g_exp = g.exp() - decay_mask = ((g.unsqueeze(-1) - - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -((torch.matmul(k_beta.contiguous(), - key.transpose(-1, -2).contiguous())) * - decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].contiguous() - sub = attn[..., :i, :] - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)[..., :i] - attn = attn + eye_constant - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1)) - last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim, - v_head_dim).to(value) if initial_state - is None else initial_state.to(value)) - core_attn_out = torch.zeros_like(value) - mask = torch.tril(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=0) - mask = mask.view(1, 1, 1, chunk_size, chunk_size) - attn = (query @ key.transpose(-1, -2)) * decay_mask * mask - qg = query * g_exp[..., None] - delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None] - k_term = (key * delta_g_exp) - - # for each chunk - for i in range(0, tot_len // chunk_size): - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = value[:, :, i] - v_prime - attn_inter = qg[:, :, i] @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn[:, :, i] @ v_new - last_recurrent_state = ( - last_recurrent_state * g_exp[:, :, i, -1, None, None] + - k_term[:, :, i].transpose(-1, -2) @ v_new) - - if not output_final_state: - last_recurrent_state = None - else: - last_recurrent_state = last_recurrent_state.to(initial_dtype) - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], - core_attn_out.shape[1], -1, - core_attn_out.shape[-1]) - core_attn_out = core_attn_out[:, :, :sequence_length] - core_attn_out = core_attn_out.transpose(1, - 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state - - -def torch_recurrent_gated_delta_rule( - query, - key, - value, - g, - beta, - recurrent_state, - output_final_state=True, - use_qk_l2norm_in_kernel=True, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - head_dim = query.size(-1) - inv_scale = head_dim**-0.5 - query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale - key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, sequence_length, num_heads, k_head_dim = key.shape - v_head_dim = value.shape[-1] - scale = 1 / (query.shape[-1]**0.5) - query = query * scale - - recurrent_state = recurrent_state.to(value) - - if num_heads > 1: - core_attn_out = torch.zeros(batch_size, sequence_length, num_heads, - v_head_dim).to(value) - for i in range(num_heads): - q_t = query[:, :, i] - k_t = key[:, :, i] - v_t = value[:, :, i] - g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, i].unsqueeze(-1) - - recurrent_state = recurrent_state * g_t - kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - recurrent_state = recurrent_state + k_t.unsqueeze( - -1) * delta.unsqueeze(-2) - core_attn_out[:, :, i] = (recurrent_state * - q_t.unsqueeze(-1)).sum(dim=-2) - else: - q_t = query.squeeze(-2) - k_t = key.squeeze(-2) - v_t = value.squeeze(-2) - g_t = g.squeeze(-1).exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta - - recurrent_state = recurrent_state * g_t - kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - recurrent_state.add_(k_t.unsqueeze(-1) * delta.unsqueeze(-2)) - core_attn_out = (recurrent_state * - q_t.unsqueeze(-1)).sum(dim=-2).unsqueeze(-2) - - if not output_final_state: - recurrent_state = None - else: - recurrent_state = recurrent_state.to(initial_dtype) - core_attn_out = core_attn_out.transpose(1, - 2).contiguous().to(initial_dtype) - return core_attn_out, recurrent_state - - class Qwen3NextSparseMoeBlock(nn.Module): def __init__( @@ -271,7 +103,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, - renormalize=config.norm_topk_prob, + renormalize=getattr(config, "norm_topk_prob", True), quant_config=quant_config, prefix=f"{prefix}.experts") @@ -765,11 +597,15 @@ def __init__( prefix=f"{prefix}.o_proj", ) + if hasattr(config, "rope_theta"): + rope_theta = config.rope_theta + else: + rope_theta = config.rope_parameters.get("rope_theta", 10000) self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=self.head_dim, max_position=config.max_position_embeddings, - base=config.rope_theta, + base=rope_theta, rope_scaling=config.rope_scaling, partial_rotary_factor=config.partial_rotary_factor, dual_chunk_attention_config=self.dual_chunk_attention_config, @@ -972,7 +808,7 @@ class Qwen3NextModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config: Qwen3NextConfig = vllm_config.model_config.hf_config + config: Qwen3NextConfig = vllm_config.model_config.hf_text_config model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -1054,7 +890,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=getattr(self.config, "num_experts", 0)) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1133,8 +969,52 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params +class QwenNextMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.moe_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( + layer.mlp, Qwen3NextSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + # Set MoE hyperparameters + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - MixtureOfExperts, IsHybrid): + QwenNextMixtureOfExperts, IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1145,7 +1025,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_text_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config @@ -1283,7 +1163,7 @@ def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig" ) -> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_config + hf_config = vllm_config.model_config.hf_text_config tp_size = parallel_config.tensor_parallel_size num_spec = (vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index b0a079a38c73..429c7196cc17 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1456,14 +1456,15 @@ def get_input_embeddings( input_ids, inputs_embeds, multimodal_embeddings, [self.config.image_token_id, self.config.video_token_id]) - if deepstack_input_embeds is None: + if deepstack_input_embeds is None and self.use_deepstack: deepstack_input_embeds = torch.zeros(inputs_embeds.size(0), inputs_embeds.size(1), self.deepstack_num_level * self.text_dim, device=inputs_embeds.device, dtype=inputs_embeds.dtype) - inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds), + if deepstack_input_embeds is not None: + inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds), dim=-1) return inputs_embeds diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index e5deedb5185d..0713314e642d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -236,6 +236,14 @@ ), "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501 + "Qwen3_5ForConditionalGeneration": ( + "qwen3_5", + "Qwen3_5ForConditionalGeneration", + ), + "Qwen3_5MoeForConditionalGeneration": ( + "qwen3_5", + "Qwen3_5MoeForConditionalGeneration", + ), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), @@ -257,6 +265,8 @@ "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"), + "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"), } _TRANSFORMERS_MODELS = { diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index b5135f9ebb82..f534dca463ed 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -91,8 +91,8 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: tokenizer_all_special_ids = tokenizer.all_special_ids tokenizer_all_special_tokens = tokenizer.all_special_tokens - tokenizer_all_special_tokens_extended = ( - tokenizer.all_special_tokens_extended) + # tokenizer_all_special_tokens_extended = ( + # tokenizer.all_special_tokens_extended) tokenizer_vocab = tokenizer.get_vocab() tokenizer_len = len(tokenizer) @@ -115,9 +115,9 @@ def all_special_ids(self) -> list[int]: def all_special_tokens(self) -> list[str]: return tokenizer_all_special_tokens - @property - def all_special_tokens_extended(self) -> list[str]: - return tokenizer_all_special_tokens_extended + # @property + # def all_special_tokens_extended(self) -> list[str]: + # return tokenizer_all_special_tokens_extended @property def max_token_id(self) -> int: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index e938f3bfc671..5827f7618a12 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -3,6 +3,7 @@ import copy from dataclasses import dataclass +from math import prod from typing import Optional import torch @@ -154,6 +155,36 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes +@dataclass +class MambaSpec(KVCacheSpec): + shapes: tuple[tuple[int, ...], ...] + dtypes: tuple[torch.dtype] + page_size_padded: int | None = None + mamba_type: str = "mamba2" + mamba_cache_mode: str = "none" + num_speculative_blocks: int = 0 + + @property + def page_size_bytes(self) -> int: + page_size = sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) + if self.page_size_padded is not None: + assert self.page_size_padded >= page_size + return self.page_size_padded + return page_size + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + if vllm_config.cache_config.mamba_cache_mode == "all": + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + elif vllm_config.cache_config.mamba_cache_mode == "align": + return self.page_size_bytes * (2 + self.num_speculative_blocks) + else: + return self.page_size_bytes * (1 + self.num_speculative_blocks) + + @dataclass class KVCacheTensor: """ diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 788caa270778..6650dc68df1b 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1761,7 +1761,10 @@ def get_model(self) -> torch.nn.Module: # fla is short for Flat Linear Attention def _is_fla_model(self): - return hasattr(self.model_config.hf_config, "linear_conv_kernel_dim") + return (hasattr(self.model_config.hf_config, "linear_conv_kernel_dim") or + (hasattr(self.model_config.hf_config, "text_config") and + hasattr(self.model_config.hf_config.text_config, + "linear_conv_kernel_dim"))) def _use_graphs(self, batch_size, seq_len, ctx_blocks=0): if self.enforce_eager: @@ -1982,8 +1985,12 @@ def _prepare_prompt( # TODO: if seq_len < conv_kernel_dim, padding token should be # masked in the prompt stage if self._is_fla_model(): + if hasattr(self.model_config.hf_config, "linear_conv_kernel_dim"): + linear_conv_kernel_dim = self.model_config.hf_config.linear_conv_kernel_dim + else: + linear_conv_kernel_dim = self.model_config.hf_config.text_config.linear_conv_kernel_dim conv_state_indices_list.append(list(range(seq_len + 1 - \ - self.model_config.hf_config.linear_conv_kernel_dim, seq_len))) + linear_conv_kernel_dim, seq_len))) token_types_ids = seq_group_metadata.token_type_ids token_types.append(token_types_ids) if token_types_ids else [] @@ -3325,7 +3332,8 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args, embed_dim = 1176 if any([ model_type in self.get_model().config.model_type - for model_type in ['qwen3_vl', "qwen3_omni"] + for model_type in ['qwen3_vl', "qwen3_omni", + 'qwen3_5', 'qwen3_5_moe'] ]): embed_dim = 1536 elif 'ernie4_5_moe_vl' in self.get_model().config.model_type: @@ -3585,7 +3593,10 @@ def _inc_preprocess(self): def add_fla_dummy_data(self, inputs) -> None: assert self._is_fla_model() - conv_dim = self.model_config.hf_config.linear_conv_kernel_dim + if hasattr(self.model_config.hf_config, "linear_conv_kernel_dim"): + conv_dim = self.model_config.hf_config.linear_conv_kernel_dim + else: + conv_dim = self.model_config.hf_config.text_config.linear_conv_kernel_dim bs, seq_len = inputs.input_tokens.shape mamba_cache_indices = list(range(bs)) mamba_cache_indices = torch.tensor(mamba_cache_indices,