Skip to content

Commit 599e75f

Browse files
authored
[ROCm] [Bugfix] Fix DeepSeek V4 Functionality and Accuracy (vllm-project#42810)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent 1c8e9c0 commit 599e75f

4 files changed

Lines changed: 88 additions & 177 deletions

File tree

vllm/model_executor/layers/mhc.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -61,31 +61,35 @@ def forward_hip(
6161
sinkhorn_repeat: int,
6262
n_splits: int = 1,
6363
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
64-
hidden_size = residual.shape[-1]
65-
if hidden_size % 256 == 0:
66-
return torch.ops.vllm.mhc_pre_aiter(
67-
residual,
68-
fn,
69-
hc_scale,
70-
hc_base,
71-
rms_eps,
72-
hc_pre_eps,
73-
hc_sinkhorn_eps,
74-
hc_post_mult_value,
75-
sinkhorn_repeat,
76-
)
77-
else:
78-
return mhc_kernels.mhc_pre_torch(
79-
residual,
80-
fn,
81-
hc_scale,
82-
hc_base,
83-
rms_eps,
84-
hc_pre_eps,
85-
hc_sinkhorn_eps,
86-
hc_post_mult_value,
87-
sinkhorn_repeat,
88-
)
64+
# TODO: Reenable aiter after we are at the aiter
65+
# version that has this bugfix
66+
# https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649
67+
# It has accuracy problem at large number of tokens.
68+
# hidden_size = residual.shape[-1]
69+
# if hidden_size % 256 == 0:
70+
# return torch.ops.vllm.mhc_pre_aiter(
71+
# residual,
72+
# fn,
73+
# hc_scale,
74+
# hc_base,
75+
# rms_eps,
76+
# hc_pre_eps,
77+
# hc_sinkhorn_eps,
78+
# hc_post_mult_value,
79+
# sinkhorn_repeat,
80+
# )
81+
# else:
82+
return mhc_kernels.mhc_pre_torch(
83+
residual,
84+
fn,
85+
hc_scale,
86+
hc_base,
87+
rms_eps,
88+
hc_pre_eps,
89+
hc_sinkhorn_eps,
90+
hc_post_mult_value,
91+
sinkhorn_repeat,
92+
)
8993

9094
def forward_native(self, *args, **kwargs):
9195
raise NotImplementedError("Native implementation of mhc_pre is not available")
@@ -124,21 +128,25 @@ def forward_hip(
124128
post_layer_mix: torch.Tensor,
125129
comb_res_mix: torch.Tensor,
126130
) -> torch.Tensor:
127-
hidden_size = residual.shape[-1]
128-
if hidden_size % 256 == 0:
129-
return torch.ops.vllm.mhc_post_aiter(
130-
x,
131-
residual,
132-
post_layer_mix,
133-
comb_res_mix,
134-
)
135-
else:
136-
return mhc_kernels.mhc_post_torch(
137-
x,
138-
residual,
139-
post_layer_mix,
140-
comb_res_mix,
141-
)
131+
# TODO: Reenable aiter after we are at the aiter
132+
# version that has this bugfix
133+
# https://github.com/ROCm/aiter/commit/b639cb63bcac4672dce33a731fad042a65cb3649
134+
# It has accuracy problem at large number of tokens.
135+
# hidden_size = residual.shape[-1]
136+
# if hidden_size % 256 == 0:
137+
# return torch.ops.vllm.mhc_post_aiter(
138+
# x,
139+
# residual,
140+
# post_layer_mix,
141+
# comb_res_mix,
142+
# )
143+
# else:
144+
return mhc_kernels.mhc_post_torch(
145+
x,
146+
residual,
147+
post_layer_mix,
148+
comb_res_mix,
149+
)
142150

143151
def forward_native(self, *args, **kwargs):
144152
raise NotImplementedError("Native implementation of mhc_post is not available")

vllm/model_executor/layers/sparse_attn_indexer.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -505,27 +505,6 @@ def forward_hip(
505505
assert isinstance(q_quant, torch.Tensor), (
506506
"AMD sparse_attn_indexer expects a single FP8 q_quant tensor"
507507
)
508-
if self.skip_k_cache_insert or not rocm_aiter_ops.is_enabled():
509-
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
510-
rocm_aiter_sparse_attn_indexer_native,
511-
)
512-
513-
return rocm_aiter_sparse_attn_indexer_native(
514-
hidden_states,
515-
_encode_layer_name(self.k_cache.prefix),
516-
self.k_cache.kv_cache,
517-
q_quant,
518-
k,
519-
weights,
520-
self.quant_block_size,
521-
self.scale_fmt,
522-
self.topk_tokens,
523-
self.head_dim,
524-
self.max_model_len,
525-
self.max_total_seq_len,
526-
self.topk_indices_buffer,
527-
skip_k_cache_insert=self.skip_k_cache_insert,
528-
)
529508
if rocm_aiter_ops.is_enabled():
530509
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
531510
hidden_states,
@@ -541,5 +520,9 @@ def forward_hip(
541520
self.max_model_len,
542521
self.max_total_seq_len,
543522
self.topk_indices_buffer,
523+
skip_k_cache_insert=self.skip_k_cache_insert,
544524
)
545-
raise RuntimeError("Sparse attention indexer ROCm path could not be selected.")
525+
raise RuntimeError(
526+
"Sparse attention indexer ROCm path is only supported on AITER. "
527+
"Please enable aiter with VLLM_ROCM_USE_AITER=1"
528+
)

vllm/model_executor/models/deepseek_v4.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,8 @@ def _forward_rocm(
12771277
x, post, comb = self.hc_pre(
12781278
x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
12791279
)
1280-
x = self.ffn_norm(x)
1280+
# ffn_norm is now folded into self.ffn.norm_gate; ffn() takes
1281+
# the pre-norm activation directly.
12811282
x = self.ffn(x, input_ids)
12821283
x = self.hc_post(x, residual, post, comb)
12831284
return x, None, None, None

vllm/v1/attention/ops/rocm_aiter_mla_sparse.py

Lines changed: 33 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,11 @@ def rocm_fp8_mqa_logits(
542542
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
543543

544544

545-
def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor:
545+
def _topk_indices_torch(
546+
logits: torch.Tensor,
547+
topk_tokens: int,
548+
row_starts: torch.Tensor | None = None,
549+
) -> torch.Tensor:
546550
k = min(topk_tokens, logits.shape[-1])
547551
values, indices = torch.topk(logits, k=k, dim=-1)
548552
indices = indices.to(torch.int32)
@@ -551,6 +555,12 @@ def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor:
551555
torch.full_like(indices, -1, dtype=torch.int32),
552556
indices,
553557
)
558+
if row_starts is not None:
559+
# Match the CUDA top_k_per_row_prefill contract: indices are local to
560+
# each row's valid [row_start, row_end) range, not columns in the
561+
# concatenated chunk logits matrix.
562+
starts = row_starts.to(dtype=torch.int32).view(-1, 1)
563+
indices = torch.where(indices < 0, indices, indices - starts)
554564
if k == topk_tokens:
555565
return indices
556566
padded = torch.full(
@@ -563,64 +573,6 @@ def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor:
563573
return padded
564574

565575

566-
# topk_tokens values with dedicated fused C++ kernel support.
567-
_TOPK_FAST_PATH_VALUES = frozenset({2048})
568-
569-
570-
def _topk_indices_prefill(
571-
logits: torch.Tensor,
572-
topk_tokens: int,
573-
topk_out: torch.Tensor,
574-
cu_seqlen_ks: torch.Tensor,
575-
cu_seqlen_ke: torch.Tensor,
576-
) -> None:
577-
"""Top-k indices for the prefill path.
578-
579-
Writes ``logits.shape[0]`` rows into ``topk_out``; caller must size the
580-
view accordingly.
581-
"""
582-
if topk_tokens in _TOPK_FAST_PATH_VALUES:
583-
torch.ops._C.top_k_per_row_prefill(
584-
logits,
585-
cu_seqlen_ks,
586-
cu_seqlen_ke,
587-
topk_out,
588-
logits.shape[0],
589-
logits.stride(0),
590-
logits.stride(1),
591-
topk_tokens,
592-
)
593-
else:
594-
topk_out.copy_(_topk_indices_torch(logits, topk_tokens))
595-
596-
597-
def _topk_indices_decode(
598-
logits: torch.Tensor,
599-
topk_tokens: int,
600-
topk_out: torch.Tensor,
601-
seq_lens: torch.Tensor,
602-
next_n: int,
603-
) -> None:
604-
"""Top-k indices for the decode path.
605-
606-
Writes ``logits.shape[0] == batch_size * next_n`` rows into ``topk_out``;
607-
caller must size the view to ``num_padded_tokens``.
608-
"""
609-
if topk_tokens in _TOPK_FAST_PATH_VALUES:
610-
torch.ops._C.top_k_per_row_decode(
611-
logits,
612-
next_n,
613-
seq_lens,
614-
topk_out,
615-
logits.shape[0],
616-
logits.stride(0),
617-
logits.stride(1),
618-
topk_tokens,
619-
)
620-
else:
621-
topk_out.copy_(_topk_indices_torch(logits, topk_tokens))
622-
623-
624576
def rocm_aiter_sparse_attn_indexer_fake(
625577
hidden_states: torch.Tensor,
626578
k_cache_prefix: LayerNameType,
@@ -635,21 +587,13 @@ def rocm_aiter_sparse_attn_indexer_fake(
635587
max_model_len: int,
636588
total_seq_lens: int,
637589
topk_indices_buffer: torch.Tensor | None,
590+
skip_k_cache_insert: bool = False,
638591
) -> torch.Tensor:
639-
# profile run
640-
# NOTE(Chen): create the max possible flattened_kv. So that
641-
# profile_run can get correct memory usage.
642-
device = hidden_states.device if k is None else k.device
643-
_flattened_kv = torch.empty(
644-
[total_seq_lens, head_dim + 4], device=device, dtype=torch.uint8
645-
)
646-
fp8_dtype = current_platform.fp8_dtype()
647-
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
648-
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
649592
return topk_indices_buffer
650593

651594

652-
def rocm_aiter_sparse_attn_indexer_native(
595+
@eager_break_during_capture
596+
def rocm_aiter_sparse_attn_indexer(
653597
hidden_states: torch.Tensor,
654598
k_cache_prefix: LayerNameType,
655599
kv_cache: torch.Tensor,
@@ -688,6 +632,7 @@ def rocm_aiter_sparse_attn_indexer_native(
688632
max_model_len,
689633
total_seq_lens,
690634
topk_indices_buffer,
635+
skip_k_cache_insert,
691636
)
692637
layer_attn_metadata = attn_metadata[k_cache_prefix]
693638
assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
@@ -768,12 +713,18 @@ def rocm_aiter_sparse_attn_indexer_native(
768713
topk_indices = topk_indices_buffer[
769714
chunk.token_start : chunk.token_end, :topk_tokens
770715
]
771-
_topk_indices_prefill(
716+
717+
num_rows = logits.shape[0]
718+
719+
torch.ops._C.top_k_per_row_prefill(
772720
logits,
773-
topk_tokens,
774-
topk_indices,
775721
chunk.cu_seqlen_ks,
776722
chunk.cu_seqlen_ke,
723+
topk_indices,
724+
num_rows,
725+
logits.stride(0),
726+
logits.stride(1),
727+
topk_tokens,
777728
)
778729

779730
if has_decode:
@@ -811,16 +762,18 @@ def rocm_aiter_sparse_attn_indexer_native(
811762
max_model_len=max_model_len,
812763
)
813764

814-
# Size the view to num_padded_tokens: top_k_per_row_decode writes
815-
# logits.shape[0] == num_padded_tokens rows, and the unpack below
816-
# reshapes to (batch_size, next_n, ...).
817765
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
818-
_topk_indices_decode(
766+
num_rows = logits.shape[0]
767+
768+
torch.ops._C.top_k_per_row_decode(
819769
logits,
820-
topk_tokens,
821-
topk_indices,
822-
decode_metadata.seq_lens,
823770
next_n,
771+
decode_metadata.seq_lens,
772+
topk_indices,
773+
num_rows,
774+
logits.stride(0),
775+
logits.stride(1),
776+
topk_tokens,
824777
)
825778

826779
if decode_metadata.requires_padding:
@@ -837,40 +790,6 @@ def rocm_aiter_sparse_attn_indexer_native(
837790
return topk_indices_buffer
838791

839792

840-
@eager_break_during_capture
841-
def rocm_aiter_sparse_attn_indexer(
842-
hidden_states: torch.Tensor,
843-
k_cache_prefix: LayerNameType,
844-
kv_cache: torch.Tensor,
845-
q_fp8: torch.Tensor,
846-
k: torch.Tensor,
847-
weights: torch.Tensor,
848-
quant_block_size: int,
849-
scale_fmt: str | None,
850-
topk_tokens: int,
851-
head_dim: int,
852-
max_model_len: int,
853-
total_seq_lens: int,
854-
topk_indices_buffer: torch.Tensor | None,
855-
) -> torch.Tensor:
856-
return rocm_aiter_sparse_attn_indexer_native(
857-
hidden_states,
858-
k_cache_prefix,
859-
kv_cache,
860-
q_fp8,
861-
k,
862-
weights,
863-
quant_block_size,
864-
scale_fmt,
865-
topk_tokens,
866-
head_dim,
867-
max_model_len,
868-
total_seq_lens,
869-
topk_indices_buffer,
870-
skip_k_cache_insert=False,
871-
)
872-
873-
874793
def _decode_e8m0_scales(scale: torch.Tensor) -> torch.Tensor:
875794
if scale.dtype == torch.float8_e8m0fnu:
876795
from vllm.model_executor.layers.quantization.utils.fp8_utils import (

0 commit comments

Comments
 (0)