From 4cd44e4583a43bb54b131b256a27ca71a991f6e2 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:59:42 +0200 Subject: [PATCH 01/39] add a base model class --- .../audioflamingo3/modeling_audioflamingo3.py | 255 +++++++++------ .../audioflamingo3/modular_audioflamingo3.py | 229 +++++++++----- .../models/glmasr/modeling_glmasr.py | 202 +++++++++--- .../models/glmasr/modular_glmasr.py | 28 +- .../granite_speech/modeling_granite_speech.py | 227 ++++++++------ .../musicflamingo/modeling_musicflamingo.py | 260 +++++++++------ .../musicflamingo/modular_musicflamingo.py | 98 ++---- .../qwen2_audio/modeling_qwen2_audio.py | 189 ++++++++--- .../vibevoice_asr/modeling_vibevoice_asr.py | 219 +++++++++---- .../vibevoice_asr/modular_vibevoice_asr.py | 210 ++++++++++--- .../models/voxtral/modeling_voxtral.py | 191 ++++++++--- .../models/voxtral/modular_voxtral.py | 192 +++++++++--- .../modeling_voxtral_realtime.py | 296 ++++++++++-------- .../modular_voxtral_realtime.py | 286 ++++++++++------- 14 files changed, 1937 insertions(+), 945 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 1fbbc733c308..8e22270ae5fd 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -31,13 +32,13 @@ from ...masking_utils import create_bidirectional_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig @@ -256,6 +257,44 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _supports_sdpa = True +@dataclass +class AudioFlamingo3ModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. 
+ """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The audio model from AudioFlamingo3 without any head or projection on top. @@ -403,19 +442,16 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _tp_plan = None - _pp_plan = None - +class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel): def __init__(self, config): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) # Initialize weights and apply final processing @@ -427,18 +463,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the 
audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -451,11 +475,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -484,77 +504,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | AudioFlamingo3ModelOutputWithPast: r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. 
+ """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -564,17 +524,117 @@ def forward( audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
+ """ +) +class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs) -> tuple | BaseModelOutputWithPooling: + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -591,4 +651,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index c325bc85300e..75c766c5982c 100644 --- 
a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults @@ -28,7 +30,12 @@ Qwen2AudioEncoder, Qwen2AudioPreTrainedModel, ) -from ..voxtral.modeling_voxtral import VoxtralForConditionalGeneration, VoxtralMultiModalProjector +from ..voxtral.modeling_voxtral import ( + VoxtralForConditionalGeneration, + VoxtralModel, + VoxtralModelOutputWithPast, + VoxtralMultiModalProjector, +) from ..whisper.modeling_whisper import WhisperAttention, WhisperEncoderLayer from .configuration_audioflamingo3 import AudioFlamingo3Config @@ -48,6 +55,37 @@ class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel): pass +@dataclass +class AudioFlamingo3ModelOutputWithPast(VoxtralModelOutputWithPast): + pass + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. 
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The audio model from AudioFlamingo3 without any head or projection on top. @@ -138,17 +176,11 @@ def __init__(self, config: AudioFlamingo3Config): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): - _tp_plan = None - _pp_plan = None - _keep_in_fp32_modules_strict = None - - def __init__(self, config): - super().__init__(config) - +class AudioFlamingo3Model(VoxtralModel): @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -161,11 +193,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -194,77 +222,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ): r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. 
+ """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -274,17 +242,103 @@ def forward( audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
+ """ +) +class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -301,4 +355,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index aff96cad3217..ddc43d4e624f 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ 
b/src/transformers/models/glmasr/modeling_glmasr.py @@ -19,6 +19,7 @@ # limitations under the License. from collections.abc import Callable +from dataclasses import dataclass from typing import Optional from ...activations import ACT2FN @@ -26,14 +27,19 @@ from ...generation import GenerationMixin from ...integrations import use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, + ModelOutput, +) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_glmasr import GlmAsrConfig, GlmAsrEncoderConfig @@ -349,21 +355,30 @@ def forward(self, audio_features): return hidden_states +@dataclass +class GlmAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" - The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
+ The GlmAsr model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _tp_plan = None - _pp_plan = None - +class GlmAsrModel(GlmAsrPreTrainedModel): def __init__(self, config): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = GlmAsrMultiModalProjector(config) # Initialize weights and apply final processing @@ -375,18 +390,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector." @@ -399,11 +402,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. 
See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -425,6 +424,113 @@ def get_audio_features( return audio_outputs + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmAsrModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return GlmAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + 
audio_hidden_states=audio_embeds, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GlmAsr causal language model (or autoregressive) outputs. + """ +) +class GlmAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
+ """ +) +class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs) -> tuple | BaseModelOutputWithPooling: + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -469,30 +575,36 @@ def forward( >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) >>> print(decoded_outputs) ```""" - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) - ) - - outputs: CausalLMOutputWithPast = self.language_model( - inputs_embeds=inputs_embeds, + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, attention_mask=attention_mask, 
position_ids=position_ids, past_key_values=past_key_values, - labels=labels, + inputs_embeds=inputs_embeds, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return GlmAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -509,4 +621,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrPreTrainedModel"] +__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrModel", "GlmAsrPreTrainedModel"] diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index ff0b8b6062a4..38ec9dcdb071 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -29,6 +29,7 @@ from ...utils.output_capturing import capture_outputs from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, AudioFlamingo3MultiModalProjector, AudioFlamingo3PreTrainedModel, ) @@ -351,12 +352,7 @@ def __init__(self, config: GlmAsrConfig): self.linear_2 = nn.Linear(config.text_config.hidden_size * 2, config.text_config.hidden_size) -@auto_docstring( - custom_intro=""" - The GlmAsr model which 
consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. - """ -) -class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +class GlmAsrModel(AudioFlamingo3Model): @can_return_tuple @auto_docstring( custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector." @@ -385,6 +381,18 @@ def get_audio_features( return audio_outputs + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.post_init() + def forward( self, input_ids: torch.LongTensor | None = None, @@ -440,4 +448,10 @@ def forward( ) -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrProcessor", "GlmAsrPreTrainedModel"] +__all__ = [ + "GlmAsrEncoder", + "GlmAsrForConditionalGeneration", + "GlmAsrModel", + "GlmAsrProcessor", + "GlmAsrPreTrainedModel", +] diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 0fbc1d1035bf..ebd4b91283de 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -22,7 +22,7 @@ from ... 
import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -35,7 +35,7 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig @@ -68,6 +68,23 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput): attentions: tuple[torch.FloatTensor] | None = None +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GraniteSpeech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + ### Projector class GraniteSpeechEncoderProjector(nn.Module): def __init__(self, config: GraniteSpeechConfig): @@ -261,6 +278,7 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> @auto_docstring class GraniteSpeechPreTrainedModel(PreTrainedModel): config: GraniteSpeechConfig + base_model_prefix = "model" input_modalities = ("audio", "text") _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this @@ -326,15 +344,15 @@ def forward( The Granite Speech model, which consists of an audio encoder, projector, and language model. 
""" ) -class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): +@auto_docstring( + custom_intro=""" + The GraniteSpeech model (CTC encoder, projector, language model), without a language modeling head. + """ +) +class GraniteSpeechModel(GraniteSpeechPreTrainedModel): def __init__(self, config: GraniteSpeechConfig): super().__init__(config) - # NOTE: It doesn't matter when we initialize from config, but we should be careful - # to make sure this does not pick up the adapter_config if in the future we use - # from_pretrained or something similar, since that should be set by the composite - # model; don't need to consider it twice - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - + self.language_model = AutoModel.from_config(config.text_config) self.encoder = GraniteSpeechCTCEncoder(config.encoder_config) self.projector = GraniteSpeechEncoderProjector(config) @@ -348,24 +366,12 @@ def __init__(self, config: GraniteSpeechConfig): self.post_init() - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - def get_input_embeddings(self): return self.language_model.get_input_embeddings() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - @can_return_tuple @auto_docstring def get_audio_features( @@ -377,6 +383,27 @@ def get_audio_features( return audio_outputs + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """Merge audio features into the language embeddings at `audio_token_id` positions.""" + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = 
torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) + + special_audio_mask = is_audio_index.unsqueeze(-1) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + torch_compilable_check( + not torch.all(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)), + "Number of audio tokens does not match number of audio features", + ) + audio_features = audio_features[input_features_mask] + + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple @auto_docstring def forward( self, @@ -387,28 +414,19 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **lm_kwargs, - ) -> tuple[torch.Tensor] | GraniteSpeechCausalLMOutputWithPast: + **kwargs, + ) -> tuple | GraniteSpeechModelOutputWithPast: r""" input_features_mask (`torch.Tensor`, *optional*): Mask to be applied to audio features prior to scattering into the language embeddings. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
""" - # TODO (@alex-jw-brooks) add an example to this docstring once models are released output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -419,21 +437,16 @@ def forward( ) if inputs_embeds is None: - # Get the base embeddings; set all audio tokens to 0 index - # to avoid out of vocabulary issues with the LLM embedding. - # Audio features will be masked into is_audio_idx indices later. is_audio_idx = input_ids == self.config.audio_token_id llm_input_ids = input_ids.clone() llm_input_ids[is_audio_idx] = 0 inputs_embeds = self.get_input_embeddings()(llm_input_ids) + audio_embeds = None if input_features is not None: if input_features.dtype != self.dtype: input_features = input_features.to(self.dtype) - # Get the audio features from the encoder / projector audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # Merge the audio features into the LLM embeddings inputs_embeds = self.get_merged_audio_embeddings( input_ids=input_ids, audio_features=audio_embeds, @@ -448,11 +461,88 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, - logits_to_keep=logits_to_keep, + **kwargs, + ) + + return GraniteSpeechModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): + _tied_weights_keys = 
{"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechConfig): + super().__init__(config) + self.model = GraniteSpeechModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, **kwargs): + return self.model.get_audio_features(input_features, **kwargs) + + def get_merged_audio_embeddings(self, *args, **kwargs): + return self.model.get_merged_audio_embeddings(*args, **kwargs) + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **lm_kwargs, + ) -> tuple[torch.Tensor] | GraniteSpeechCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, **lm_kwargs, ) - logits = outputs[0] + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -466,14 +556,13 @@ def forward( else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) if not return_dict: - output = (logits,) + outputs[1:] + output = (logits,) + tuple(v for v in (outputs.past_key_values, outputs.hidden_states, outputs.attentions) if v is not None) return (loss,) + output if loss is not None else output return GraniteSpeechCausalLMOutputWithPast( @@ -496,8 +585,7 @@ def prepare_inputs_for_generation( **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model - - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -507,55 +595,12 @@ def prepare_inputs_for_generation( **kwargs, ) - # If we're in cached decoding stage, input_features should be None because - # input ids do not contain special audio token anymore Otherwise we need - # input feature values to be passed to the model if is_first_iteration or not kwargs.get("use_cache", 
True): model_inputs["input_features"] = input_features return model_inputs - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. - """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - # Mask the audio features into the text embeddings - special_audio_mask = is_audio_index.unsqueeze(-1) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - torch_compilable_check( - not torch.all(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)), - "Number of audio tokens does not match number of audio features", - ) - audio_features = audio_features[input_features_mask] - - inputs_embeds = inputs_embeds.masked_scatter( - special_audio_mask, - audio_features, - ) - return inputs_embeds - def generate(self, *args, **kwargs) -> torch.LongTensor: - # This model is expected to have a lora adapter, which is only - # enabled when considering audio inputs. As such, we override generate - # to conditionally enable / disable the lora adapter based on whether - # or not any input features were provided. 
- + # Enable/disable LoRA adapter based on whether audio inputs are provided. input_features = kwargs.pop("input_features", None) if is_peft_available and self._hf_peft_config_loaded: if input_features is not None: @@ -565,12 +610,11 @@ def generate(self, *args, **kwargs) -> torch.LongTensor: return super().generate(*args, input_features=input_features, **kwargs) def save_pretrained(self, save_directory, *args, **kwargs): - # overwrite save_pretrained to first save the adapter if we have one + # Save the adapter first, then the base model if is_peft_available and self._hf_peft_config_loaded: adapter_name = self._get_adapter_name() self.peft_config[adapter_name].base_model_name_or_path = save_directory super().save_pretrained(save_directory, *args, **kwargs) - # Then save the base model afterwards prev_val = self._hf_peft_config_loaded self._hf_peft_config_loaded = False super().save_pretrained(save_directory, *args, **kwargs) @@ -583,5 +627,6 @@ def _get_adapter_name(self): __all__ = [ "GraniteSpeechCTCEncoder", "GraniteSpeechForConditionalGeneration", + "GraniteSpeechModel", "GraniteSpeechPreTrainedModel", ] diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index adec95bbf3e1..adeb0c89d01c 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -20,6 +20,7 @@ # limitations under the License. 
from collections.abc import Callable +from dataclasses import dataclass from math import pi from typing import Optional @@ -29,12 +30,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -150,6 +151,18 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) +@dataclass +class MusicFlamingoModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + class MusicFlamingoMultiModalProjector(nn.Module): """ Audio adaptor (small MLP) that projects MusicFlamingoEncoder features @@ -195,19 +208,16 @@ def apply_rotary_time_emb(hidden_states, cos, sin): @auto_docstring( custom_intro=""" - The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. + The MusicFlamingo model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. 
""" ) -class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _tp_plan = None - _pp_plan = None - +class MusicFlamingoModel(MusicFlamingoPreTrainedModel): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) self.pos_emb = MusicFlamingoRotaryEmbedding(config) @@ -220,18 +230,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -279,65 +277,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | MusicFlamingoModelOutputWithPast: r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): - Mask to avoid performing attention on padding feature indices. 
Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/music-flamingo-2601-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversation = [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", - >>> }, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", - >>> }, - >>> ], - >>> } - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversation, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device, model.dtype) - - >>> outputs = model.generate(**inputs, max_new_tokens=100) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["This track is an uplifting Eurodance-style Trance-Pop anthem..."] - ```""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. 
+ """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features( input_features, input_features_mask, input_ids=input_ids, return_dict=True @@ -349,31 +299,22 @@ def forward( audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs - - def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): - input_features = kwargs.pop("input_features", None) - input_features_mask = kwargs.pop("input_features_mask", None) - model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) - - if is_first_iteration or not model_inputs.get("use_cache", False): - if input_features is not None: - model_inputs["input_features"] = input_features - if input_features_mask is not None: - model_inputs["input_features_mask"] = input_features_mask - - return model_inputs + return MusicFlamingoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) def _build_audio_timestamps( self, @@ -410,4 +351,139 @@ def _build_audio_timestamps( return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets -__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"] +@dataclass +@auto_docstring( + custom_intro=""" + Base class for MusicFlamingo causal language model (or autoregressive) outputs. 
+ """ +) +class MusicFlamingoCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
+ """ +) +class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config: MusicFlamingoConfig): + super().__init__(config) + self.model = MusicFlamingoModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs) -> tuple | BaseModelOutputWithPooling: + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MusicFlamingoCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ + Example: + + ```python + >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/music-flamingo-2601-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return MusicFlamingoCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) + + def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): + input_features = kwargs.pop("input_features", None) + input_features_mask = kwargs.pop("input_features_mask", None) + + model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + if is_first_iteration or not model_inputs.get("use_cache", False): + if input_features is not None: + model_inputs["input_features"] = input_features + if input_features_mask is not None: + model_inputs["input_features_mask"] = input_features_mask + + return model_inputs + + +__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoModel", "MusicFlamingoPreTrainedModel"] diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py
b/src/transformers/models/musicflamingo/modular_musicflamingo.py index 7d98d0ffdeab..8fb1ead23cd1 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -14,6 +14,7 @@ # limitations under the License. import re +from dataclasses import dataclass from math import pi from huggingface_hub.dataclasses import strict @@ -29,6 +30,8 @@ from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, + AudioFlamingo3ModelOutputWithPast, AudioFlamingo3PreTrainedModel, ) from ..audioflamingo3.processing_audioflamingo3 import AudioFlamingo3Processor @@ -252,15 +255,16 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) -@auto_docstring( - custom_intro=""" - The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
- """ -) -class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +@dataclass +class MusicFlamingoModelOutputWithPast(AudioFlamingo3ModelOutputWithPast): + pass + + +class MusicFlamingoModel(AudioFlamingo3Model): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) self.pos_emb = MusicFlamingoRotaryEmbedding(config) + self.post_init() def _build_audio_timestamps( self, @@ -343,65 +347,13 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Example: - - ```python - >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/music-flamingo-2601-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversation = [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", - >>> }, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", - >>> }, - >>> ], - >>> } - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversation, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device, model.dtype) - - >>> outputs = model.generate(**inputs, max_new_tokens=100) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["This track is an uplifting Eurodance-style Trance-Pop anthem..."] - ```""" + ): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features( input_features, input_features_mask, input_ids=input_ids, return_dict=True @@ -413,22 +365,40 @@ def forward( audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return MusicFlamingoModelOutputWithPast( 
+ last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. + """ +) +class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): + def __init__(self, config: MusicFlamingoConfig): + super().__init__(config) + self.model = MusicFlamingoModel(config) + self.post_init() __all__ = [ "MusicFlamingoConfig", "MusicFlamingoProcessor", "MusicFlamingoForConditionalGeneration", + "MusicFlamingoModel", "MusicFlamingoPreTrainedModel", ] diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 442eab1edcd4..fa319bc70a64 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -25,19 +25,43 @@ from ...generation import GenerationMixin from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging, torch_compilable_check from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig logger = logging.get_logger(__name__) +@dataclass 
+@auto_docstring( + custom_intro=""" + Base class for Qwen2Audio outputs, with hidden states and attentions. + """ +) +class Qwen2AudioModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask, potentially updated by the audio merging logic so that audio tokens are unmasked. + labels (`torch.LongTensor`, *optional*): + Labels, potentially re-aligned by the legacy audio merging logic. Returned so the language-modeling + head can compute the loss against the expanded sequence. + """ + + attention_mask: torch.FloatTensor | None = None + labels: torch.LongTensor | None = None + + @dataclass @auto_docstring( custom_intro=""" @@ -394,17 +418,17 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The QWEN2AUDIO model which consists of a audio backbone and a language model. + The Qwen2Audio model which consists of an audio backbone and a language model, without a language modeling head. 
""" ) -class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): +class Qwen2AudioModel(Qwen2AudioPreTrainedModel): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config) # Usually a `Qwen2AudioEncoder` instance self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = ( self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1 @@ -428,18 +452,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def _merge_input_ids_with_audio_features( self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels ): @@ -651,7 +663,7 @@ def forward( labels: torch.LongTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | Qwen2AudioCausalLMOutputWithPast: + ) -> tuple | Qwen2AudioModelOutputWithPast: r""" feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: @@ -659,32 +671,9 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from io import BytesIO - >>> from urllib.request import urlopen - >>> import librosa - >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration - - >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") - - >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" - >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" - >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) - - >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Generate the caption in English: Glass is breaking." - ```""" + Labels kept in the signature for the legacy merge path that may re-align them with audio tokens. + The loss is not computed here; `Qwen2AudioForConditionalGeneration` is responsible for that. + """ target_device = self.audio_tower.device @@ -767,7 +756,115 @@ def forward( **kwargs, ) - logits = outputs.logits + return Qwen2AudioModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + attention_mask=attention_mask, + labels=labels, + ) + + +@auto_docstring( + custom_intro=""" + The QWEN2AUDIO model which consists of an audio backbone and a language model. 
+ """ +) +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: Qwen2AudioConfig): + super().__init__(config) + self.model = Qwen2AudioModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @property + def padding_side(self): + return self.model.padding_side + + @padding_side.setter + def padding_side(self, padding_side: str): + self.model.padding_side = padding_side + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + feature_attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen2AudioCausalLMOutputWithPast: + r""" + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from io import BytesIO + >>> from urllib.request import urlopen + >>> import librosa + >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration + + >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") + + >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" + >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" + >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate) + + >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Generate the caption in English: Glass is breaking."
+ ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + feature_attention_mask=feature_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + attention_mask = outputs.attention_mask + labels = outputs.labels if outputs.labels is not None else labels loss = None if labels is not None: @@ -809,4 +906,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"] +__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder", "Qwen2AudioModel"] diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 703bb6ca5130..e02d0f647b23 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -17,6 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from dataclasses import dataclass + import torch from torch import nn @@ -25,11 +27,11 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_vibevoice_asr import VibeVoiceAsrConfig @@ -249,26 +251,69 @@ def _init_weights(self, module): init.constant_(module.ffn_gamma, self.config.layer_scale_init_value) +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. """ ) -class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _tp_plan = None - _pp_plan = None +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. + """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. + """ +) +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - - # Initialize weights and apply final processing + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) self.post_init() + # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze + # their parameters so grad-checkpointing and training sanity checks don't flag them. 
+ for p in self.acoustic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + for p in self.semantic_tokenizer_encoder.parameters(): + p.requires_grad_(False) def get_input_embeddings(self): return self.language_model.get_input_embeddings() @@ -276,18 +321,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") def get_audio_features( @@ -296,17 +329,15 @@ def get_audio_features( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: + ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. 
""" - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -351,7 +382,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -374,13 +404,88 @@ def forward( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_values is not None and input_ids is not None: + audio_embeds = self.get_audio_features( + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + ).pooler_output + + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
+ """ +) +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. + Size of audio chunks processed by the acoustic and semantic tokenizers. 
Example: @@ -390,33 +495,35 @@ def forward( >>> model_id = "microsoft/VibeVoice-ASR-HF" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_values is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_values=input_values, - padding_mask=padding_mask, - acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, - ).pooler_output + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, 
past_key_values=past_key_values, **kwargs + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -437,4 +544,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg return model_inputs -__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrPreTrainedModel"] +__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel"] diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index fc9c960c1033..93391c91f14c 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -11,16 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from dataclasses import dataclass + import torch from huggingface_hub.dataclasses import strict from torch import nn from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, + ModelOutput, +) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ..audioflamingo3.modeling_audioflamingo3 import AudioFlamingo3ForConditionalGeneration from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from ..vibevoice_acoustic_tokenizer.modeling_vibevoice_acoustic_tokenizer import ( @@ -161,17 +168,75 @@ class VibeVoiceAsrPreTrainedModel(VibeVoiceAcousticTokenizerPreTrainedModel): _supports_sdpa = True +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. + """ +) +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. + """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. """ ) -class VibeVoiceAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - del self.audio_tower + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze + # their parameters so grad-checkpointing and training sanity checks don't flag them. 
+ for p in self.acoustic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + for p in self.semantic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") @@ -184,14 +249,12 @@ def get_audio_features( ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. 
""" - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -236,7 +299,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -259,13 +321,88 @@ def forward( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_values is not None and input_ids is not None: + audio_embeds = self.get_audio_features( + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + ).pooler_output + + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
+ """ +) +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. + Size of audio chunks processed by the acoustic and semantic tokenizers. 
Example: @@ -275,33 +412,35 @@ def forward( >>> model_id = "microsoft/VibeVoice-ASR-HF" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_values is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_values=input_values, - padding_mask=padding_mask, - acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, - ).pooler_output + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, 
past_key_values=past_key_values, **kwargs + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -325,5 +464,6 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg __all__ = [ "VibeVoiceAsrConfig", "VibeVoiceAsrForConditionalGeneration", + "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel", ] diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 76da78cc558f..09328ea6a28c 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -29,13 +30,13 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig @@ -226,6 +227,7 @@ class VoxtralPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_attention_backend = True _can_compile_fullgraph = True + 
_keep_in_fp32_modules_strict = ["embed_positions"] @auto_docstring( @@ -359,19 +361,61 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Voxtral causal language model (or autoregressive) outputs. + """ +) +class VoxtralCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. 
+ """ + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a LLama language model, + without a language modeling head. + """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) # Initialize weights and apply final processing @@ -383,18 +427,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -404,11 +436,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. 
Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. """ audio_outputs = self.audio_tower(input_features, return_dict=True, **kwargs) audio_hidden_states = audio_outputs.last_hidden_state @@ -418,6 +446,81 @@ def get_audio_features( return audio_outputs + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a LLama language model. + """ +) +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features( + self, input_features: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] + ) -> tuple | BaseModelOutputWithPooling: + return self.model.get_audio_features(input_features, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -432,7 +535,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VoxtralCausalLMOutputWithPast: r""" Example: @@ -466,29 +569,35 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = 
self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) - ) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return VoxtralCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -505,4 +614,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index c7b2c53e16d4..855962b4bf5b 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ 
b/src/transformers/models/voxtral/modular_voxtral.py @@ -13,6 +13,8 @@ # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn @@ -22,13 +24,13 @@ from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPooling, - CausalLMOutputWithPast, + ModelOutput, ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from ..qwen2_audio.modeling_qwen2_audio import ( Qwen2AudioAttention, Qwen2AudioEncoder, @@ -52,6 +54,7 @@ class VoxtralPreTrainedModel(Qwen2AudioPreTrainedModel): _supports_attention_backend = True _can_compile_fullgraph = True _no_split_modules = None + _keep_in_fp32_modules_strict = ["embed_positions"] # TODO: @eustlb, I would really prefer to use WhisperEncoder but it's messing with modular @@ -128,19 +131,61 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. 
+ """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Voxtral causal language model (or autoregressive) outputs. + """ +) +class VoxtralCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a LLama language model, + without a language modeling head. 
+ """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) # Initialize weights and apply final processing @@ -152,18 +197,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -173,11 +206,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. 
""" audio_outputs = self.audio_tower(input_features, return_dict=True, **kwargs) audio_hidden_states = audio_outputs.last_hidden_state @@ -187,6 +216,81 @@ def get_audio_features( return audio_outputs + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a LLama language model. 
+ """ +) +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features( + self, input_features: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs] + ) -> tuple | BaseModelOutputWithPooling: + return self.model.get_audio_features(input_features, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -201,7 +305,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VoxtralCausalLMOutputWithPast: r""" Example: @@ -235,29 +339,35 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), 
audio_embeds.to(inputs_embeds.device) - ) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return VoxtralCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -274,4 +384,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index 07325b0ea559..087c9ed48700 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -42,7 +42,6 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import maybe_autocast, 
merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel from .configuration_voxtral_realtime import ( VoxtralRealtimeConfig, VoxtralRealtimeEncoderConfig, @@ -118,6 +117,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. + """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -480,6 +497,7 @@ class VoxtralRealtimePreTrainedModel(PreTrainedModel): _supports_attention_backend = True # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -820,80 +838,6 @@ def forward( ) -@auto_docstring -class VoxtralRealtimeTextForCausalLM(VoxtralRealtimeTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = VoxtralRealtimeTextModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, 
config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -928,17 +872,15 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The VoxtralRealtime model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. 
""" ) -class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) @@ -951,18 +893,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -978,11 +908,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -1020,43 +946,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. 
- - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1066,6 +969,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -1075,7 +980,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -1094,25 +1000,140 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] 
# broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @property + def audio_tower(self): + return self.model.audio_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + 
past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. 
+ + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + 
encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -1308,4 +1329,9 @@ def _prepare_generated_length( return generation_config -__all__ = ["VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel"] +__all__ = [ + "VoxtralRealtimeForConditionalGeneration", + "VoxtralRealtimeEncoder", + "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", +] diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index edad37679927..08780f9ad706 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -31,13 +31,11 @@ from ...models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, - MistralForCausalLM, MistralMLP, MistralModel, MistralRMSNorm, ) from ...models.voxtral.modeling_voxtral import ( - VoxtralForConditionalGeneration, VoxtralMultiModalProjector, VoxtralPreTrainedModel, ) @@ -116,6 +114,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. 
+ """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -255,6 +271,7 @@ def forward( class VoxtralRealtimePreTrainedModel(VoxtralPreTrainedModel, PreTrainedModel): # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -436,66 +453,6 @@ def __init__(self, config): self.rotary_emb = VoxtralRealtimeRotaryEmbedding(config=config) -class VoxtralRealtimeTextForCausalLM(MistralForCausalLM): - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -520,14 +477,29 @@ def __init__(self, config): ) -class VoxtralRealtimeForConditionalGeneration(VoxtralForConditionalGeneration, GenerationMixin): - _keep_in_fp32_modules_strict = None - +@auto_docstring( + custom_intro=""" + The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. 
+ """ +) +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) + self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -543,11 +515,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. 
encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -585,43 +553,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. 
- - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -631,6 +576,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -640,7 +587,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -659,25 +607,138 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] 
# broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values if (audio_outputs is not None and use_cache) else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @property + def audio_tower(self): + return self.model.audio_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + 
past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. 
+ + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + 
encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -705,7 +766,7 @@ def _prepare_model_inputs( bos_token_id: torch.Tensor | None = None, model_kwargs: dict[str, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]: - inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + inputs, input_name, model_kwargs = super()._prepare_model_inputs(inputs, bos_token_id, model_kwargs) input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): @@ -725,7 +786,7 @@ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, if getattr(self, "_stream_exhausted", False): self._stream_exhausted = False return False - return GenerationMixin._has_unfinished_sequences(this_peer_finished, synced_gpus, device) + return super()._has_unfinished_sequences(this_peer_finished, synced_gpus, device) def _update_model_kwargs_for_generation( self, @@ -734,7 +795,7 @@ def _update_model_kwargs_for_generation( is_encoder_decoder: bool = False, num_new_tokens: int = 1, ): - model_kwargs = GenerationMixin._update_model_kwargs_for_generation( + model_kwargs = super()._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder, num_new_tokens ) @@ -761,7 +822,7 @@ def _prepare_cache_for_generation( batch_size: int, max_cache_length: int, ): - GenerationMixin._prepare_cache_for_generation( + super()._prepare_cache_for_generation( generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) @@ -815,7 +876,7 @@ def _prepare_generation_config( generation_config is not None and generation_config.max_new_tokens is not None ) - generation_config, model_kwargs = GenerationMixin._prepare_generation_config(generation_config, **kwargs) + generation_config, model_kwargs = 
super()._prepare_generation_config(generation_config, **kwargs) input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): @@ -854,7 +915,7 @@ def _prepare_generated_length( if getattr(generation_config, "_voxtral_set_max_length", False): has_default_max_length = False - generation_config = GenerationMixin._prepare_generated_length( + generation_config = super()._prepare_generated_length( generation_config, has_default_max_length, has_default_min_length, @@ -877,4 +938,5 @@ def _prepare_generated_length( "VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", ] From 79b833b9bf05e92b9d800b8ff5b32c26d84373d3 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:00:04 +0200 Subject: [PATCH 02/39] ensure BC via conversion mapping --- src/transformers/conversion_mapping.py | 53 ++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 7dd2a8826448..832550e63245 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -89,6 +89,59 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "qwen2_audio": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "voxtral": [ + 
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "voxtral_realtime": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "audioflamingo3": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "glmasr": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "musicflamingo": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", 
target_patterns="model.multi_modal_projector"), + ], + "granite_speech": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^encoder", target_patterns="model.encoder"), + WeightRenaming(source_patterns=r"^projector", target_patterns="model.projector"), + ], + "vibevoice_asr": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming( + source_patterns=r"^acoustic_tokenizer_encoder", target_patterns="model.acoustic_tokenizer_encoder" + ), + WeightRenaming( + source_patterns=r"^semantic_tokenizer_encoder", target_patterns="model.semantic_tokenizer_encoder" + ), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], "llava_next": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), From 57436db1c3d833b3fc0e78d2a5e832eb0d50fc30 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:00:23 +0200 Subject: [PATCH 03/39] auto classes --- src/transformers/models/auto/modeling_auto.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index deb1153d335e..f06904b20885 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -52,7 +52,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("aria", "AriaModel"), ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), - ("audioflamingo3", 
"AudioFlamingo3ForConditionalGeneration"), + ("audioflamingo3", "AudioFlamingo3Model"), ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "AutoformerModel"), ("aya_vision", "AyaVisionModel"), @@ -196,7 +196,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("glm_ocr", "GlmOcrModel"), ("glm_ocr_text", "GlmOcrTextModel"), ("glm_ocr_vision", "GlmOcrVisionModel"), - ("glmasr", "GlmAsrForConditionalGeneration"), + ("glmasr", "GlmAsrModel"), ("glmasr_encoder", "GlmAsrEncoder"), ("glpn", "GLPNModel"), ("got_ocr2", "GotOcr2Model"), @@ -209,6 +209,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gpt_oss", "GptOssModel"), ("gptj", "GPTJModel"), ("granite", "GraniteModel"), + ("granite_speech", "GraniteSpeechModel"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), ("granitemoeshared", "GraniteMoeSharedModel"), @@ -308,7 +309,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mpt", "MptModel"), ("mra", "MraModel"), ("mt5", "MT5Model"), - ("musicflamingo", "MusicFlamingoForConditionalGeneration"), + ("musicflamingo", "MusicFlamingoModel"), ("musicgen", "MusicgenModel"), ("musicgen_melody", "MusicgenMelodyModel"), ("mvp", "MvpModel"), @@ -366,6 +367,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen2", "Qwen2Model"), ("qwen2_5_vl", "Qwen2_5_VLModel"), ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"), + ("qwen2_audio", "Qwen2AudioModel"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), @@ -460,7 +462,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"), ("vibevoice_acoustic_tokenizer_decoder", "VibeVoiceAcousticTokenizerDecoderModel"), ("vibevoice_acoustic_tokenizer_encoder", "VibeVoiceAcousticTokenizerEncoderModel"), - ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"), + ("vibevoice_asr", 
"VibeVoiceAsrModel"), ("video_llama_3", "VideoLlama3Model"), ("video_llama_3_vision", "VideoLlama3VisionModel"), ("video_llava", "VideoLlavaModel"), @@ -476,9 +478,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("vits", "VitsModel"), ("vivit", "VivitModel"), ("vjepa2", "VJEPA2Model"), - ("voxtral", "VoxtralForConditionalGeneration"), + ("voxtral", "VoxtralModel"), ("voxtral_encoder", "VoxtralEncoder"), - ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"), + ("voxtral_realtime", "VoxtralRealtimeModel"), ("voxtral_realtime_encoder", "VoxtralRealtimeEncoder"), ("voxtral_realtime_text", "VoxtralRealtimeTextModel"), ("wav2vec2", "Wav2Vec2Model"), From 1bd269c437b7b1fc57f37c9fdc53f4b94ffabea8 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:00:36 +0200 Subject: [PATCH 04/39] test updates --- .../test_modeling_audioflamingo3.py | 16 +++----- tests/models/glmasr/test_modeling_glmasr.py | 16 +++----- .../test_modeling_granite_speech.py | 13 ++++--- .../test_modeling_musicflamingo.py | 22 +++++------ .../qwen2_audio/test_modeling_qwen2_audio.py | 16 +++----- .../test_modeling_vibevoice_asr.py | 38 +++++++++++++++---- tests/models/voxtral/test_modeling_voxtral.py | 16 +++----- 7 files changed, 72 insertions(+), 65 deletions(-) diff --git a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py index 7301812e7032..e436770452ad 100644 --- a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py +++ b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py @@ -190,10 +190,6 @@ def test_sdpa_can_dispatch_on_flash(self): def test_flash_attn_2_inference_equivalence_right_padding(self): pass - @unittest.skip(reason="AudioFlamingo3 has no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_sdpa_can_dispatch_composite_models(self): # AF3 is audio+text composite; verify SDPA 
toggles propagate to submodules. if not self.has_attentions: @@ -213,19 +209,19 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - audio_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" + audio_attn = "sdpa" if model.model.audio_tower._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.audio_tower.config._attn_implementation == audio_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.audio_tower.config._attn_implementation == audio_attn) # Eager model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") - self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.audio_tower.config._attn_implementation == "eager") for _, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/glmasr/test_modeling_glmasr.py b/tests/models/glmasr/test_modeling_glmasr.py index 744e268e74c7..8637a27c5617 100644 --- a/tests/models/glmasr/test_modeling_glmasr.py +++ b/tests/models/glmasr/test_modeling_glmasr.py @@ -167,10 +167,6 @@ def test_sdpa_can_dispatch_on_flash(self): def test_flash_attn_2_inference_equivalence_right_padding(self): pass - @unittest.skip(reason="GlmAsr has no 
separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_sdpa_can_dispatch_composite_models(self): # GlmAsr is audio+text composite; verify SDPA toggles propagate to submodules. if not self.has_attentions: @@ -189,19 +185,19 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - audio_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" + audio_attn = "sdpa" if model.model.audio_tower._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.audio_tower.config._attn_implementation == audio_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.audio_tower.config._attn_implementation == audio_attn) # Eager model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") - self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.audio_tower.config._attn_implementation == "eager") for _, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/granite_speech/test_modeling_granite_speech.py b/tests/models/granite_speech/test_modeling_granite_speech.py index c5e7aa3defcd..0c71c3d5bb48 100644 --- 
a/tests/models/granite_speech/test_modeling_granite_speech.py +++ b/tests/models/granite_speech/test_modeling_granite_speech.py @@ -271,17 +271,17 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ @@ -294,8 +294,11 @@ def test_sdpa_can_dispatch_composite_models(self): def test_eager_matches_sdpa_generate(self): pass - @unittest.skip(reason="GraniteSpeech has no separate base model without a head.") - def test_model_base_model_prefix(self): + @unittest.skip( + reason="The bundled BertEncoder in GraniteSpeechEncoderProjector rewrites LayerNorm gamma/beta ↔ weight/bias " + "on save/load, which trips `test_reverse_loading_mapping` independently of the VLM base-class refactor." 
+ ) + def test_reverse_loading_mapping(self): pass diff --git a/tests/models/musicflamingo/test_modeling_musicflamingo.py b/tests/models/musicflamingo/test_modeling_musicflamingo.py index 8c3b0ce549c8..71f505b9a3b7 100644 --- a/tests/models/musicflamingo/test_modeling_musicflamingo.py +++ b/tests/models/musicflamingo/test_modeling_musicflamingo.py @@ -176,7 +176,7 @@ def setUp(self): def test_rotary_window_axis_resets_per_audio(self): config = self.model_tester.get_config() - pos_emb = MusicFlamingoForConditionalGeneration(config).pos_emb.to(torch_device) + pos_emb = MusicFlamingoForConditionalGeneration(config).model.pos_emb.to(torch_device) timestamps = torch.tensor( [ @@ -206,7 +206,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self): input_ids[0, :45] = config.audio_token_id input_ids[1, :30] = config.audio_token_id - _, post_lengths = model.audio_tower._get_feat_extract_output_lengths( + _, post_lengths = model.model.audio_tower._get_feat_extract_output_lengths( input_features_mask.sum(-1).to(torch.long) ) max_post_length = int(post_lengths.max().item()) @@ -224,7 +224,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self): ] ) - inferred = model._build_audio_timestamps(input_ids, post_lengths, max_post_length) + inferred = model.model._build_audio_timestamps(input_ids, post_lengths, max_post_length) torch.testing.assert_close(inferred, audio_timestamps) @unittest.skip( @@ -246,10 +246,6 @@ def test_sdpa_can_dispatch_on_flash(self): def test_flash_attn_2_inference_equivalence_right_padding(self): pass - @unittest.skip(reason="MusicFlamingo has no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_sdpa_can_dispatch_composite_models(self): # MusicFlamingo is audio+text composite; verify SDPA toggles propagate to submodules. 
if not self.has_attentions: @@ -269,19 +265,19 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - audio_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" + audio_attn = "sdpa" if model.model.audio_tower._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.audio_tower.config._attn_implementation == audio_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.audio_tower.config._attn_implementation == audio_attn) # Eager model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") - self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.audio_tower.config._attn_implementation == "eager") for _, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 4df16b9f6f4b..4c51fc7ae3bf 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -159,10 +159,6 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass - @unittest.skip(reason="Qwen2Audio has no separate base 
model without a head.") - def test_model_base_model_prefix(self): - pass - def test_sdpa_can_dispatch_composite_models(self): # overwrite because Qwen2 is audio+text model (not vision+text) if not self.has_attentions: @@ -180,20 +176,20 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - vision_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.model.audio_tower._supports_sdpa else "eager" # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.audio_tower.config._attn_implementation == vision_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.audio_tower.config._attn_implementation == vision_attn) model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") - self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.audio_tower.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py 
b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index be0ece165e36..669e2f65aef1 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -163,14 +163,38 @@ def test_sdpa_can_dispatch_on_flash(self): def test_flash_attn_2_inference_equivalence_right_padding(self): pass - @unittest.skip(reason="VibeVoiceAsr has no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - @unittest.skip(reason="VibeVoiceAsr audio components do not use attention.") def test_get_audio_features_attentions(self): pass + @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling in get_audio_features.") + def test_forward_with_logits_to_keep(self): + pass + + @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling in get_audio_features.") + def test_generate_methods_with_logits_to_keep(self): + pass + + @unittest.skip( + reason="VibeVoiceAsr's acoustic/semantic tokenizer encoders run under torch.no_grad() and are " + "frozen by design, so they don't receive gradients — which trips the global 'all params have " + "gradients' check." + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="Same as test_training_gradient_checkpointing: tokenizer encoders are frozen." + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip( + reason="Same as test_training_gradient_checkpointing: tokenizer encoders are frozen." 
+ ) + def test_training_gradient_checkpointing_use_reentrant_true(self): + pass + @unittest.skip(reason="VibeVoiceAsr has unique audio processing with acoustic and semantic tokenizers.") def test_get_audio_features_hidden_states(self): pass @@ -211,16 +235,16 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) # Eager model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") for _, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/voxtral/test_modeling_voxtral.py b/tests/models/voxtral/test_modeling_voxtral.py index 0cff2a66779b..557b8ecf9b8c 100644 --- a/tests/models/voxtral/test_modeling_voxtral.py +++ b/tests/models/voxtral/test_modeling_voxtral.py @@ -192,10 +192,6 @@ def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self): def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): pass - @unittest.skip(reason="Voxtral has no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_sdpa_can_dispatch_composite_models(self): # overwrite because Voxtral is audio+text model (not vision+text) if not 
self.has_attentions: @@ -213,20 +209,20 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" - vision_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager" + text_attn = "sdpa" if model.model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.model.audio_tower._supports_sdpa else "eager" # `None` as it is the requested one which will be assigned to each sub-config # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) - self.assertTrue(model.audio_tower.config._attn_implementation == vision_attn) + self.assertTrue(model.model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.model.audio_tower.config._attn_implementation == vision_attn) model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") - self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.model.audio_tower.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ From 9010366bccaf8c77024dc68deb7684b48b92a6f3 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 22 Apr 2026 18:37:29 +0200 Subject: [PATCH 05/39] ensure BC for accessing attributes --- .../audioflamingo3/modeling_audioflamingo3.py | 2 + 
.../audioflamingo3/modular_audioflamingo3.py | 2 + .../models/glmasr/modeling_glmasr.py | 2 + .../models/glmasr/modular_glmasr.py | 2 + .../granite_speech/modeling_granite_speech.py | 2 + .../musicflamingo/modeling_musicflamingo.py | 2 + .../musicflamingo/modular_musicflamingo.py | 2 + .../qwen2_audio/modeling_qwen2_audio.py | 2 + .../vibevoice_asr/modeling_vibevoice_asr.py | 2 + .../vibevoice_asr/modular_vibevoice_asr.py | 2 + .../models/voxtral/modeling_voxtral.py | 2 + .../models/voxtral/modular_voxtral.py | 2 + .../modeling_voxtral_realtime.py | 2 + .../modular_voxtral_realtime.py | 2 + src/transformers/utils/deprecation.py | 56 +++++++++++++++++++ 15 files changed, 84 insertions(+) diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 8e22270ae5fd..e1192b948b7f 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -36,6 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -547,6 +548,7 @@ def forward( The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
""" ) +@forward_base_model_attrs(version="5.7") class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _tp_plan = None diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index 75c766c5982c..7b5ce21bac37 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -24,6 +24,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..qwen2_audio.modeling_qwen2_audio import ( @@ -265,6 +266,7 @@ def forward( The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
""" ) +@forward_base_model_attrs(version="5.7") class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): _tp_plan = None _pp_plan = None diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index ddc43d4e624f..79b25ba5db65 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -37,6 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -504,6 +505,7 @@ class GlmAsrCausalLMOutputWithPast(ModelOutput): The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) +@forward_base_model_attrs(version="5.7") class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _tp_plan = None diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index 38ec9dcdb071..cbf5b1d71f42 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -25,6 +25,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..audioflamingo3.modeling_audioflamingo3 import ( @@ -387,6 +388,7 @@ def get_audio_features( The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) +@forward_base_model_attrs(version="5.7") class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index ebd4b91283de..ca0094bd1e3f 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -33,6 +33,7 @@ logging, torch_compilable_check, ) +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -473,6 +474,7 @@ def forward( ) +@forward_base_model_attrs(version="5.7") class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index adeb0c89d01c..4fe94ac6ce76 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -35,6 +35,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available +from ...utils.deprecation import forward_base_model_attrs from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -382,6 +383,7 @@ class MusicFlamingoCausalLMOutputWithPast(ModelOutput): The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
""" ) +@forward_base_model_attrs(version="5.7") class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _tp_plan = None diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index 8fb1ead23cd1..a5dafca76f51 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -27,6 +27,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available +from ...utils.deprecation import forward_base_model_attrs from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, @@ -388,6 +389,7 @@ def forward( The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
""" ) +@forward_base_model_attrs(version="5.7") class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index fa319bc70a64..4fe9027df026 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -771,6 +772,7 @@ def forward( The QWEN2AUDIO model which consists of an audio backbone and a language model. 
""" ) +@forward_base_model_attrs(version="5.7") class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index e02d0f647b23..a6d0c8320e25 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -31,6 +31,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils.deprecation import forward_base_model_attrs from ..auto import AutoModel from .configuration_vibevoice_asr import VibeVoiceAsrConfig @@ -442,6 +443,7 @@ def forward( The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
""" ) +@forward_base_model_attrs(version="5.7") class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 93391c91f14c..b3d515ee8881 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -28,6 +28,7 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import forward_base_model_attrs from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from ..vibevoice_acoustic_tokenizer.modeling_vibevoice_acoustic_tokenizer import ( @@ -359,6 +360,7 @@ def forward( The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
""" ) +@forward_base_model_attrs(version="5.7") class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 09328ea6a28c..74bbb38048a3 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -34,6 +34,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -495,6 +496,7 @@ def forward( The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a LLama language model. 
""" ) +@forward_base_model_attrs(version="5.7") class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 855962b4bf5b..56880107ff5b 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -28,6 +28,7 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -265,6 +266,7 @@ def forward( The Voxtral model, which consists of a Whisper encoder, a multi-modal projector and a LLama language model. """ ) +@forward_base_model_attrs(version="5.7") class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index 087c9ed48700..31b3850c591e 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -40,6 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_voxtral_realtime import ( @@ -1023,6 
+1024,7 @@ def forward( ) +@forward_base_model_attrs(version="5.7") class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index 08780f9ad706..4f12393d2e02 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -41,6 +41,7 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_voxtral_realtime import VoxtralRealtimeEncoderConfig @@ -628,6 +629,7 @@ def forward( ) +@forward_base_model_attrs(version="5.7") class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index db0e67325d78..7091bfa6759a 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -173,3 +173,59 @@ def wrapped_func(*args, **kwargs): return wrapped_func return wrapper + + +def forward_base_model_attrs(version: str): + """ + Class decorator that forwards attribute access to the base model (`getattr(self, self.base_model_prefix)`) + when the attribute is not found on the instance directly, and warns that direct access on the + outer class is deprecated. + + Intended for backward compatibility during refactors that move submodules from the outer + `*ForConditionalGeneration` class down to the inner base model — e.g. 
`model.language_model` + becoming `model.model.language_model`. + + Apply only to the outer wrapper class (the `*ForConditionalGeneration`), not to the inner + base model itself. The decorator relies on `base_model_prefix` being set on the class (which + `PreTrainedModel` subclasses always do). + + Args: + version (`str`): + The Transformers version in which direct access will be removed (e.g. `"5.7"`). + """ + + def decorator(cls): + # Resolve the inherited __getattr__ (typically nn.Module's, which looks up + # submodules/parameters/buffers) so we can delegate to it without recursing. + inherited_getattr = cls.__getattr__ + + def __getattr__(self, name): + # First, the normal nn.Module lookup (submodules, parameters, buffers). + try: + return inherited_getattr(self, name) + except AttributeError: + pass + # Only forward public attributes to the base model — private names are + # framework internals (e.g. `_is_hf_initialized`) and shouldn't warn. + if name.startswith("_"): + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + prefix = type(self).base_model_prefix + try: + base = inherited_getattr(self, prefix) + except AttributeError: + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + if hasattr(base, name): + if not is_torchdynamo_compiling(): + warnings.warn( + f"Accessing `{name}` directly on `{type(self).__name__}` is deprecated and " + f"will be removed in Transformers v{version}. 
Use `.{prefix}.{name}` instead.", + FutureWarning, + stacklevel=2, + ) + return getattr(base, name) + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + + cls.__getattr__ = __getattr__ + return cls + + return decorator From af429ef810ee0b0785411e85c5fcc54936053211 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:19:33 +0200 Subject: [PATCH 06/39] simplify conversion mapping --- src/transformers/conversion_mapping.py | 35 ++++---------------------- 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 832550e63245..c60806e70229 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -73,6 +73,11 @@ "qwen2_5_vl": "qwen2_vl", "sam3_tracker_video": "sam3_tracker", "pp_chart2table": "llava", + "voxtral": "qwen2_audio", + "voxtral_realtime": "qwen2_audio", + "audioflamingo3": "qwen2_audio", + "glmasr": "qwen2_audio", + "musicflamingo": "qwen2_audio", "gemma3n_text": "qwen3_5_text", "qwen3_5_moe_text": "qwen3_5_text", } @@ -95,36 +100,6 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], - "voxtral": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), - WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), - WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), - WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), - ], - "voxtral_realtime": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), - 
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), - WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), - WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), - ], - "audioflamingo3": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), - WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), - WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), - WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), - ], - "glmasr": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), - WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), - WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), - WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), - ], - "musicflamingo": [ - WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), - WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), - WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), - WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), - ], "granite_speech": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), From 1da57e8b025334b689a4fad0c7d3a42e80c5ee5f Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 11 May 2026 12:04:57 +0200 Subject: [PATCH 07/39] convert modular --- .../audioflamingo3/modeling_audioflamingo3.py | 3 +- 
src/transformers/models/auto/modeling_auto.py | 2 +- .../models/glmasr/modeling_glmasr.py | 9 +- .../granite_speech/modeling_granite_speech.py | 4 +- .../modeling_granite_speech_plus.py | 232 +++++++++++------- .../musicflamingo/modeling_musicflamingo.py | 4 +- .../musicflamingo/modular_musicflamingo.py | 2 +- .../vibevoice_asr/modeling_vibevoice_asr.py | 41 +++- .../vibevoice_asr/modular_vibevoice_asr.py | 1 - .../modular_voxtral_realtime.py | 4 +- .../test_modeling_vibevoice_asr.py | 8 +- 11 files changed, 190 insertions(+), 120 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index ea6966d52bd7..bbb17b8ccee2 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -448,11 +448,10 @@ def forward(self, audio_features): """ ) class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel): -class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e6f7f3523465..7bced9ad55eb 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -213,8 +213,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gpt_oss", "GptOssModel"), ("gptj", "GPTJModel"), ("granite", "GraniteModel"), - ("granite_speech", "GraniteSpeechModel"), ("granite4_vision", "Granite4VisionModel"), + ("granite_speech", "GraniteSpeechModel"), ("granite_speech", "GraniteSpeechForConditionalGeneration"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), diff --git 
a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index 44cd27ba6a03..438e5c6d9d52 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -370,20 +370,15 @@ class GlmAsrModelOutputWithPast(BaseModelOutputWithPast): @auto_docstring( custom_intro=""" - The GlmAsr model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), - without a language modeling head. + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -<<<<<<< alm-base-model-class class GlmAsrModel(GlmAsrPreTrainedModel): -======= -class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None ->>>>>>> main def __init__(self, config): super().__init__(config) self.vocab_size = config.text_config.vocab_size diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 58c20970645c..4b86aeaf48c4 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -561,7 +561,9 @@ def forward( ) if not return_dict: - output = (logits,) + tuple(v for v in (outputs.past_key_values, outputs.hidden_states, outputs.attentions) if v is not None) + output = (logits,) + tuple( + v for v in (outputs.past_key_values, outputs.hidden_states, outputs.attentions) if v is not None + ) return (loss,) + output if loss is not None else output return GraniteSpeechCausalLMOutputWithPast( diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index 9ff819aa3b6c..dde7458ac105 100644 --- 
a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -28,7 +28,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -41,7 +41,7 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_granite_speech_plus import GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig @@ -86,6 +86,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @auto_docstring class GraniteSpeechPlusPreTrainedModel(PreTrainedModel): config: GraniteSpeechPlusConfig + base_model_prefix = "model" input_modalities = ("audio", "text") _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this @@ -341,23 +342,34 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): attentions: tuple[torch.FloatTensor] | None = None +@dataclass @auto_docstring( custom_intro=""" - The Granite Speech Plus model, a Granite Speech variant whose projector consumes the concatenation of the - encoder's final hidden states with an arbitrary subset of its intermediate hidden states. + Base class for GraniteSpeechPlus outputs, with hidden states and attentions. 
""" ) -class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): +class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Granite Speech model, which consists of an audio encoder, projector, and language model. + """ +) +class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel): _supports_attention_backend = True def __init__(self, config: GraniteSpeechPlusConfig): super().__init__(config) - # NOTE: It doesn't matter when we initialize from config, but we should be careful - # to make sure this does not pick up the adapter_config if in the future we use - # from_pretrained or something similar, since that should be set by the composite - # model; don't need to consider it twice - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - + self.language_model = AutoModel.from_config(config.text_config) self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) self.projector = GraniteSpeechPlusEncoderProjector(config) @@ -371,24 +383,12 @@ def __init__(self, config: GraniteSpeechPlusConfig): self.post_init() - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - def get_input_embeddings(self): return self.language_model.get_input_embeddings() - def get_output_embeddings(self): - return 
self.language_model.get_output_embeddings() - @can_return_tuple @auto_docstring def get_audio_features( @@ -400,6 +400,27 @@ def get_audio_features( return audio_outputs + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """Merge audio features into the language embeddings at `audio_token_id` positions.""" + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) + + special_audio_mask = is_audio_index.unsqueeze(-1) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + torch_compilable_check( + not torch.all(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)), + "Number of audio tokens does not match number of audio features", + ) + audio_features = audio_features[input_features_mask] + + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple @auto_docstring def forward( self, @@ -410,28 +431,19 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **lm_kwargs, - ) -> tuple[torch.Tensor] | GraniteSpeechPlusCausalLMOutputWithPast: + **kwargs, + ) -> tuple | GraniteSpeechPlusModelOutputWithPast: r""" input_features_mask (`torch.Tensor`, *optional*): Mask to be applied to audio features prior to scattering into the language embeddings. 
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ - # TODO (@alex-jw-brooks) add an example to this docstring once models are released output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -442,21 +454,16 @@ def forward( ) if inputs_embeds is None: - # Get the base embeddings; set all audio tokens to 0 index - # to avoid out of vocabulary issues with the LLM embedding. - # Audio features will be masked into is_audio_idx indices later. 
is_audio_idx = input_ids == self.config.audio_token_id llm_input_ids = input_ids.clone() llm_input_ids[is_audio_idx] = 0 inputs_embeds = self.get_input_embeddings()(llm_input_ids) + audio_embeds = None if input_features is not None: if input_features.dtype != self.dtype: input_features = input_features.to(self.dtype) - # Get the audio features from the encoder / projector audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # Merge the audio features into the LLM embeddings inputs_embeds = self.get_merged_audio_embeddings( input_ids=input_ids, audio_features=audio_embeds, @@ -471,11 +478,94 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, - logits_to_keep=logits_to_keep, + **kwargs, + ) + + return GraniteSpeechPlusModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Granite Speech Plus model, a Granite Speech variant whose projector consumes the concatenation of the + encoder's final hidden states with an arbitrary subset of its intermediate hidden states. 
+ """ +) +class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechPlusConfig): + super().__init__(config) + self.model = GraniteSpeechPlusModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, **kwargs): + return self.model.get_audio_features(input_features, **kwargs) + + def get_merged_audio_embeddings(self, *args, **kwargs): + return self.model.get_merged_audio_embeddings(*args, **kwargs) + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **lm_kwargs, + ) -> tuple[torch.Tensor] | GraniteSpeechPlusCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, **lm_kwargs, ) - logits = outputs[0] + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -489,14 +579,15 @@ def forward( else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) if not return_dict: - output = (logits,) + outputs[1:] + output = (logits,) + tuple( + v for v in (outputs.past_key_values, outputs.hidden_states, outputs.attentions) if v is not None + ) return (loss,) + output if loss is not None else output return GraniteSpeechPlusCausalLMOutputWithPast( @@ -519,8 +610,7 @@ def prepare_inputs_for_generation( **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model - - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -530,55 +620,12 @@ def prepare_inputs_for_generation( **kwargs, ) - # If we're in cached decoding stage, input_features should be None because - # input ids do not contain 
special audio token anymore Otherwise we need - # input feature values to be passed to the model if is_first_iteration or not kwargs.get("use_cache", True): model_inputs["input_features"] = input_features return model_inputs - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. - """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - # Mask the audio features into the text embeddings - special_audio_mask = is_audio_index.unsqueeze(-1) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - torch_compilable_check( - not torch.all(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)), - "Number of audio tokens does not match number of audio features", - ) - audio_features = audio_features[input_features_mask] - - inputs_embeds = inputs_embeds.masked_scatter( - special_audio_mask, - audio_features, - ) - return inputs_embeds - def generate(self, *args, **kwargs) -> torch.LongTensor: - # This model is expected to have a lora adapter, which is only - # enabled when considering audio inputs. 
As such, we override generate - # to conditionally enable / disable the lora adapter based on whether - # or not any input features were provided. - + # Enable/disable LoRA adapter based on whether audio inputs are provided. input_features = kwargs.pop("input_features", None) if is_peft_available and self._hf_peft_config_loaded: if input_features is not None: @@ -588,12 +635,11 @@ def generate(self, *args, **kwargs) -> torch.LongTensor: return super().generate(*args, input_features=input_features, **kwargs) def save_pretrained(self, save_directory, *args, **kwargs): - # overwrite save_pretrained to first save the adapter if we have one + # Save the adapter first, then the base model if is_peft_available and self._hf_peft_config_loaded: adapter_name = self._get_adapter_name() self.peft_config[adapter_name].base_model_name_or_path = save_directory super().save_pretrained(save_directory, *args, **kwargs) - # Then save the base model afterwards prev_val = self._hf_peft_config_loaded self._hf_peft_config_loaded = False super().save_pretrained(save_directory, *args, **kwargs) diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index 6d27eaffeed5..5896f626449a 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -213,11 +213,11 @@ def apply_rotary_time_emb(hidden_states, cos, sin): without a language modeling head. 
""" ) -class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class MusicFlamingoModel(MusicFlamingoPreTrainedModel): _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config: MusicFlamingoConfig): super().__init__(config) diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index a5dafca76f51..d8515745208d 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -23,7 +23,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index b5bbba1c2a9d..0a412957819b 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -258,11 +258,34 @@ def _init_weights(self, module): Base class for VibeVoice ASR outputs, with hidden states and attentions. 
""" ) -class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _supports_attention_backend = True - _tp_plan = None - _pp_plan = None +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. + """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. 
+ """ loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None @@ -279,6 +302,8 @@ class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, Generati """ ) class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): + _supports_attention_backend = True + def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) @@ -383,6 +408,12 @@ def forward( acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | VibeVoiceAsrModelOutputWithPast: + r""" + padding_mask (): + + acoustic_tokenizer_chunk_size (): + + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index b60dc41ef430..97a58ad79a21 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -23,7 +23,6 @@ from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPooling, - CausalLMOutputWithPast, ModelOutput, ) from ...processing_utils import Unpack diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index 4f12393d2e02..d82c6417cb20 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -623,7 +623,9 @@ def forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if (audio_outputs is not None and use_cache) else None, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, 
padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, audio_hidden_states=audio_embeds, ) diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index 669e2f65aef1..6fcff39e6513 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -183,15 +183,11 @@ def test_generate_methods_with_logits_to_keep(self): def test_training_gradient_checkpointing(self): pass - @unittest.skip( - reason="Same as test_training_gradient_checkpointing: tokenizer encoders are frozen." - ) + @unittest.skip(reason="Same as test_training_gradient_checkpointing: tokenizer encoders are frozen.") def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip( - reason="Same as test_training_gradient_checkpointing: tokenizer encoders are frozen." - ) + @unittest.skip(reason="Same as test_training_gradient_checkpointing: tokenizer encoders are frozen.") def test_training_gradient_checkpointing_use_reentrant_true(self): pass From 748140214dc2b041007c76628e9a02457b5ba9ff Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 11 May 2026 15:33:34 +0200 Subject: [PATCH 08/39] convert modular --- .../vibevoice_asr/modeling_vibevoice_asr.py | 252 ++++++++----- .../modeling_voxtral_realtime.py | 331 +++++++++--------- 2 files changed, 333 insertions(+), 250 deletions(-) diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index b66dd15b2cb1..0a412957819b 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -17,6 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn @@ -25,17 +27,12 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - torch_compilable_check, -) -from ..auto import AutoModel, AutoModelForCausalLM +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils.deprecation import forward_base_model_attrs +from ..auto import AutoModel from .configuration_vibevoice_asr import VibeVoiceAsrConfig @@ -255,27 +252,71 @@ def _init_weights(self, module): init.constant_(module.ffn_gamma, self.config.layer_scale_init_value) +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. """ ) -class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. 
+ """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. + """ +) +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): _supports_attention_backend = True - _tp_plan = None - _pp_plan = None def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - - # Initialize weights and apply final processing + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) self.post_init() + # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze + # their parameters so grad-checkpointing and training sanity checks don't flag them. 
+ for p in self.acoustic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + for p in self.semantic_tokenizer_encoder.parameters(): + p.requires_grad_(False) def get_input_embeddings(self): return self.language_model.get_input_embeddings() @@ -283,18 +324,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") def get_audio_features( @@ -303,17 +332,15 @@ def get_audio_features( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: + ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. 
""" - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -358,7 +385,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -369,29 +395,86 @@ def get_audio_features( return BaseModelOutputWithPooling(last_hidden_state=acoustic_latents, pooler_output=combined_features) - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrModelOutputWithPast: + r""" + padding_mask (): + + acoustic_tokenizer_chunk_size (): + """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_values is not None and input_ids is not None: + audio_embeds = self.get_audio_features( + input_values=input_values, + padding_mask=padding_mask, + 
acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + ).pooler_output + + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, ) - return special_audio_mask + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) @can_return_tuple @auto_docstring @@ -404,14 +487,15 @@ def forward( input_values: torch.FloatTensor | None = None, padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. + Size of audio chunks processed by the acoustic and semantic tokenizers. 
Example: @@ -421,33 +505,35 @@ def forward( >>> model_id = "microsoft/VibeVoice-ASR-HF" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_values is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_values=input_values, - padding_mask=padding_mask, - acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, - ).pooler_output + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, 
past_key_values=past_key_values, **kwargs + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -468,4 +554,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg return model_inputs -__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrPreTrainedModel"] +__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel"] diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index dbecd9a6f530..31b3850c591e 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -39,17 +39,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, - torch_compilable_check, -) +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel from .configuration_voxtral_realtime import ( VoxtralRealtimeConfig, VoxtralRealtimeEncoderConfig, @@ -125,6 +118,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class 
VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. + """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -487,6 +498,7 @@ class VoxtralRealtimePreTrainedModel(PreTrainedModel): _supports_attention_backend = True # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -827,80 +839,6 @@ def forward( ) -@auto_docstring -class VoxtralRealtimeTextForCausalLM(VoxtralRealtimeTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = VoxtralRealtimeTextModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | 
None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -935,17 +873,15 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The 
VoxtralRealtime model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. """ ) -class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) @@ -958,18 +894,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -985,11 +909,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. 
Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -1014,30 +934,6 @@ def get_audio_features( return audio_outputs - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. 
- """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - @can_return_tuple @auto_docstring def forward( @@ -1051,43 +947,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. 
num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. - - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1097,6 +970,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -1106,7 +981,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -1125,25 +1001,141 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] 
# broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +@forward_base_model_attrs(version="5.7") +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @property + def audio_tower(self): + return self.model.audio_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: 
torch.LongTensor | None = None, + past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. 
+ + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + 
encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -1339,4 +1331,9 @@ def _prepare_generated_length( return generation_config -__all__ = ["VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel"] +__all__ = [ + "VoxtralRealtimeForConditionalGeneration", + "VoxtralRealtimeEncoder", + "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", +] From cf7c5f1f9c026cafdb13cbc35752a2cf0c64be32 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 11 May 2026 16:02:41 +0200 Subject: [PATCH 09/39] apply to voxtral --- .../models/voxtral/modeling_voxtral.py | 157 ++++++++++++----- .../models/voxtral/modular_voxtral.py | 158 +++++++++++++----- 2 files changed, 239 insertions(+), 76 deletions(-) diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 54466321b79e..15260a14936a 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -33,9 +34,10 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig @@ -359,22 +361,33 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of 
Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model, + without a language modeling head. + """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -383,18 +396,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." 
@@ -442,6 +443,81 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs: BaseModelOutputWithPast = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -456,7 +532,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | CausalLMOutputWithPast: r""" Example: @@ -490,29 +566,34 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, 
audio_embeds.to(inputs_embeds.device)) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -529,4 +610,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 02e8e2806a0f..31c7193d71f8 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -13,6 +13,8 @@ # limitations under the License. 
+from dataclasses import dataclass + import torch from torch import nn @@ -26,9 +28,10 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from ..qwen2_audio.modeling_qwen2_audio import ( Qwen2AudioAttention, Qwen2AudioEncoder, @@ -128,22 +131,33 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model, + without a language modeling head. 
+ """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -152,18 +166,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." 
@@ -211,6 +213,81 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs: BaseModelOutputWithPast = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -225,7 +302,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | CausalLMOutputWithPast: r""" Example: @@ -259,29 +336,34 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, 
audio_embeds.to(inputs_embeds.device)) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -298,4 +380,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] From 83799ce63c4b73be32d7cbcaf8fe1c6447336a24 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 11 May 2026 16:02:45 +0200 Subject: [PATCH 10/39] convert modular --- .../audioflamingo3/modeling_audioflamingo3.py | 256 +++++++++++------- .../models/glmasr/modeling_glmasr.py | 200 +++++++++++--- .../vibevoice_asr/modeling_vibevoice_asr.py | 6 - 3 files changed, 317 insertions(+), 145 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 
6f18fcc437ad..f4b8f79bca3c 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -31,13 +32,14 @@ from ...masking_utils import create_bidirectional_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig @@ -256,6 +258,42 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _supports_sdpa = True +@dataclass +class AudioFlamingo3ModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). 
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The audio model from AudioFlamingo3 without any head or projection on top. @@ -403,23 +441,21 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. 
""" ) -class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel): _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -428,18 +464,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -452,11 +476,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -509,77 +529,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | AudioFlamingo3ModelOutputWithPast: r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. 
+ """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -589,17 +549,118 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = None + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -616,4 +677,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index f2c68e56df71..fa4a50e97555 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ 
b/src/transformers/models/glmasr/modeling_glmasr.py @@ -19,6 +19,7 @@ # limitations under the License. from collections.abc import Callable +from dataclasses import dataclass from typing import Optional from ...activations import ACT2FN @@ -26,14 +27,20 @@ from ...generation import GenerationMixin from ...integrations import use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, + ModelOutput, +) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_glmasr import GlmAsrConfig, GlmAsrEncoderConfig @@ -349,25 +356,32 @@ def forward(self, audio_features): return hidden_states +@dataclass +class GlmAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) -class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class GlmAsrModel(GlmAsrPreTrainedModel): _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = GlmAsrMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -376,18 +390,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector." @@ -400,11 +402,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. 
See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -450,6 +448,114 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmAsrModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. 
+ """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return GlmAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GlmAsr causal language model (or autoregressive) outputs. + """ +) +class GlmAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. 
+ """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +@forward_base_model_attrs(version="5.7") +class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = None + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -494,30 +600,36 @@ def forward( >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) >>> print(decoded_outputs) ```""" - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings 
- special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: CausalLMOutputWithPast = self.language_model( - inputs_embeds=inputs_embeds, + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, + inputs_embeds=inputs_embeds, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return GlmAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -534,4 +646,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrPreTrainedModel"] +__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrModel", "GlmAsrPreTrainedModel"] diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 0a412957819b..a80989570671 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ 
b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -408,12 +408,6 @@ def forward( acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | VibeVoiceAsrModelOutputWithPast: - r""" - padding_mask (): - - acoustic_tokenizer_chunk_size (): - - """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) From 464fae5c9e66eafa0f7aec8f6b4bff16330107b8 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 11 May 2026 16:28:35 +0200 Subject: [PATCH 11/39] remove test_model_base_model_prefix overwrite --- tests/alm_tester.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/alm_tester.py b/tests/alm_tester.py index c34d4d45524c..c9026aa8960e 100644 --- a/tests/alm_tester.py +++ b/tests/alm_tester.py @@ -153,11 +153,6 @@ class ALMModelTest(MultiModalModelTest): - `pipeline_model_mapping`: Override if not using default from model_tester """ - # TODO: @eustlb, remove this once #45534 is merged - @unittest.skip("Audio-LMs have no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_mismatching_num_audio_tokens(self): """ Tests that ALMs throw an error with explicit message saying what is wrong From 744567c54fa701d1c1dd77fd79c930a280136ad6 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 11 May 2026 17:46:47 +0200 Subject: [PATCH 12/39] make --- .../configuration_audioflamingo3.py | 1 + .../models/glmasr/configuration_glmasr.py | 1 + .../configuration_musicflamingo.py | 1 + .../musicflamingo/modeling_musicflamingo.py | 195 ++++++++++++++---- .../qwen2_audio/configuration_qwen2_audio.py | 1 + .../configuration_vibevoice_asr.py | 1 + .../vibevoice_asr/modeling_vibevoice_asr.py | 6 + .../vibevoice_asr/modular_vibevoice_asr.py | 1 + .../models/voxtral/configuration_voxtral.py | 1 + .../configuration_voxtral_realtime.py | 1 + tests/alm_tester.py | 1 - 11 files 
changed, 164 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py index 096e263d856d..cd81ef805205 100644 --- a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py @@ -100,6 +100,7 @@ class AudioFlamingo3Config(PreTrainedConfig): audio_token_id: int = 151669 projector_hidden_act: str = "gelu" projector_bias: bool = True + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/glmasr/configuration_glmasr.py b/src/transformers/models/glmasr/configuration_glmasr.py index c3d320bb1db4..c89379ead3e7 100644 --- a/src/transformers/models/glmasr/configuration_glmasr.py +++ b/src/transformers/models/glmasr/configuration_glmasr.py @@ -101,6 +101,7 @@ class GlmAsrConfig(PreTrainedConfig): text_config: dict | PreTrainedConfig | None = None audio_token_id: int = 59260 projector_hidden_act: str = "gelu" + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/musicflamingo/configuration_musicflamingo.py b/src/transformers/models/musicflamingo/configuration_musicflamingo.py index 7eff8861558a..a733f73004c5 100644 --- a/src/transformers/models/musicflamingo/configuration_musicflamingo.py +++ b/src/transformers/models/musicflamingo/configuration_musicflamingo.py @@ -66,6 +66,7 @@ class MusicFlamingoConfig(PreTrainedConfig): audio_token_id: int = 151669 projector_hidden_act: str = "gelu" projector_bias: bool = True + tie_word_embeddings: bool = False audio_bos_token_id: int = 151670 audio_eos_token_id: int = 151671 diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index 
a9e05470662d..5ceed90170ed 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -20,6 +20,7 @@ # limitations under the License. from collections.abc import Callable +from dataclasses import dataclass from math import pi from typing import Optional @@ -29,12 +30,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -150,6 +151,16 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) +@dataclass +class MusicFlamingoModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + class MusicFlamingoMultiModalProjector(nn.Module): """ Audio adaptor (small MLP) that projects MusicFlamingoEncoder features @@ -173,6 +184,134 @@ def forward(self, audio_features): return hidden_states +@auto_docstring( + custom_intro=""" + The MusicFlamingo model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. 
+ """ +) +class MusicFlamingoModel(MusicFlamingoPreTrainedModel): + _supports_attention_backend = True + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.audio_tower = AutoModel.from_config(config.audio_config) + self.language_model = AutoModel.from_config(config.text_config) + self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring( + custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." + ) + def get_audio_features( + self, + input_features: torch.FloatTensor, + input_features_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + input_features (`torch.FloatTensor`): + Float values of mel features extracted from the raw speech waveform. + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padded feature indices. 
+ """ + + audio_output = self.audio_tower( + input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs + ) + audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state) + + # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling + input_lengths = input_features_mask.sum(-1).to(torch.long) + _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_lengths) + valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None] + audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)] + + return audio_output + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MusicFlamingoModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. 
+ """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return MusicFlamingoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + def rotate_half(x): x = x.reshape(*x.shape[:-1], -1, 2) x1, x2 = x.unbind(dim=-1) @@ -200,38 +339,28 @@ def apply_rotary_time_emb(hidden_states, cos, sin): ) class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None - _supports_attention_backend = True + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _tp_plan = None _pp_plan = None def __init__(self, config: MusicFlamingoConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) + self.model = MusicFlamingoModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.pos_emb = 
MusicFlamingoRotaryEmbedding(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.language_model.get_input_embeddings() + return self.model.get_input_embeddings() def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) + self.model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def get_output_embeddings(self) -> nn.Module: + return self.lm_head def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() + self.lm_head = new_embeddings @can_return_tuple @auto_docstring( @@ -269,30 +398,6 @@ def get_audio_features( return audio_output - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. 
- """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - @can_return_tuple @auto_docstring def forward( @@ -442,4 +547,4 @@ def _build_audio_timestamps( return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets -__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"] +__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoModel", "MusicFlamingoPreTrainedModel"] diff --git a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py index 6aec9eace900..749f24123bb4 100644 --- a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py @@ -98,6 +98,7 @@ class Qwen2AudioConfig(PreTrainedConfig): audio_config: dict | PreTrainedConfig | None = None text_config: dict | PreTrainedConfig | None = None audio_token_index: int = 151646 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py index a673a5845871..4d56a948eda1 100644 --- 
a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py @@ -75,6 +75,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 + tie_word_embeddings: bool = False def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index a80989570671..0a412957819b 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -408,6 +408,12 @@ def forward( acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | VibeVoiceAsrModelOutputWithPast: + r""" + padding_mask (): + + acoustic_tokenizer_chunk_size (): + + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 97a58ad79a21..a0dbcb158268 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -89,6 +89,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): diff --git a/src/transformers/models/voxtral/configuration_voxtral.py b/src/transformers/models/voxtral/configuration_voxtral.py index 2ecbedfc1a9e..b476d80dd976 100644 --- a/src/transformers/models/voxtral/configuration_voxtral.py +++ 
b/src/transformers/models/voxtral/configuration_voxtral.py @@ -110,6 +110,7 @@ class VoxtralConfig(PreTrainedConfig): text_config: dict | PreTrainedConfig | None = None audio_token_id: int | None = None projector_hidden_act: str = "gelu" + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py index b0227b418771..a1593dfbcdca 100644 --- a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py @@ -170,6 +170,7 @@ class VoxtralRealtimeConfig(PreTrainedConfig): audio_length_per_tok: int = 8 default_num_delay_tokens: int = 6 downsample_factor: int = 4 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/tests/alm_tester.py b/tests/alm_tester.py index c9026aa8960e..96d527aef5aa 100644 --- a/tests/alm_tester.py +++ b/tests/alm_tester.py @@ -13,7 +13,6 @@ # limitations under the License. 
import copy -import unittest from inspect import signature from .multimodal_tester import MultiModalModelTest, MultiModalModelTester From 687b693a4f4428f1a536d1fc678e6ddf98666d21 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 15:03:26 +0200 Subject: [PATCH 13/39] make --- docs/source/en/model_doc/hyperclovax.md | 2 +- .../configuration_granite_speech.py | 1 + .../granite_speech/modeling_granite_speech.py | 297 ++++++++++------- .../configuration_granite_speech_plus.py | 1 + .../modeling_granite_speech_plus.py | 298 +++++++++++------- .../modular_granite_speech_plus.py | 2 + .../musicflamingo/modeling_musicflamingo.py | 296 ++++++++--------- .../musicflamingo/modular_musicflamingo.py | 107 +++---- .../configuration_vibevoice_asr.py | 2 +- .../vibevoice_asr/modeling_vibevoice_asr.py | 8 +- .../vibevoice_asr/modular_vibevoice_asr.py | 6 + 11 files changed, 553 insertions(+), 467 deletions(-) diff --git a/docs/source/en/model_doc/hyperclovax.md b/docs/source/en/model_doc/hyperclovax.md index 2725fbbc090f..0a26d0bde2b3 100644 --- a/docs/source/en/model_doc/hyperclovax.md +++ b/docs/source/en/model_doc/hyperclovax.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2025-07-21 and added to Hugging Face Transformers on 2026-05-06.* +*This model was released on 2025-07-21 and added to Hugging Face Transformers on 2026-05-08.*
diff --git a/src/transformers/models/granite_speech/configuration_granite_speech.py b/src/transformers/models/granite_speech/configuration_granite_speech.py index e5532b3bf880..bf5eb93235d1 100644 --- a/src/transformers/models/granite_speech/configuration_granite_speech.py +++ b/src/transformers/models/granite_speech/configuration_granite_speech.py @@ -126,6 +126,7 @@ class GraniteSpeechConfig(PreTrainedConfig): has_lora_adapter: bool = True downsample_rate: int = 5 window_size: int = 15 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.text_config, dict): diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index d7ba0c94c950..07aabbde3970 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -22,7 +22,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -33,18 +33,34 @@ logging, torch_compilable_check, ) +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig logger = logging.get_logger(__name__) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. 
+ """ +) +class GraniteSpeechModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" - Base class for LlavaNext causal language model (or autoregressive) outputs. + Base class for Granite Speech causal language model (or autoregressive) outputs. """ ) @dataclass @@ -59,6 +75,8 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput): Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. """ loss: torch.FloatTensor | None = None @@ -66,6 +84,7 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput): past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None ### Projector @@ -262,6 +281,7 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> class GraniteSpeechPreTrainedModel(PreTrainedModel): config: GraniteSpeechConfig input_modalities = ("audio", "text") + base_model_prefix = "model" _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True @@ -323,50 +343,25 @@ def forward( @auto_docstring( custom_intro=""" - The Granite Speech model, which consists of an audio encoder, projector, and language model. + The Granite Speech model, which consists of an audio encoder, projector, and language model, + without a language modeling head. 
""" ) -class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): +class GraniteSpeechModel(GraniteSpeechPreTrainedModel): _supports_attention_backend = True def __init__(self, config: GraniteSpeechConfig): super().__init__(config) - # NOTE: It doesn't matter when we initialize from config, but we should be careful - # to make sure this does not pick up the adapter_config if in the future we use - # from_pretrained or something similar, since that should be set by the composite - # model; don't need to consider it twice - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.encoder = GraniteSpeechCTCEncoder(config.encoder_config) self.projector = GraniteSpeechEncoderProjector(config) - - if config.has_lora_adapter and not is_peft_available(): - logger.warning( - "Config indicates that a lora adapter should be present, but " - "peft is not installed; this will cause the model to perform " - "incorrectly when audio inputs are provided. Please install " - "peft and reload the model!" 
- ) - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - def get_input_embeddings(self): return self.language_model.get_input_embeddings() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring @@ -379,6 +374,61 @@ def get_audio_features( return audio_outputs + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Adds the audio token to the model's LLM vocabulary so that we can pass it + through the tokenizer; it's assumed that the embeddings corresponding to the + <|audio|> token will be clobbered with speech features. + + Args: + input_ids (`torch.Tensor`): + Input IDs containing one or more audio tokens. + audio_features (`torch.Tensor`): + Audio features to be masked into the language embeddings to form multimodal embeddings. + input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) + Mask to be applied to audio features prior to scattering into the language embeddings. 
+ """ + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] + + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + audio_features = audio_features[input_features_mask] + + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple @auto_docstring def forward( self, @@ -389,29 +439,13 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **lm_kwargs, - ) -> tuple[torch.Tensor] | GraniteSpeechCausalLMOutputWithPast: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GraniteSpeechModelOutputWithPast: r""" input_features_mask (`torch.Tensor`, *optional*): Mask to be applied to audio features prior to scattering into the language embeddings. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
""" - # TODO (@alex-jw-brooks) add an example to this docstring once models are released - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -429,6 +463,7 @@ def forward( llm_input_ids[is_audio_idx] = 0 inputs_embeds = self.get_input_embeddings()(llm_input_ids) + audio_embeds = None if input_features is not None: if input_features.dtype != self.dtype: input_features = input_features.to(self.dtype) @@ -448,13 +483,97 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, ) - logits = outputs[0] + + return GraniteSpeechModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Granite Speech model, which consists of an audio encoder, projector, and language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): + _supports_attention_backend = True + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechConfig): + super().__init__(config) + self.model = GraniteSpeechModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + if config.has_lora_adapter and not is_peft_available(): + logger.warning( + "Config indicates that a lora adapter should be present, but " + "peft is not installed; this will cause the model to perform " + "incorrectly when audio inputs are provided. Please install " + "peft and reload the model!" + ) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs, + ) -> tuple | GraniteSpeechCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + # TODO (@alex-jw-brooks) add an example to this docstring once models are released + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -474,16 +593,13 @@ def forward( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return GraniteSpeechCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation( @@ -499,7 +615,7 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -516,60 +632,6 @@ def prepare_inputs_for_generation( model_inputs["input_features"] = input_features return model_inputs - def 
get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. - """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. 
- """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - audio_features = audio_features[input_features_mask] - - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) - return inputs_embeds - def generate(self, *args, **kwargs) -> torch.LongTensor: # This model is expected to have a lora adapter, which is only # enabled when considering audio inputs. As such, we override generate @@ -603,5 +665,6 @@ def _get_adapter_name(self): __all__ = [ "GraniteSpeechCTCEncoder", "GraniteSpeechForConditionalGeneration", + "GraniteSpeechModel", "GraniteSpeechPreTrainedModel", ] diff --git a/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py index 1eec538091a4..f7d47a976e79 100644 --- a/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py @@ -137,6 +137,7 @@ class GraniteSpeechPlusConfig(PreTrainedConfig): has_lora_adapter: bool = True downsample_rate: int = 5 window_size: int = 15 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.text_config, dict): diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index d16fb52290ca..7e0a46ff85cf 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ 
b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -28,7 +28,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -39,9 +39,10 @@ logging, torch_compilable_check, ) +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_granite_speech_plus import GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig @@ -87,6 +88,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteSpeechPlusPreTrainedModel(PreTrainedModel): config: GraniteSpeechPlusConfig input_modalities = ("audio", "text") + base_model_prefix = "model" _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True @@ -315,9 +317,24 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=hidden_states) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" - Base class for LlavaNext causal language model (or autoregressive) outputs. + Base class for Granite Speech causal language model (or autoregressive) outputs. 
""" ) @dataclass @@ -332,6 +349,8 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. """ loss: torch.FloatTensor | None = None @@ -339,55 +358,30 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None @auto_docstring( custom_intro=""" - The Granite Speech Plus model, a Granite Speech variant whose projector consumes the concatenation of the - encoder's final hidden states with an arbitrary subset of its intermediate hidden states. + The Granite Speech model, which consists of an audio encoder, projector, and language model, + without a language modeling head. 
""" ) -class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): +class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel): _supports_attention_backend = True def __init__(self, config: GraniteSpeechPlusConfig): super().__init__(config) - # NOTE: It doesn't matter when we initialize from config, but we should be careful - # to make sure this does not pick up the adapter_config if in the future we use - # from_pretrained or something similar, since that should be set by the composite - # model; don't need to consider it twice - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) self.projector = GraniteSpeechPlusEncoderProjector(config) - - if config.has_lora_adapter and not is_peft_available(): - logger.warning( - "Config indicates that a lora adapter should be present, but " - "peft is not installed; this will cause the model to perform " - "incorrectly when audio inputs are provided. Please install " - "peft and reload the model!" 
- ) - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - def get_input_embeddings(self): return self.language_model.get_input_embeddings() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring @@ -400,6 +394,61 @@ def get_audio_features( return audio_outputs + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Adds the audio token to the model's LLM vocabulary so that we can pass it + through the tokenizer; it's assumed that the embeddings corresponding to the + <|audio|> token will be clobbered with speech features. + + Args: + input_ids (`torch.Tensor`): + Input IDs containing one or more audio tokens. + audio_features (`torch.Tensor`): + Audio features to be masked into the language embeddings to form multimodal embeddings. + input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) + Mask to be applied to audio features prior to scattering into the language embeddings. 
+ """ + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] + + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + audio_features = audio_features[input_features_mask] + + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple @auto_docstring def forward( self, @@ -410,29 +459,13 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **lm_kwargs, - ) -> tuple[torch.Tensor] | GraniteSpeechPlusCausalLMOutputWithPast: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GraniteSpeechPlusModelOutputWithPast: r""" input_features_mask (`torch.Tensor`, *optional*): Mask to be applied to audio features prior to scattering into the language embeddings. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
""" - # TODO (@alex-jw-brooks) add an example to this docstring once models are released - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -450,6 +483,7 @@ def forward( llm_input_ids[is_audio_idx] = 0 inputs_embeds = self.get_input_embeddings()(llm_input_ids) + audio_embeds = None if input_features is not None: if input_features.dtype != self.dtype: input_features = input_features.to(self.dtype) @@ -469,13 +503,98 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, + ) + + return GraniteSpeechPlusModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Granite Speech Plus model, a Granite Speech variant whose projector consumes the concatenation of the + encoder's final hidden states with an arbitrary subset of its intermediate hidden states. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): + _supports_attention_backend = True + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechPlusConfig): + super().__init__(config) + self.model = GraniteSpeechPlusModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + if config.has_lora_adapter and not is_peft_available(): + logger.warning( + "Config indicates that a lora adapter should be present, but " + "peft is not installed; this will cause the model to perform " + "incorrectly when audio inputs are provided. Please install " + "peft and reload the model!" + ) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs, + ) -> tuple | GraniteSpeechPlusCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + # TODO (@alex-jw-brooks) add an example to this docstring once models are released + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, ) - logits = outputs[0] + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -495,16 +614,13 @@ def forward( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return GraniteSpeechPlusCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation( @@ -520,7 +636,7 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -537,60 +653,6 @@ def prepare_inputs_for_generation( model_inputs["input_features"] = input_features 
return model_inputs - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. - """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. 
- """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - audio_features = audio_features[input_features_mask] - - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) - return inputs_embeds - def generate(self, *args, **kwargs) -> torch.LongTensor: # This model is expected to have a lora adapter, which is only # enabled when considering audio inputs. As such, we override generate diff --git a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py index aa4f966246d2..5d1d03024ba7 100644 --- a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py @@ -21,6 +21,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..granite_speech.configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig @@ -160,6 +161,7 @@ def forward( encoder's final hidden states with an arbitrary subset of its intermediate hidden states. """ ) +@forward_base_model_attrs(version="5.7") class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechForConditionalGeneration): ... 
diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index 5ceed90170ed..9425836ee183 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -30,11 +30,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -184,6 +185,26 @@ def forward(self, audio_features): return hidden_states +def rotate_half(x): + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def apply_rotary_time_emb(hidden_states, cos, sin): + original_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float64) + cos = cos.to(hidden_states) + sin = sin.to(hidden_states) + rot_dim = cos.shape[-1] + + passthrough = hidden_states[..., rot_dim:] + rotated = hidden_states[..., :rot_dim] + rotated = (rotated * cos) + (rotate_half(rotated) * sin) + return torch.cat((rotated, passthrough), dim=-1).to(original_dtype) + + @auto_docstring( custom_intro=""" The MusicFlamingo model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), @@ -196,11 +217,12 @@ class MusicFlamingoModel(MusicFlamingoPreTrainedModel): _pp_plan = None _keep_in_fp32_modules_strict 
= None - def __init__(self, config): + def __init__(self, config: MusicFlamingoConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config) self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) + self.pos_emb = MusicFlamingoRotaryEmbedding(config) self.post_init() def get_input_embeddings(self): @@ -217,23 +239,29 @@ def get_audio_features( self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor, + input_ids: torch.LongTensor, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" - input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Token ids containing the audio token ID placeholders, for reconstructing rotary time embedding timestamps. 
""" - audio_output = self.audio_tower( - input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs + input_features, + input_features_mask=input_features_mask, + return_dict=True, + **kwargs, ) - audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state) + hidden_states = audio_output.last_hidden_state + _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_features_mask.sum(-1).to(torch.long)) + audio_timestamps = self._build_audio_timestamps(input_ids, post_lengths, hidden_states.shape[-2]) + cos, sin = self.pos_emb(audio_timestamps.to(hidden_states.device), seq_len=hidden_states.shape[-2]) + hidden_states = apply_rotary_time_emb(hidden_states, cos, sin) + audio_embeds = self.multi_modal_projector(hidden_states) # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling - input_lengths = input_features_mask.sum(-1).to(torch.long) - _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_lengths) valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None] audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)] @@ -286,7 +314,9 @@ def forward( audio_embeds = None if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + audio_embeds = self.get_audio_features( + input_features, input_features_mask, input_ids=input_ids, return_dict=True + ).pooler_output # replace text-audio token placeholders with audio embeddings special_audio_mask = self.get_placeholder_mask( @@ -311,25 +341,72 @@ def forward( audio_hidden_states=audio_embeds, ) + def _build_audio_timestamps( + self, + input_ids: torch.LongTensor, + post_lengths: torch.LongTensor, + max_post_length: int, + ) -> torch.FloatTensor: + audio_token_mask = input_ids == self.config.audio_token_id + diff = 
torch.diff(torch.nn.functional.pad(audio_token_mask.int(), (1, 1), value=0), dim=1) + _, starts = torch.where(diff == 1) + _, ends = torch.where(diff == -1) + sample_lengths = (ends - starts).to(torch.long) -def rotate_half(x): - x = x.reshape(*x.shape[:-1], -1, 2) - x1, x2 = x.unbind(dim=-1) - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) + n_audio_tokens = audio_token_mask.sum() + n_audio_features = post_lengths.sum() + torch_compilable_check( + n_audio_tokens == n_audio_features, + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + # Account for 4x downsampling in audio encoder (conv2 and avg pooling) + audio_embed_frame_step = self.config.audio_frame_step * 4 + frame_offsets = ( + torch.arange(max_post_length, device=post_lengths.device, dtype=torch.float32) * audio_embed_frame_step + ) -def apply_rotary_time_emb(hidden_states, cos, sin): - original_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float64) - cos = cos.to(hidden_states) - sin = sin.to(hidden_states) - rot_dim = cos.shape[-1] + # Map each encoder output row to its audio sample using token counts + cumsum_post = torch.cat([torch.zeros(1, device=post_lengths.device), torch.cumsum(post_lengths, dim=0)[:-1]]) + cumsum_samples = torch.cumsum(sample_lengths, dim=0) + sample_indices = torch.searchsorted(cumsum_samples, cumsum_post, right=True) + + # Compute window index within each sample (0, 1, 2, ... 
then reset for next sample) + sample_start_rows = torch.searchsorted( + sample_indices, torch.arange(sample_lengths.shape[0], device=post_lengths.device) + ) + window_indices = ( + torch.arange(post_lengths.shape[0], device=post_lengths.device) - sample_start_rows[sample_indices] + ) + + # Compute timestamps + return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets - passthrough = hidden_states[..., rot_dim:] - rotated = hidden_states[..., :rot_dim] - rotated = (rotated * cos) + (rotate_half(rotated) * sin) - return torch.cat((rotated, passthrough), dim=-1).to(original_dtype) + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for MusicFlamingo causal language model (or autoregressive) outputs. + """ +) +class MusicFlamingoCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None @auto_docstring( @@ -337,6 +414,7 @@ def apply_rotary_time_emb(hidden_states, cos, sin): The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
""" ) +@forward_base_model_attrs(version="5.7") class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -347,7 +425,6 @@ def __init__(self, config: MusicFlamingoConfig): super().__init__(config) self.model = MusicFlamingoModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - self.pos_emb = MusicFlamingoRotaryEmbedding(config) self.post_init() def get_input_embeddings(self): @@ -362,41 +439,8 @@ def get_output_embeddings(self) -> nn.Module: def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - @can_return_tuple - @auto_docstring( - custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." - ) - def get_audio_features( - self, - input_features: torch.FloatTensor, - input_features_mask: torch.Tensor, - input_ids: torch.LongTensor, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: - r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padded feature indices. - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Token ids containing the audio token ID placeholders, for reconstructing rotary time embedding timestamps. 
- """ - audio_output = self.audio_tower( - input_features, - input_features_mask=input_features_mask, - return_dict=True, - **kwargs, - ) - hidden_states = audio_output.last_hidden_state - _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_features_mask.sum(-1).to(torch.long)) - audio_timestamps = self._build_audio_timestamps(input_ids, post_lengths, hidden_states.shape[-2]) - cos, sin = self.pos_emb(audio_timestamps.to(hidden_states.device), seq_len=hidden_states.shape[-2]) - hidden_states = apply_rotary_time_emb(hidden_states, cos, sin) - audio_embeds = self.multi_modal_projector(hidden_states) - - # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling - valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None] - audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)] - - return audio_output + def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs) @can_return_tuple @auto_docstring @@ -413,83 +457,52 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | MusicFlamingoCausalLMOutputWithPast: r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Labels for computing the masked language modeling loss. Example: ```python >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor - >>> model_id = "nvidia/music-flamingo-2601-hf" + >>> model_id = "nvidia/audio-flamingo-3-hf" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversation = [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", - >>> }, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", - >>> }, - >>> ], - >>> } - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversation, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device, model.dtype) - - >>> outputs = model.generate(**inputs, max_new_tokens=100) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["This track is an uplifting Eurodance-style Trance-Pop anthem..."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_features, input_features_mask, input_ids=input_ids, return_dict=True - ).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, 
inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: CausalLMOutputWithPast = self.language_model( - inputs_embeds=inputs_embeds, + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, + inputs_embeds=inputs_embeds, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return MusicFlamingoCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -505,46 +518,5 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs - def _build_audio_timestamps( - self, - input_ids: torch.LongTensor, - post_lengths: torch.LongTensor, - max_post_length: int, - ) -> torch.FloatTensor: - audio_token_mask = input_ids == self.config.audio_token_id - diff = torch.diff(torch.nn.functional.pad(audio_token_mask.int(), (1, 1), value=0), dim=1) - _, starts = torch.where(diff == 1) - _, ends = torch.where(diff == -1) - sample_lengths = (ends - starts).to(torch.long) - - n_audio_tokens = audio_token_mask.sum() - n_audio_features = 
post_lengths.sum() - torch_compilable_check( - n_audio_tokens == n_audio_features, - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - - # Account for 4x downsampling in audio encoder (conv2 and avg pooling) - audio_embed_frame_step = self.config.audio_frame_step * 4 - frame_offsets = ( - torch.arange(max_post_length, device=post_lengths.device, dtype=torch.float32) * audio_embed_frame_step - ) - - # Map each encoder output row to its audio sample using token counts - cumsum_post = torch.cat([torch.zeros(1, device=post_lengths.device), torch.cumsum(post_lengths, dim=0)[:-1]]) - cumsum_samples = torch.cumsum(sample_lengths, dim=0) - sample_indices = torch.searchsorted(cumsum_samples, cumsum_post, right=True) - - # Compute window index within each sample (0, 1, 2, ... then reset for next sample) - sample_start_rows = torch.searchsorted( - sample_indices, torch.arange(sample_lengths.shape[0], device=post_lengths.device) - ) - window_indices = ( - torch.arange(post_lengths.shape[0], device=post_lengths.device) - sample_start_rows[sample_indices] - ) - - # Compute timestamps - return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets - __all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoModel", "MusicFlamingoPreTrainedModel"] diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index e16ae28f6c68..b8e1a999daea 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -14,6 +14,7 @@ # limitations under the License. import re +from dataclasses import dataclass from math import pi from huggingface_hub.dataclasses import strict @@ -22,13 +23,16 @@ from ... 
import initialization as init from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, + AudioFlamingo3ModelOutputWithPast, AudioFlamingo3PreTrainedModel, ) from ..audioflamingo3.processing_audioflamingo3 import AudioFlamingo3Processor @@ -76,6 +80,7 @@ class MusicFlamingoConfig(AudioFlamingo3Config): audio_eos_token_id: int = 151671 audio_frame_step: float = 0.01 rope_parameters: dict | None = None + tie_word_embeddings: bool = False def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): @@ -252,12 +257,12 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) -@auto_docstring( - custom_intro=""" - The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
- """ -) -class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +@dataclass +class MusicFlamingoModelOutputWithPast(AudioFlamingo3ModelOutputWithPast): + pass + + +class MusicFlamingoModel(AudioFlamingo3Model): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) self.pos_emb = MusicFlamingoRotaryEmbedding(config) @@ -350,65 +355,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ): r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Example: - - ```python - >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/music-flamingo-2601-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversation = [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", - >>> }, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", - >>> }, - >>> ], - >>> } - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversation, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device, model.dtype) - - >>> outputs = model.generate(**inputs, max_new_tokens=100) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["This track is an uplifting Eurodance-style Trance-Pop anthem..."] - ```""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. 
+ """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features( input_features, input_features_mask, input_ids=input_ids, return_dict=True @@ -420,22 +377,44 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return MusicFlamingoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): + def __init__(self, config: MusicFlamingoConfig): + super().__init__(config) + self.model = MusicFlamingoModel(config) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs) __all__ = [ "MusicFlamingoConfig", "MusicFlamingoProcessor", "MusicFlamingoForConditionalGeneration", + "MusicFlamingoModel", "MusicFlamingoPreTrainedModel", ] diff --git a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py index 4d56a948eda1..11856b20b851 100644 --- a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py @@ -75,7 +75,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 - tie_word_embeddings: bool = False + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 0a412957819b..5d9c6745e045 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -409,10 +409,10 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> tuple | VibeVoiceAsrModelOutputWithPast: r""" - padding_mask (): - - acoustic_tokenizer_chunk_size (): - + padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing operations on padding feature indices. 
+ acoustic_tokenizer_chunk_size (`int`, *optional*): + Size of audio chunks processed by the acoustic and semantic tokenizers. """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index a0dbcb158268..99ffb7b77e2b 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -325,6 +325,12 @@ def forward( acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | VibeVoiceAsrModelOutputWithPast: + r""" + padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing operations on padding feature indices. + acoustic_tokenizer_chunk_size (`int`, *optional*): + Size of audio chunks processed by the acoustic and semantic tokenizers. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) From ff8b9f54c7eb2e956343b15759587770bb6f8ad0 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 15:37:16 +0200 Subject: [PATCH 14/39] XXXModel class in doc --- docs/source/en/model_doc/audioflamingo3.md | 5 +++++ docs/source/en/model_doc/glmasr.md | 5 +++++ docs/source/en/model_doc/granite_speech.md | 5 +++++ docs/source/en/model_doc/granite_speech_plus.md | 5 +++++ docs/source/en/model_doc/musicflamingo.md | 5 +++++ docs/source/en/model_doc/qwen2_audio.md | 5 +++++ docs/source/en/model_doc/vibevoice_asr.md | 5 +++++ docs/source/en/model_doc/voxtral.md | 5 +++++ docs/source/en/model_doc/voxtral_realtime.md | 5 +++++ 9 files changed, 45 insertions(+) diff --git a/docs/source/en/model_doc/audioflamingo3.md b/docs/source/en/model_doc/audioflamingo3.md index 0249480b3317..95cf1d9caa84 100644 --- a/docs/source/en/model_doc/audioflamingo3.md +++ 
b/docs/source/en/model_doc/audioflamingo3.md @@ -403,6 +403,11 @@ are forwarded, so you can tweak padding or tensor formats just like when calling [[autodoc]] AudioFlamingo3Encoder - forward +## AudioFlamingo3Model + +[[autodoc]] AudioFlamingo3Model + - forward + ## AudioFlamingo3ForConditionalGeneration [[autodoc]] AudioFlamingo3ForConditionalGeneration diff --git a/docs/source/en/model_doc/glmasr.md b/docs/source/en/model_doc/glmasr.md index a9acd132d8eb..ad7dda6757a6 100644 --- a/docs/source/en/model_doc/glmasr.md +++ b/docs/source/en/model_doc/glmasr.md @@ -231,6 +231,11 @@ assert decoded_outputs == EXPECTED_OUTPUT [[autodoc]] GlmAsrEncoder - forward +## GlmAsrModel + +[[autodoc]] GlmAsrModel + - forward + ## GlmAsrForConditionalGeneration [[autodoc]] GlmAsrForConditionalGeneration diff --git a/docs/source/en/model_doc/granite_speech.md b/docs/source/en/model_doc/granite_speech.md index 5115292a84d8..aedfb5eea61a 100644 --- a/docs/source/en/model_doc/granite_speech.md +++ b/docs/source/en/model_doc/granite_speech.md @@ -163,6 +163,11 @@ for i, transcription in enumerate(transcriptions): [[autodoc]] GraniteSpeechFeatureExtractor +## GraniteSpeechModel + +[[autodoc]] GraniteSpeechModel + - forward + ## GraniteSpeechForConditionalGeneration [[autodoc]] GraniteSpeechForConditionalGeneration diff --git a/docs/source/en/model_doc/granite_speech_plus.md b/docs/source/en/model_doc/granite_speech_plus.md index 22cf57935c12..810c0330fda6 100644 --- a/docs/source/en/model_doc/granite_speech_plus.md +++ b/docs/source/en/model_doc/granite_speech_plus.md @@ -143,6 +143,11 @@ for k in range(NUM_SEGMENTS): [[autodoc]] GraniteSpeechPlusEncoderConfig +## GraniteSpeechPlusModel + +[[autodoc]] GraniteSpeechPlusModel + - forward + ## GraniteSpeechPlusForConditionalGeneration [[autodoc]] GraniteSpeechPlusForConditionalGeneration diff --git a/docs/source/en/model_doc/musicflamingo.md b/docs/source/en/model_doc/musicflamingo.md index 52cab097b3b0..bdb3fe3437c6 100644 --- 
a/docs/source/en/model_doc/musicflamingo.md +++ b/docs/source/en/model_doc/musicflamingo.md @@ -287,6 +287,11 @@ loss.backward() [[autodoc]] MusicFlamingoProcessor +## MusicFlamingoModel + +[[autodoc]] MusicFlamingoModel + - forward + ## MusicFlamingoForConditionalGeneration [[autodoc]] MusicFlamingoForConditionalGeneration diff --git a/docs/source/en/model_doc/qwen2_audio.md b/docs/source/en/model_doc/qwen2_audio.md index 2cfd00b8e808..785b4bf5ca06 100644 --- a/docs/source/en/model_doc/qwen2_audio.md +++ b/docs/source/en/model_doc/qwen2_audio.md @@ -251,6 +251,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_ [[autodoc]] Qwen2AudioEncoder - forward +## Qwen2AudioModel + +[[autodoc]] Qwen2AudioModel + - forward + ## Qwen2AudioForConditionalGeneration [[autodoc]] Qwen2AudioForConditionalGeneration diff --git a/docs/source/en/model_doc/vibevoice_asr.md b/docs/source/en/model_doc/vibevoice_asr.md index f28485e6cb9e..8c29de227240 100644 --- a/docs/source/en/model_doc/vibevoice_asr.md +++ b/docs/source/en/model_doc/vibevoice_asr.md @@ -452,6 +452,11 @@ print(transcription) - apply_transcription_request - decode +## VibeVoiceAsrModel + +[[autodoc]] VibeVoiceAsrModel + - forward + ## VibeVoiceAsrForConditionalGeneration [[autodoc]] VibeVoiceAsrForConditionalGeneration diff --git a/docs/source/en/model_doc/voxtral.md b/docs/source/en/model_doc/voxtral.md index 0f520f547584..adb2c3706ffb 100644 --- a/docs/source/en/model_doc/voxtral.md +++ b/docs/source/en/model_doc/voxtral.md @@ -352,6 +352,11 @@ This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb) [[autodoc]] VoxtralEncoder - forward +## VoxtralModel + +[[autodoc]] VoxtralModel + - forward + ## VoxtralForConditionalGeneration [[autodoc]] VoxtralForConditionalGeneration diff --git a/docs/source/en/model_doc/voxtral_realtime.md b/docs/source/en/model_doc/voxtral_realtime.md index 49b274ea6659..97fe5fecff55 100644 --- 
a/docs/source/en/model_doc/voxtral_realtime.md +++ b/docs/source/en/model_doc/voxtral_realtime.md @@ -182,6 +182,11 @@ This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb) [[autodoc]] VoxtralRealtimeEncoder - forward +## VoxtralRealtimeModel + +[[autodoc]] VoxtralRealtimeModel + - forward + ## VoxtralRealtimeForConditionalGeneration [[autodoc]] VoxtralRealtimeForConditionalGeneration From 753b75576a66987f0c90d4022873074458aab2b2 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 15:37:33 +0200 Subject: [PATCH 15/39] add GraniteSpeechPlusModel --- src/transformers/models/auto/modeling_auto.py | 2 +- .../modeling_granite_speech_plus.py | 339 +++++++++--------- .../modular_granite_speech_plus.py | 5 + 3 files changed, 176 insertions(+), 170 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7bced9ad55eb..d636414897b3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -215,7 +215,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), ("granite_speech", "GraniteSpeechModel"), - ("granite_speech", "GraniteSpeechForConditionalGeneration"), + ("granite_speech_plus", "GraniteSpeechPlusModel"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), ("granitemoeshared", "GraniteMoeSharedModel"), diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index 7e0a46ff85cf..361b4678e823 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -107,6 +107,175 @@ def _init_weights(self, module: nn.Module): 
init.copy_(module.attention_dists, attention_dists) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Granite Speech model, which consists of an audio encoder, projector, and language model, + without a language modeling head. + """ +) +class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel): + _supports_attention_backend = True + + def __init__(self, config: GraniteSpeechPlusConfig): + super().__init__(config) + self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechPlusEncoderProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def get_audio_features( + self, input_features: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + ) -> tuple | BaseModelOutputWithPooling: + audio_outputs = self.encoder(input_features, return_dict=True, **kwargs) + projected_embeds = self.projector(audio_outputs.last_hidden_state) + audio_outputs.pooler_output = projected_embeds + + return audio_outputs + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Adds the audio token to the model's LLM vocabulary so that we can pass it + through the tokenizer; it's assumed that the embeddings corresponding to the + <|audio|> token will be clobbered with speech features. + + Args: + input_ids (`torch.Tensor`): + Input IDs containing one or more audio tokens. + audio_features (`torch.Tensor`): + Audio features to be masked into the language embeddings to form multimodal embeddings. + input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) + Mask to be applied to audio features prior to scattering into the language embeddings. 
+ """ + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] + + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + audio_features = audio_features[input_features_mask] + + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GraniteSpeechPlusModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_features is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + # Get the base embeddings; set all audio tokens to 0 index + # to avoid out of vocabulary issues with the LLM embedding. + # Audio features will be masked into is_audio_idx indices later. 
+ is_audio_idx = input_ids == self.config.audio_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[is_audio_idx] = 0 + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + audio_embeds = None + if input_features is not None: + if input_features.dtype != self.dtype: + input_features = input_features.to(self.dtype) + # Get the audio features from the encoder / projector + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # Merge the audio features into the LLM embeddings + inputs_embeds = self.get_merged_audio_embeddings( + input_ids=input_ids, + audio_features=audio_embeds, + input_features_mask=input_features_mask, + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return GraniteSpeechPlusModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + ### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git class GraniteSpeechPlusConformerFeedForward(nn.Module): """Feedforward module for conformer encoder blocks.""" @@ -317,21 +486,6 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=hidden_states) -@dataclass -@auto_docstring( - custom_intro=""" - Base class for Granite Speech outputs, with hidden states and attentions. - """ -) -class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): - r""" - audio_hidden_states (`torch.FloatTensor`, *optional*): - Projected audio hidden states. - """ - - audio_hidden_states: torch.FloatTensor | None = None - - @auto_docstring( custom_intro=""" Base class for Granite Speech causal language model (or autoregressive) outputs. 
@@ -361,160 +515,6 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): audio_hidden_states: torch.FloatTensor | None = None -@auto_docstring( - custom_intro=""" - The Granite Speech model, which consists of an audio encoder, projector, and language model, - without a language modeling head. - """ -) -class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel): - _supports_attention_backend = True - - def __init__(self, config: GraniteSpeechPlusConfig): - super().__init__(config) - self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) - self.projector = GraniteSpeechPlusEncoderProjector(config) - self.language_model = AutoModel.from_config(config.text_config) - self.post_init() - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - @can_return_tuple - @auto_docstring - def get_audio_features( - self, input_features: torch.Tensor, **kwargs: Unpack[TransformersKwargs] - ) -> tuple | BaseModelOutputWithPooling: - audio_outputs = self.encoder(input_features, return_dict=True, **kwargs) - projected_embeds = self.projector(audio_outputs.last_hidden_state) - audio_outputs.pooler_output = projected_embeds - - return audio_outputs - - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. 
- """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. 
- """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - audio_features = audio_features[input_features_mask] - - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) - return inputs_embeds - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - input_features: torch.FloatTensor | None = None, - input_features_mask: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | GraniteSpeechPlusModelOutputWithPast: - r""" - input_features_mask (`torch.Tensor`, *optional*): - Mask to be applied to audio features prior to scattering into the language embeddings. - """ - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if input_features is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - # Get the base embeddings; set all audio tokens to 0 index - # to avoid out of vocabulary issues with the LLM embedding. - # Audio features will be masked into is_audio_idx indices later. 
- is_audio_idx = input_ids == self.config.audio_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[is_audio_idx] = 0 - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - audio_embeds = None - if input_features is not None: - if input_features.dtype != self.dtype: - input_features = input_features.to(self.dtype) - # Get the audio features from the encoder / projector - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # Merge the audio features into the LLM embeddings - inputs_embeds = self.get_merged_audio_embeddings( - input_ids=input_ids, - audio_features=audio_embeds, - input_features_mask=input_features_mask, - ) - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - return GraniteSpeechPlusModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - audio_hidden_states=audio_embeds, - ) - - @auto_docstring( custom_intro=""" The Granite Speech Plus model, a Granite Speech variant whose projector consumes the concatenation of the @@ -684,6 +684,7 @@ def _get_adapter_name(self): __all__ = [ + "GraniteSpeechPlusModel", "GraniteSpeechPlusCTCEncoder", "GraniteSpeechPlusForConditionalGeneration", "GraniteSpeechPlusPreTrainedModel", diff --git a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py index 5d1d03024ba7..bd4527c04d87 100644 --- a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py @@ -28,6 +28,7 @@ from ..granite_speech.modeling_granite_speech import ( GraniteSpeechCTCEncoder, GraniteSpeechForConditionalGeneration, + 
GraniteSpeechModel, GraniteSpeechPreTrainedModel, ) @@ -123,6 +124,9 @@ def __post_init__(self, **kwargs): class GraniteSpeechPlusPreTrainedModel(GraniteSpeechPreTrainedModel): ... +class GraniteSpeechPlusModel(GraniteSpeechModel): ... + + class GraniteSpeechPlusCTCEncoder(GraniteSpeechCTCEncoder): @merge_with_config_defaults @capture_outputs @@ -168,6 +172,7 @@ class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechForConditionalGener __all__ = [ "GraniteSpeechPlusConfig", "GraniteSpeechPlusEncoderConfig", + "GraniteSpeechPlusModel", "GraniteSpeechPlusCTCEncoder", "GraniteSpeechPlusForConditionalGeneration", "GraniteSpeechPlusPreTrainedModel", From ecce4405188fd98e7ebca2fa6191695493a9cc84 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 15:37:46 +0200 Subject: [PATCH 16/39] tests now have a base_model_class --- tests/models/audioflamingo3/test_modeling_audioflamingo3.py | 2 ++ tests/models/glmasr/test_modeling_glmasr.py | 2 ++ tests/models/granite_speech/test_modeling_granite_speech.py | 2 ++ .../granite_speech_plus/test_modeling_granite_speech_plus.py | 2 ++ tests/models/musicflamingo/test_modeling_musicflamingo.py | 2 ++ tests/models/qwen2_audio/test_modeling_qwen2_audio.py | 2 ++ tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py | 3 ++- tests/models/voxtral/test_modeling_voxtral.py | 2 ++ .../models/voxtral_realtime/test_modeling_voxtral_realtime.py | 2 ++ 9 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py index 9629fe3ba086..fef42c3ac32f 100644 --- a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py +++ b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py @@ -22,6 +22,7 @@ AudioFlamingo3Config, AudioFlamingo3EncoderConfig, AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, AutoProcessor, Qwen2Config, is_torch_available, @@ -42,6 
+43,7 @@ class AudioFlamingo3ModelTester(ALMModelTester): config_class = AudioFlamingo3Config + base_model_class = AudioFlamingo3Model conditional_generation_class = AudioFlamingo3ForConditionalGeneration text_config_class = Qwen2Config audio_config_class = AudioFlamingo3EncoderConfig diff --git a/tests/models/glmasr/test_modeling_glmasr.py b/tests/models/glmasr/test_modeling_glmasr.py index b19e91a61209..f7dfc1a98c29 100644 --- a/tests/models/glmasr/test_modeling_glmasr.py +++ b/tests/models/glmasr/test_modeling_glmasr.py @@ -19,6 +19,7 @@ AutoProcessor, GlmAsrConfig, GlmAsrForConditionalGeneration, + GlmAsrModel, LlamaConfig, is_torch_available, ) @@ -39,6 +40,7 @@ class GlmAsrModelTester(ALMModelTester): config_class = GlmAsrConfig + base_model_class = GlmAsrModel conditional_generation_class = GlmAsrForConditionalGeneration text_config_class = LlamaConfig audio_config_class = GlmAsrEncoderConfig diff --git a/tests/models/granite_speech/test_modeling_granite_speech.py b/tests/models/granite_speech/test_modeling_granite_speech.py index e4ecebbcb0ee..613790dbaa5a 100644 --- a/tests/models/granite_speech/test_modeling_granite_speech.py +++ b/tests/models/granite_speech/test_modeling_granite_speech.py @@ -23,6 +23,7 @@ GraniteSpeechConfig, GraniteSpeechEncoderConfig, GraniteSpeechForConditionalGeneration, + GraniteSpeechModel, ) from transformers.testing_utils import ( cleanup, @@ -49,6 +50,7 @@ class GraniteSpeechModelTester(ALMModelTester): config_class = GraniteSpeechConfig + base_model_class = GraniteSpeechModel conditional_generation_class = GraniteSpeechForConditionalGeneration text_config_class = GraniteConfig audio_config_class = GraniteSpeechEncoderConfig diff --git a/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py b/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py index 21f1d997efb4..04372995ffc3 100644 --- a/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py +++ 
b/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py @@ -20,6 +20,7 @@ GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig, GraniteSpeechPlusForConditionalGeneration, + GraniteSpeechPlusModel, ) from transformers.testing_utils import cleanup, require_torch, slow, torch_device from transformers.utils import is_datasets_available, is_torch_available @@ -43,6 +44,7 @@ class GraniteSpeechPlusForConditionalGenerationModelTester(GraniteSpeechModelTes """ config_class = GraniteSpeechPlusConfig + base_model_class = GraniteSpeechPlusModel conditional_generation_class = GraniteSpeechPlusForConditionalGeneration audio_config_class = GraniteSpeechPlusEncoderConfig diff --git a/tests/models/musicflamingo/test_modeling_musicflamingo.py b/tests/models/musicflamingo/test_modeling_musicflamingo.py index 2615af219ff5..43473dec0cca 100644 --- a/tests/models/musicflamingo/test_modeling_musicflamingo.py +++ b/tests/models/musicflamingo/test_modeling_musicflamingo.py @@ -24,6 +24,7 @@ AutoProcessor, MusicFlamingoConfig, MusicFlamingoForConditionalGeneration, + MusicFlamingoModel, Qwen2Config, is_torch_available, ) @@ -51,6 +52,7 @@ class MusicFlamingoModelTester(ALMModelTester): """ config_class = MusicFlamingoConfig + base_model_class = MusicFlamingoModel conditional_generation_class = MusicFlamingoForConditionalGeneration text_config_class = Qwen2Config audio_config_class = AudioFlamingo3EncoderConfig diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 1557217fdd63..8cc0fcab05c9 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -24,6 +24,7 @@ Qwen2AudioConfig, Qwen2AudioEncoderConfig, Qwen2AudioForConditionalGeneration, + Qwen2AudioModel, Qwen2Config, is_torch_available, ) @@ -43,6 +44,7 @@ class Qwen2AudioModelTester(ALMModelTester): config_class = Qwen2AudioConfig + base_model_class = Qwen2AudioModel 
conditional_generation_class = Qwen2AudioForConditionalGeneration text_config_class = Qwen2Config audio_config_class = Qwen2AudioEncoderConfig diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index fc8bb11568ea..ba4d9666c527 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -22,6 +22,7 @@ from transformers import ( VibeVoiceAsrConfig, VibeVoiceAsrForConditionalGeneration, + VibeVoiceAsrModel, is_datasets_available, is_torch_available, ) @@ -133,7 +134,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class VibeVoiceAsrForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (VibeVoiceAsrForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (VibeVoiceAsrModel, VibeVoiceAsrForConditionalGeneration) if is_torch_available() else () pipeline_model_mapping = ( {"audio-text-to-text": VibeVoiceAsrForConditionalGeneration} if is_torch_available() else {} ) diff --git a/tests/models/voxtral/test_modeling_voxtral.py b/tests/models/voxtral/test_modeling_voxtral.py index 4f0c604ce05f..deeed284bc5d 100644 --- a/tests/models/voxtral/test_modeling_voxtral.py +++ b/tests/models/voxtral/test_modeling_voxtral.py @@ -21,6 +21,7 @@ VoxtralConfig, VoxtralEncoderConfig, VoxtralForConditionalGeneration, + VoxtralModel, is_torch_available, ) from transformers.testing_utils import ( @@ -40,6 +41,7 @@ class VoxtralModelTester(ALMModelTester): config_class = VoxtralConfig + base_model_class = VoxtralModel conditional_generation_class = VoxtralForConditionalGeneration text_config_class = LlamaConfig audio_config_class = VoxtralEncoderConfig diff --git a/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py b/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py index 150d7a894104..420b794cd8ee 100644 --- 
a/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py +++ b/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py @@ -20,6 +20,7 @@ AutoProcessor, VoxtralRealtimeConfig, VoxtralRealtimeForConditionalGeneration, + VoxtralRealtimeModel, is_datasets_available, is_torch_available, ) @@ -48,6 +49,7 @@ class VoxtralRealtimeModelTester(ALMModelTester): config_class = VoxtralRealtimeConfig + base_model_class = VoxtralRealtimeModel conditional_generation_class = VoxtralRealtimeForConditionalGeneration text_config_class = VoxtralRealtimeTextConfig audio_config_class = VoxtralRealtimeEncoderConfig From 1f5e199531b1286241cee67d7192a3809a542cea Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 15:46:31 +0200 Subject: [PATCH 17/39] fix qwen2audio --- src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index f8079c133d28..0f43db3dbd57 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -44,6 +44,7 @@ Base class for Qwen2Audio outputs, with hidden states and attentions. 
""" ) +@dataclass class Qwen2AudioModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): From 6a1166b896ce4e6a3cad7e4fee308dbef3fe4f0a Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 16:11:03 +0200 Subject: [PATCH 18/39] class level base mappings --- src/transformers/conversion_mapping.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index a120e15ac355..9b6bb9bce54b 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -109,6 +109,11 @@ "LlavaOnevisionModel": "LlavaModel", "FuyuModel": "LlavaModel", "MllamaModel": "LlavaModel", + "VoxtralModel": "Qwen2AudioModel", + "VoxtralRealtimeModel": "Qwen2AudioModel", + "AudioFlamingo3Model": "Qwen2AudioModel", + "GlmAsrModel": "Qwen2AudioModel", + "MusicFlamingoModel": "Qwen2AudioModel", "MaskFormerDetrDecoder": "DetrModel", "Qwen2_5_VLForConditionalGeneration": "Qwen2VLForConditionalGeneration", # ViT-style vision models (old HuggingFace checkpoint format → new modular format) @@ -414,6 +419,9 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "Qwen2AudioModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], "granite_speech": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), From 933f8c6949deb742715fb7ad6ee5cef523004d02 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 17:21:00 +0200 Subject: 
[PATCH 19/39] fix test_reverse_loading_mapping --- src/transformers/conversion_mapping.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 9b6bb9bce54b..8900fc1aece0 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -428,6 +428,9 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^encoder", target_patterns="model.encoder"), WeightRenaming(source_patterns=r"^projector", target_patterns="model.projector"), ], + "GraniteSpeechModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], "vibevoice_asr": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), @@ -439,6 +442,9 @@ def _build_checkpoint_conversion_mapping(): ), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "VibeVoiceAsrModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], "llava_next": [ WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), From 5383313df55d5b8198da60f4f74a28013ec98abf Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 17:48:12 +0200 Subject: [PATCH 20/39] fix --- .../models/granite_speech/modeling_granite_speech.py | 1 + .../models/granite_speech_plus/modeling_granite_speech_plus.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 07aabbde3970..f5647b1779ea 100644 --- 
a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -535,6 +535,7 @@ def set_output_embeddings(self, new_embeddings): def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( self, diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index 361b4678e823..d04dc5de9d99 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -556,6 +556,7 @@ def set_output_embeddings(self, new_embeddings): def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( self, From 44ab9cda6e251ed1794aab59899723bcbd8a4741 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 12 May 2026 18:44:31 +0200 Subject: [PATCH 21/39] fix --- tests/models/musicflamingo/test_modeling_musicflamingo.py | 2 +- tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/musicflamingo/test_modeling_musicflamingo.py b/tests/models/musicflamingo/test_modeling_musicflamingo.py index 43473dec0cca..d52c84eac6f9 100644 --- a/tests/models/musicflamingo/test_modeling_musicflamingo.py +++ b/tests/models/musicflamingo/test_modeling_musicflamingo.py @@ -151,7 +151,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self): ] ) - inferred = model._build_audio_timestamps(input_ids, post_lengths, max_post_length) + inferred = model.model._build_audio_timestamps(input_ids, post_lengths, max_post_length) torch.testing.assert_close(inferred, audio_timestamps) @unittest.skip( diff --git 
a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index ba4d9666c527..b2db4ee88212 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -139,6 +139,9 @@ class VibeVoiceAsrForConditionalGenerationModelTest(ModelTesterMixin, Generation {"audio-text-to-text": VibeVoiceAsrForConditionalGeneration} if is_torch_available() else {} ) _is_composite = True + # Acoustic/semantic tokenizers run under torch.no_grad() in get_audio_features, + # so their params never receive grads — the mixin's force-unfreeze can't change that. + test_all_params_have_gradient = False def setUp(self): self.model_tester = VibeVoiceAsrModelTester(self) From 13656cc0141ca9306e840cc16628e9e6c391b511 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 13 May 2026 16:49:10 +0200 Subject: [PATCH 22/39] nit --- .../models/voxtral_realtime/modeling_voxtral_realtime.py | 2 +- .../models/voxtral_realtime/modular_voxtral_realtime.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index ae7edb28544e..8545cdb35b2c 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -982,7 +982,7 @@ def forward( return_dict=True, ) audio_embeds = audio_outputs.pooler_output - inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) + inputs_embeds += audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index d82c6417cb20..218baafb4d0a 
100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -589,7 +589,7 @@ def forward( return_dict=True, ) audio_embeds = audio_outputs.pooler_output - inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) + inputs_embeds += audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens From f8f99f03387fed6d216226cab5226fb5fd1ee8b2 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 13 May 2026 16:57:00 +0200 Subject: [PATCH 23/39] fix --- src/transformers/conversion_mapping.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index b0ce4780aab6..84817b01ec11 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -91,6 +91,7 @@ "audioflamingo3": "qwen2_audio", "glmasr": "qwen2_audio", "musicflamingo": "qwen2_audio", + "granite_speech_plus": "granite_speech", "gemma3n_text": "qwen3_5_text", "qwen3_5_moe_text": "qwen3_5_text", "llava_next_video": "llava_next", @@ -114,6 +115,7 @@ "AudioFlamingo3Model": "Qwen2AudioModel", "GlmAsrModel": "Qwen2AudioModel", "MusicFlamingoModel": "Qwen2AudioModel", + "GraniteSpeechPlusModel": "GraniteSpeechModel", "MaskFormerDetrDecoder": "DetrModel", "Qwen2_5_VLForConditionalGeneration": "Qwen2VLForConditionalGeneration", # ViT-style vision models (old HuggingFace checkpoint format → new modular format) From 5184f47cf716a4e4b2197d0262af0e9bb57eddfd Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 13 May 2026 17:01:01 +0200 Subject: [PATCH 24/39] deprec in 5.7 -> 5.15 --- .../models/audioflamingo3/modeling_audioflamingo3.py | 2 +- .../models/audioflamingo3/modular_audioflamingo3.py | 2 +- src/transformers/models/glmasr/modeling_glmasr.py | 2 +- 
src/transformers/models/glmasr/modular_glmasr.py | 2 +- .../models/granite_speech/modeling_granite_speech.py | 2 +- .../models/granite_speech_plus/modeling_granite_speech_plus.py | 2 +- .../models/granite_speech_plus/modular_granite_speech_plus.py | 2 +- src/transformers/models/musicflamingo/modeling_musicflamingo.py | 2 +- src/transformers/models/musicflamingo/modular_musicflamingo.py | 2 +- src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 2 +- src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py | 2 +- src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py | 2 +- src/transformers/models/voxtral/modeling_voxtral.py | 2 +- src/transformers/models/voxtral/modular_voxtral.py | 2 +- .../models/voxtral_realtime/modeling_voxtral_realtime.py | 2 +- .../models/voxtral_realtime/modular_voxtral_realtime.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 3771bcfc30f0..5ed22f081913 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -572,7 +572,7 @@ def forward( The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
""" ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index f823bf929321..dbdf48fc6041 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -274,7 +274,7 @@ def forward( The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): _tp_plan = None _pp_plan = None diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index 2701de757e22..5d7b64e91bb6 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -528,7 +528,7 @@ class GlmAsrCausalLMOutputWithPast(ModelOutput): The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index d836d89b5625..e0cf51806ebb 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -395,7 +395,7 @@ def get_audio_features( The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index f5647b1779ea..5ab849be7893 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -500,7 +500,7 @@ def forward( The Granite Speech model, which consists of an audio encoder, projector, and language model. 
""" ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): _supports_attention_backend = True _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index d04dc5de9d99..56bc714fceac 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -521,7 +521,7 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): encoder's final hidden states with an arbitrary subset of its intermediate hidden states. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): _supports_attention_backend = True _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py index bd4527c04d87..9587dd033fd0 100644 --- a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py @@ -165,7 +165,7 @@ def forward( encoder's final hidden states with an arbitrary subset of its intermediate hidden states. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechForConditionalGeneration): ... 
diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index 3bacd6112acc..aaf4e8894f0c 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -414,7 +414,7 @@ class MusicFlamingoCausalLMOutputWithPast(ModelOutput): The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index b8e1a999daea..157d363de3a7 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -400,7 +400,7 @@ def forward( The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
""" ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index c1aed1ef9135..71a5e70ac674 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -773,7 +773,7 @@ def forward( The QWEN2AUDIO model which consists of an audio backbone and a language model. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 7f3ffab8d6de..88ad8b4cb609 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -451,7 +451,7 @@ def forward( The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
""" ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 30a688782e05..45dcb37b753c 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -368,7 +368,7 @@ def forward( The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 68de447a3b38..1b5286547179 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -492,7 +492,7 @@ def forward( The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = ["embed_positions"] _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 31c7193d71f8..6162f8eafabb 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -262,7 +262,7 @@ def forward( The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. """ ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = ["embed_positions"] _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index 8545cdb35b2c..2ef91367e918 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -1024,7 +1024,7 @@ def forward( ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index 218baafb4d0a..ce9db0b1e7eb 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ 
b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -631,7 +631,7 @@ def forward( ) -@forward_base_model_attrs(version="5.7") +@forward_base_model_attrs(version="5.15") class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} From b272b0777355032995657e02971b3a74118a222b Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 13 May 2026 18:35:25 +0200 Subject: [PATCH 25/39] fix flaky test --- tests/models/glmasr/test_modeling_glmasr.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/glmasr/test_modeling_glmasr.py b/tests/models/glmasr/test_modeling_glmasr.py index f7dfc1a98c29..1df06e819636 100644 --- a/tests/models/glmasr/test_modeling_glmasr.py +++ b/tests/models/glmasr/test_modeling_glmasr.py @@ -50,6 +50,12 @@ def __init__(self, parent, **kwargs): kwargs.setdefault("head_dim", 8) super().__init__(parent, **kwargs) + def create_audio_mask(self): + # Deterministic full-length mask: the base default randomizes lengths in [1, feat_seq_length], + # and short samples collapse to 0 audio tokens after conv2 (s=2) + merge_factor=4, breaking + # test_mismatching_num_audio_tokens (a 0-contribution sample makes "duplicate audio" a no-op). + return torch.ones([self.batch_size, self.feat_seq_length], dtype=torch.bool).to(torch_device) + def get_audio_embeds_mask(self, audio_mask): # conv1 (s=1) preserves length; conv2 (s=2, k=3, p=1) halves; merge_factor=4 post-projector. 
audio_lengths = audio_mask.sum(-1) From 71c4cb513accfe89713248a31f46b36c2a2ce906 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 13 May 2026 19:21:58 +0200 Subject: [PATCH 26/39] fix flaky test --- tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index b2db4ee88212..138d1fa653f6 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -189,6 +189,10 @@ def test_model_outputs_equivalence(self): def test_left_padding_compatibility(self): pass + @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling.") + def test_forward_with_logits_to_keep(self): + pass + def test_sdpa_can_dispatch_composite_models(self): # VibeVoiceAsr is audio+text composite; but audio components do not use attention for model_class in self.all_model_classes: From 59c99d5475aca33cccc68c83f91535b2f82c90f1 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 14 May 2026 11:33:43 +0200 Subject: [PATCH 27/39] fix flaky test --- tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index 138d1fa653f6..ad2225582b34 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -193,6 +193,10 @@ def test_left_padding_compatibility(self): def test_forward_with_logits_to_keep(self): pass + @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling.") + def test_generate_methods_with_logits_to_keep(self): + pass + def test_sdpa_can_dispatch_composite_models(self): # VibeVoiceAsr is 
audio+text composite; but audio components do not use attention for model_class in self.all_model_classes: From 54e11e47074bb9038b6140786e981644745c4890 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 11:56:27 +0200 Subject: [PATCH 28/39] remove forward_base_model_attrs --- .../audioflamingo3/modular_audioflamingo3.py | 2 - .../models/glmasr/modular_glmasr.py | 2 - .../musicflamingo/modular_musicflamingo.py | 2 - src/transformers/utils/deprecation.py | 56 ------------------- 4 files changed, 62 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index dbdf48fc6041..8a452e54da07 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -24,7 +24,6 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..qwen2_audio.modeling_qwen2_audio import ( @@ -274,7 +273,6 @@ def forward( The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
""" ) -@forward_base_model_attrs(version="5.15") class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): _tp_plan = None _pp_plan = None diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index e0cf51806ebb..669f4507bdd0 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -25,7 +25,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, logging -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..audioflamingo3.modeling_audioflamingo3 import ( @@ -395,7 +394,6 @@ def get_audio_features( The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) -@forward_base_model_attrs(version="5.15") class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index 157d363de3a7..99fcba64cbf3 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -27,7 +27,6 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, @@ -400,7 +399,6 @@ def forward( The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. """ ) -@forward_base_model_attrs(version="5.15") class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index 7091bfa6759a..db0e67325d78 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -173,59 +173,3 @@ def wrapped_func(*args, **kwargs): return wrapped_func return wrapper - - -def forward_base_model_attrs(version: str): - """ - Class decorator that forwards attribute access to the base model (`self.`) - when the attribute is not found on the instance directly, and warns that direct access on the - outer class is deprecated. 
- - Intended for backward compatibility during refactors that move submodules from the outer - `*ForConditionalGeneration` class down to the inner base model — e.g. `model.language_model` - becoming `model.model.language_model`. - - Apply only to the outer wrapper class (the `*ForConditionalGeneration`), not to the inner - base model itself. The decorator relies on `base_model_prefix` being set on the class (which - `PreTrainedModel` subclasses always do). - - Args: - version (`str`): - The Transformers version in which direct access will be removed (e.g. `"5.7"`). - """ - - def decorator(cls): - # Resolve the inherited __getattr__ (typically nn.Module's, which looks up - # submodules/parameters/buffers) so we can delegate to it without recursing. - inherited_getattr = cls.__getattr__ - - def __getattr__(self, name): - # First, the normal nn.Module lookup (submodules, parameters, buffers). - try: - return inherited_getattr(self, name) - except AttributeError: - pass - # Only forward public attributes to the base model — private names are - # framework internals (e.g. `_is_hf_initialized`) and shouldn't warn. - if name.startswith("_"): - raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") - prefix = type(self).base_model_prefix - try: - base = inherited_getattr(self, prefix) - except AttributeError: - raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") - if hasattr(base, name): - if not is_torchdynamo_compiling(): - warnings.warn( - f"Accessing `{name}` directly on `{type(self).__name__}` is deprecated and " - f"will be removed in Transformers v{version}. 
Use `.{prefix}.{name}` instead.", - FutureWarning, - stacklevel=2, - ) - return getattr(base, name) - raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") - - cls.__getattr__ = __getattr__ - return cls - - return decorator From 241f6f8cf1d4f2e1e905cde55053269cb09f91a9 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 12:07:38 +0200 Subject: [PATCH 29/39] _supports_attention_backend in PretrainedModel --- .../models/audioflamingo3/modular_audioflamingo3.py | 3 +-- src/transformers/models/glmasr/modular_glmasr.py | 2 -- .../models/granite_speech/modeling_granite_speech.py | 4 +--- .../models/vibevoice_asr/modular_vibevoice_asr.py | 3 +-- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index 8a452e54da07..bbf51b421b11 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -52,7 +52,7 @@ class AudioFlamingo3EncoderLayer(WhisperEncoderLayer): class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel): - pass + _supports_attention_backend = True @dataclass @@ -181,7 +181,6 @@ def __init__(self, config: AudioFlamingo3Config): """ ) class AudioFlamingo3Model(VoxtralModel): - _supports_attention_backend = True _tp_plan = None _pp_plan = None _keep_in_fp32_modules_strict = None diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index 669f4507bdd0..2ee71006d782 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -358,8 +358,6 @@ def __init__(self, config: GlmAsrConfig): """ ) class GlmAsrModel(AudioFlamingo3Model): - _supports_attention_backend = True - @can_return_tuple @auto_docstring( custom_intro="Compute audio 
embeddings from log-mel input features using the audio encoder and multi-modal projector." diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 5ab849be7893..d0eef000fb72 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -285,6 +285,7 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel): _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True + _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module: nn.Module): @@ -348,8 +349,6 @@ def forward( """ ) class GraniteSpeechModel(GraniteSpeechPreTrainedModel): - _supports_attention_backend = True - def __init__(self, config: GraniteSpeechConfig): super().__init__(config) self.encoder = GraniteSpeechCTCEncoder(config.encoder_config) @@ -502,7 +501,6 @@ def forward( ) @forward_base_model_attrs(version="5.15") class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): - _supports_attention_backend = True _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: GraniteSpeechConfig): diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 45dcb37b753c..ff9bcdb6a450 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -167,6 +167,7 @@ class VibeVoiceAsrPreTrainedModel(VibeVoiceAcousticTokenizerPreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True @dataclass @@ -219,8 +220,6 @@ class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): """ ) class 
VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): - _supports_attention_backend = True - def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) From d75fca8283627973ac9c67411ea4eabfe2fb7c2c Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 12:59:02 +0200 Subject: [PATCH 30/39] removed redundant get/set_input_embeddings --- .../audioflamingo3/modeling_audioflamingo3.py | 16 +------ .../models/glmasr/modeling_glmasr.py | 16 +------ .../granite_speech/modeling_granite_speech.py | 14 ------ .../modeling_granite_speech_plus.py | 48 +++++++------------ .../modular_granite_speech_plus.py | 2 - .../musicflamingo/modeling_musicflamingo.py | 16 +------ .../qwen2_audio/modeling_qwen2_audio.py | 14 ------ .../vibevoice_asr/modeling_vibevoice_asr.py | 17 +------ .../vibevoice_asr/modular_vibevoice_asr.py | 14 ------ .../models/voxtral/modeling_voxtral.py | 14 ------ .../models/voxtral/modular_voxtral.py | 14 ------ .../modeling_voxtral_realtime.py | 18 ------- .../modular_voxtral_realtime.py | 18 ------- 13 files changed, 20 insertions(+), 201 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 5ed22f081913..23bf2425a221 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -36,7 +36,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ 
-256,6 +255,7 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True @dataclass @@ -446,7 +446,6 @@ def forward(self, audio_features): """ ) class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel): - _supports_attention_backend = True _tp_plan = None _pp_plan = None _keep_in_fp32_modules_strict = None @@ -572,7 +571,6 @@ def forward( The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. """ ) -@forward_base_model_attrs(version="5.15") class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -585,18 +583,6 @@ def __init__(self, config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, input_features, input_features_mask, **kwargs): return self.model.get_audio_features(input_features, input_features_mask, **kwargs) diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index 5d7b64e91bb6..6ed4e3acc8fe 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -37,7 +37,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, 
torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -291,6 +290,7 @@ class GlmAsrPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True # TODO: @eustlb, this is what WhisperEncoder should look like @@ -372,7 +372,6 @@ class GlmAsrModelOutputWithPast(BaseModelOutputWithPast): """ ) class GlmAsrModel(GlmAsrPreTrainedModel): - _supports_attention_backend = True _tp_plan = None _pp_plan = None _keep_in_fp32_modules_strict = None @@ -528,7 +527,6 @@ class GlmAsrCausalLMOutputWithPast(ModelOutput): The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -@forward_base_model_attrs(version="5.15") class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -541,18 +539,6 @@ def __init__(self, config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, input_features, input_features_mask, **kwargs): return self.model.get_audio_features(input_features, input_features_mask, **kwargs) diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 
d0eef000fb72..ca9c4cddfd08 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -33,7 +33,6 @@ logging, torch_compilable_check, ) -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -499,7 +498,6 @@ def forward( The Granite Speech model, which consists of an audio encoder, projector, and language model. """ ) -@forward_base_model_attrs(version="5.15") class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -518,18 +516,6 @@ def __init__(self, config: GraniteSpeechConfig): self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index 56bc714fceac..cde1c300f364 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -39,7 +39,6 @@ logging, torch_compilable_check, ) -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -49,6 +48,21 @@ logger = logging.get_logger(__name__) +@dataclass +@auto_docstring( + 
custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + ### Projector class GraniteSpeechPlusEncoderProjector(nn.Module): def __init__(self, config: GraniteSpeechPlusConfig): @@ -92,6 +106,7 @@ class GraniteSpeechPlusPreTrainedModel(PreTrainedModel): _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True + _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module: nn.Module): @@ -107,21 +122,6 @@ def _init_weights(self, module: nn.Module): init.copy_(module.attention_dists, attention_dists) -@dataclass -@auto_docstring( - custom_intro=""" - Base class for Granite Speech outputs, with hidden states and attentions. - """ -) -class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): - r""" - audio_hidden_states (`torch.FloatTensor`, *optional*): - Projected audio hidden states. - """ - - audio_hidden_states: torch.FloatTensor | None = None - - @auto_docstring( custom_intro=""" The Granite Speech model, which consists of an audio encoder, projector, and language model, @@ -129,8 +129,6 @@ class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): """ ) class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel): - _supports_attention_backend = True - def __init__(self, config: GraniteSpeechPlusConfig): super().__init__(config) self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) @@ -521,9 +519,7 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): encoder's final hidden states with an arbitrary subset of its intermediate hidden states. 
""" ) -@forward_base_model_attrs(version="5.15") class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): - _supports_attention_backend = True _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: GraniteSpeechPlusConfig): @@ -541,18 +537,6 @@ def __init__(self, config: GraniteSpeechPlusConfig): self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) diff --git a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py index 9587dd033fd0..f5865379fba1 100644 --- a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py @@ -21,7 +21,6 @@ from ...modeling_outputs import BaseModelOutputWithPooling from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..granite_speech.configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig @@ -165,7 +164,6 @@ def forward( encoder's final hidden states with an arbitrary subset of its intermediate hidden states. """ ) -@forward_base_model_attrs(version="5.15") class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechForConditionalGeneration): ... 
diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index aaf4e8894f0c..91ac86689ef0 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -35,7 +35,6 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -143,6 +142,7 @@ class MusicFlamingoPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module): @@ -212,7 +212,6 @@ def apply_rotary_time_emb(hidden_states, cos, sin): """ ) class MusicFlamingoModel(MusicFlamingoPreTrainedModel): - _supports_attention_backend = True _tp_plan = None _pp_plan = None _keep_in_fp32_modules_strict = None @@ -414,7 +413,6 @@ class MusicFlamingoCausalLMOutputWithPast(ModelOutput): The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. 
""" ) -@forward_base_model_attrs(version="5.15") class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -427,18 +425,6 @@ def __init__(self, config: MusicFlamingoConfig): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs): return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 71a5e70ac674..a69594941ca1 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -29,7 +29,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging, torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -773,7 +772,6 @@ def forward( The QWEN2AUDIO model which consists of an audio backbone and a language model. 
""" ) -@forward_base_model_attrs(version="5.15") class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -783,18 +781,6 @@ def __init__(self, config: Qwen2AudioConfig): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - @property def padding_side(self): return self.model.padding_side diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 88ad8b4cb609..1b9e68244855 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -31,7 +31,6 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling -from ...utils.deprecation import forward_base_model_attrs from ..auto import AutoModel from .configuration_vibevoice_asr import VibeVoiceAsrConfig @@ -244,6 +243,7 @@ class VibeVoiceAsrPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True def _init_weights(self, module): super()._init_weights(module) @@ -302,8 +302,6 @@ class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): """ ) class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): - _supports_attention_backend = True - def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) 
self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) @@ -451,7 +449,6 @@ def forward( The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. """ ) -@forward_base_model_attrs(version="5.15") class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -461,18 +458,6 @@ def __init__(self, config: VibeVoiceAsrConfig): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index ff9bcdb6a450..55315bf51a49 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -27,7 +27,6 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.deprecation import forward_base_model_attrs from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from ..vibevoice_acoustic_tokenizer.modeling_vibevoice_acoustic_tokenizer import ( @@ -367,7 +366,6 @@ def forward( The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
""" ) -@forward_base_model_attrs(version="5.15") class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -377,18 +375,6 @@ def __init__(self, config: VibeVoiceAsrConfig): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 1b5286547179..5903170ef44a 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -34,7 +34,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -492,7 +491,6 @@ def forward( The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) -@forward_base_model_attrs(version="5.15") class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = ["embed_positions"] _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -503,18 +501,6 @@ def __init__(self, config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 6162f8eafabb..3ecaa4cd6175 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -28,7 +28,6 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..auto import AutoModel @@ -262,7 +261,6 @@ def forward( The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. 
""" ) -@forward_base_model_attrs(version="5.15") class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = ["embed_positions"] _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -273,18 +271,6 @@ def __init__(self, config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index 2ef91367e918..a513ca0e29c2 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -40,7 +40,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_voxtral_realtime import ( @@ -1024,7 +1023,6 @@ def forward( ) -@forward_base_model_attrs(version="5.15") class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -1034,25 +1032,9 @@ def __init__(self, config): self.lm_head = 
nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) - @property - def audio_tower(self): - return self.model.audio_tower - @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index ce9db0b1e7eb..a45d50bdb222 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -41,7 +41,6 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_voxtral_realtime import VoxtralRealtimeEncoderConfig @@ -631,7 +630,6 @@ def forward( ) -@forward_base_model_attrs(version="5.15") class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} @@ -641,25 +639,9 @@ def __init__(self, config): self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def 
get_output_embeddings(self) -> nn.Module: - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - def get_audio_features(self, *args, **kwargs): return self.model.get_audio_features(*args, **kwargs) - @property - def audio_tower(self): - return self.model.audio_tower - @can_return_tuple @auto_docstring def forward( From 9cb259b90a4463da2c273cab70e0a48b14eaac5b Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 14:43:58 +0200 Subject: [PATCH 31/39] remove requires_grad_ handling --- .../models/vibevoice_asr/modeling_vibevoice_asr.py | 6 ------ .../models/vibevoice_asr/modular_vibevoice_asr.py | 6 ------ 2 files changed, 12 deletions(-) diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 1b9e68244855..b7d2902aa6c1 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -309,12 +309,6 @@ def __init__(self, config: VibeVoiceAsrConfig): self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) self.language_model = AutoModel.from_config(config.text_config) self.post_init() - # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze - # their parameters so grad-checkpointing and training sanity checks don't flag them. 
- for p in self.acoustic_tokenizer_encoder.parameters(): - p.requires_grad_(False) - for p in self.semantic_tokenizer_encoder.parameters(): - p.requires_grad_(False) def get_input_embeddings(self): return self.language_model.get_input_embeddings() diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 55315bf51a49..c664823c039a 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -226,12 +226,6 @@ def __init__(self, config: VibeVoiceAsrConfig): self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) self.language_model = AutoModel.from_config(config.text_config) self.post_init() - # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze - # their parameters so grad-checkpointing and training sanity checks don't flag them. - for p in self.acoustic_tokenizer_encoder.parameters(): - p.requires_grad_(False) - for p in self.semantic_tokenizer_encoder.parameters(): - p.requires_grad_(False) def get_input_embeddings(self): return self.language_model.get_input_embeddings() From c8b7788ec661e80f119e71d8ed8649d972f0612a Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 14:52:24 +0200 Subject: [PATCH 32/39] make fix-repo --- docs/source/en/model_doc/pe_audio.md | 2 +- docs/source/en/model_doc/pe_audio_video.md | 2 +- docs/source/en/model_doc/pe_video.md | 2 +- .../modeling_granite_speech_plus.py | 30 +++++++++---------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/source/en/model_doc/pe_audio.md b/docs/source/en/model_doc/pe_audio.md index 21e0f8938b9e..6b2210f26d4e 100644 --- a/docs/source/en/model_doc/pe_audio.md +++ b/docs/source/en/model_doc/pe_audio.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. 
rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.* +*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.* # PE Audio diff --git a/docs/source/en/model_doc/pe_audio_video.md b/docs/source/en/model_doc/pe_audio_video.md index b091383413d1..cd2377d35781 100644 --- a/docs/source/en/model_doc/pe_audio_video.md +++ b/docs/source/en/model_doc/pe_audio_video.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.* +*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.* # PE Audio Video diff --git a/docs/source/en/model_doc/pe_video.md b/docs/source/en/model_doc/pe_video.md index 872273f78597..48cb691792f7 100644 --- a/docs/source/en/model_doc/pe_video.md +++ b/docs/source/en/model_doc/pe_video.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.* +*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.* # PE Video diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index cde1c300f364..b9750ac9d581 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -48,21 +48,6 @@ logger = logging.get_logger(__name__) -@dataclass -@auto_docstring( - custom_intro=""" - Base class for Granite Speech outputs, with hidden states and attentions. 
- """ -) -class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): - r""" - audio_hidden_states (`torch.FloatTensor`, *optional*): - Projected audio hidden states. - """ - - audio_hidden_states: torch.FloatTensor | None = None - - ### Projector class GraniteSpeechPlusEncoderProjector(nn.Module): def __init__(self, config: GraniteSpeechPlusConfig): @@ -122,6 +107,21 @@ def _init_weights(self, module: nn.Module): init.copy_(module.attention_dists, attention_dists) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The Granite Speech model, which consists of an audio encoder, projector, and language model, From 8c51b43af93ff574bc44d67561427da558691523 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 15:09:43 +0200 Subject: [PATCH 33/39] fix voxtral realtime --- .../models/voxtral_realtime/modeling_voxtral_realtime.py | 2 +- .../models/voxtral_realtime/modular_voxtral_realtime.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index a513ca0e29c2..aebd752d2aa8 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -1149,7 +1149,7 @@ def _prepare_model_inputs( input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): - model_kwargs["encoder_inputs_embeds"] = self.audio_tower.embedder(model_kwargs.pop("input_features")) + 
model_kwargs["encoder_inputs_embeds"] = self.model.audio_tower.embedder(model_kwargs.pop("input_features")) elif isinstance(input_features, GeneratorType): input_features_generator = model_kwargs.pop("input_features") diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index a45d50bdb222..81bf2259d726 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -756,7 +756,7 @@ def _prepare_model_inputs( input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): - model_kwargs["encoder_inputs_embeds"] = self.audio_tower.embedder(model_kwargs.pop("input_features")) + model_kwargs["encoder_inputs_embeds"] = self.model.audio_tower.embedder(model_kwargs.pop("input_features")) elif isinstance(input_features, GeneratorType): input_features_generator = model_kwargs.pop("input_features") From 26329fcdeed413ef667a96c5d68d3982750384af Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 17:25:18 +0200 Subject: [PATCH 34/39] update vibevoice_asr test --- tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index ad2225582b34..dc01897a07ff 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -209,16 +209,17 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + language_model_sdpa = 
model_sdpa.base_model.language_model + text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(language_model_sdpa.config._attn_implementation == text_attn) # Eager model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.base_model.language_model.config._attn_implementation == "eager") for _, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ From 5519f41a1a9d2ad63c94ed30f1307719683bf586 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 18 May 2026 18:49:35 +0200 Subject: [PATCH 35/39] update test --- tests/models/musicflamingo/test_modeling_musicflamingo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/musicflamingo/test_modeling_musicflamingo.py b/tests/models/musicflamingo/test_modeling_musicflamingo.py index d52c84eac6f9..f9d48f3e0640 100644 --- a/tests/models/musicflamingo/test_modeling_musicflamingo.py +++ b/tests/models/musicflamingo/test_modeling_musicflamingo.py @@ -103,7 +103,7 @@ class MusicFlamingoForConditionalGenerationModelTest(ALMModelTest, unittest.Test def test_rotary_window_axis_resets_per_audio(self): config = self.model_tester.get_config() - pos_emb = MusicFlamingoForConditionalGeneration(config).pos_emb.to(torch_device) + pos_emb = MusicFlamingoForConditionalGeneration(config).model.pos_emb.to(torch_device) timestamps = torch.tensor( [ @@ -133,7 +133,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self): input_ids[0, :45] = config.audio_token_id input_ids[1, :30] = 
config.audio_token_id - _, post_lengths = model.audio_tower._get_feat_extract_output_lengths( + _, post_lengths = model.model.audio_tower._get_feat_extract_output_lengths( input_features_mask.sum(-1).to(torch.long) ) max_post_length = int(post_lengths.max().item()) From c4ac7407e6e19fa415c4a194cb6c2444ed8c5ee0 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 20 May 2026 10:43:46 +0200 Subject: [PATCH 36/39] Update src/transformers/models/audioflamingo3/modular_audioflamingo3.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/audioflamingo3/modular_audioflamingo3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index bbf51b421b11..d03ed66df90d 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -275,7 +275,6 @@ def forward( class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): _tp_plan = None _pp_plan = None - _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) From daad5ecf2e498bf4b5470a4cfec29caded02cdb5 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 20 May 2026 10:43:54 +0200 Subject: [PATCH 37/39] Update src/transformers/models/audioflamingo3/modular_audioflamingo3.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/audioflamingo3/modular_audioflamingo3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index d03ed66df90d..825675e55d01 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ 
b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -273,8 +273,6 @@ def forward( """ ) class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): - _tp_plan = None - _pp_plan = None def __init__(self, config): super().__init__(config) From b09349dbc979edf2bdfc5db79f8dba943b261e67 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 20 May 2026 10:44:06 +0200 Subject: [PATCH 38/39] Update src/transformers/models/glmasr/modeling_glmasr.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/glmasr/modeling_glmasr.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index 6ed4e3acc8fe..7b5511dd5dbd 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -528,10 +528,7 @@ class GlmAsrCausalLMOutputWithPast(ModelOutput): """ ) class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} - _tp_plan = None - _pp_plan = None def __init__(self, config): super().__init__(config) From 6f5be638f8955a572d2b82fca8b51d79c0484279 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 20 May 2026 11:40:38 +0200 Subject: [PATCH 39/39] make fix-repo --- .../models/audioflamingo3/modular_audioflamingo3.py | 1 - src/transformers/models/musicflamingo/modular_musicflamingo.py | 1 + src/transformers/models/parakeet/processing_parakeet.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index 5b36b7cecf87..86be1ffa376f 100644 --- 
a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -274,7 +274,6 @@ def forward( """ ) class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): - def __init__(self, config): super().__init__(config) self.model = AudioFlamingo3Model(config) diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index 0ea0f39c3146..325845ad4da9 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from math import pi from huggingface_hub.dataclasses import strict diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 10563490345d..d7a4fb4ed457 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -44,7 +44,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False): @auto_docstring class ParakeetProcessor(ProcessorMixin): def __init__(self, feature_extractor, tokenizer, blank_token=""): - """ + r""" blank_token (`str`, *optional*, defaults to `""`): Blank token for TDT decoding. """