From a0c3d28fcb1f506a73eb614fd9244a6eba3a84d8 Mon Sep 17 00:00:00 2001 From: "Lin, Wei" Date: Tue, 3 Feb 2026 10:40:57 +0800 Subject: [PATCH 1/4] disable TP for the qk proj and norm in Minimax-M2 Signed-off-by: Youlei Yang --- vllm/model_executor/models/minimax_m2.py | 150 +++++++++++++++++++---- 1 file changed, 128 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index dfcf25adc5e0..a8b301a7d0bd 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -34,13 +34,12 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import (MiniMaxText01RMSNormTP, - RMSNorm) -from vllm.model_executor.layers.linear import (QKVParallelLinear, +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -123,7 +122,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states.to(torch.float32)) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - final_hidden_states = final_hidden_states if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -172,7 +170,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - + ''' self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -182,6 +180,29 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + ''' + + self.q_proj = ReplicatedLinear( + hidden_size, + self.head_dim * self.total_num_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.k_proj = ReplicatedLinear( + hidden_size, + self.head_dim * self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.k_proj", + ) + self.v_proj = ColumnParallelLinear( + hidden_size, + self.head_dim * self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.v_proj", + ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -209,22 +230,107 @@ def __init__( prefix=f"{prefix}.attn", ) - self.q_norm = MiniMaxText01RMSNormTP(self.head_dim * - self.total_num_heads, - eps=rms_norm_eps) - self.k_norm = MiniMaxText01RMSNormTP(self.head_dim * - self.total_num_kv_heads, - eps=rms_norm_eps) + self.q_norm = RMSNorm(self.head_dim * self.total_num_heads, + eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim * self.total_num_kv_heads, + eps=rms_norm_eps) + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = self.q_norm(q) - k = self.k_norm(k) + tp_size = self.tp_size + tp_rank = self.tp_rank + bs, seq, _ = hidden_states.shape + x = hidden_states.reshape(-1, self.hidden_size) + x_fp8 = torch.ops.hpu.cast_to_fp8_v2(x, 1.0 / self.q_proj.input_scale, + False, False, + torch.float8_e4m3fn)[0] + qweight_slice = self.q_proj.weight.size(1) // tp_size + kweight_slice = self.k_proj.weight.size(1) // tp_size + + if bs * seq > 256: + W_Q = self.q_proj.weight.transpose(0, 1)[ \ + tp_rank * qweight_slice : (tp_rank + 1) * qweight_slice, :] + W_K = self.k_proj.weight.transpose(0, 1)[ \ + tp_rank * kweight_slice : (tp_rank + 1) * kweight_slice, :] + S_Q = self.q_proj.weight_scale[ \ + tp_rank * qweight_slice : (tp_rank + 1) * qweight_slice] + S_K = self.k_proj.weight_scale[ \ + tp_rank * kweight_slice : (tp_rank + 1) * kweight_slice] + else: + W_Q = self.q_proj.weight.transpose(0, 1) + W_K = self.k_proj.weight.transpose(0, 1) + S_Q = self.q_proj.weight_scale + S_K = self.k_proj.weight_scale + W_V = self.v_proj.weight.transpose(0, 1) + S_V = self.v_proj.weight_scale + + q = torch.ops.hpu.fp8_gemm_v2(A=x_fp8, + trans_A=False, + B=W_Q.contiguous(), + trans_B=True, + D=None, + out_dtype=torch.bfloat16, + A_scale_inv=self.q_proj.input_scale, + B_scale_inv=S_Q, + bias=None, + accumulate=False).reshape(bs, seq, -1) + k = torch.ops.hpu.fp8_gemm_v2(A=x_fp8, + trans_A=False, + B=W_K.contiguous(), + trans_B=True, + D=None, + out_dtype=torch.bfloat16, + A_scale_inv=self.q_proj.input_scale, + B_scale_inv=S_K, + bias=None, + accumulate=False).reshape(bs, seq, -1) + v = torch.ops.hpu.fp8_gemm_v2(A=x_fp8, + trans_A=False, + B=W_V.contiguous(), + trans_B=True, + D=None, + out_dtype=torch.bfloat16, + A_scale_inv=self.q_proj.input_scale, + B_scale_inv=S_V, + bias=None, + accumulate=False).reshape(bs, seq, -1) + + if bs * seq > 256: + qnorm_slice = self.head_dim * self.total_num_heads // tp_size + knorm_slice = self.head_dim * self.total_num_kv_heads // tp_size + qnorm_weight = self.q_norm.weight[ \ + tp_rank * qnorm_slice : (tp_rank + 1) * qnorm_slice] + knorm_weight = self.k_norm.weight[ \ + tp_rank * knorm_slice : (tp_rank + 1) * knorm_slice] + orig_dtype = q.dtype + q = q.to(torch.float32) + k = k.to(torch.float32) + q_var = q.pow(2).mean(dim=-1).unsqueeze(1) + k_var = k.pow(2).mean(dim=-1).unsqueeze(1) + if tp_size > 1: + qk_var = torch.cat([q_var, k_var], dim=1) + qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_size + q_var, k_var = qk_var.chunk(2, dim=1) + q = q * torch.rsqrt( + q_var.transpose(-1, -2) + + self.q_norm.variance_epsilon) * qnorm_weight + k = k * torch.rsqrt( + k_var.transpose(-1, -2) + + self.k_norm.variance_epsilon) * knorm_weight + q = q.to(orig_dtype) + k = k.to(orig_dtype) + else: + q = self.q_norm(q) + k = self.k_norm(k) + q = q.reshape(bs, seq, tp_size, -1).permute(2, 0, 1, 3)[tp_rank] + k = k.reshape(bs, seq, tp_size, -1).permute(2, 0, 1, 3)[tp_rank] + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -396,12 +502,12 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] + stacked_params_mapping = [] + # (param_name, shard_name, shard_id) + # ("qkv_proj", "q_proj", "q"), + # ("qkv_proj", "k_proj", "k"), + # ("qkv_proj", "v_proj", "v"), + # ] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) From 68de8da951b1547c259fd24be44ec8ae2ac55a6c Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 3 Feb 2026 10:53:08 +0800 Subject: [PATCH 2/4] remove mark_step before allreduce and allgather Signed-off-by: Youlei Yang --- vllm/distributed/device_communicators/hpu_communicator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py index f00f6b62bf24..7a9c672e3d42 100644 --- a/vllm/distributed/device_communicators/hpu_communicator.py +++ b/vllm/distributed/device_communicators/hpu_communicator.py @@ -15,10 +15,6 @@ class HpuCommunicator(DeviceCommunicatorBase): def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge - # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used - # (which is required for tensor parallel HPUGraph inference) - htorch.core.mark_step() dist.all_reduce(input_, group=self.device_group) return input_ @@ -33,7 +29,6 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: dtype=input_.dtype, device=input_.device) # All-gather. - htorch.core.mark_step() dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) From 1c07b1d5d0e7a23cc809cc2880610c684a65481d Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Tue, 3 Feb 2026 15:50:22 +0800 Subject: [PATCH 3/4] fix dim error Signed-off-by: Youlei Yang --- vllm/model_executor/models/minimax_m2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index a8b301a7d0bd..df5f40814737 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -245,7 +245,8 @@ def forward( ) -> torch.Tensor: tp_size = self.tp_size tp_rank = self.tp_rank - bs, seq, _ = hidden_states.shape + bs = hidden_states.shape[0] if hidden_states.dim() == 3 else 1 + seq = hidden_states.shape[-2] x = hidden_states.reshape(-1, self.hidden_size) x_fp8 = torch.ops.hpu.cast_to_fp8_v2(x, 1.0 / self.q_proj.input_scale, False, False, From 1dee034f32343427df254e2800a37bf15868b1f4 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Wed, 4 Feb 2026 10:36:35 +0800 Subject: [PATCH 4/4] fix shape mismatch error for chunked prefill Signed-off-by: Youlei Yang --- vllm/attention/backends/hpu_attn.py | 3 +++ vllm/model_executor/models/minimax_m2.py | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 80de7b026941..9019af60aa96 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -582,6 +582,9 @@ def forward_chunked_prefill( decode_batch_size = 0 decode_seq_len = 0 decode_hidden_size = 0 + query = query.squeeze(0) + key = key.squeeze(0) + value = value.squeeze(0) if attn_metadata.num_prefills > 0: attn_data = self.preprocess_forward( query[:attn_metadata.num_prefill_tokens], diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index df5f40814737..1f533fb0b283 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -245,8 +245,9 @@ def forward( ) -> torch.Tensor: tp_size = self.tp_size tp_rank = self.tp_rank - bs = hidden_states.shape[0] if hidden_states.dim() == 3 else 1 - seq = hidden_states.shape[-2] + orig_shape = hidden_states.shape + bs = orig_shape[0] if len(orig_shape) == 3 else 1 + seq = orig_shape[-2] x = hidden_states.reshape(-1, self.hidden_size) x_fp8 = torch.ops.hpu.cast_to_fp8_v2(x, 1.0 / self.q_proj.input_scale, False, False, @@ -335,7 +336,7 @@ def forward( q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) - return output + return output.reshape(orig_shape) class MiniMaxM2DecoderLayer(nn.Module):