From 5920a03130bf3c96ae642fcbd7aa1d6f11770acd Mon Sep 17 00:00:00 2001
From: Siddartha Pothapragada
Date: Tue, 23 Jun 2026 21:31:00 -0700
Subject: [PATCH] fix: auto-disable dynamic shapes when custom KV cache is enabled

The custom KV cache op (update_cache) only supports single-token inputs,
but dynamic shapes enables variable-length prefill. This combination
exports successfully but crashes at inference with an opaque error.

Auto-set disable_dynamic_shapes=True when use_custom_kv_cache=True,
with a warning log so users know the override happened.

Co-authored-by: Claude
---
 optimum/exporters/executorch/integrations.py |  8 +++++++
 tests/models/test_modeling_common.py         | 25 ++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 33cf665..e2a94a7 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -418,6 +418,14 @@ def __init__(
         self.config = model.config
         self.use_custom_kv_cache = use_custom_kv_cache
         self.use_custom_sdpa = use_custom_sdpa
+
+        # update_cache op only supports single-token (decode) inputs
+        if use_custom_kv_cache and not disable_dynamic_shapes:
+            logging.warning(
+                "Custom KV cache requires static shapes. Automatically setting disable_dynamic_shapes=True."
+            )
+            disable_dynamic_shapes = True
+
         self.disable_dynamic_shapes = disable_dynamic_shapes
         self.metadata = save_config_to_constant_methods(
             model.config,
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 63017df..60e988c 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -33,6 +33,7 @@
 from optimum.executorch import ExecuTorchModelForCausalLM
 from optimum.executorch.modeling import _FILE_PATTERN
 from optimum.exporters.executorch import main_export
+from optimum.exporters.executorch.integrations import CausalLMExportableModule
 from optimum.utils.file_utils import find_files_matching_pattern
 
 from ..utils import check_causal_lm_output_quality
@@ -183,3 +184,27 @@ def forward(self, x):
                 if node.op == "call_function" and node.target == exir_ops.edge.aten.embedding.default
             )
         )
+
+    def test_custom_kv_cache_auto_disables_dynamic_shapes(self):
+        model_id = "optimum-internal-testing/tiny-random-llama"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        wrapper = CausalLMExportableModule(
+            model,
+            use_custom_kv_cache=True,
+            disable_dynamic_shapes=False,
+        )
+        self.assertTrue(wrapper.disable_dynamic_shapes)
+        self.assertFalse(wrapper.metadata.get("enable_dynamic_shape", True))
+
+    def test_dynamic_shapes_preserved_without_custom_kv_cache(self):
+        model_id = "optimum-internal-testing/tiny-random-llama"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        wrapper = CausalLMExportableModule(
+            model,
+            use_custom_kv_cache=False,
+            disable_dynamic_shapes=False,
+        )
+        self.assertFalse(wrapper.disable_dynamic_shapes)
+        self.assertTrue(wrapper.metadata.get("enable_dynamic_shape", False))