Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,13 @@ def __init__(self, encoder_model):
self.config = encoder_model.config

def forward(self, input_ids):
return self.encoder(input_ids).last_hidden_state
# Compute attention_mask from input_ids so PAD tokens (id=0) do not
# attend to real positions. Without this, padded encoder inputs corrupt
# the hidden states at real token positions.
attention_mask = (input_ids != 0).long()
hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
# Zero out PAD positions so decoder cross-attention ignores them.
return hidden * attention_mask.unsqueeze(-1).float()


class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
Expand Down Expand Up @@ -837,18 +843,25 @@ def __init__(self, model, max_static_cache_length, batch_size):
cross_attn.out_proj = layer.encoder_attn.out_proj
layer.encoder_attn = cross_attn

def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor,
cache_position: torch.Tensor,
encoder_attention_mask: torch.Tensor | None = None,
):
# Get outputs from decoder.
# encoder_attention_mask must be passed so that models using relative
# position bias (e.g. T5) scale the bias by the real encoder length
# rather than the full padded length.
outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)
# Set the cross attention cache as initialized after the first forward pass
# This allows torch.cond to branch differently on subsequent runs
# self.cross_attention_cache_initialized.fill_(True)

# Apply linear projection (lm head) to obtain logits
logits = self.proj_out(outputs[0])
Expand Down Expand Up @@ -922,7 +935,7 @@ def _export_encoder(self, encoder_input_ids):
)
return exported_encoder

def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position, encoder_attention_mask=None):
wrapped_decoder = (
Seq2SeqLMDecoderExportableModuleWithStaticCache(
model=self.model,
Expand All @@ -935,14 +948,14 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi

if isinstance(self.model, WhisperForConditionalGeneration):
dynamic_shapes = None
export_args = (decoder_input_ids, encoder_hidden_states, cache_position)
elif isinstance(self.model, T5ForConditionalGeneration):
# Define dynamic dimension for encoder output sequence length
encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_len)
dynamic_shapes = {
"decoder_input_ids": None,
"encoder_hidden_states": {1: encoder_seq_len_dim},
"cache_position": None,
}
# T5's cross-attention causal mask slices against the static KV-cache
# size at torch.export time, which conflicts with a symbolic encoder
# dim. Fix: use fully static shapes — callers must pad encoder inputs
# to max_seq_len before running the encoder.
dynamic_shapes = None
export_args = (decoder_input_ids, encoder_hidden_states, cache_position, encoder_attention_mask)
else:
raise ValueError(
f"Unsupported model type {type(self.model)} for Seq2SeqLMExportableModule decoder export."
Expand All @@ -952,7 +965,7 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
exported_decoder = torch.export.export(
wrapped_decoder,
(decoder_input_ids, encoder_hidden_states, cache_position),
export_args,
dynamic_shapes=dynamic_shapes,
strict=True,
)
Expand Down Expand Up @@ -982,6 +995,14 @@ def export(
example_encoder_input_ids = torch.rand(
self._expected_encoder_input_shape, device=self.model.device, dtype=self.model.dtype
)
elif isinstance(self.model, T5ForConditionalGeneration):
# Use max_seq_len-sized input so the encoder output has the static
# shape expected by the T5 decoder (which uses fully static shapes
# to avoid a cross-attention mask conflict with the KV cache size).
example_encoder_input_ids = torch.zeros(
(1, self.max_seq_len), dtype=torch.long, device=self.model.device
)
example_encoder_input_ids[0, 0] = 1 # one real token, rest PAD
else:
example_encoder_input_ids = torch.ones((1, 10), dtype=torch.long, device=self.model.device)
else:
Expand All @@ -1005,10 +1026,19 @@ def export(
else torch.tensor([0], dtype=torch.long, device=self.model.device)
)

# Build encoder attention mask for the decoder export example.
# For T5 the decoder uses fully static shapes; the mask shape must match
# the encoder hidden states shape (batch, max_seq_len).
if isinstance(self.model, T5ForConditionalGeneration):
example_encoder_attention_mask = (example_encoder_input_ids != 0).long()
else:
example_encoder_attention_mask = None

self.exported_decoder = self._export_decoder(
example_decoder_input_ids,
example_encoder_hidden_states,
example_cache_position,
example_encoder_attention_mask,
)

# Skip sampler export for MPS + bfloat16 due to Metal shader compilation error
Expand Down