diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index b1947fa7f072..d6e2181aede3 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -569,6 +569,8 @@ Specified using `--task generate`.
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B`
+| `Qwen3_5ForConditionalGeneration` | Qwen3.5 | T + IE+ + VE+ | `Qwen/Qwen3.5-9B-Instruct`, etc. | ✅︎ | ✅︎ |
+| `Qwen3_5MoeForConditionalGeneration` | Qwen3.5-MOE | T + IE+ + VE+ | `Qwen/Qwen3.5-35B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + IE+ + VE+ | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + IE+ + VE+ | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | | ✅︎ | ✅︎\* |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py
index d078c517d00e..7f4e15249201 100644
--- a/examples/offline_inference/basic/chat.py
+++ b/examples/offline_inference/basic/chat.py
@@ -1,80 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm import LLM, EngineArgs
-from vllm.utils import FlexibleArgumentParser
+import argparse
+import os
+
+os.environ["PT_HPU_LAZY_MODE"] = "1"
+
+from vllm import LLM, EngineArgs, SamplingParams
+
+# Parse the command-line arguments.
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--model",
+    type=str,
+    default="facebook/opt-125m",
+    help="The model path.",
+)
+# --tp-size is forwarded to LLM(tensor_parallel_size=...): it is the
+# tensor-parallel degree, not a thread count.
+parser.add_argument(
+    "--tp-size", type=int, default=2, help="The tensor parallel size."
+)
+parser.add_argument(
+    "--output-tokens", type=int, default=512, help="The number of output tokens."
+)
+parser.add_argument(
+    "--max-model-length", type=int, default=16384, help="Max model length."
+)
+parser.add_argument("--enable-ep", action="store_true", help="Enable EP for MOE models")
+parser.add_argument("--temperature", type=float, default=0.8)
+parser.add_argument("--top-p", type=float, default=0.95)
+parser.add_argument(
+    "--enable-thinking",
+    action="store_true",
+    help="Enable think mode for inference",
+)
+# Add example params
+parser.add_argument("--chat-template-path", type=str)
+args = parser.parse_args()
+
+os.environ["VLLM_SKIP_WARMUP"] = "true"
+os.environ["HABANA_VISIBLE_DEVICES"] = "ALL"
+os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
+os.environ["PT_HPU_WEIGHT_SHARING"] = "0"
-def create_parser():
- parser = FlexibleArgumentParser()
- # Add engine args
- EngineArgs.add_cli_args(parser)
- parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
- # Add sampling params
- sampling_group = parser.add_argument_group("Sampling parameters")
- sampling_group.add_argument("--max-tokens", type=int)
- sampling_group.add_argument("--temperature", type=float)
- sampling_group.add_argument("--top-p", type=float)
- sampling_group.add_argument("--top-k", type=int)
- # Add example params
- parser.add_argument("--chat-template-path", type=str)
-
- return parser
-
-
-def main(args: dict):
- # Pop arguments not used by LLM
- max_tokens = args.pop("max_tokens")
- temperature = args.pop("temperature")
- top_p = args.pop("top_p")
- top_k = args.pop("top_k")
- chat_template_path = args.pop("chat_template_path")
-
- # Create an LLM
- llm = LLM(**args)
-
- # Create sampling params object
- sampling_params = llm.get_default_sampling_params()
- if max_tokens is not None:
- sampling_params.max_tokens = max_tokens
- if temperature is not None:
- sampling_params.temperature = temperature
- if top_p is not None:
- sampling_params.top_p = top_p
- if top_k is not None:
- sampling_params.top_k = top_k
+if __name__ == "__main__":
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    # One single-turn conversation (system + user message) per prompt, in the
+    # same order as `prompts` so outputs can be matched back by index.
+    messages = [
+        [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant"
+            },
+            {
+                "role": "user",
+                "content": prompt
+            },
+        ]
+        for prompt in prompts
+    ]
+    # Create a sampling params object.
+    sampling_params = SamplingParams(
+        temperature=args.temperature,
+        top_p=args.top_p,
+        max_tokens=args.output_tokens,
+    )
+    chat_template_path = args.chat_template_path
+    model = args.model
+    # Arguments shared by the single-device and the tensor-parallel setup.
+    llm_kwargs = {
+        "model": model,
+        "tokenizer": model,
+        "trust_remote_code": True,
+        "dtype": "bfloat16",
+        "max_model_len": args.max_model_length,
+    }
+    # The parallelism-related arguments are only passed when more than one
+    # device is requested; with --tp-size 1 the engine defaults are used.
+    if args.tp_size != 1:
+        llm_kwargs.update(
+            tensor_parallel_size=args.tp_size,
+            distributed_executor_backend="mp",
+            enable_expert_parallel=args.enable_ep,
+        )
+    llm = LLM(**llm_kwargs)
def print_outputs(outputs):
print("\nGenerated Outputs:\n" + "-" * 80)
- for output in outputs:
- prompt = output.prompt
- generated_text = output.outputs[0].text
+ for idx in range(len(outputs)):
+ prompt = prompts[idx]
+ generated_text = outputs[idx].outputs[0].text
print(f"Prompt: {prompt!r}\n")
print(f"Generated text: {generated_text!r}")
print("-" * 80)
print("=" * 80)
- # In this script, we demonstrate how to pass input to the chat method:
- conversation = [
- {"role": "system", "content": "You are a helpful assistant"},
- {"role": "user", "content": "Hello"},
- {"role": "assistant", "content": "Hello! How can I assist you today?"},
- {
- "role": "user",
- "content": "Write an essay about the importance of higher education.",
- },
- ]
- outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
- print_outputs(outputs)
-
- # You can run batch inference with llm.chat API
- conversations = [conversation for _ in range(10)]
-
- # We turn on tqdm progress bar to verify it's indeed running batch inference
- outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
- print_outputs(outputs)
-
# A chat template can be optionally supplied.
# If not, the model will use its default chat template.
if chat_template_path is not None:
@@ -82,14 +104,16 @@ def print_outputs(outputs):
chat_template = f.read()
outputs = llm.chat(
- conversations,
+ messages,
sampling_params,
use_tqdm=False,
chat_template=chat_template,
+ chat_template_kwargs={"enable_thinking": args.enable_thinking},
)
-
-
-if __name__ == "__main__":
- parser = create_parser()
- args: dict = vars(parser.parse_args())
- main(args)
+ else:
+ outputs = llm.chat(
+ messages,
+ sampling_params,
+ chat_template_kwargs={"enable_thinking": args.enable_thinking}
+ )
+ print_outputs(outputs)
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index aaac656f1ad0..5d8004463ad3 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -1435,6 +1435,80 @@ def run_qwen3_omni_moe(questions: list[str], modality: str) -> ModelRequestData:
)
+# Qwen3.5 Dense
+def run_qwen3_5(questions: list[str], modality: str) -> ModelRequestData:
+    """Build the engine arguments and prompts for the Qwen3.5 dense example.
+
+    Args:
+        questions: Questions to ask; one prompt is produced per question.
+        modality: Either ``"image"`` or ``"video"``; selects the vision
+            placeholder token and the per-prompt multimodal limit.
+
+    Returns:
+        A ``ModelRequestData`` carrying the engine arguments and the prompts.
+
+    Raises:
+        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
+    """
+    model_name = "Qwen/Qwen3.5-4B-Base"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+            "fps": 1,
+        },
+        limit_mm_per_prompt={modality: 1},
+    )
+
+    if modality == "image":
+        placeholder = "<|image_pad|>"
+    elif modality == "video":
+        placeholder = "<|video_pad|>"
+    else:
+        # Without this branch `placeholder` would be unbound below and the
+        # caller would only see an UnboundLocalError.
+        raise ValueError(f"Unsupported modality for Qwen3.5: {modality!r}")
+
+    prompts = [
+        (
+            f"<|vision_start|>{placeholder}<|vision_end|>"
+            f"{question}"
+        )
+        for question in questions
+    ]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
+# Qwen3.5 MoE
+def run_qwen3_5_moe(questions: list[str], modality: str) -> ModelRequestData:
+    """Build the engine arguments and prompts for the Qwen3.5 MoE example.
+
+    The engine is configured for 8-way tensor parallelism with expert
+    parallelism enabled and the "mp" distributed executor backend.
+
+    Args:
+        questions: Questions to ask; one prompt is produced per question.
+        modality: Either ``"image"`` or ``"video"``; selects the vision
+            placeholder token and the per-prompt multimodal limit.
+
+    Returns:
+        A ``ModelRequestData`` carrying the engine arguments and the prompts.
+
+    Raises:
+        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
+    """
+    # NOTE(review): this is a machine-local checkpoint path, so the example
+    # only runs where that directory exists. Consider a Hugging Face model id
+    # (with a matching parallel configuration) before merging.
+    model_name = "/data/Qwen3.5-397B-A17B-FP8-G2"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=16384,
+        max_num_seqs=5,
+        enable_expert_parallel=True,
+        trust_remote_code=True,
+        tensor_parallel_size=8,
+        distributed_executor_backend="mp",
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+            "fps": 1,
+        },
+        limit_mm_per_prompt={modality: 1},
+    )
+
+    if modality == "image":
+        placeholder = "<|image_pad|>"
+    elif modality == "video":
+        placeholder = "<|video_pad|>"
+    else:
+        # Without this branch `placeholder` would be unbound below and the
+        # caller would only see an UnboundLocalError.
+        raise ValueError(f"Unsupported modality for Qwen3.5-MoE: {modality!r}")
+
+    prompts = [
+        (
+            f"<|vision_start|>{placeholder}<|vision_end|>"
+            f"{question}"
+        )
+        for question in questions
+    ]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1516,6 +1590,8 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"qwen3_vl": run_qwen3_vl,
"qwen3_vl_moe": run_qwen3_vl_moe,
"qwen3_omni_moe": run_qwen3_omni_moe,
+ "qwen3_5": run_qwen3_5,
+ "qwen3_5_moe": run_qwen3_5_moe,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
"tarsier": run_tarsier,
@@ -1526,6 +1602,8 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"glm4_5v",
"qwen3_vl",
"qwen3_vl_moe",
+ "qwen3_5",
+ "qwen3_5_moe",
]
diff --git a/tests/models/registry.py b/tests/models/registry.py
index f4180da268ff..11c25276ee8e 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -421,6 +421,26 @@ def check_available_online(
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"),
+ "Qwen3_5ForConditionalGeneration": _HfExamplesInfo(
+ "Qwen/Qwen3.5-9B-Instruct",
+ max_model_len=4096,
+ min_transformers_version="5.1.0",
+ ),
+ "Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo(
+ "Qwen/Qwen3.5-35B-A3B-Instruct",
+ max_model_len=4096,
+ min_transformers_version="5.1.0",
+ ),
+ "Qwen3_5MTP": _HfExamplesInfo(
+ "Qwen/Qwen3.5-9B-Instruct",
+ speculative_model="Qwen/Qwen3.5-9B-Instruct",
+ min_transformers_version="5.1.0",
+ ),
+ "Qwen3_5MoeMTP": _HfExamplesInfo(
+ "Qwen/Qwen3.5-35B-A3B-Instruct",
+ speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct",
+ min_transformers_version="5.1.0",
+ ),
"Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-Omni-30B-A3B-Instruct", # noqa: E501
max_model_len=4096, # noqa: E501
min_transformers_version="4.57"), # noqa: E501
diff --git a/vllm/config.py b/vllm/config.py
index 394867537e09..70d90c9e64a7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1299,8 +1299,8 @@ def get_num_layers_by_block_type(
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
- # Hybrid model Qwen3Next
- layer_types_value = getattr(self.hf_config, "layer_types", None)
+ # Hybrid model Qwen3Next Qwen3.5 Series
+ layer_types_value = getattr(self.hf_text_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
return sum(t == "full_attention"
@@ -2575,6 +2575,16 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["Glm4MoeMTPModel"]
})
+ if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
+ is_moe = hf_config.model_type == "qwen3_5_moe"
+ hf_config.model_type = "qwen3_5_mtp"
+ n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
+ hf_config.update(
+ {
+ "n_predict": n_predict,
+ "architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
+ }
+ )
return hf_config
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index b6d1a02aa434..0c670d8954d8 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -321,7 +321,8 @@ def forward_cuda(
return self.forward_native(x, residual)
-class RMSNormGated(nn.Module):
+@CustomOp.register("rms_normgated")
+class RMSNormGated(CustomOp):
def __init__(
self,
@@ -347,20 +348,30 @@ def __init__(
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
- def forward(self, hidden_states, gate=None):
+ def forward_native(self, hidden_states, gate=None):
"""
If z is not None, we do norm(x) * silu(z)
if norm_before_gate, else norm(x * silu(z))
"""
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# Norm before gate
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
- hidden_states = self.weight * hidden_states.to(input_dtype)
- hidden_states = hidden_states * F.silu(gate.to(torch.float32))
+ hidden_states = self.weight.to(hidden_states.dtype) * hidden_states
+ if gate is not None:
+ hidden_states = hidden_states * F.silu(gate.to(hidden_states.dtype))
- return hidden_states.to(input_dtype)
+ return hidden_states
+
+    def forward_hpu(self, hidden_states, gate=None):
+        """RMS-normalize ``hidden_states`` with the HPU kernel, then gate.
+
+        The kernel object returned by ``vllm_hpu_extension.kernels.rms_norm()``
+        is applied to ``hidden_states`` with ``self.weight`` (cast to the
+        input dtype) and ``self.eps``. When ``gate`` is given, the normalized
+        result is multiplied by ``silu(gate)``; when it is ``None`` the
+        normalized tensor is returned as is.
+        """
+        # Imported inside the method, not at module level.
+        # NOTE(review): presumably so this module stays importable when the
+        # HPU extension is not installed - confirm.
+        from vllm_hpu_extension.kernels import rms_norm
+        # rms_norm() is called on every forward and must return an object
+        # exposing `.apply(input, weight, eps)`.
+        HPUFusedRMSNorm = rms_norm()
+
+        hidden_states = HPUFusedRMSNorm.apply(hidden_states,
+                                              self.weight.to(hidden_states.dtype),
+                                              self.eps)
+        if gate is not None:
+            # Gate is cast to the activation dtype before the SiLU.
+            hidden_states = hidden_states * F.silu(gate.to(hidden_states.dtype))
+        return hidden_states
class MiniMaxText01RMSNormTP(CustomOp):
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index a524e1340580..2281a582e543 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -6,10 +6,12 @@
import torch
+from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
class MambaBase(AttentionLayerBase):
@@ -43,3 +45,29 @@ def mamba_type(self) -> str:
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        """Build the ``MambaSpec`` describing this layer's state cache.
+
+        Args:
+            vllm_config: Engine configuration; the cache, model and
+                speculative-decoding sections are read.
+
+        Returns:
+            A ``MambaSpec`` built from this layer's state shapes/dtypes and
+            the Mamba settings in ``vllm_config.cache_config``.
+            ``num_speculative_blocks`` is the configured number of
+            speculative tokens, or 0 without speculative decoding.
+
+        Raises:
+            NotImplementedError: If speculative decoding is configured for a
+                model type other than ``qwen3_next``, ``qwen3_5`` or
+                ``qwen3_5_moe``.
+        """
+        speculative_config = vllm_config.speculative_config
+        if speculative_config is not None:
+            # The model type is only inspected when speculative decoding is
+            # actually requested.
+            model_type = vllm_config.model_config.hf_config.model_type
+            if model_type not in ("qwen3_next", "qwen3_5", "qwen3_5_moe"):
+                raise NotImplementedError(
+                    "Mamba with speculative decoding is not supported yet."
+                )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
+            num_speculative_blocks=(
+                speculative_config.num_speculative_tokens
+                if speculative_config
+                else 0
+            ),
+        )
\ No newline at end of file
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index a6c1af91de42..42da62e759b0 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Union
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Union, TypeAlias
import torch
@@ -200,3 +202,89 @@ def gated_delta_net_state_shape(
temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
return conv_state_shape, temporal_state_shape
+
+@dataclass
+class MambaCopySpec:
+ """
+ Data class specifying the memory-copy parameters for Mamba states used for
+ prefix caching in align mode.
+
+ Attributes:
+ start_addr (int): Starting address for the memory copy operation.
+ num_elements (int): Number of elements to copy from the starting address.
+ """
+
+ start_addr: int
+ num_elements: int
+
+MambaStateCopyFunc: TypeAlias = Callable[
+ [torch.Tensor, list[int], int, int], MambaCopySpec
+]
+"""
+Type alias for a function that computes a MambaCopySpec for copying state slices.
+Parameters:
+ state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
+ block_ids: list[int] - the list of block indices for the state to copy.
+ cur_block_idx: int - current block index within `block_ids` to copy from.
+ num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
+ Range: 1 .. 1 + num_speculative_tokens (inclusive).
+"""
+
+def get_conv_copy_spec(
+    state: torch.Tensor,
+    block_ids: list[int],
+    cur_block_idx: int,
+    num_accepted_tokens: int,
+) -> MambaCopySpec:
+    """Return a MambaCopySpec for copying a convolutional state slice.
+
+    The source is block ``block_ids[cur_block_idx]`` of ``state``, sliced
+    along the next dimension from index ``num_accepted_tokens - 1`` to the
+    end. The spec records that slice's data pointer and element count.
+    """
+    src_block_id = block_ids[cur_block_idx]
+    # Skip the first `num_accepted_tokens - 1` entries of the block.
+    src_state = state[src_block_id, num_accepted_tokens - 1 :]
+    # NOTE(review): a (start address, element count) pair describes the slice
+    # only if it is contiguous in memory - confirm `state`'s layout.
+    return MambaCopySpec(
+        start_addr=src_state.data_ptr(), num_elements=src_state.numel()
+    )
+
+
+def get_temporal_copy_spec(
+    state: torch.Tensor,
+    block_ids: list[int],
+    cur_block_idx: int,
+    num_accepted_tokens: int,
+) -> MambaCopySpec:
+    """Return a MambaCopySpec for copying a temporal state slice.
+
+    Unlike the convolutional variant, the whole block is copied; the number
+    of accepted tokens selects *which* block: the one at position
+    ``cur_block_idx + num_accepted_tokens - 1`` of ``block_ids``.
+    """
+    # With num_accepted_tokens == 1 this is simply block_ids[cur_block_idx].
+    src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1]
+    src_state = state[src_block_id]
+    return MambaCopySpec(
+        start_addr=src_state.data_ptr(), num_elements=src_state.numel()
+    )
+
+get_full_copy_spec = get_temporal_copy_spec
+
+class MambaStateCopyFuncCalculator:
+    """Lookup of the copy-spec functions to use for each layer type.
+
+    Every classmethod returns a tuple of ``MambaStateCopyFunc`` callables.
+    NOTE(review): presumably one entry per state tensor of the layer, in the
+    same order as the layer's state shapes - confirm against the callers.
+    """
+
+    @classmethod
+    def linear_attention_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Return the copy functions for linear attention (temporal only)."""
+        return (get_temporal_copy_spec,)
+
+    @classmethod
+    def mamba1_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Return the copy functions for Mamba1 (conv, temporal)."""
+        return (get_conv_copy_spec, get_temporal_copy_spec)
+
+    @classmethod
+    def mamba2_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Return the copy functions for Mamba2 (conv, temporal)."""
+        return get_conv_copy_spec, get_temporal_copy_spec
+
+    @classmethod
+    def short_conv_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Return the copy functions for short convolution (conv only)."""
+        return (get_conv_copy_spec,)
+
+    @classmethod
+    def gated_delta_net_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Return the copy functions for gated delta net (conv, temporal)."""
+        return (get_conv_copy_spec, get_temporal_copy_spec)
+
+    @classmethod
+    def kda_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
+        """Return the copy functions for KDA (three conv, one temporal)."""
+        return (
+            get_conv_copy_spec,
+            get_conv_copy_spec,
+            get_conv_copy_spec,
+            get_temporal_copy_spec,
+        )
\ No newline at end of file
diff --git a/vllm/model_executor/layers/mamba/ops/torch_gated_delta_relu.py b/vllm/model_executor/layers/mamba/ops/torch_gated_delta_relu.py
new file mode 100644
index 000000000000..1ee51566d7c0
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/ops/torch_gated_delta_relu.py
@@ -0,0 +1,335 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao.
+# Adapted from https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+
+
+is_hpu = current_platform.is_hpu()
+
+if is_hpu:
+    import habana_frameworks.torch as htorch
+    import habana_frameworks.torch.core as htcore
+else:
+    from types import SimpleNamespace
+
+    # `torch_chunk_gated_delta_rule_opt` calls `htcore.mark_step()`
+    # unconditionally. Without the Habana bridge that name would be
+    # undefined, so provide a stand-in whose `mark_step` does nothing.
+    htcore = SimpleNamespace(mark_step=lambda: None)
+
+
+def torch_chunk_gated_delta_rule_opt(
+ query,
+ key,
+ value,
+ g,
+ beta,
+ eye_constant,
+ chunk_size=64,
+ inv_loop=12,
+ initial_state=None,
+ output_final_state=True,
+ use_qk_l2norm_in_kernel=True,
+):
+ ssm_dtype = g.dtype
+ if use_qk_l2norm_in_kernel:
+ head_dim = query.size(-1)
+ inv_scale = head_dim**-0.5
+ query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale
+ key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale
+ query, key, value, beta, g = [
+ x.transpose(1, 2).contiguous()
+ for x in (query, key, value, beta, g)
+ ]
+
+ batch_size, num_heads, sequence_length, k_head_dim = key.shape
+ v_head_dim = value.shape[-1]
+ pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
+ if pad_size > 0:
+ query = F.pad(query, (0, 0, 0, pad_size))
+ key = F.pad(key, (0, 0, 0, pad_size))
+ value = F.pad(value, (0, 0, 0, pad_size))
+ beta = F.pad(beta, (0, pad_size))
+ g = F.pad(g, (0, pad_size))
+ tot_len = sequence_length + pad_size
+ scale = 1 / (query.shape[-1]**0.5)
+ query = query * scale
+
+ v_beta = value * beta.unsqueeze(-1)
+ k_beta = key * beta.unsqueeze(-1)
+ # reshape to chunks
+ query, key, value, k_beta, v_beta = [
+ x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
+ for x in (query, key, value, k_beta, v_beta)
+ ]
+ g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+ mask = torch.ones(chunk_size,
+ chunk_size,
+ dtype=value.dtype,
+ device=query.device).tril(-1)
+
+ # chunk decay
+ g = g.cumsum(dim=-1)
+ g_exp = g.exp().to(value.dtype)
+ decay_mask = ((g.unsqueeze(-1) -
+ g.unsqueeze(-2)).exp().to(value.dtype)).tril()
+
+ attn = torch.matmul(k_beta,
+ key.transpose(-1, -2).contiguous()) * \
+ decay_mask * mask + eye_constant
+ inv_attn = torch.zeros_like(attn) + eye_constant
+ htcore.mark_step()
+ for _ in range(inv_loop):
+ prod = torch.matmul(attn, inv_attn)
+ err = prod * mask
+ update = torch.matmul(inv_attn, err)
+ inv_attn.sub_(update)
+ htcore.mark_step()
+ attn = inv_attn
+
+ value = attn @ v_beta
+ k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1))
+ last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim,
+ v_head_dim).to(value) if initial_state
+ is None else initial_state.to(value))
+ mask = torch.tril(torch.ones(chunk_size,
+ chunk_size,
+ dtype=value.dtype,
+ device=value.device),
+ diagonal=0)
+ attn = torch.matmul(query, key.transpose(-1, -2).contiguous()) * decay_mask * mask
+ qg = query * g_exp[..., None]
+ delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None].to(value.dtype)
+ k_term = key * delta_g_exp
+
+ num_chunks = tot_len // chunk_size
+ k_eye = torch.eye(k_head_dim, dtype=value.dtype, device=value.device)
+ k_eye = k_eye.view(1, 1, 1, k_head_dim, k_head_dim)
+
+ alpha = g_exp[:, :, :, -1, None, None]
+ B = k_term.transpose(-1, -2).contiguous()
+ K = k_cumdecay
+ V = value
+ Q = qg
+ A = attn
+
+ M = alpha * k_eye - torch.matmul(B, K)
+ N = torch.matmul(B, V)
+ C = Q - torch.matmul(A, K)
+ core_attn_out = torch.matmul(A, V)
+
+ # for each chunk
+ htcore.mark_step()
+ for i in range(num_chunks):
+ core_attn_out[:, :, i].add_(torch.matmul(C[:, :, i], last_recurrent_state))
+ last_recurrent_state = torch.matmul(M[:, :, i], last_recurrent_state) + N[:, :, i]
+ htcore.mark_step()
+
+ if not output_final_state:
+ last_recurrent_state = None
+ else:
+ last_recurrent_state = last_recurrent_state.to(ssm_dtype)
+ core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
+ core_attn_out.shape[1], -1,
+ core_attn_out.shape[-1])
+ core_attn_out = core_attn_out[:, :, :sequence_length]
+ core_attn_out = core_attn_out.transpose(1, 2)
+ return core_attn_out, last_recurrent_state
+
+
+def torch_recurrent_gated_delta_rule_opt(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    recurrent_state,
+    output_final_state=True,
+    use_qk_l2norm_in_kernel=True,
+):
+    """Single-step gated delta rule update using matmuls in the input dtype.
+
+    The state is decayed by ``exp(g)``, corrected with the delta
+    ``(v - k @ state) * beta`` along ``k``, and then read out with ``q``.
+    No float32 upcast is performed here; compare
+    ``torch_recurrent_gated_delta_rule``, which upcasts every input.
+
+    Args:
+        query, key, value, g, beta: Per-token inputs; dims 1 and 2 are
+            swapped on entry.
+            # assumes exactly one token per sequence (see the squeezes
+            # below) - TODO confirm against the caller
+        recurrent_state: State carried over from the previous step.
+        output_final_state: When false, ``None`` is returned as the state.
+        use_qk_l2norm_in_kernel: Normalize ``query`` and ``key`` over the
+            last dimension before use.
+
+    Returns:
+        ``(core_attn_out, recurrent_state)``; the state is cast to the
+        dtype of ``g`` (or is ``None``).
+    """
+    ssm_dtype = g.dtype
+    if use_qk_l2norm_in_kernel:
+        head_dim = query.size(-1)
+        inv_scale = head_dim**-0.5
+        # rms_norm followed by 1/sqrt(head_dim) divides by the L2 norm of
+        # the last dimension (up to the eps term).
+        query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale
+        key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous()
+        for x in (query, key, value, beta, g)
+    ]
+
+    # NOTE(review): v_head_dim is not used below.
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1]**0.5)
+    query = query * scale
+
+    recurrent_state = recurrent_state.to(value.dtype)
+
+    # Drop the size-1 dimension at position -2.
+    q_t = query.squeeze(-2)
+    k_t = key.squeeze(-2)
+    v_t = value.squeeze(-2)
+    # Decay factor, given two trailing singleton dims so it broadcasts over
+    # the last two dimensions of the state.
+    g_t = g.squeeze(-1).exp().to(value.dtype).unsqueeze(-1).unsqueeze(-1)
+
+    # The multiplication yields a new tensor, so the in-place add_ below
+    # does not modify the tensor the caller passed in.
+    recurrent_state = recurrent_state * g_t
+    kv_mem = torch.matmul(k_t.unsqueeze(-2), recurrent_state).squeeze(-2)
+    delta = (v_t - kv_mem) * beta
+    recurrent_state.add_(k_t.unsqueeze(-1) * delta.unsqueeze(-2))
+    core_attn_out = torch.matmul(q_t.unsqueeze(-2), recurrent_state)
+
+    if not output_final_state:
+        recurrent_state = None
+    else:
+        recurrent_state = recurrent_state.to(ssm_dtype)
+    core_attn_out = core_attn_out.transpose(1, 2)
+    return core_attn_out, recurrent_state
+
+
+def torch_chunk_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    eye_constant,
+    chunk_size=64,
+    initial_state=None,
+    output_final_state=True,
+    use_qk_l2norm_in_kernel=True,
+):
+    """Chunked gated delta rule computed in float32.
+
+    The sequence is padded to a multiple of ``chunk_size`` and split into
+    chunks. Intra-chunk terms are computed for all chunks at once; the
+    recurrent state is then carried from chunk to chunk in a Python loop.
+
+    Args:
+        query, key, value, g, beta: Inputs; dims 1 and 2 are swapped and
+            everything is cast to float32 on entry.
+        eye_constant: Added to the intra-chunk matrix after the row update.
+            # presumably an identity matrix of size chunk_size - TODO
+            # confirm against the caller
+        chunk_size: Chunk length along the sequence dimension.
+        initial_state: Starting recurrent state; zeros when ``None``.
+        output_final_state: When false, ``None`` is returned as the state.
+        use_qk_l2norm_in_kernel: Normalize ``query`` and ``key`` over the
+            last dimension before use.
+
+    Returns:
+        ``(core_attn_out, last_recurrent_state)``, both cast back to the
+        original dtype of ``query``; padding is removed from the output.
+    """
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        head_dim = query.size(-1)
+        inv_scale = head_dim**-0.5
+        # rms_norm followed by 1/sqrt(head_dim) divides by the L2 norm of
+        # the last dimension (up to the eps term).
+        query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale
+        key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32)
+        for x in (query, key, value, beta, g)
+    ]
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    # Number of positions needed to reach the next multiple of chunk_size
+    # (0 when the length already is a multiple).
+    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
+    if pad_size > 0:
+        query = F.pad(query, (0, 0, 0, pad_size))
+        key = F.pad(key, (0, 0, 0, pad_size))
+        value = F.pad(value, (0, 0, 0, pad_size))
+        beta = F.pad(beta, (0, pad_size))
+        g = F.pad(g, (0, pad_size))
+    tot_len = sequence_length + pad_size
+    scale = 1 / (query.shape[-1]**0.5)
+    query = query * scale
+
+    v_beta = value * beta.unsqueeze(-1)
+    k_beta = key * beta.unsqueeze(-1)
+    # reshape to chunks
+    query, key, value, k_beta, v_beta = [
+        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
+        for x in (query, key, value, k_beta, v_beta)
+    ]
+    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
+    # True on and above the diagonal; those entries are zeroed in `attn`.
+    mask = torch.triu(torch.ones(chunk_size,
+                                 chunk_size,
+                                 dtype=torch.bool,
+                                 device=query.device),
+                      diagonal=0)
+
+    # chunk decay
+    g = g.cumsum(dim=-1)
+    g_exp = g.exp()
+    # Pairwise exp(g_i - g_j), kept on and below the diagonal only.
+    decay_mask = ((g.unsqueeze(-1) -
+                   g.unsqueeze(-2)).tril().exp().float()).tril()
+    attn = -((torch.matmul(k_beta.contiguous(),
+                           key.transpose(-1, -2).contiguous())) *
+             decay_mask).masked_fill(mask, 0)
+    # Row-by-row in-place update: row i is rewritten from the already
+    # updated rows above it.
+    for i in range(1, chunk_size):
+        row = attn[..., i, :i].contiguous()
+        sub = attn[..., :i, :]
+        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)[..., :i]
+    attn = attn + eye_constant
+    value = attn @ v_beta
+    k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1))
+    # Start from zeros unless the caller supplied a state; either way it is
+    # matched to `value`'s dtype and device.
+    last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim,
+                                        v_head_dim).to(value) if initial_state
+                            is None else initial_state.to(value))
+    core_attn_out = torch.zeros_like(value)
+    # From here on `mask` keeps the diagonal and everything below it.
+    mask = torch.tril(torch.ones(chunk_size,
+                                 chunk_size,
+                                 dtype=torch.bool,
+                                 device=query.device),
+                      diagonal=0)
+    mask = mask.view(1, 1, 1, chunk_size, chunk_size)
+    attn = (query @ key.transpose(-1, -2)) * decay_mask * mask
+    qg = query * g_exp[..., None]
+    # Decay from each position to the end of its chunk.
+    delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None]
+    k_term = (key * delta_g_exp)
+
+    # for each chunk
+    for i in range(0, tot_len // chunk_size):
+        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
+        v_new = value[:, :, i] - v_prime
+        attn_inter = qg[:, :, i] @ last_recurrent_state
+        core_attn_out[:, :, i] = attn_inter + attn[:, :, i] @ v_new
+        # Decay the state by the chunk's total decay, then add this chunk's
+        # contribution.
+        last_recurrent_state = (
+            last_recurrent_state * g_exp[:, :, i, -1, None, None] +
+            k_term[:, :, i].transpose(-1, -2) @ v_new)
+
+    if not output_final_state:
+        last_recurrent_state = None
+    else:
+        last_recurrent_state = last_recurrent_state.to(initial_dtype)
+    # Merge the chunk dimensions back into one sequence dimension and cut
+    # off the padded tail.
+    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
+                                          core_attn_out.shape[1], -1,
+                                          core_attn_out.shape[-1])
+    core_attn_out = core_attn_out[:, :, :sequence_length]
+    core_attn_out = core_attn_out.transpose(1, 2).to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
+
+def torch_recurrent_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    recurrent_state,
+    output_final_state=True,
+    use_qk_l2norm_in_kernel=True,
+):
+    """Single-step gated delta rule update computed in float32.
+
+    The state is decayed by ``exp(g)``, corrected with the delta
+    ``(v - k @ state) * beta`` along ``k``, and then read out with ``q``.
+    The contractions are written as broadcast-multiply-and-sum, not as
+    matmuls.
+
+    Args:
+        query, key, value, g, beta: Per-token inputs; dims 1 and 2 are
+            swapped and everything is cast to float32 on entry.
+            # assumes exactly one token per sequence (see the squeezes
+            # below) - TODO confirm against the caller
+        recurrent_state: State carried over from the previous step.
+        output_final_state: When false, ``None`` is returned as the state.
+        use_qk_l2norm_in_kernel: Normalize ``query`` and ``key`` over the
+            last dimension before use.
+
+    Returns:
+        ``(core_attn_out, recurrent_state)``, both cast back to the
+        original dtype of ``query`` (the state may be ``None``).
+    """
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        head_dim = query.size(-1)
+        inv_scale = head_dim**-0.5
+        # rms_norm followed by 1/sqrt(head_dim) divides by the L2 norm of
+        # the last dimension (up to the eps term).
+        query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale
+        key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32)
+        for x in (query, key, value, beta, g)
+    ]
+
+    # NOTE(review): v_head_dim is not used below.
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1]**0.5)
+    query = query * scale
+
+    # Match the state's dtype and device to `value`.
+    recurrent_state = recurrent_state.to(value)
+
+    # Drop the size-1 dimension at position -2.
+    q_t = query.squeeze(-2)
+    k_t = key.squeeze(-2)
+    v_t = value.squeeze(-2)
+    # Decay factor, given two trailing singleton dims so it broadcasts over
+    # the last two dimensions of the state.
+    g_t = g.squeeze(-1).exp().unsqueeze(-1).unsqueeze(-1)
+    beta_t = beta
+
+    # The multiplication yields a new tensor, so the in-place add_ below
+    # does not modify the tensor the caller passed in.
+    recurrent_state = recurrent_state * g_t
+    kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+    delta = (v_t - kv_mem) * beta_t
+    recurrent_state.add_(k_t.unsqueeze(-1) * delta.unsqueeze(-2))
+    core_attn_out = (recurrent_state *
+                     q_t.unsqueeze(-1)).sum(dim=-2).unsqueeze(-2)
+
+    if not output_final_state:
+        recurrent_state = None
+    else:
+        recurrent_state = recurrent_state.to(initial_dtype)
+    core_attn_out = core_attn_out.transpose(1, 2).to(initial_dtype)
+    return core_attn_out, recurrent_state
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 9c8cbf52fb63..f009175474e5 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -1449,7 +1449,7 @@ def get_input_positions_tensor(
context_len=context_len,
seq_len=seq_len,
)
- elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+ elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
return cls._qwen3vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 350068b2c3d6..b53d06b34e7a 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -32,6 +32,22 @@
"""
+def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
+    """
+    A helper function to be used in the context of
+    [vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids][]
+    to provide a better error message.
+
+    Args:
+        is_multimodal: The `is_multimodal` argument received by the caller,
+            or `None` when it was not supplied.
+
+    Returns:
+        `is_multimodal` itself, unchanged, now known not to be `None`.
+
+    Raises:
+        ValueError: If `is_multimodal` is `None`.
+    """
+    if is_multimodal is None:
+        raise ValueError(
+            "`embed_input_ids` now requires `is_multimodal` arg, "
+            "please update your model runner according to "
+            "https://github.com/vllm-project/vllm/pull/16229."
+        )
+
+    return is_multimodal
+
+
@runtime_checkable
class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models."""
diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
new file mode 100644
index 000000000000..cfc84f10e196
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -0,0 +1,1166 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The vLLM team.
+# Copyright 2025 The Qwen Team.
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen3.5 Series compatible with HuggingFace weights."""
+
+import os
+import typing
+from collections.abc import Callable, Iterable
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.models.qwen3_5.configuration_qwen3_5 import (
+ Qwen3_5Config,
+ Qwen3_5TextConfig,
+)
+from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
+ Qwen3_5MoeConfig,
+ Qwen3_5MoeTextConfig,
+)
+from vllm.attention import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (
+ CacheConfig,
+ ModelConfig,
+ SpeculativeConfig,
+ VllmConfig,
+ get_current_vllm_config,
+)
+from vllm.distributed import (
+ divide,
+ get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import (
+ GemmaRMSNorm as Qwen3_5RMSNorm,
+)
+from vllm.model_executor.layers.layernorm import RMSNormGated
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.mamba_mixer2 import (
+ mamba_v2_sharded_weight_loader,
+)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+ MambaStateCopyFunc,
+ MambaStateCopyFuncCalculator,
+ MambaStateDtypeCalculator,
+ MambaStateShapeCalculator,
+)
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
+ causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops.torch_gated_delta_relu import (
+ torch_chunk_gated_delta_rule_opt,
+ torch_recurrent_gated_delta_rule_opt,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ sharded_weight_loader,
+)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import (
+ HasInnerState,
+ IsHybrid,
+ MixtureOfExperts,
+ MultiModalEmbeddings,
+ SupportsLoRA,
+ SupportsPP,
+ _require_is_multimodal,
+)
+from .qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from .qwen3_next import (
+ Qwen3NextAttention,
+ Qwen3NextDecoderLayer,
+ Qwen3NextGatedDeltaNet,
+ Qwen3NextModel,
+ Qwen3NextSparseMoeBlock,
+ QwenNextMixtureOfExperts,
+)
+from .qwen3_vl import (
+ Qwen3_VisionTransformer,
+ Qwen3_VisionTransformerStaticShape,
+ Qwen3VLDummyInputsBuilder,
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLMultiModalProcessor,
+ Qwen3VLProcessingInfo,
+)
+from .utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ _merge_multimodal_embeddings,
+ extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+
+is_hpu = current_platform.is_hpu()
+logger = init_logger(__name__)
+
+
+class Qwen3_5ProcessingInfo(Qwen3VLProcessingInfo):
+    """Qwen3-VL processing info that requests the HF config as `Qwen3_5Config`."""
+
+    def get_hf_config(self):
+        """Return the context's HF config, requested as `Qwen3_5Config`."""
+        hf_config = self.ctx.get_hf_config(Qwen3_5Config)
+        return hf_config
+
+
+class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
+    """Qwen3-VL processing info that requests the HF config as `Qwen3_5MoeConfig`."""
+
+    def get_hf_config(self):
+        """Return the context's HF config, requested as `Qwen3_5MoeConfig`."""
+        hf_config = self.ctx.get_hf_config(Qwen3_5MoeConfig)
+        return hf_config
+
+
+class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        speculative_config: SpeculativeConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        """Create the projections, conv weight and state caches of one layer.
+
+        Args:
+            vllm_config: Engine configuration; only ``scheduler_config`` is
+                read here, to size the conv/SSM state caches.
+            config: Text config providing the ``linear_*`` head settings,
+                ``hidden_size``, ``hidden_act``, ``rms_norm_eps`` and ``dtype``.
+            model_config: Stored on the instance as-is.
+            cache_config: Stored on the instance as-is.
+            quant_config: Passed to the QKV, Z and output projections. The
+                ``a``/``b`` projections and the conv weight are built
+                without it.
+            speculative_config: When set, its ``num_speculative_tokens`` is
+                stored as ``self.num_spec`` (0 otherwise).
+            prefix: Dotted module name. Used for sub-layer prefixes, to derive
+                ``layer_idx`` and as the key in ``static_forward_context``.
+
+        Raises:
+            ValueError: If ``prefix`` is already registered in the compilation
+                config's ``static_forward_context``.
+        """
+        # Two-argument super(): start the MRO lookup *after*
+        # Qwen3NextGatedDeltaNet, so that class's own __init__ is not run and
+        # every attribute is built below instead.
+        super(Qwen3NextGatedDeltaNet, self).__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.hidden_size = config.hidden_size
+        self.num_v_heads = config.linear_num_value_heads
+        self.num_k_heads = config.linear_num_key_heads
+        self.head_k_dim = config.linear_key_head_dim
+        self.head_v_dim = config.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+
+        self.conv_kernel_size = config.linear_conv_kernel_dim
+        self.layer_idx = extract_layer_index(prefix)
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+        self.layer_norm_epsilon = config.rms_norm_eps
+        self.prefix = prefix
+
+        self.config = config
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.speculative_config = speculative_config
+        self.num_spec = (
+            self.speculative_config.num_speculative_tokens
+            if self.speculative_config
+            else 0
+        )
+
+        # The convolution runs over the concatenated [q, k, v] channels. Its
+        # weight is held by a ColumnParallelLinear with one row per channel
+        # and one column per kernel tap.
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.conv1d = ColumnParallelLinear(
+            input_size=self.conv_kernel_size,
+            output_size=self.conv_dim,
+            bias=False,
+            prefix=f"{prefix}.conv1d",
+        )
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        # Filled lazily on the first forward() call from self.conv1d.weight.
+        self.conv1d_weight = None
+
+        self.in_proj_qkv = MergedColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_sizes=[self.key_dim, self.key_dim, self.value_dim],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_qkv",
+        )
+        self.in_proj_z = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.value_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_z",
+        )
+        # Prefix matches the attribute name (previously "in_proj_ba", which
+        # named a layer this class does not have).
+        self.in_proj_b = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.in_proj_b",
+        )
+        self.in_proj_a = ColumnParallelLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.in_proj_a",
+        )
+
+        query_key_settings = (self.key_dim, 0, False)
+        value_settings = (self.value_dim, 0, False)
+
+        # Replace the default loader of the conv weight with the Mamba2
+        # sharded loader, configured with one (size, 0, False) entry each for
+        # the q, k and v sections.
+        delattr(self.conv1d.weight, "weight_loader")
+        set_weight_attrs(
+            self.conv1d.weight,
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        query_key_settings,
+                        query_key_settings,
+                        value_settings,
+                    ],
+                    self.tp_size,
+                    self.tp_rank,
+                )
+            },
+        )
+
+        max_prefill_bs = vllm_config.scheduler_config.max_num_prefill_seqs
+        max_decode_bs = vllm_config.scheduler_config.max_num_seqs
+
+        # Number of cache slots: decode batch + max(8, decode batch) + prefill
+        # batch (decode batch again when no prefill limit is configured).
+        # NOTE(review): the rationale for the extra slots is not visible
+        # here - confirm against the model runner that hands out the indices.
+        mamba_cache_bs = max_decode_bs + max(8, max_decode_bs)
+        if max_prefill_bs is not None:
+            mamba_cache_bs += max_prefill_bs
+        else:
+            mamba_cache_bs += max_decode_bs
+
+        conv_state_shape = (
+            mamba_cache_bs,
+            self.conv_kernel_size - 1,
+            divide(self.conv_dim, self.tp_size),
+        )
+        temporal_state_shape = (
+            mamba_cache_bs,
+            divide(self.num_v_heads, self.tp_size),
+            self.head_k_dim,
+            self.head_v_dim,
+        )
+
+        # The caches are plain tensor attributes (assigned, not registered as
+        # buffers), allocated uninitialised in float32 on the device that
+        # holds the conv weight at construction time.
+        self.conv_state = torch.empty(
+            conv_state_shape,
+            dtype=torch.float32,
+            device=self.conv1d.weight.device,
+        )
+        self.ssm_state = torch.empty(
+            temporal_state_shape,
+            dtype=torch.float32,
+            device=self.conv1d.weight.device,
+        )
+
+        # Constants forwarded to torch_chunk_gated_delta_rule_opt in forward().
+        self.chunk_size = 64
+        self.eye_constant = torch.eye(
+            self.chunk_size,
+            dtype=torch.bfloat16,
+            device=self.conv1d.weight.device,
+        )
+        # Overridable through the VLLM_GDN_INV_LOOP environment variable.
+        self.inv_loop = int(os.environ.get("VLLM_GDN_INV_LOOP", 12))
+
+        # Per-head parameters used in forward() to compute the gate
+        # g = -exp(A_log) * softplus(a + dt_bias). Both are sharded along
+        # dim 0 at load time.
+        self.dt_bias = nn.Parameter(
+            torch.ones(divide(self.num_v_heads, self.tp_size)),
+        )
+        self.A_log = nn.Parameter(
+            torch.empty(
+                divide(self.num_v_heads, self.tp_size),
+            )
+        )
+
+        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
+        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+        self.norm = RMSNormGated(
+            self.head_v_dim,
+            eps=self.layer_norm_epsilon,
+            group_size=None,
+            norm_before_gate=True,
+            dtype=config.dtype,
+        )
+
+        self.out_proj = RowParallelLinear(
+            self.value_dim,
+            self.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+        # Register this layer under its prefix; a repeated prefix is an error.
+        compilation_config = get_current_vllm_config().compilation_config
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
+    def fix_query_key_value_ordering(
+        self,
+        mixed_qkv,
+        z,
+        b,
+        a,
+    ):
+        """Always raise: this class does not implement the reordering hook.
+
+        Raises:
+            NotImplementedError: On every call, whatever the arguments.
+        """
+        reason = "Qwen3.5 Series dont need to fix query key value ordering"
+        raise NotImplementedError(reason)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ):
+        """Run the gated delta-net mixer on ``hidden_states``.
+
+        The pass has three parts: input projection; causal convolution plus
+        gated delta rule (pure-torch kernels); gated norm plus output
+        projection. Prefill and decode are selected by
+        ``attn_metadata.is_prompt``. Both branches write the new conv and
+        recurrent state into ``self.conv_state`` / ``self.ssm_state`` at the
+        cache slots named by the attention metadata.
+
+        Args:
+            hidden_states: Layer input. It is indexed as a 3-D tensor below
+                (the prefill branch unpacks the projection shape as
+                ``(bs, seq_len, dim)``).
+
+        Returns:
+            The tensor produced by ``self.out_proj``.
+        """
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        conv_state = self.conv_state
+        ssm_state = self.ssm_state
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        mixed_qkv, _ = self.in_proj_qkv(hidden_states)
+        z, _ = self.in_proj_z(hidden_states)
+        # Split the gate into heads: (..., heads, head_v_dim).
+        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
+        b, _ = self.in_proj_b(hidden_states)
+        a, _ = self.in_proj_a(hidden_states)
+
+        # The convolution and the gate are evaluated in float32.
+        b = b.contiguous().float()
+        a = a.contiguous().float()
+        mixed_qkv = mixed_qkv.float()
+
+        # ============================================================
+        # Part 2: Convolution and gated delta rule
+        # ============================================================
+        mamba_cache_prefill_indices = attn_metadata.mamba_cache_prefill_indices
+        mamba_cache_decode_indices = attn_metadata.mamba_cache_decode_indices
+
+        # First call only: cache the conv weight as a float32
+        # (kernel_size, channels_per_rank) matrix and drop the original
+        # parameter.
+        if self.conv1d_weight is None:
+            self.conv1d_weight = self.conv1d.weight.squeeze(1).transpose(
+                0, 1).flatten().reshape(self.conv_kernel_size,
+                                        self.conv_dim // self.tp_size).float()
+            del self.conv1d.weight
+
+        if attn_metadata.is_prompt:
+            bs, seq_len, qkv_dim = mixed_qkv.shape
+            # Store the pre-convolution rows selected by conv_state_indices
+            # as the conv state of each prefill sequence.
+            # NOTE(review): presumably the last (kernel_size - 1) tokens of
+            # each sequence - confirm against the code that builds the
+            # metadata.
+            conv_state_indices = attn_metadata.conv_state_indices
+            prefill_conv_state = torch.index_select(
+                mixed_qkv.reshape(-1, qkv_dim),
+                dim=0,
+                index=conv_state_indices).reshape(bs, -1, qkv_dim)
+            conv_state.index_copy_(dim=0,
+                                   index=mamba_cache_prefill_indices,
+                                   source=prefill_conv_state)
+
+            # Causal per-channel convolution written as a sum of shifted
+            # copies: left-pad the sequence axis by (kernel_size - 1), then
+            # add slice `idx` scaled by tap `idx`.
+            mixed_qkv_with_pad = F.pad(mixed_qkv,
+                                       (0, 0, self.conv_kernel_size - 1, 0))
+            for idx in range(self.conv_kernel_size):
+                qkv_slice = mixed_qkv_with_pad[:, idx:(idx + seq_len), :]
+                conv1d_weight_slice = self.conv1d_weight[idx]
+                qkv_conv = qkv_slice * conv1d_weight_slice
+                if idx == 0:
+                    mixed_qkv_non_spec = qkv_conv
+                else:
+                    mixed_qkv_non_spec.add_(qkv_conv)
+
+            # NOTE(review): SiLU is hard-coded here while the decode branch
+            # passes self.activation - the two agree only if hidden_act is
+            # SiLU.
+            mixed_qkv_non_spec = F.silu(mixed_qkv_non_spec)
+
+        else:
+            mixed_qkv_non_spec, cur_conv_state = causal_conv1d_update(
+                mixed_qkv,
+                conv_state,
+                self.conv1d_weight,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=mamba_cache_decode_indices,
+            )
+            conv_state.index_copy_(0, mamba_cache_decode_indices,
+                                   cur_conv_state)
+
+        query, key, value = torch.split(
+            mixed_qkv_non_spec.to(hidden_states.dtype),
+            [
+                self.key_dim // self.tp_size,
+                self.key_dim // self.tp_size,
+                self.value_dim // self.tp_size,
+            ],
+            dim=-1,
+        )
+        query_non_spec = query.reshape(query.shape[0], query.shape[1], -1,
+                                       self.head_k_dim)
+        key_non_spec = key.reshape(key.shape[0], key.shape[1], -1,
+                                   self.head_k_dim)
+        value_non_spec = value.reshape(value.shape[0], value.shape[1], -1,
+                                       self.head_v_dim)
+
+        beta = b.sigmoid().to(hidden_states.dtype)
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+
+        # With more value heads than key heads, repeat q/k along the head
+        # axis so that every value head has a matching q/k head.
+        if self.num_v_heads // self.num_k_heads > 1:
+            query_non_spec = query_non_spec.repeat_interleave(
+                self.num_v_heads // self.num_k_heads, dim=2)
+            key_non_spec = key_non_spec.repeat_interleave(self.num_v_heads //
+                                                          self.num_k_heads,
+                                                          dim=2)
+
+        if attn_metadata.is_prompt:
+            # Prefill always starts from an empty recurrent state
+            # (initial_state=None) and stores the final one.
+            core_attn_out, last_recurrent_state = (
+                torch_chunk_gated_delta_rule_opt(
+                    query_non_spec,
+                    key_non_spec,
+                    value_non_spec,
+                    g=g,
+                    beta=beta,
+                    eye_constant=self.eye_constant,
+                    chunk_size=self.chunk_size,
+                    inv_loop=self.inv_loop,
+                    initial_state=None,
+                    output_final_state=True,
+                    use_qk_l2norm_in_kernel=True,
+                ))
+            ssm_state.index_copy_(dim=0,
+                                  index=mamba_cache_prefill_indices,
+                                  source=last_recurrent_state)
+        else:
+            # Decode: gather the cached state, advance it, write it back.
+            recurrent_state = torch.index_select(
+                ssm_state,
+                dim=0,
+                index=mamba_cache_decode_indices,
+            )
+            core_attn_out, last_recurrent_state = (
+                torch_recurrent_gated_delta_rule_opt(
+                    query_non_spec,
+                    key_non_spec,
+                    value_non_spec,
+                    g=g,
+                    beta=beta,
+                    recurrent_state=recurrent_state,
+                    output_final_state=True,
+                    use_qk_l2norm_in_kernel=True,
+                ))
+            ssm_state.index_copy_(
+                dim=0,
+                index=mamba_cache_decode_indices,
+                source=last_recurrent_state,
+            )
+
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # The gated norm works on 2-D (rows, head_v_dim) inputs.
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z).to(hidden_states.dtype)
+        # Restore the per-head layout, then merge heads for the projection.
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
+                                              core_attn_out.shape[1], -1)
+
+        output, _ = self.out_proj(core_attn_out)
+        return output
+
+
+class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
+    """Decoder layer with a linear- or full-attention mixer and an MLP.
+
+    The mixer is chosen by ``layer_type``. The MLP is a sparse MoE block or a
+    dense MLP depending on the text config's ``model_type``.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        layer_type: str,
+        prefix: str = "",
+    ) -> None:
+        """Build the mixer, MLP, norms and optional layer-scale parameters.
+
+        Raises:
+            ValueError: If ``layer_type`` or the text config's ``model_type``
+                is not one of the supported values.
+        """
+        # Two-argument super(): Qwen3NextDecoderLayer.__init__ is skipped and
+        # all sub-modules are created here.
+        super(Qwen3NextDecoderLayer, self).__init__()
+
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_text_config
+        quant_config = vllm_config.quant_config
+
+        self.layer_type = layer_type
+        self.layer_idx = extract_layer_index(prefix)
+
+        # Token mixer.
+        if layer_type == "linear_attention":
+            self.linear_attn = Qwen3_5GatedDeltaNet(
+                vllm_config,
+                text_config,
+                model_config=model_config,
+                cache_config=vllm_config.cache_config,
+                quant_config=quant_config,
+                speculative_config=vllm_config.speculative_config,
+                prefix=f"{prefix}.linear_attn",
+            )
+        elif layer_type == "full_attention":
+            self.self_attn = Qwen3NextAttention(
+                text_config,
+                model_config=model_config,
+                cache_config=vllm_config.cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
+        else:
+            raise ValueError(f"Invalid layer_type {self.layer_type}")
+
+        # Feed-forward: sparse MoE for the MoE text model, dense otherwise.
+        model_type = text_config.model_type
+        if model_type == "qwen3_5_moe_text":
+            self.mlp = Qwen3NextSparseMoeBlock(
+                config=text_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+            )
+        elif model_type == "qwen3_5_text":
+            # No prefix is passed to the dense MLP.
+            self.mlp = Qwen3NextMLP(
+                hidden_size=text_config.hidden_size,
+                intermediate_size=text_config.intermediate_size,
+                hidden_act=text_config.hidden_act,
+                quant_config=quant_config,
+            )
+        else:
+            raise ValueError(f"Invalid model_type {text_config.model_type}")
+
+        hidden_size = text_config.hidden_size
+        norm_eps = text_config.rms_norm_eps
+        self.input_layernorm = Qwen3_5RMSNorm(hidden_size, eps=norm_eps)
+        self.post_attention_layernorm = Qwen3_5RMSNorm(hidden_size, eps=norm_eps)
+
+        # Optional zero-initialised per-channel scales for both branches.
+        self.layer_scale = getattr(text_config, "layer_scale", False)
+        if self.layer_scale:
+            for scale_name in ("attn_layer_scale", "ffn_layer_scale"):
+                scale = torch.nn.Parameter(
+                    torch.zeros(1, 1, hidden_size, dtype=text_config.dtype),
+                )
+                setattr(self, scale_name, scale)
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
+class Qwen3_5Model(Qwen3NextModel):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the embedding, the decoder layers and the final norm.
+
+        Args:
+            vllm_config: Engine configuration; the text config is read from
+                ``model_config.hf_text_config``.
+            prefix: Dotted module name; layers are created under
+                ``f"{prefix}.layers"``.
+        """
+        # Two-argument super(): Qwen3NextModel.__init__ is skipped and all
+        # sub-modules are created here.
+        super(Qwen3NextModel, self).__init__()
+
+        config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = (
+            vllm_config.model_config.hf_text_config
+        )
+
+        self.config = config
+
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+        )
+
+        def get_layer(prefix: str):
+            # The layer index parsed from the prefix selects the layer type
+            # ("linear_attention" or "full_attention") from the config.
+            return Qwen3_5DecoderLayer(
+                vllm_config,
+                layer_type=config.layer_types[extract_layer_index(prefix)],
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+        )
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], config.hidden_size
+        )
+
+        # Only the last pipeline rank owns the final norm.
+        if get_pp_group().is_last_rank:
+            self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+    def load_fused_expert_weights(
+        self,
+        name: str,
+        params_dict: dict,
+        loaded_weight: torch.Tensor,
+        shard_id: str,
+        num_experts: int,
+    ) -> bool:
+        """Load one shard of a fused expert weight, one expert at a time.
+
+        Calls the target parameter's ``weight_loader`` once per expert index
+        in ``range(num_experts)``, passing ``loaded_weight[expert_id]``.
+
+        Args:
+            name: Key of the target parameter in ``params_dict``.
+            params_dict: Mapping from parameter name to parameter.
+            loaded_weight: Fused weight, indexed by expert along dim 0.
+            shard_id: Shard identifier forwarded to the loader.
+            num_experts: Number of leading expert slices to load.
+
+        Returns:
+            Always ``True``: the flag starts out ``True`` and the loader's
+            return value can only set it to ``True`` again.
+        """
+        target_param = params_dict[name]
+        load_one = typing.cast(Callable[..., bool], target_param.weight_loader)
+        # NOTE(review): the flag is seeded with True and the loader is called
+        # without `return_success`, so per-expert failures are never reported.
+        any_loaded = True
+        expert_id = 0
+        while expert_id < num_experts:
+            loaded = load_one(
+                target_param,
+                loaded_weight[expert_id],
+                name,
+                shard_id,
+                expert_id,
+            )
+            if loaded:
+                any_loaded = True
+            expert_id += 1
+
+        return any_loaded
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load checkpoint tensors into this module's parameters.
+
+        Each ``(name, tensor)`` pair is tried, in order, as:
+
+        1. a shard of a stacked parameter (``qkv_proj`` / ``gate_up_proj``);
+        2. an expert weight, either per-expert or fused across experts;
+        3. a regular parameter loaded under its own name.
+
+        Names containing ``rotary_emb.inv_freq`` and names starting with
+        ``mtp.`` are skipped.
+
+        Args:
+            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
+
+        Returns:
+            The set of parameter names recorded as loaded.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
+        # Becomes True (and stays True) once a fused expert tensor is seen.
+        is_fused_expert = False
+        fused_expert_params_mapping = [
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+        ]
+        num_experts = (
+            self.config.num_experts if hasattr(self.config, "num_experts") else 0
+        )
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if name.startswith("mtp."):
+                continue
+
+            # --- 1. stacked (merged) parameters ---------------------------
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # A fused expert tensor switches the expert mapping used in
+                # step 2 to the fused one. The check does not depend on the
+                # loop variables; it is simply repeated per iteration.
+                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                    is_fused_expert = True
+                    expert_params_mapping = fused_expert_params_mapping
+
+                if weight_name not in name:
+                    continue
+
+                # Expert weights are handled in step 2.
+                if "mlp.experts" in name:
+                    continue
+
+                # NOTE(review): `name` is rewritten before the checks below;
+                # if one of them hits `continue`, the rewritten name is what
+                # the later steps see.
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # --- 2. expert parameters (no stacked mapping matched) -----
+                is_expert_weight = False
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
+                    if is_fused_expert:
+                        # Qwen3.5 fused expert weights are loaded without
+                        # the transpose (no `.transpose(-1, -2)` here).
+                        if "experts.gate_up_proj" in name:
+                            # Split the fused tensor in two along dim -2:
+                            # first half is loaded as "w1", second as "w3".
+                            loaded_weight = loaded_weight.chunk(2, dim=-2)
+                            success_w1 = self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[0],
+                                "w1",
+                                num_experts,
+                            )
+                            success_w3 = self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[1],
+                                "w3",
+                                num_experts,
+                            )
+                            success = success_w1 and success_w3
+                        else:
+                            # down_proj
+                            success = self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight,
+                                shard_id,
+                                num_experts,
+                            )
+                        if success:
+                            name = name_mapped
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if (
+                            name_mapped.endswith(".bias")
+                            or name_mapped.endswith("_bias")
+                        ) and name_mapped not in params_dict:
+                            continue
+                        param = params_dict[name_mapped]
+                        weight_loader = param.weight_loader
+                        # NOTE(review): `return_success=True` is not passed.
+                        # If the loader then returns None, `success` is falsy,
+                        # the loop never breaks and the weight is not added to
+                        # `loaded_params` - confirm the loader's return value.
+                        success = weight_loader(
+                            param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                        if success:
+                            name = name_mapped
+                            break
+                else:
+                    # --- 3. regular parameters ------------------------------
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    if name not in params_dict:
+                        logger.warning_once(
+                            f"Parameter {name} not found in params_dict, skip loading"
+                        )
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+            # Reached only when none of the `continue` statements above fired.
+            loaded_params.add(name)
+        return loaded_params
+
+
+class Qwen3_5ForCausalLMBase(
+    nn.Module,
+    HasInnerState,
+    SupportsLoRA,
+    SupportsPP,
+):
+    """Text-only Qwen3.5 causal LM: a `Qwen3_5Model` backbone plus LM head."""
+
+    # Checkpoint projections that are merged into one parameter in this model.
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the backbone, the LM head and the logits processor.
+
+        Args:
+            vllm_config: Engine configuration; the text config is read from
+                ``model_config.hf_text_config``.
+            prefix: Dotted module name of this model.
+        """
+        # Initialise nn.Module before any attribute is assigned.
+        super().__init__()
+
+        config = vllm_config.model_config.hf_text_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.quant_config = vllm_config.quant_config
+        self.config = config
+        self.scheduler_config = vllm_config.scheduler_config
+        # NOTE: cache_config.mamba_cache_mode is not validated here.
+
+        self.model = Qwen3_5Model(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+        )
+
+        # Only the last pipeline rank owns an LM head; with tied embeddings
+        # the head is the input embedding module itself.
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    prefix=maybe_prefix(prefix, "lm_head"),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """Delegate to ``self.model.embed_input_ids``."""
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: object,
+    ):
+        """Run the backbone and return its output unchanged.
+
+        Extra keyword arguments are accepted and ignored.
+        """
+        hidden_states = self.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds
+        )
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor | None:
+        """Apply the logits processor with ``self.lm_head`` to ``hidden_states``."""
+        return self.logits_processor(self.lm_head, hidden_states,
+                                     sampling_metadata)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load weights via `AutoWeightsLoader`, skipping the ``mtp.`` prefix."""
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=["mtp."],
+        )
+        return loader.load_weights(weights)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """Delegate to ``self.model.get_input_embeddings``."""
+        return self.model.get_input_embeddings(input_ids)
+
+
+class Qwen3_5ForCausalLM(Qwen3_5ForCausalLMBase):
+    """Qwen3.5 causal LM; adds nothing to `Qwen3_5ForCausalLMBase`."""
+
+
+class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
+    """Qwen3.5-MoE causal LM: the base causal LM plus MoE bookkeeping."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the base model, then record the MoE hyperparameters."""
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        # Not defined in this class; resolved through the base classes.
+        self.set_moe_parameters()
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """Return the expert mapping reported by the backbone model."""
+        expert_mapping = self.model.get_expert_mapping()
+        return expert_mapping
+
+
+########################################################
+# Qwen3_5-Dense
+########################################################
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen3VLMultiModalProcessor,
+    info=Qwen3_5ProcessingInfo,
+    dummy_inputs=Qwen3VLDummyInputsBuilder,
+)
+class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
+    """Qwen3.5 (dense): a Qwen3 vision tower plus a `Qwen3_5ForCausalLM`."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the vision tower and the language model.
+
+        Args:
+            vllm_config: Engine configuration; ``model_config.hf_config`` is
+                used as the `Qwen3_5Config`.
+            prefix: Dotted module name of this model.
+        """
+        # Call nn.Module.__init__ directly instead of super().__init__(), so
+        # the Qwen3VLForConditionalGeneration constructor is not run; the
+        # sub-modules are created below.
+        nn.Module.__init__(self)
+        config: Qwen3_5Config = vllm_config.model_config.hf_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+        # NOTE: use_data_parallel, video_pruning_rate and
+        # is_multimodal_pruning_enabled are not set by this constructor.
+
+        # Vision tower: the static-shape variant is used on HPU. It is built
+        # without a quantization config.
+        if is_hpu:
+            qwen3_visionTransformer = Qwen3_VisionTransformerStaticShape
+        else:
+            qwen3_visionTransformer = Qwen3_VisionTransformer
+        self.visual = qwen3_visionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=None,
+            prefix=maybe_prefix(prefix, "visual"),
+        )
+
+        self.language_model = Qwen3_5ForCausalLM(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
+        )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+        self.use_deepstack = False
+        self.text_dim = config.text_config.hidden_size
+
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
+        *,
+        is_multimodal: torch.Tensor | None = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        """Embed ``input_ids`` and merge in the multimodal embeddings.
+
+        Args:
+            input_ids: Token ids to embed.
+            multimodal_embeddings: Embeddings to merge in. When ``None`` or
+                empty, the text embeddings are returned as they are.
+            is_multimodal: Mask forwarded to the text embedding helper and to
+                the merge; required whenever embeddings are merged.
+            handle_oov_mm_token: Forwarded to the text embedding helper.
+
+        Raises:
+            ValueError: If multimodal embeddings are given but
+                ``is_multimodal`` is ``None``.
+        """
+        inputs_embeds = self._embed_text_input_ids(
+            input_ids,
+            self.language_model.embed_input_ids,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        is_multimodal = _require_is_multimodal(is_multimodal)
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+        )
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: object,
+    ) -> torch.Tensor | IntermediateTensors:
+        """Run forward pass for Qwen3.5.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen3VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+            intermediate_tensors: Intermediate tensors from previous pipeline
+                stages.
+            inputs_embeds: Pre-computed input embeddings.
+            **kwargs: Additional keyword arguments including:
+                - pixel_values: Pixel values to be fed to a model.
+                    `None` if no images are passed.
+                - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
+                    LLM. `None` if no images are passed.
+                - pixel_values_videos: Pixel values of videos to be fed to a
+                    model. `None` if no videos are passed.
+                - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
+                    LLM. `None` if no videos are passed.
+        """
+
+        # When intermediate tensors are supplied, the embeddings are dropped.
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        hidden_states = self.language_model.model(
+            input_ids=input_ids,
+            positions=positions,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load weights through ``hf_to_vllm_mapper``, skipping ``mtp.``."""
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=["mtp."],
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    @classmethod
+    def get_mamba_state_dtype_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+    ) -> tuple[torch.dtype, torch.dtype]:
+        """Return the gated-delta-net state dtypes for this configuration."""
+        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
+            vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
+        )
+
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls, vllm_config: "VllmConfig"
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        """Return the gated-delta-net state shapes for this configuration.
+
+        The shapes are computed from the tensor-parallel size, the
+        ``linear_*`` settings of the text config and the number of
+        speculative tokens (0 when speculative decoding is off).
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_text_config
+        tp_size = parallel_config.tensor_parallel_size
+        num_spec = (
+            vllm_config.speculative_config.num_speculative_tokens
+            if vllm_config.speculative_config
+            else 0
+        )
+        return MambaStateShapeCalculator.gated_delta_net_state_shape(
+            tp_size,
+            hf_config.linear_num_key_heads,
+            hf_config.linear_num_value_heads,
+            hf_config.linear_key_head_dim,
+            hf_config.linear_value_head_dim,
+            hf_config.linear_conv_kernel_dim,
+            num_spec,
+        )
+
+    @classmethod
+    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
+        """Return the gated-delta-net state copy functions."""
+        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()
+
+
+########################################################
+# Qwen3_5-MoE
+########################################################
+
+
+class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
+    """MoE bookkeeping for models exposing ``self.language_model.model.layers``."""
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        """Record new physical-expert counts and push them to every MoE block.
+
+        Args:
+            num_physical_experts: New total number of physical experts.
+            num_local_physical_experts: Local count; must equal the value
+                already stored on this object.
+        """
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for layer in self.language_model.model.layers:
+            # Same filter as set_moe_parameters: entries that are not
+            # Qwen3_5DecoderLayer may have no `mlp` attribute.
+            if not isinstance(layer, Qwen3_5DecoderLayer):
+                continue
+            if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
+    def set_moe_parameters(self):
+        """Collect the MoE layers and copy their hyperparameters onto ``self``.
+
+        Raises:
+            RuntimeError: If no decoder layer holds a sparse MoE block.
+        """
+        self.expert_weights = []
+
+        self.moe_layers = []
+        example_moe = None
+        for layer in self.language_model.model.layers:
+            if isinstance(layer, Qwen3_5DecoderLayer) and isinstance(
+                layer.mlp, Qwen3NextSparseMoeBlock
+            ):
+                # The last MoE block found serves as the reference below.
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            raise RuntimeError(
+                "No Qwen3_5 MoE layer found in the language_model.model.layers."
+            )
+
+        # Set MoE hyperparameters
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen3VLMultiModalProcessor,
+    info=Qwen3_5MoeProcessingInfo,
+    dummy_inputs=Qwen3VLDummyInputsBuilder,
+)
+class Qwen3_5MoeForConditionalGeneration(
+    Qwen3_5ForConditionalGeneration, Qwen3_5_MoeMixtureOfExperts
+):
+    """Qwen3.5-MoE: a Qwen3 vision tower plus a `Qwen3_5MoeForCausalLM`."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the vision tower and the MoE language model.
+
+        Args:
+            vllm_config: Engine configuration; ``model_config.hf_config`` is
+                used as the `Qwen3_5MoeConfig`.
+            prefix: Dotted module name of this model.
+        """
+        # Call nn.Module.__init__ directly instead of super().__init__(), so
+        # the parent constructors are not run; the sub-modules are created
+        # below.
+        nn.Module.__init__(self)
+        config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+        # NOTE: use_data_parallel, video_pruning_rate and
+        # is_multimodal_pruning_enabled are not set by this constructor.
+
+        # Vision tower: the static-shape variant is used on HPU. It is built
+        # without a quantization config.
+        if is_hpu:
+            qwen3_visionTransformer = Qwen3_VisionTransformerStaticShape
+        else:
+            qwen3_visionTransformer = Qwen3_VisionTransformer
+        self.visual = qwen3_visionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=None,
+            prefix=maybe_prefix(prefix, "visual"),
+        )
+
+        self.language_model = Qwen3_5MoeForCausalLM(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
+        )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+        self.use_deepstack = False
+        self.text_dim = config.text_config.hidden_size
+
+        # set MoE hyperparameters
+        self.set_moe_parameters()
diff --git a/vllm/model_executor/models/qwen3_5_mtp.py b/vllm/model_executor/models/qwen3_5_mtp.py
new file mode 100644
index 000000000000..8bd29f352dbf
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_5_mtp.py
@@ -0,0 +1,447 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Qwen3_5 MTP model."""
+
+import typing
+from collections.abc import Callable, Iterable
+
+import torch
+from torch import nn
+from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
+from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
+ Qwen3_5MoeTextConfig,
+)
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5RMSNorm
+from vllm.model_executor.models.qwen3_next import QwenNextMixtureOfExperts
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import (
+ MultiModalEmbeddings,
+ SupportsMultiModal,
+ _require_is_multimodal,
+)
+from .utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ _merge_multimodal_embeddings,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ "hidden_states": 0,
+ }
+)
+class Qwen3_5MultiTokenPredictor(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ model_config = vllm_config.model_config
+ quant_config = vllm_config.quant_config
+
+ config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config
+
+ self.config = config
+
+ self.vocab_size = config.vocab_size
+
+ self.mtp_start_layer_idx = config.num_hidden_layers
+ self.num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1)
+
+ self.embed_tokens = VocabParallelEmbedding(
+ self.vocab_size,
+ config.hidden_size,
+ )
+
+ self.fc = ColumnParallelLinear(
+ self.config.hidden_size * 2,
+ self.config.hidden_size,
+ gather_output=True,
+ bias=False,
+ return_bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc",
+ )
+
+ self.layers = torch.nn.ModuleList(
+ Qwen3_5DecoderLayer(
+ vllm_config,
+ layer_type="full_attention",
+ prefix=f"{prefix}.layers.{idx}",
+ )
+ for idx in range(self.num_mtp_layers)
+ )
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.pre_fc_norm_hidden = Qwen3_5RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+ self.pre_fc_norm_embedding = Qwen3_5RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:  # NOTE: non-last PP ranks return IntermediateTensors instead
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_input_ids(input_ids)
+ assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
+ inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
+ hidden_states = self.pre_fc_norm_hidden(hidden_states)
+ hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)  # [embedding, hidden]
+ hidden_states = self.fc(hidden_states)  # fc takes 2 * hidden_size features
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ current_step_idx = spec_step_idx % self.num_mtp_layers  # wraps past the last MTP layer
+ hidden_states, residual = self.layers[current_step_idx](
+ positions=positions,
+ hidden_states=hidden_states,
+ residual=residual,
+ )
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {"hidden_states": hidden_states, "residual": residual}
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)  # final norm, last PP rank only
+ return hidden_states
+
+ def load_fused_expert_weights(
+ self,
+ name: str,
+ params_dict: dict,
+ loaded_weight: torch.Tensor,
+ shard_id: str,
+ num_experts: int,
+ ) -> bool:
+ param = params_dict[name]
+ weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+ loaded_local_expert = False
+ for expert_id in range(num_experts):
+ curr_expert_weight = loaded_weight[expert_id]
+ success = weight_loader(
+ param,
+ curr_expert_weight,
+ name,
+ shard_id,
+ expert_id,
+ return_success=True,
+ )
+ if success:
+ loaded_local_expert = True
+
+ return loaded_local_expert
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ self,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.num_experts
+ if hasattr(self.config, "num_experts")
+ else 0,
+ )
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ is_fused_expert = False
+ fused_expert_params_mapping = [
+ ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+ ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+ ]
+ num_experts = (
+ self.config.num_experts if hasattr(self.config, "num_experts") else 0
+ )
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+ is_fused_expert = True
+ expert_params_mapping = fused_expert_params_mapping
+
+ if weight_name not in name:
+ continue
+
+ if "mlp.experts" in name:
+ continue
+
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ is_expert_weight = True
+ name_mapped = name.replace(weight_name, param_name)
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+ if is_fused_expert:
+ # qwen3.5 no need to transpose
+ # loaded_weight = loaded_weight.transpose(-1, -2)
+ if "experts.gate_up_proj" in name:
+ loaded_weight = loaded_weight.chunk(2, dim=-2)
+ success_w1 = self.load_fused_expert_weights(
+ name_mapped,
+ params_dict,
+ loaded_weight[0],
+ "w1",
+ num_experts,
+ )
+ success_w3 = self.load_fused_expert_weights(
+ name_mapped,
+ params_dict,
+ loaded_weight[1],
+ "w3",
+ num_experts,
+ )
+ success = success_w1 and success_w3
+ else:
+ # down_proj
+ success = self.load_fused_expert_weights(
+ name_mapped,
+ params_dict,
+ loaded_weight,
+ shard_id,
+ num_experts,
+ )
+ if success:
+ name = name_mapped
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if (
+ name_mapped.endswith(".bias")
+ or name_mapped.endswith("_bias")
+ ) and name_mapped not in params_dict:
+ continue
+ param = params_dict[name_mapped]
+ weight_loader = param.weight_loader
+ success = weight_loader(
+ param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ name = name_mapped
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+ if name not in params_dict:
+ logger.warning_once(
+ f"Parameter {name} not found in params_dict, skip loading"
+ )
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+ # otherwise (seq_len, ).
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ "hidden_states": 0,
+ }
+)
+class Qwen3_5MTP(nn.Module, SupportsMultiModal):
+ packed_modules_mapping = {  # fused parameter -> checkpoint projections packed into it
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": ["gate_proj", "up_proj"],  # down_proj is loaded separately
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ config = vllm_config.model_config.hf_text_config
+ self.vllm_config = vllm_config
+ cache_config = vllm_config.cache_config
+ if cache_config.mamba_cache_mode == "all":
+ raise NotImplementedError(
+ "Qwen3_5MTP currently does not support 'all' prefix caching, "
+ "please use '--mamba-cache-mode=align' instead"
+ )
+
+ self.quant_config = vllm_config.quant_config
+
+ super().__init__()
+ self.config = config
+ self.model = Qwen3_5MultiTokenPredictor(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
+ )
+
+ if get_pp_group().is_last_rank:
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
+ *,
+ is_multimodal: torch.Tensor | None = None,
+ handle_oov_mm_token: bool = False,
+ ) -> torch.Tensor:
+ inputs_embeds = self._embed_text_input_ids(
+ input_ids,
+ self.model.embed_input_ids,
+ is_multimodal=is_multimodal,
+ handle_oov_mm_token=handle_oov_mm_token,
+ )
+
+ if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+ return inputs_embeds
+
+ is_multimodal = _require_is_multimodal(is_multimodal)
+
+ inputs_embeds = _merge_multimodal_embeddings(
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ )
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs: object,
+ ):
+ hidden_states = self.model(
+ input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor | None:
+ return self.logits_processor(self.lm_head, hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ def remap_weight_names(weights):  # lazily renames/filters checkpoint entries
+ for name, weight in weights:
+ if name.startswith("mtp."):
+ name = name.replace("mtp.", "model.", 1)  # rewrite the leading prefix only
+ elif any(key in name for key in ["embed_tokens", "lm_head"]):
+ if "embed_tokens" in name:
+ name = name.replace("language_model.", "")
+ else:
+ continue
+ yield name, weight
+
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(remap_weight_names(weights))
+
+
+class Qwen3_5MoeMTP(Qwen3_5MTP, QwenNextMixtureOfExperts):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ self.set_moe_parameters()
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 4e8d81b7f1b7..46b9f1579be3 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
+import os
from collections.abc import Iterable
from typing import Optional
@@ -38,6 +39,8 @@
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops.torch_gated_delta_relu import (
+ torch_chunk_gated_delta_rule, torch_recurrent_gated_delta_rule)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -62,177 +65,6 @@
KVCache = tuple[torch.Tensor, torch.Tensor]
-def torch_chunk_gated_delta_rule(
- query,
- key,
- value,
- g,
- beta,
- eye_constant,
- chunk_size=64,
- initial_state=None,
- output_final_state=True,
- use_qk_l2norm_in_kernel=True,
-):
- initial_dtype = query.dtype
- if use_qk_l2norm_in_kernel:
- head_dim = query.size(-1)
- inv_scale = head_dim**-0.5
- query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale
- key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale
- query, key, value, beta, g = [
- x.transpose(1, 2).contiguous().to(torch.float32)
- for x in (query, key, value, beta, g)
- ]
-
- batch_size, num_heads, sequence_length, k_head_dim = key.shape
- v_head_dim = value.shape[-1]
- pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
- if pad_size > 0:
- query = F.pad(query, (0, 0, 0, pad_size))
- key = F.pad(key, (0, 0, 0, pad_size))
- value = F.pad(value, (0, 0, 0, pad_size))
- beta = F.pad(beta, (0, pad_size))
- g = F.pad(g, (0, pad_size))
- tot_len = sequence_length + pad_size
- scale = 1 / (query.shape[-1]**0.5)
- query = query * scale
-
- v_beta = value * beta.unsqueeze(-1)
- k_beta = key * beta.unsqueeze(-1)
- # reshape to chunks
- query, key, value, k_beta, v_beta = [
- x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
- for x in (query, key, value, k_beta, v_beta)
- ]
- g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
- mask = torch.triu(torch.ones(chunk_size,
- chunk_size,
- dtype=torch.bool,
- device=query.device),
- diagonal=0)
-
- # chunk decay
- g = g.cumsum(dim=-1)
- g_exp = g.exp()
- decay_mask = ((g.unsqueeze(-1) -
- g.unsqueeze(-2)).tril().exp().float()).tril()
- attn = -((torch.matmul(k_beta.contiguous(),
- key.transpose(-1, -2).contiguous())) *
- decay_mask).masked_fill(mask, 0)
- for i in range(1, chunk_size):
- row = attn[..., i, :i].contiguous()
- sub = attn[..., :i, :]
- attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)[..., :i]
- attn = attn + eye_constant
- value = attn @ v_beta
- k_cumdecay = attn @ (k_beta * g_exp.unsqueeze(-1))
- last_recurrent_state = (torch.zeros(batch_size, num_heads, k_head_dim,
- v_head_dim).to(value) if initial_state
- is None else initial_state.to(value))
- core_attn_out = torch.zeros_like(value)
- mask = torch.tril(torch.ones(chunk_size,
- chunk_size,
- dtype=torch.bool,
- device=query.device),
- diagonal=0)
- mask = mask.view(1, 1, 1, chunk_size, chunk_size)
- attn = (query @ key.transpose(-1, -2)) * decay_mask * mask
- qg = query * g_exp[..., None]
- delta_g_exp = (g[:, :, :, -1, None] - g).exp()[..., None]
- k_term = (key * delta_g_exp)
-
- # for each chunk
- for i in range(0, tot_len // chunk_size):
- v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
- v_new = value[:, :, i] - v_prime
- attn_inter = qg[:, :, i] @ last_recurrent_state
- core_attn_out[:, :, i] = attn_inter + attn[:, :, i] @ v_new
- last_recurrent_state = (
- last_recurrent_state * g_exp[:, :, i, -1, None, None] +
- k_term[:, :, i].transpose(-1, -2) @ v_new)
-
- if not output_final_state:
- last_recurrent_state = None
- else:
- last_recurrent_state = last_recurrent_state.to(initial_dtype)
- core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
- core_attn_out.shape[1], -1,
- core_attn_out.shape[-1])
- core_attn_out = core_attn_out[:, :, :sequence_length]
- core_attn_out = core_attn_out.transpose(1,
- 2).contiguous().to(initial_dtype)
- return core_attn_out, last_recurrent_state
-
-
-def torch_recurrent_gated_delta_rule(
- query,
- key,
- value,
- g,
- beta,
- recurrent_state,
- output_final_state=True,
- use_qk_l2norm_in_kernel=True,
-):
- initial_dtype = query.dtype
- if use_qk_l2norm_in_kernel:
- head_dim = query.size(-1)
- inv_scale = head_dim**-0.5
- query = F.rms_norm(query, (head_dim, ), eps=1e-6) * inv_scale
- key = F.rms_norm(key, (head_dim, ), eps=1e-6) * inv_scale
- query, key, value, beta, g = [
- x.transpose(1, 2).contiguous().to(torch.float32)
- for x in (query, key, value, beta, g)
- ]
-
- batch_size, sequence_length, num_heads, k_head_dim = key.shape
- v_head_dim = value.shape[-1]
- scale = 1 / (query.shape[-1]**0.5)
- query = query * scale
-
- recurrent_state = recurrent_state.to(value)
-
- if num_heads > 1:
- core_attn_out = torch.zeros(batch_size, sequence_length, num_heads,
- v_head_dim).to(value)
- for i in range(num_heads):
- q_t = query[:, :, i]
- k_t = key[:, :, i]
- v_t = value[:, :, i]
- g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
- beta_t = beta[:, :, i].unsqueeze(-1)
-
- recurrent_state = recurrent_state * g_t
- kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
- delta = (v_t - kv_mem) * beta_t
- recurrent_state = recurrent_state + k_t.unsqueeze(
- -1) * delta.unsqueeze(-2)
- core_attn_out[:, :, i] = (recurrent_state *
- q_t.unsqueeze(-1)).sum(dim=-2)
- else:
- q_t = query.squeeze(-2)
- k_t = key.squeeze(-2)
- v_t = value.squeeze(-2)
- g_t = g.squeeze(-1).exp().unsqueeze(-1).unsqueeze(-1)
- beta_t = beta
-
- recurrent_state = recurrent_state * g_t
- kv_mem = (recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
- delta = (v_t - kv_mem) * beta_t
- recurrent_state.add_(k_t.unsqueeze(-1) * delta.unsqueeze(-2))
- core_attn_out = (recurrent_state *
- q_t.unsqueeze(-1)).sum(dim=-2).unsqueeze(-2)
-
- if not output_final_state:
- recurrent_state = None
- else:
- recurrent_state = recurrent_state.to(initial_dtype)
- core_attn_out = core_attn_out.transpose(1,
- 2).contiguous().to(initial_dtype)
- return core_attn_out, recurrent_state
-
-
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(
@@ -271,7 +103,7 @@ def __init__(
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
- renormalize=config.norm_topk_prob,
+ renormalize=getattr(config, "norm_topk_prob", True),
quant_config=quant_config,
prefix=f"{prefix}.experts")
@@ -765,11 +597,15 @@ def __init__(
prefix=f"{prefix}.o_proj",
)
+ if hasattr(config, "rope_theta"):
+ rope_theta = config.rope_theta
+ else:
+ rope_theta = config.rope_parameters.get("rope_theta", 10000)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=config.max_position_embeddings,
- base=config.rope_theta,
+ base=rope_theta,
rope_scaling=config.rope_scaling,
partial_rotary_factor=config.partial_rotary_factor,
dual_chunk_attention_config=self.dual_chunk_attention_config,
@@ -972,7 +808,7 @@ class Qwen3NextModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
- config: Qwen3NextConfig = vllm_config.model_config.hf_config
+ config: Qwen3NextConfig = vllm_config.model_config.hf_text_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
@@ -1054,7 +890,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
- num_experts=self.config.num_experts)
+ num_experts=getattr(self.config, "num_experts", 0))
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
@@ -1133,8 +969,52 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params
+class QwenNextMixtureOfExperts(MixtureOfExperts):
+ def update_physical_experts_metadata(
+ self,
+ num_physical_experts: int,
+ num_local_physical_experts: int,
+ ) -> None:
+ assert self.num_local_physical_experts == num_local_physical_experts
+ self.num_physical_experts = num_physical_experts
+ self.num_local_physical_experts = num_local_physical_experts
+ self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+ for layer in self.model.layers:
+ if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
+ moe = layer.mlp
+ moe.n_local_physical_experts = num_local_physical_experts
+ moe.n_physical_experts = num_physical_experts
+ moe.n_redundant_experts = self.num_redundant_experts
+ moe.experts.update_expert_map()
+
+ def set_moe_parameters(self):
+ self.expert_weights = []
+
+ self.moe_layers = []
+ example_moe = None
+ for layer in self.model.layers:
+ if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
+ layer.mlp, Qwen3NextSparseMoeBlock
+ ):
+ example_moe = layer.mlp
+ self.moe_layers.append(layer.mlp.experts)
+
+ if example_moe is None:
+ raise RuntimeError("No Qwen3Next layer found in the model.layers.")
+
+ # Set MoE hyperparameters
+ self.num_moe_layers = len(self.moe_layers)
+ self.num_expert_groups = 1
+ self.num_shared_experts = 0
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+
class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
- MixtureOfExperts, IsHybrid):
+ QwenNextMixtureOfExperts, IsHybrid):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1145,7 +1025,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- config = vllm_config.model_config.hf_config
+ config = vllm_config.model_config.hf_text_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
@@ -1283,7 +1163,7 @@ def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, int], tuple[int, int]]:
parallel_config = vllm_config.parallel_config
- hf_config = vllm_config.model_config.hf_config
+ hf_config = vllm_config.model_config.hf_text_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0)
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index b0a079a38c73..429c7196cc17 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -1456,14 +1456,15 @@ def get_input_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
- if deepstack_input_embeds is None:
+ if deepstack_input_embeds is None and self.use_deepstack:
deepstack_input_embeds = torch.zeros(inputs_embeds.size(0),
inputs_embeds.size(1),
self.deepstack_num_level *
self.text_dim,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype)
- inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds),
+ if deepstack_input_embeds is not None:
+ inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds),
dim=-1)
return inputs_embeds
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index e5deedb5185d..0713314e642d 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -236,6 +236,14 @@
),
"Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501
"Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501
+ "Qwen3_5ForConditionalGeneration": (
+ "qwen3_5",
+ "Qwen3_5ForConditionalGeneration",
+ ),
+ "Qwen3_5MoeForConditionalGeneration": (
+ "qwen3_5",
+ "Qwen3_5MoeForConditionalGeneration",
+ ),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
@@ -257,6 +265,8 @@
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
+ "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
+ "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
}
_TRANSFORMERS_MODELS = {
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index b5135f9ebb82..f534dca463ed 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -91,8 +91,8 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer_all_special_ids = tokenizer.all_special_ids
tokenizer_all_special_tokens = tokenizer.all_special_tokens
- tokenizer_all_special_tokens_extended = (
- tokenizer.all_special_tokens_extended)
+ # tokenizer_all_special_tokens_extended = (
+ # tokenizer.all_special_tokens_extended)
tokenizer_vocab = tokenizer.get_vocab()
tokenizer_len = len(tokenizer)
@@ -115,9 +115,9 @@ def all_special_ids(self) -> list[int]:
def all_special_tokens(self) -> list[str]:
return tokenizer_all_special_tokens
- @property
- def all_special_tokens_extended(self) -> list[str]:
- return tokenizer_all_special_tokens_extended
+ # @property
+ # def all_special_tokens_extended(self) -> list[str]:
+ # return tokenizer_all_special_tokens_extended
@property
def max_token_id(self) -> int:
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index e938f3bfc671..5827f7618a12 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass
+from math import prod
from typing import Optional
import torch
@@ -154,6 +155,36 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
+@dataclass
+class MambaSpec(KVCacheSpec):
+ shapes: tuple[tuple[int, ...], ...]
+ dtypes: tuple[torch.dtype, ...]  # paired element-wise with `shapes`
+ page_size_padded: int | None = None
+ mamba_type: str = "mamba2"
+ mamba_cache_mode: str = "none"
+ num_speculative_blocks: int = 0
+
+ @property
+ def page_size_bytes(self) -> int:
+ page_size = sum(  # bytes summed over every (shape, dtype) pair
+ prod(shape) * get_dtype_size(dtype)
+ for (shape, dtype) in zip(self.shapes, self.dtypes)
+ )
+ if self.page_size_padded is not None:
+ assert self.page_size_padded >= page_size  # padding may only grow the page
+ return self.page_size_padded
+ return page_size
+
+ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+ if vllm_config.cache_config.mamba_cache_mode == "all":  # mode read from cache_config
+ max_model_len = vllm_config.model_config.max_model_len
+ return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+ elif vllm_config.cache_config.mamba_cache_mode == "align":
+ return self.page_size_bytes * (2 + self.num_speculative_blocks)
+ else:
+ return self.page_size_bytes * (1 + self.num_speculative_blocks)
+
+
@dataclass
class KVCacheTensor:
"""
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 788caa270778..6650dc68df1b 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -1761,7 +1761,10 @@ def get_model(self) -> torch.nn.Module:
# fla is short for Flat Linear Attention
def _is_fla_model(self):
- return hasattr(self.model_config.hf_config, "linear_conv_kernel_dim")
+ return (hasattr(self.model_config.hf_config, "linear_conv_kernel_dim") or
+ (hasattr(self.model_config.hf_config, "text_config") and
+ hasattr(self.model_config.hf_config.text_config,
+ "linear_conv_kernel_dim")))
def _use_graphs(self, batch_size, seq_len, ctx_blocks=0):
if self.enforce_eager:
@@ -1982,8 +1985,12 @@ def _prepare_prompt(
# TODO: if seq_len < conv_kernel_dim, padding token should be
# masked in the prompt stage
if self._is_fla_model():
+ if hasattr(self.model_config.hf_config, "linear_conv_kernel_dim"):
+ linear_conv_kernel_dim = self.model_config.hf_config.linear_conv_kernel_dim
+ else:
+ linear_conv_kernel_dim = self.model_config.hf_config.text_config.linear_conv_kernel_dim
conv_state_indices_list.append(list(range(seq_len + 1 - \
- self.model_config.hf_config.linear_conv_kernel_dim, seq_len)))
+ linear_conv_kernel_dim, seq_len)))
token_types_ids = seq_group_metadata.token_type_ids
token_types.append(token_types_ids) if token_types_ids else []
@@ -3325,7 +3332,8 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
embed_dim = 1176
if any([
model_type in self.get_model().config.model_type
- for model_type in ['qwen3_vl', "qwen3_omni"]
+ for model_type in ['qwen3_vl', "qwen3_omni",
+ 'qwen3_5', 'qwen3_5_moe']
]):
embed_dim = 1536
elif 'ernie4_5_moe_vl' in self.get_model().config.model_type:
@@ -3585,7 +3593,10 @@ def _inc_preprocess(self):
def add_fla_dummy_data(self, inputs) -> None:
assert self._is_fla_model()
- conv_dim = self.model_config.hf_config.linear_conv_kernel_dim
+ if hasattr(self.model_config.hf_config, "linear_conv_kernel_dim"):
+ conv_dim = self.model_config.hf_config.linear_conv_kernel_dim
+ else:
+ conv_dim = self.model_config.hf_config.text_config.linear_conv_kernel_dim
bs, seq_len = inputs.input_tokens.shape
mamba_cache_indices = list(range(bs))
mamba_cache_indices = torch.tensor(mamba_cache_indices,