Skip to content
Open

Qwen3.5 #2234

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ Specified using `--task generate`.
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B`
| `Qwen3_5ForConditionalGeneration` | Qwen3.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3.5-9B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3_5MoeForConditionalGeneration` | Qwen3.5-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3.5-35B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | | ✅︎ | ✅︎\* |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ |
Expand Down
164 changes: 94 additions & 70 deletions examples/offline_inference/basic/chat.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
import argparse
import os

os.environ["PT_HPU_LAZY_MODE"] = "1"

from vllm import LLM, EngineArgs, SamplingParams

# Parse the command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="facebook/opt-125m",
help="The model path.",
)
parser.add_argument("--tp-size", type=int, default=2, help="The number of threads.")
parser.add_argument(
"--output-tokens", type=int, default=512, help="The number of output tokens."
)
parser.add_argument(
"--max-model-length", type=int, default=16384, help="Max model length."
)
parser.add_argument("--enable-ep", action="store_true", help="Enable EP for MOE models")
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top-p", type=float, default=0.95)
parser.add_argument("--enable-thinking", action="store_true", help="Enable think mode for inference")

Check failure on line 29 in examples/offline_inference/basic/chat.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

examples/offline_inference/basic/chat.py:29:89: E501 Line too long (101 > 88)
# Add example params
parser.add_argument("--chat-template-path", type=str)
args = parser.parse_args()

os.environ["VLLM_SKIP_WARMUP"] = "true"
os.environ["HABANA_VISIBLE_DEVICES"] = "ALL"
os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"
os.environ["PT_HPU_WEIGHT_SHARING"] = "0"


def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
# Add example params
parser.add_argument("--chat-template-path", type=str)

return parser


def main(args: dict):
# Pop arguments not used by LLM
max_tokens = args.pop("max_tokens")
temperature = args.pop("temperature")
top_p = args.pop("top_p")
top_k = args.pop("top_k")
chat_template_path = args.pop("chat_template_path")

# Create an LLM
llm = LLM(**args)

# Create sampling params object
sampling_params = llm.get_default_sampling_params()
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if temperature is not None:
sampling_params.temperature = temperature
if top_p is not None:
sampling_params.top_p = top_p
if top_k is not None:
sampling_params.top_k = top_k
if __name__ == "__main__":
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
messages = []
for idx in range(len(prompts)):
conversation = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompts[idx]
},
]
messages.append(conversation)
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
max_tokens=args.output_tokens,
)
chat_template_path = args.chat_template_path
model = args.model
if args.tp_size == 1:
llm = LLM(
model=model,
tokenizer=model,
trust_remote_code=True,
dtype="bfloat16",
max_model_len=args.max_model_length,
)
else:
llm = LLM(
model=model,
tokenizer=model,
tensor_parallel_size=args.tp_size,
distributed_executor_backend="mp",
trust_remote_code=True,
max_model_len=args.max_model_length,
enable_expert_parallel=args.enable_ep,
dtype="bfloat16",
)

def print_outputs(outputs):
print("\nGenerated Outputs:\n" + "-" * 80)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
for idx in range(len(outputs)):
prompt = prompts[idx]
generated_text = outputs[idx].outputs[0].text
print(f"Prompt: {prompt!r}\n")
print(f"Generated text: {generated_text!r}")
print("-" * 80)

print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:
conversation = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
print_outputs(outputs)

# You can run batch inference with llm.chat API
conversations = [conversation for _ in range(10)]

# We turn on tqdm progress bar to verify it's indeed running batch inference
outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.
if chat_template_path is not None:
with open(chat_template_path) as f:
chat_template = f.read()

outputs = llm.chat(
conversations,
messages,
sampling_params,
use_tqdm=False,
chat_template=chat_template,
chat_template_kwargs={"enable_thinking": args.enable_thinking},
)


if __name__ == "__main__":
parser = create_parser()
args: dict = vars(parser.parse_args())
main(args)
else:
outputs = llm.chat(
messages,
sampling_params,
chat_template_kwargs={"enable_thinking": args.enable_thinking}
)
print_outputs(outputs)
78 changes: 78 additions & 0 deletions examples/offline_inference/vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,80 @@ def run_qwen3_omni_moe(questions: list[str], modality: str) -> ModelRequestData:
)


