From 4f6172c82415c3b63e3b27487b74d441627219a6 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 19 Feb 2026 17:42:10 -0800 Subject: [PATCH 1/2] Add attention sink KV cache and SDPA support for ExecuTorch Introduce CustomRingKVCacheWithSink and ETCustomAttentionSinkCache that preserve the first sink_size tokens while using a ring buffer for the remaining window. Add get_custom_sdpa_for_attention_sink to build per-layer attention masks with sink token preservation. Wire the attention_sink parameter through replace_with_et_custom_kv_cache. Co-authored-by: Claude --- .../executorch/attentions/custom_kv_cache.py | 170 +++++++++++++++++- optimum/executorch/attentions/custom_sdpa.py | 53 ++++++ optimum/exporters/executorch/integrations.py | 1 + 3 files changed, 218 insertions(+), 6 deletions(-) diff --git a/optimum/executorch/attentions/custom_kv_cache.py b/optimum/executorch/attentions/custom_kv_cache.py index 64b7322d..6a6204b5 100644 --- a/optimum/executorch/attentions/custom_kv_cache.py +++ b/optimum/executorch/attentions/custom_kv_cache.py @@ -23,6 +23,56 @@ except ImportError: raise ImportError("ExecutorTorch is not installed. Please install it to use Custom Cache.") +try: + from executorch.examples.models.llama.source_transformation.attention_sink import ( + CachePositionsManagerWithSink, + _create_causal_mask_for_attention_sink, + ) +except ImportError: + CachePositionsManagerWithSink = None + _create_causal_mask_for_attention_sink = None + + +class CustomRingKVCacheWithSink(CustomKVCache): + """Ring buffer KV cache with attention sink — preserves first sink_size tokens.""" + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + sink_size: int, + dtype=torch.float32, + ): + assert CachePositionsManagerWithSink is not None, ( + "CachePositionsManagerWithSink not available. " + "Install ExecuTorch with attention sink support." + ) + super().__init__(max_batch_size, max_context_length, n_heads, head_dim, dtype) + self.sink_size = sink_size + self.window_size = max_context_length - sink_size + self.is_ring_buffer = True + self.cache_positions_manager = CachePositionsManagerWithSink( + max_context_length, sink_size + ) + + def update(self, input_pos, k_val, v_val): + seq_len = k_val.transpose(1, 2).size(1) + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ).unsqueeze(0) + return super().update(input_pos, k_val, v_val, indices) + + def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): + return _create_causal_mask_for_attention_sink( + self.cache_positions_manager.cache_positions, + self.window_size, + self.sink_size, + start_pos, + seq_len, + ) + class ETCustomStaticCache(StaticCache): """ @@ -333,33 +383,141 @@ def get_layer_cache(self, layer_idx: int): return self.kv_cache[layer_idx] -def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): +class ETCustomAttentionSinkCache(StaticCache): + """ + KV Cache with attention sink for ExecuTorch. All layers use ring buffer + with sink token preservation. + + Sink tokens (first sink_size positions) are never evicted from cache. + Remaining positions use a ring buffer for sliding window. + """ + + def __init__( + self, + config, + max_batch_size: int, + max_cache_len: Optional[int] = None, + sink_size: int = 4, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.float32, + ): + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=device, + dtype=dtype, + ) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + num_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads + self.early_initialization( + batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device + ) + self.sink_size = sink_size + self.cache_position = None + + self.kv_cache = torch.nn.ModuleList() + for layer in self.layers: + layer_cache = CustomRingKVCacheWithSink( + max_batch_size=layer.max_batch_size, + max_context_length=layer.max_cache_len, + n_heads=layer.num_heads, + head_dim=layer.head_dim, + sink_size=sink_size, + dtype=dtype, + ) + self.kv_cache.append(layer_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert cache_kwargs is not None + cache_position = cache_kwargs.get("cache_position") + assert cache_position is not None + assert isinstance(cache_position, torch.Tensor) + self.cache_position = cache_position + + layer_cache = self.kv_cache[layer_idx] + k_out, v_out = layer_cache.update( + input_pos=cache_position, + k_val=key_states, + v_val=value_states, + ) + return k_out, v_out + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + if layer_idx is None: + layer_idx = 0 + return self.kv_cache[layer_idx].cache_positions_manager.cache_positions.max().item() + 1 + + def get_layer_cache(self, layer_idx: int): + return self.kv_cache[layer_idx] + + +def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype, attention_sink=None): """ - Replace all KV caches in the module with ETCustomStaticCache or ETCustomHybridCache. - This modifies the model in place. + Replace all KV caches in the module with ETCustomStaticCache, ETCustomHybridCache, + or ETCustomAttentionSinkCache. Args: module: The module to modify config: The model configuration + attention_sink: Optional tuple (sink_size, window_size) for attention sink mode Returns: The modified module """ - # Recursively replace KV caches - return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype) + return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype, attention_sink) -def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): +def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype, attention_sink=None): """ Helper function to recursively replace KV caches in the module. Args: module: The module to modify config: The model configuration + attention_sink: Optional tuple (sink_size, window_size) for attention sink mode Returns: The modified module """ + # Attention sink mode: replace static_cache with ETCustomAttentionSinkCache + if attention_sink is not None: + sink_size, window_size = attention_sink + cache_size = sink_size + window_size + + if hasattr(module, "static_cache"): + sink_cache = ETCustomAttentionSinkCache( + config=config, + max_batch_size=generation_config.cache_config.get("batch_size"), + max_cache_len=cache_size, + sink_size=sink_size, + device=generation_config.cache_config.get("device"), + dtype=cache_dtype, + ) + if getattr(module, "replace_cache", None) is not None: + module.replace_cache(sink_cache) + else: + module.static_cache = sink_cache + for i in range(len(sink_cache.kv_cache)): + setattr(module, f"key_cache_{i}", sink_cache.kv_cache[i].k_cache) + setattr(module, f"value_cache_{i}", sink_cache.kv_cache[i].v_cache) + module.register_buffer( + f"cache_positions_{i}", + sink_cache.kv_cache[i].cache_positions_manager.cache_positions, + persistent=False, + ) + else: + raise ValueError( + "Attention sink requires 'static_cache' attribute on module" + ) + return module + # Check if module has static_cache (TorchExportableModuleWithStaticCache) if hasattr(module, "static_cache"): assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}" diff --git a/optimum/executorch/attentions/custom_sdpa.py b/optimum/executorch/attentions/custom_sdpa.py index f857cd12..71d638f2 100644 --- a/optimum/executorch/attentions/custom_sdpa.py +++ b/optimum/executorch/attentions/custom_sdpa.py @@ -191,3 +191,56 @@ def _custom_sdpa_for_ring_kv_cache( ) return _custom_sdpa_for_ring_kv_cache + + +def get_custom_sdpa_for_attention_sink( + exportable_module: torch.nn.Module, +) -> Callable: + """Create SDPA function for attention sink models. + ALL layers use ring buffer mask with sink token preservation.""" + + from optimum.executorch.attentions.custom_kv_cache import ( + CustomRingKVCacheWithSink, + ETCustomAttentionSinkCache, + ) + + def _custom_sdpa_for_attention_sink( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], # noqa + position_ids: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + head_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, None]: + layer_idx = module.layer_idx + assert layer_idx is not None, "layer_idx is not set." + sink_cache = exportable_module.model.static_cache + assert isinstance(sink_cache, ETCustomAttentionSinkCache), ( + f"Expected ETCustomAttentionSinkCache, got {type(sink_cache)}" + ) + ring_cache = sink_cache.get_layer_cache(layer_idx) + assert isinstance(ring_cache, CustomRingKVCacheWithSink), ( + f"Expected CustomRingKVCacheWithSink, got {type(ring_cache)}" + ) + input_pos = sink_cache.cache_position[0].item() + seqlen = query.shape[2] + attention_mask = ring_cache.create_causal_mask_for_ring_buffer(input_pos, seqlen) + kwargs.update({"is_sliding": True}) + return custom_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + position_ids, + scaling, + softcap, + head_mask, + **kwargs, + ) + + return _custom_sdpa_for_attention_sink diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 555e7a16..628f424e 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -412,6 +412,7 @@ def __init__( use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False, + attention_sink=None, ): super().__init__() self.model = model From 5f77beb8cd622038fb47e4f1b0823baeaf2dd19b Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 19 Feb 2026 18:20:25 -0800 Subject: [PATCH 2/2] Wire attention sink SDPA into CausalLMExportableModule export path Register a dedicated custom_sdpa_attention_sink attention implementation when the attention_sink option is provided, with priority over the existing ring KV cache SDPA path. Pass attention_sink through to the cache setup at export time. --- optimum/exporters/executorch/integrations.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 628f424e..fda9c0c7 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -34,7 +34,11 @@ from transformers.masking_utils import AttentionMaskInterface from transformers.modeling_utils import AttentionInterface -from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache, sdpa_mask_passthrough +from optimum.executorch.attentions.custom_sdpa import ( + get_custom_sdpa_for_attention_sink, + get_custom_sdpa_for_ring_kv_cache, + sdpa_mask_passthrough, +) from optimum.executorch.attentions.whisper_attention import WhisperCrossAttention from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods @@ -420,6 +424,7 @@ def __init__( self.use_custom_kv_cache = use_custom_kv_cache self.use_custom_sdpa = use_custom_sdpa self.disable_dynamic_shapes = disable_dynamic_shapes + self.attention_sink = attention_sink self.metadata = save_config_to_constant_methods( model.config, generation_config=getattr(model, "generation_config", None), @@ -472,16 +477,17 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module): from transformers.modeling_utils import AttentionInterface if self.use_custom_sdpa: - if self.use_custom_kv_cache: + if self.attention_sink is not None: + _custom_sdpa = get_custom_sdpa_for_attention_sink(exportable_module) + AttentionInterface.register("custom_sdpa_attention_sink", _custom_sdpa) + AttentionMaskInterface.register("custom_sdpa_attention_sink", sdpa_mask_passthrough) + exportable_module.model.model.config._attn_implementation = "custom_sdpa_attention_sink" + elif self.use_custom_kv_cache: _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module) AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache) AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough) - # Manually set the attention implementation to custom_sdpa_ring_kv_cache - # This handles both regular sdpa and one for sliding window/local attention exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache" else: - # Manually set the attention implementation to custom_sdpa_ring_kv_cache - # This handles both regular sdpa and one for sliding window/local attention exportable_module.model.model.config._attn_implementation = "custom_sdpa" def export( @@ -507,6 +513,7 @@ def export( self.model.config, self.model.generation_config, self.model.dtype, + attention_sink=self.attention_sink, ) with torch.no_grad():