diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py index 679bd3f..b9f831e 100644 --- a/optimum/exporters/executorch/integrations.py +++ b/optimum/exporters/executorch/integrations.py @@ -449,13 +449,21 @@ def _prepare_export_inputs(self): and not (self.use_custom_kv_cache and self.use_custom_sdpa) ) - if not self.disable_dynamic_shapes and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache: + allow_dynamic_shapes_for_hybrid = getattr(getattr(self.model, "device", None), "type", None) == "cuda" + if allow_dynamic_shapes_for_hybrid: + logging.info("Enabling dynamic shapes for CUDA with hybrid cache.") + + if not self.disable_dynamic_shapes and ( + not is_using_hybrid_cache_wo_custom_sdpa_kv_cache or allow_dynamic_shapes_for_hybrid + ): # Prepare inputs with dynamic shapes seq_length = 3 # Sequence length > 1 to avoid specialization issue example_input_ids = torch.zeros((1, seq_length), dtype=torch.long, device=self.model.device) example_cache_position = torch.arange(seq_length, dtype=torch.long, device=self.model.device) max_seq_len = self.metadata.get("get_max_seq_len") - sliding_window = self.metadata.get("sliding_window", float("inf")) + sliding_window = getattr(self.config, "sliding_window", None) + if sliding_window is None: + sliding_window = self.metadata.get("sliding_window", float("inf")) max_dim = min(max_seq_len, sliding_window) - 1 seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim) dynamic_shapes = { diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py index 9f3b38e..6080362 100644 --- a/optimum/exporters/executorch/tasks/causal_lm.py +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -50,7 +50,7 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl CausalLMExportableModule: An instance of `CausalLMExportableModule` for exporting and lowering to ExecuTorch. """ - device = "cpu" + device = kwargs.get("device", "cpu") batch_size = 1 dtype = kwargs.get("dtype", "float32") disable_dynamic_shapes = kwargs.get("disable_dynamic_shapes", False)