@@ -542,7 +542,11 @@ def rocm_fp8_mqa_logits(
542542 return fp8_mqa_logits_torch (q , kv , weights , cu_seqlen_ks , cu_seqlen_ke )
543543
544544
545- def _topk_indices_torch (logits : torch .Tensor , topk_tokens : int ) -> torch .Tensor :
545+ def _topk_indices_torch (
546+ logits : torch .Tensor ,
547+ topk_tokens : int ,
548+ row_starts : torch .Tensor | None = None ,
549+ ) -> torch .Tensor :
546550 k = min (topk_tokens , logits .shape [- 1 ])
547551 values , indices = torch .topk (logits , k = k , dim = - 1 )
548552 indices = indices .to (torch .int32 )
@@ -551,6 +555,12 @@ def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor:
551555 torch .full_like (indices , - 1 , dtype = torch .int32 ),
552556 indices ,
553557 )
558+ if row_starts is not None :
559+ # Match the CUDA top_k_per_row_prefill contract: indices are local to
560+ # each row's valid [row_start, row_end) range, not columns in the
561+ # concatenated chunk logits matrix.
562+ starts = row_starts .to (dtype = torch .int32 ).view (- 1 , 1 )
563+ indices = torch .where (indices < 0 , indices , indices - starts )
554564 if k == topk_tokens :
555565 return indices
556566 padded = torch .full (
@@ -563,64 +573,6 @@ def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor:
563573 return padded
564574
565575
566- # topk_tokens values with dedicated fused C++ kernel support.
567- _TOPK_FAST_PATH_VALUES = frozenset ({2048 })
568-
569-
570- def _topk_indices_prefill (
571- logits : torch .Tensor ,
572- topk_tokens : int ,
573- topk_out : torch .Tensor ,
574- cu_seqlen_ks : torch .Tensor ,
575- cu_seqlen_ke : torch .Tensor ,
576- ) -> None :
577- """Top-k indices for the prefill path.
578-
579- Writes ``logits.shape[0]`` rows into ``topk_out``; caller must size the
580- view accordingly.
581- """
582- if topk_tokens in _TOPK_FAST_PATH_VALUES :
583- torch .ops ._C .top_k_per_row_prefill (
584- logits ,
585- cu_seqlen_ks ,
586- cu_seqlen_ke ,
587- topk_out ,
588- logits .shape [0 ],
589- logits .stride (0 ),
590- logits .stride (1 ),
591- topk_tokens ,
592- )
593- else :
594- topk_out .copy_ (_topk_indices_torch (logits , topk_tokens ))
595-
596-
597- def _topk_indices_decode (
598- logits : torch .Tensor ,
599- topk_tokens : int ,
600- topk_out : torch .Tensor ,
601- seq_lens : torch .Tensor ,
602- next_n : int ,
603- ) -> None :
604- """Top-k indices for the decode path.
605-
606- Writes ``logits.shape[0] == batch_size * next_n`` rows into ``topk_out``;
607- caller must size the view to ``num_padded_tokens``.
608- """
609- if topk_tokens in _TOPK_FAST_PATH_VALUES :
610- torch .ops ._C .top_k_per_row_decode (
611- logits ,
612- next_n ,
613- seq_lens ,
614- topk_out ,
615- logits .shape [0 ],
616- logits .stride (0 ),
617- logits .stride (1 ),
618- topk_tokens ,
619- )
620- else :
621- topk_out .copy_ (_topk_indices_torch (logits , topk_tokens ))
622-
623-
624576def rocm_aiter_sparse_attn_indexer_fake (
625577 hidden_states : torch .Tensor ,
626578 k_cache_prefix : LayerNameType ,
@@ -635,21 +587,13 @@ def rocm_aiter_sparse_attn_indexer_fake(
635587 max_model_len : int ,
636588 total_seq_lens : int ,
637589 topk_indices_buffer : torch .Tensor | None ,
590+ skip_k_cache_insert : bool = False ,
638591) -> torch .Tensor :
639- # profile run
640- # NOTE(Chen): create the max possible flattened_kv. So that
641- # profile_run can get correct memory usage.
642- device = hidden_states .device if k is None else k .device
643- _flattened_kv = torch .empty (
644- [total_seq_lens , head_dim + 4 ], device = device , dtype = torch .uint8
645- )
646- fp8_dtype = current_platform .fp8_dtype ()
647- _k_fp8 = _flattened_kv [..., :head_dim ].view (fp8_dtype ).contiguous ()
648- _k_scale = _flattened_kv [..., head_dim :].view (torch .float32 ).contiguous ()
649592 return topk_indices_buffer
650593
651594
652- def rocm_aiter_sparse_attn_indexer_native (
595+ @eager_break_during_capture
596+ def rocm_aiter_sparse_attn_indexer (
653597 hidden_states : torch .Tensor ,
654598 k_cache_prefix : LayerNameType ,
655599 kv_cache : torch .Tensor ,
@@ -688,6 +632,7 @@ def rocm_aiter_sparse_attn_indexer_native(
688632 max_model_len ,
689633 total_seq_lens ,
690634 topk_indices_buffer ,
635+ skip_k_cache_insert ,
691636 )
692637 layer_attn_metadata = attn_metadata [k_cache_prefix ]
693638 assert isinstance (layer_attn_metadata , DeepseekV32IndexerMetadata )
@@ -768,12 +713,18 @@ def rocm_aiter_sparse_attn_indexer_native(
768713 topk_indices = topk_indices_buffer [
769714 chunk .token_start : chunk .token_end , :topk_tokens
770715 ]
771- _topk_indices_prefill (
716+
717+ num_rows = logits .shape [0 ]
718+
719+ torch .ops ._C .top_k_per_row_prefill (
772720 logits ,
773- topk_tokens ,
774- topk_indices ,
775721 chunk .cu_seqlen_ks ,
776722 chunk .cu_seqlen_ke ,
723+ topk_indices ,
724+ num_rows ,
725+ logits .stride (0 ),
726+ logits .stride (1 ),
727+ topk_tokens ,
777728 )
778729
779730 if has_decode :
@@ -811,16 +762,18 @@ def rocm_aiter_sparse_attn_indexer_native(
811762 max_model_len = max_model_len ,
812763 )
813764
814- # Size the view to num_padded_tokens: top_k_per_row_decode writes
815- # logits.shape[0] == num_padded_tokens rows, and the unpack below
816- # reshapes to (batch_size, next_n, ...).
817765 topk_indices = topk_indices_buffer [:num_padded_tokens , :topk_tokens ]
818- _topk_indices_decode (
766+ num_rows = logits .shape [0 ]
767+
768+ torch .ops ._C .top_k_per_row_decode (
819769 logits ,
820- topk_tokens ,
821- topk_indices ,
822- decode_metadata .seq_lens ,
823770 next_n ,
771+ decode_metadata .seq_lens ,
772+ topk_indices ,
773+ num_rows ,
774+ logits .stride (0 ),
775+ logits .stride (1 ),
776+ topk_tokens ,
824777 )
825778
826779 if decode_metadata .requires_padding :
@@ -837,40 +790,6 @@ def rocm_aiter_sparse_attn_indexer_native(
837790 return topk_indices_buffer
838791
839792
840- @eager_break_during_capture
841- def rocm_aiter_sparse_attn_indexer (
842- hidden_states : torch .Tensor ,
843- k_cache_prefix : LayerNameType ,
844- kv_cache : torch .Tensor ,
845- q_fp8 : torch .Tensor ,
846- k : torch .Tensor ,
847- weights : torch .Tensor ,
848- quant_block_size : int ,
849- scale_fmt : str | None ,
850- topk_tokens : int ,
851- head_dim : int ,
852- max_model_len : int ,
853- total_seq_lens : int ,
854- topk_indices_buffer : torch .Tensor | None ,
855- ) -> torch .Tensor :
856- return rocm_aiter_sparse_attn_indexer_native (
857- hidden_states ,
858- k_cache_prefix ,
859- kv_cache ,
860- q_fp8 ,
861- k ,
862- weights ,
863- quant_block_size ,
864- scale_fmt ,
865- topk_tokens ,
866- head_dim ,
867- max_model_len ,
868- total_seq_lens ,
869- topk_indices_buffer ,
870- skip_k_cache_insert = False ,
871- )
872-
873-
874793def _decode_e8m0_scales (scale : torch .Tensor ) -> torch .Tensor :
875794 if scale .dtype == torch .float8_e8m0fnu :
876795 from vllm .model_executor .layers .quantization .utils .fp8_utils import (
0 commit comments