# Qwen3.5 Dense
def run_qwen3_5(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen3.5-4B-Base"

engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompts = [
(
f"<|vision_start|>{placeholder}<|vision_end|>"
f"{question}"
)
for question in questions
]

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# Qwen3.5 MoE
def run_qwen3_5_moe(questions: list[str], modality: str) -> ModelRequestData:
model_name = "/data/Qwen3.5-397B-A17B-FP8-G2"

engine_args = EngineArgs(
model=model_name,
max_model_len=16384,
max_num_seqs=5,
enable_expert_parallel=True,
trust_remote_code=True,
tensor_parallel_size=8,
distributed_executor_backend="mp",
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompts = [
(
f"<|vision_start|>{placeholder}<|vision_end|>"
f"{question}"
)
for question in questions
]

return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
Expand Down Expand Up @@ -1516,6 +1590,8 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"qwen3_vl": run_qwen3_vl,
"qwen3_vl_moe": run_qwen3_vl_moe,
"qwen3_omni_moe": run_qwen3_omni_moe,
"qwen3_5": run_qwen3_5,
"qwen3_5_moe": run_qwen3_5_moe,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
"tarsier": run_tarsier,
Expand All @@ -1526,6 +1602,8 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"glm4_5v",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
]


Expand Down
20 changes: 20 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,26 @@ def check_available_online(
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57"),
"Qwen3_5ForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct",
max_model_len=4096,
min_transformers_version="5.1.0",
),
"Qwen3_5MoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct",
max_model_len=4096,
min_transformers_version="5.1.0",
),
"Qwen3_5MTP": _HfExamplesInfo(
"Qwen/Qwen3.5-9B-Instruct",
speculative_model="Qwen/Qwen3.5-9B-Instruct",
min_transformers_version="5.1.0",
),
"Qwen3_5MoeMTP": _HfExamplesInfo(
"Qwen/Qwen3.5-35B-A3B-Instruct",
speculative_model="Qwen/Qwen3.5-35B-A3B-Instruct",
min_transformers_version="5.1.0",
),
"Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-Omni-30B-A3B-Instruct", # noqa: E501
max_model_len=4096, # noqa: E501
min_transformers_version="4.57"), # noqa: E501
Expand Down
14 changes: 12 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,8 +1299,8 @@ def get_num_layers_by_block_type(
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])

# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
# Hybrid model Qwen3Next Qwen3.5 Series
layer_types_value = getattr(self.hf_text_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
return sum(t == "full_attention"
Expand Down Expand Up @@ -2575,6 +2575,16 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
is_moe = hf_config.model_type == "qwen3_5_moe"
hf_config.model_type = "qwen3_5_mtp"
n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
hf_config.update(
{
"n_predict": n_predict,
"architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
}
)

return hf_config

Expand Down
25 changes: 18 additions & 7 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ def forward_cuda(
return self.forward_native(x, residual)


class RMSNormGated(nn.Module):
@CustomOp.register("rms_normgated")
class RMSNormGated(CustomOp):

def __init__(
self,
Expand All @@ -347,20 +348,30 @@ def __init__(
def reset_parameters(self):
torch.nn.init.ones_(self.weight)

def forward(self, hidden_states, gate=None):
def forward_native(self, hidden_states, gate=None):
"""
If z is not None, we do norm(x) * silu(z)
if norm_before_gate, else norm(x * silu(z))
"""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# Norm before gate
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
hidden_states = self.weight.to(hidden_states.dtype) * hidden_states
if gate is not None:
hidden_states = hidden_states * F.silu(gate.to(hidden_states.dtype))

return hidden_states.to(input_dtype)
return hidden_states

def forward_hpu(self, hidden_states, gate=None):
from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()

hidden_states = HPUFusedRMSNorm.apply(hidden_states,
self.weight.to(hidden_states.dtype),
self.eps)
if gate is not None:
hidden_states = hidden_states * F.silu(gate.to(hidden_states.dtype))
return hidden_states


class MiniMaxText01RMSNormTP(CustomOp):
Expand Down
Loading
Loading