Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,9 @@ def forward_chunked_prefill(
decode_batch_size = 0
decode_seq_len = 0
decode_hidden_size = 0
query = query.squeeze(0)
key = key.squeeze(0)
value = value.squeeze(0)
if attn_metadata.num_prefills > 0:
attn_data = self.preprocess_forward(
query[:attn_metadata.num_prefill_tokens],
Expand Down
5 changes: 0 additions & 5 deletions vllm/distributed/device_communicators/hpu_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
class HpuCommunicator(DeviceCommunicatorBase):

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not suggest to do it.

dist.all_reduce(input_, group=self.device_group)
return input_

Expand All @@ -33,7 +29,6 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
dtype=input_.dtype,
device=input_.device)
# All-gather.
htorch.core.mark_step()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not safe to remove it. Sometimes it will cause hang or the accuracy issue.

dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
Expand Down
154 changes: 131 additions & 23 deletions vllm/model_executor/models/minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import (MiniMaxText01RMSNormTP,
RMSNorm)
from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand Down Expand Up @@ -123,7 +122,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.gate(hidden_states.to(torch.float32))
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
final_hidden_states = final_hidden_states
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
Expand Down Expand Up @@ -172,7 +170,7 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

'''
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
Expand All @@ -182,6 +180,29 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
'''

self.q_proj = ReplicatedLinear(
hidden_size,
self.head_dim * self.total_num_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.k_proj = ReplicatedLinear(
hidden_size,
self.head_dim * self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.k_proj",
)
self.v_proj = ColumnParallelLinear(
hidden_size,
self.head_dim * self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.v_proj",
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
Expand Down Expand Up @@ -209,26 +230,113 @@ def __init__(
prefix=f"{prefix}.attn",
)

self.q_norm = MiniMaxText01RMSNormTP(self.head_dim *
self.total_num_heads,
eps=rms_norm_eps)
self.k_norm = MiniMaxText01RMSNormTP(self.head_dim *
self.total_num_kv_heads,
eps=rms_norm_eps)
self.q_norm = RMSNorm(self.head_dim * self.total_num_heads,
eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim * self.total_num_kv_heads,
eps=rms_norm_eps)

self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
tp_size = self.tp_size
tp_rank = self.tp_rank
orig_shape = hidden_states.shape
bs = orig_shape[0] if len(orig_shape) == 3 else 1
seq = orig_shape[-2]
x = hidden_states.reshape(-1, self.hidden_size)
x_fp8 = torch.ops.hpu.cast_to_fp8_v2(x, 1.0 / self.q_proj.input_scale,
False, False,
torch.float8_e4m3fn)[0]
qweight_slice = self.q_proj.weight.size(1) // tp_size
kweight_slice = self.k_proj.weight.size(1) // tp_size

if bs * seq > 256:
W_Q = self.q_proj.weight.transpose(0, 1)[ \
tp_rank * qweight_slice : (tp_rank + 1) * qweight_slice, :]
W_K = self.k_proj.weight.transpose(0, 1)[ \
tp_rank * kweight_slice : (tp_rank + 1) * kweight_slice, :]
S_Q = self.q_proj.weight_scale[ \
tp_rank * qweight_slice : (tp_rank + 1) * qweight_slice]
S_K = self.k_proj.weight_scale[ \
tp_rank * kweight_slice : (tp_rank + 1) * kweight_slice]
else:
W_Q = self.q_proj.weight.transpose(0, 1)
W_K = self.k_proj.weight.transpose(0, 1)
S_Q = self.q_proj.weight_scale
S_K = self.k_proj.weight_scale
W_V = self.v_proj.weight.transpose(0, 1)
S_V = self.v_proj.weight_scale

q = torch.ops.hpu.fp8_gemm_v2(A=x_fp8,
trans_A=False,
B=W_Q.contiguous(),
trans_B=True,
D=None,
out_dtype=torch.bfloat16,
A_scale_inv=self.q_proj.input_scale,
B_scale_inv=S_Q,
bias=None,
accumulate=False).reshape(bs, seq, -1)
k = torch.ops.hpu.fp8_gemm_v2(A=x_fp8,
trans_A=False,
B=W_K.contiguous(),
trans_B=True,
D=None,
out_dtype=torch.bfloat16,
A_scale_inv=self.q_proj.input_scale,
B_scale_inv=S_K,
bias=None,
accumulate=False).reshape(bs, seq, -1)
v = torch.ops.hpu.fp8_gemm_v2(A=x_fp8,
trans_A=False,
B=W_V.contiguous(),
trans_B=True,
D=None,
out_dtype=torch.bfloat16,
A_scale_inv=self.q_proj.input_scale,
B_scale_inv=S_V,
bias=None,
accumulate=False).reshape(bs, seq, -1)

if bs * seq > 256:
qnorm_slice = self.head_dim * self.total_num_heads // tp_size
knorm_slice = self.head_dim * self.total_num_kv_heads // tp_size
qnorm_weight = self.q_norm.weight[ \
tp_rank * qnorm_slice : (tp_rank + 1) * qnorm_slice]
knorm_weight = self.k_norm.weight[ \
tp_rank * knorm_slice : (tp_rank + 1) * knorm_slice]
orig_dtype = q.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
q_var = q.pow(2).mean(dim=-1).unsqueeze(1)
k_var = k.pow(2).mean(dim=-1).unsqueeze(1)
if tp_size > 1:
qk_var = torch.cat([q_var, k_var], dim=1)
qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_size
q_var, k_var = qk_var.chunk(2, dim=1)
q = q * torch.rsqrt(
q_var.transpose(-1, -2) +
self.q_norm.variance_epsilon) * qnorm_weight
k = k * torch.rsqrt(
k_var.transpose(-1, -2) +
self.k_norm.variance_epsilon) * knorm_weight
q = q.to(orig_dtype)
k = k.to(orig_dtype)
else:
q = self.q_norm(q)
k = self.k_norm(k)
q = q.reshape(bs, seq, tp_size, -1).permute(2, 0, 1, 3)[tp_rank]
k = k.reshape(bs, seq, tp_size, -1).permute(2, 0, 1, 3)[tp_rank]

q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
return output.reshape(orig_shape)


class MiniMaxM2DecoderLayer(nn.Module):
Expand Down Expand Up @@ -396,12 +504,12 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
stacked_params_mapping = []
# (param_name, shard_name, shard_id)
# ("qkv_proj", "q_proj", "q"),
# ("qkv_proj", "k_proj", "k"),
# ("qkv_proj", "v_proj", "v"),
# ]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
Expand Down