From 13aa3ba44fc70298339a7cba0cdbf3ec7cbc2c25 Mon Sep 17 00:00:00 2001 From: Rohan <34810284+rhn19@users.noreply.github.com> Date: Fri, 20 Feb 2026 14:22:03 -0800 Subject: [PATCH] Fix: respect --max_seq_len for sliding window models with custom kv cache + sdpa --- optimum/exporters/executorch/integrations.py | 58 +++++++++++++++----- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 33cf665..32abf95 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -464,7 +464,10 @@ def _prepare_export_inputs(self): sliding_window = getattr(self.config, "sliding_window", None) if sliding_window is None: sliding_window = self.metadata.get("sliding_window", float("inf")) - max_dim = min(max_seq_len, sliding_window) - 1 + if self.use_custom_kv_cache and self.use_custom_sdpa: + max_dim = max_seq_len - 1 + else: + max_dim = min(max_seq_len, sliding_window) - 1 seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) dynamic_shapes = { "input_ids": {1: seq_len_dim}, @@ -499,22 +502,47 @@ def export( f"Exporting using input_ids({input_ids.shape})={input_ids}, cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}, strict={strict}" ) - exportable_module = TorchExportableModuleForDecoderOnlyLM( - self.model, - ) - self._register_custom_attention(exportable_module) - - if self.use_custom_kv_cache: - from optimum.executorch.attentions.custom_kv_cache import ( - replace_with_et_custom_kv_cache, + max_seq_len = self.metadata.get("get_max_seq_len") + sliding_window = self.metadata.get("sliding_window") + use_ring_cache = self.use_custom_kv_cache and self.use_custom_sdpa and sliding_window is not None + + if use_ring_cache: + from transformers.integrations.executorch import TorchExportableModuleWithHybridCache + from optimum.executorch.attentions.custom_kv_cache import ETCustomHybridCache + + # Bypass TorchExportableModuleWithHybridCache.__init__ — it calls StaticCache which + # caps sliding layers to sliding_window via StaticSlidingWindowLayer, baking a + # <= sliding_window guard into torch.export. Instead, directly install + # ETCustomHybridCache sized to max_seq_len, then patch sliding layer max_cache_len + # so get_mask_sizes() returns max_seq_len during tracing. + exportable_module_inner = TorchExportableModuleWithHybridCache.__new__(TorchExportableModuleWithHybridCache) + torch.nn.Module.__init__(exportable_module_inner) + exportable_module_inner.model = self.model + exportable_module_inner.cache = ETCustomHybridCache( + config=self.model.config, max_batch_size=1, max_cache_len=max_seq_len, + device=self.model.device, dtype=self.model.dtype, ) + for layer in exportable_module_inner.cache.layers: + if layer.is_sliding: + layer.max_cache_len = max_seq_len + for i in range(len(exportable_module_inner.cache.kv_cache)): + exportable_module_inner.register_buffer( + f"key_cache_{i}", exportable_module_inner.cache.kv_cache[i].k_cache, persistent=False) + exportable_module_inner.register_buffer( + f"value_cache_{i}", exportable_module_inner.cache.kv_cache[i].v_cache, persistent=False) + if exportable_module_inner.cache.layers[i].is_sliding: + exportable_module_inner.register_buffer( + f"cache_positions_{i}", + exportable_module_inner.cache.kv_cache[i].cache_positions_manager.cache_positions, + persistent=False, + ) + exportable_module = TorchExportableModuleForDecoderOnlyLM.__new__(TorchExportableModuleForDecoderOnlyLM) + torch.nn.Module.__init__(exportable_module) + exportable_module.model = exportable_module_inner + else: + exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model) - replace_with_et_custom_kv_cache( - exportable_module.model, - self.model.config, - self.model.generation_config, - self.model.dtype, - ) + self._register_custom_attention(exportable_module) with torch.no_grad(): exported_program = exportable_module.export(