diff --git a/docs/source/en/model_doc/audioflamingo3.md b/docs/source/en/model_doc/audioflamingo3.md index 0249480b3317..95cf1d9caa84 100644 --- a/docs/source/en/model_doc/audioflamingo3.md +++ b/docs/source/en/model_doc/audioflamingo3.md @@ -403,6 +403,11 @@ are forwarded, so you can tweak padding or tensor formats just like when calling [[autodoc]] AudioFlamingo3Encoder - forward +## AudioFlamingo3Model + +[[autodoc]] AudioFlamingo3Model + - forward + ## AudioFlamingo3ForConditionalGeneration [[autodoc]] AudioFlamingo3ForConditionalGeneration diff --git a/docs/source/en/model_doc/glmasr.md b/docs/source/en/model_doc/glmasr.md index a9acd132d8eb..ad7dda6757a6 100644 --- a/docs/source/en/model_doc/glmasr.md +++ b/docs/source/en/model_doc/glmasr.md @@ -231,6 +231,11 @@ assert decoded_outputs == EXPECTED_OUTPUT [[autodoc]] GlmAsrEncoder - forward +## GlmAsrModel + +[[autodoc]] GlmAsrModel + - forward + ## GlmAsrForConditionalGeneration [[autodoc]] GlmAsrForConditionalGeneration diff --git a/docs/source/en/model_doc/granite_speech.md b/docs/source/en/model_doc/granite_speech.md index 5115292a84d8..aedfb5eea61a 100644 --- a/docs/source/en/model_doc/granite_speech.md +++ b/docs/source/en/model_doc/granite_speech.md @@ -163,6 +163,11 @@ for i, transcription in enumerate(transcriptions): [[autodoc]] GraniteSpeechFeatureExtractor +## GraniteSpeechModel + +[[autodoc]] GraniteSpeechModel + - forward + ## GraniteSpeechForConditionalGeneration [[autodoc]] GraniteSpeechForConditionalGeneration diff --git a/docs/source/en/model_doc/granite_speech_plus.md b/docs/source/en/model_doc/granite_speech_plus.md index 22cf57935c12..810c0330fda6 100644 --- a/docs/source/en/model_doc/granite_speech_plus.md +++ b/docs/source/en/model_doc/granite_speech_plus.md @@ -143,6 +143,11 @@ for k in range(NUM_SEGMENTS): [[autodoc]] GraniteSpeechPlusEncoderConfig +## GraniteSpeechPlusModel + +[[autodoc]] GraniteSpeechPlusModel + - forward + ## GraniteSpeechPlusForConditionalGeneration [[autodoc]] GraniteSpeechPlusForConditionalGeneration diff --git a/docs/source/en/model_doc/hyperclovax.md b/docs/source/en/model_doc/hyperclovax.md index 2725fbbc090f..0a26d0bde2b3 100644 --- a/docs/source/en/model_doc/hyperclovax.md +++ b/docs/source/en/model_doc/hyperclovax.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2025-07-21 and added to Hugging Face Transformers on 2026-05-06.* +*This model was released on 2025-07-21 and added to Hugging Face Transformers on 2026-05-08.*
diff --git a/docs/source/en/model_doc/musicflamingo.md b/docs/source/en/model_doc/musicflamingo.md index 52cab097b3b0..bdb3fe3437c6 100644 --- a/docs/source/en/model_doc/musicflamingo.md +++ b/docs/source/en/model_doc/musicflamingo.md @@ -287,6 +287,11 @@ loss.backward() [[autodoc]] MusicFlamingoProcessor +## MusicFlamingoModel + +[[autodoc]] MusicFlamingoModel + - forward + ## MusicFlamingoForConditionalGeneration [[autodoc]] MusicFlamingoForConditionalGeneration diff --git a/docs/source/en/model_doc/pe_audio.md b/docs/source/en/model_doc/pe_audio.md index 21e0f8938b9e..6b2210f26d4e 100644 --- a/docs/source/en/model_doc/pe_audio.md +++ b/docs/source/en/model_doc/pe_audio.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.* +*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.* # PE Audio diff --git a/docs/source/en/model_doc/pe_audio_video.md b/docs/source/en/model_doc/pe_audio_video.md index b091383413d1..cd2377d35781 100644 --- a/docs/source/en/model_doc/pe_audio_video.md +++ b/docs/source/en/model_doc/pe_audio_video.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.* +*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.* # PE Audio Video diff --git a/docs/source/en/model_doc/pe_video.md b/docs/source/en/model_doc/pe_video.md index 872273f78597..48cb691792f7 100644 --- a/docs/source/en/model_doc/pe_video.md +++ b/docs/source/en/model_doc/pe_video.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.* +*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.* # PE Video diff --git a/docs/source/en/model_doc/qwen2_audio.md b/docs/source/en/model_doc/qwen2_audio.md index 2cfd00b8e808..785b4bf5ca06 100644 --- a/docs/source/en/model_doc/qwen2_audio.md +++ b/docs/source/en/model_doc/qwen2_audio.md @@ -251,6 +251,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_ [[autodoc]] Qwen2AudioEncoder - forward +## Qwen2AudioModel + +[[autodoc]] Qwen2AudioModel + - forward + ## Qwen2AudioForConditionalGeneration [[autodoc]] Qwen2AudioForConditionalGeneration diff --git a/docs/source/en/model_doc/vibevoice_asr.md b/docs/source/en/model_doc/vibevoice_asr.md index f28485e6cb9e..8c29de227240 100644 --- a/docs/source/en/model_doc/vibevoice_asr.md +++ b/docs/source/en/model_doc/vibevoice_asr.md @@ -452,6 +452,11 @@ print(transcription) - apply_transcription_request - decode +## VibeVoiceAsrModel + +[[autodoc]] VibeVoiceAsrModel + - forward + ## VibeVoiceAsrForConditionalGeneration [[autodoc]] VibeVoiceAsrForConditionalGeneration diff --git a/docs/source/en/model_doc/voxtral.md b/docs/source/en/model_doc/voxtral.md index 0f520f547584..adb2c3706ffb 100644 --- a/docs/source/en/model_doc/voxtral.md +++ b/docs/source/en/model_doc/voxtral.md @@ -352,6 +352,11 @@ This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb) [[autodoc]] VoxtralEncoder - forward +## VoxtralModel + +[[autodoc]] VoxtralModel + - forward + ## VoxtralForConditionalGeneration [[autodoc]] VoxtralForConditionalGeneration diff --git a/docs/source/en/model_doc/voxtral_realtime.md b/docs/source/en/model_doc/voxtral_realtime.md index 49b274ea6659..97fe5fecff55 100644 --- a/docs/source/en/model_doc/voxtral_realtime.md +++ b/docs/source/en/model_doc/voxtral_realtime.md @@ -182,6 +182,11 @@ This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb) [[autodoc]] VoxtralRealtimeEncoder - forward +## VoxtralRealtimeModel + +[[autodoc]] VoxtralRealtimeModel + - forward + ## VoxtralRealtimeForConditionalGeneration [[autodoc]] VoxtralRealtimeForConditionalGeneration diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 858b8d8d0127..4b4a464caf05 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -87,6 +87,14 @@ "vipllava": "llava", "mistral3": "llava", "pp_chart2table": "llava", + "voxtral": "qwen2_audio", + "voxtral_realtime": "qwen2_audio", + "audioflamingo3": "qwen2_audio", + "glmasr": "qwen2_audio", + "musicflamingo": "qwen2_audio", + "granite_speech_plus": "granite_speech", + "gemma3n_text": "qwen3_5_text", + "qwen3_5_moe_text": "qwen3_5_text", "llava_next_video": "llava_next", "llava_onevision": "llava_next", # class-based mappings @@ -103,6 +111,12 @@ "LlavaOnevisionModel": "LlavaModel", "FuyuModel": "LlavaModel", "MllamaModel": "LlavaModel", + "VoxtralModel": "Qwen2AudioModel", + "VoxtralRealtimeModel": "Qwen2AudioModel", + "AudioFlamingo3Model": "Qwen2AudioModel", + "GlmAsrModel": "Qwen2AudioModel", + "MusicFlamingoModel": "Qwen2AudioModel", + "GraniteSpeechPlusModel": "GraniteSpeechModel", "MaskFormerDetrDecoder": "DetrModel", "Qwen2_5_VLForConditionalGeneration": "Qwen2VLForConditionalGeneration", # ViT-style vision models (old HuggingFace checkpoint format → new modular format) @@ -420,6 +434,38 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "qwen2_audio": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "Qwen2AudioModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], + "granite_speech": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^encoder", target_patterns="model.encoder"), + WeightRenaming(source_patterns=r"^projector", target_patterns="model.projector"), + ], + "GraniteSpeechModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], + "vibevoice_asr": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming( + source_patterns=r"^acoustic_tokenizer_encoder", target_patterns="model.acoustic_tokenizer_encoder" + ), + WeightRenaming( + source_patterns=r"^semantic_tokenizer_encoder", target_patterns="model.semantic_tokenizer_encoder" + ), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "VibeVoiceAsrModel": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"), + ], "llava_next": [ WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), diff --git a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py index 096e263d856d..cd81ef805205 100644 --- a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py @@ -100,6 +100,7 @@ class AudioFlamingo3Config(PreTrainedConfig): audio_token_id: int = 151669 projector_hidden_act: str = "gelu" projector_bias: bool = True + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index ca77ef4a6cb1..26972fcf18b2 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -31,13 +32,13 @@ from ...masking_utils import create_bidirectional_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig @@ -254,6 +255,43 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True + + +@dataclass +class AudioFlamingo3ModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None @auto_docstring( @@ -403,38 +441,23 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _supports_attention_backend = True +class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel): _tp_plan = None _sp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -447,11 +470,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -504,77 +523,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | AudioFlamingo3ModelOutputWithPast: r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -584,17 +543,103 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + """ +) +class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. + + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -611,4 +656,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index a170d8ee4124..86be1ffa376f 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults @@ -28,7 +30,12 @@ Qwen2AudioEncoder, Qwen2AudioPreTrainedModel, ) -from ..voxtral.modeling_voxtral import VoxtralForConditionalGeneration, VoxtralMultiModalProjector +from ..voxtral.modeling_voxtral import ( + VoxtralForConditionalGeneration, + VoxtralModel, + VoxtralModelOutputWithPast, + VoxtralMultiModalProjector, +) from ..whisper.modeling_whisper import WhisperAttention, WhisperEncoderLayer from .configuration_audioflamingo3 import AudioFlamingo3Config @@ -45,9 +52,40 @@ class AudioFlamingo3EncoderLayer(WhisperEncoderLayer): class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel): + _supports_attention_backend = True + + +@dataclass +class AudioFlamingo3ModelOutputWithPast(VoxtralModelOutputWithPast): pass +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The audio model from AudioFlamingo3 without any head or projection on top. @@ -138,11 +176,11 @@ def __init__(self, config: AudioFlamingo3Config): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): - _supports_attention_backend = True +class AudioFlamingo3Model(VoxtralModel): _tp_plan = None _sp_plan = None _pp_plan = None @@ -163,11 +201,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -196,77 +230,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ): r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -276,17 +250,99 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + """ +) +class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. + + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -303,4 +359,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e9de6747076f..cf421c30e7c6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -52,7 +52,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("aria", "AriaModel"), ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), - ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"), + ("audioflamingo3", "AudioFlamingo3Model"), ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "AutoformerModel"), ("aya_vision", "AyaVisionModel"), @@ -201,7 +201,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("glm_ocr", "GlmOcrModel"), ("glm_ocr_text", "GlmOcrTextModel"), ("glm_ocr_vision", "GlmOcrVisionModel"), - ("glmasr", "GlmAsrForConditionalGeneration"), + ("glmasr", "GlmAsrModel"), ("glmasr_encoder", "GlmAsrEncoder"), ("glpn", "GLPNModel"), ("got_ocr2", "GotOcr2Model"), @@ -215,7 +215,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gptj", "GPTJModel"), ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), - ("granite_speech", "GraniteSpeechForConditionalGeneration"), + ("granite_speech", "GraniteSpeechModel"), + ("granite_speech_plus", "GraniteSpeechPlusModel"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), ("granitemoeshared", "GraniteMoeSharedModel"), @@ -320,7 +321,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mpt", "MptModel"), ("mra", "MraModel"), ("mt5", "MT5Model"), - ("musicflamingo", "MusicFlamingoForConditionalGeneration"), + ("musicflamingo", "MusicFlamingoModel"), ("musicgen", "MusicgenModel"), ("musicgen_melody", "MusicgenMelodyModel"), ("mvp", "MvpModel"), @@ -384,6 +385,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen2", "Qwen2Model"), ("qwen2_5_vl", "Qwen2_5_VLModel"), ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"), + ("qwen2_audio", "Qwen2AudioModel"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), @@ -479,7 +481,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"), ("vibevoice_acoustic_tokenizer_decoder", "VibeVoiceAcousticTokenizerDecoderModel"), ("vibevoice_acoustic_tokenizer_encoder", "VibeVoiceAcousticTokenizerEncoderModel"), - ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"), + ("vibevoice_asr", "VibeVoiceAsrModel"), ("video_llama_3", "VideoLlama3Model"), ("video_llama_3_vision", "VideoLlama3VisionModel"), ("video_llava", "VideoLlavaModel"), @@ -495,9 +497,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("vits", "VitsModel"), ("vivit", "VivitModel"), ("vjepa2", "VJEPA2Model"), - ("voxtral", "VoxtralForConditionalGeneration"), + ("voxtral", "VoxtralModel"), ("voxtral_encoder", "VoxtralEncoder"), - ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"), + ("voxtral_realtime", "VoxtralRealtimeModel"), ("voxtral_realtime_encoder", "VoxtralRealtimeEncoder"), ("voxtral_realtime_text", "VoxtralRealtimeTextModel"), ("wav2vec2", "Wav2Vec2Model"), diff --git a/src/transformers/models/glmasr/configuration_glmasr.py b/src/transformers/models/glmasr/configuration_glmasr.py index c3d320bb1db4..c89379ead3e7 100644 --- a/src/transformers/models/glmasr/configuration_glmasr.py +++ b/src/transformers/models/glmasr/configuration_glmasr.py @@ -101,6 +101,7 @@ class GlmAsrConfig(PreTrainedConfig): text_config: dict | PreTrainedConfig | None = None audio_token_id: int = 59260 projector_hidden_act: str = "gelu" + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index df70cfb3884d..29b82adfc52c 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -19,6 +19,7 @@ # limitations under the License. from collections.abc import Callable +from dataclasses import dataclass from typing import Optional from ...activations import ACT2FN @@ -26,14 +27,19 @@ from ...generation import GenerationMixin from ...integrations import use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, + ModelOutput, +) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, torch_compilable_check from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_glmasr import GlmAsrConfig, GlmAsrEncoderConfig @@ -284,6 +290,7 @@ class GlmAsrPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True # TODO: @eustlb, this is what WhisperEncoder should look like @@ -349,40 +356,34 @@ def forward(self, audio_features): return hidden_states +@dataclass +class GlmAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _supports_attention_backend = True +class GlmAsrModel(GlmAsrPreTrainedModel): _tp_plan = None _sp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = GlmAsrMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector." @@ -395,11 +396,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -445,6 +442,99 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmAsrModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return GlmAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GlmAsr causal language model (or autoregressive) outputs. + """ +) +class GlmAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -489,30 +579,36 @@ def forward( >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) >>> print(decoded_outputs) ```""" - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: CausalLMOutputWithPast = self.language_model( - inputs_embeds=inputs_embeds, + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, + inputs_embeds=inputs_embeds, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return GlmAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -529,4 +625,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrPreTrainedModel"] +__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrModel", "GlmAsrPreTrainedModel"] diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index f46c64224b15..22820df66947 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -29,6 +29,7 @@ from ...utils.output_capturing import capture_outputs from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, AudioFlamingo3MultiModalProjector, AudioFlamingo3PreTrainedModel, ) @@ -342,9 +343,7 @@ def __init__(self, config: GlmAsrConfig): The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): - _supports_attention_backend = True - +class GlmAsrModel(AudioFlamingo3Model): @can_return_tuple @auto_docstring( custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector." @@ -373,6 +372,18 @@ def get_audio_features( return audio_outputs + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.post_init() + def forward( self, input_ids: torch.LongTensor | None = None, @@ -428,4 +439,10 @@ def forward( ) -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrProcessor", "GlmAsrPreTrainedModel"] +__all__ = [ + "GlmAsrEncoder", + "GlmAsrForConditionalGeneration", + "GlmAsrModel", + "GlmAsrProcessor", + "GlmAsrPreTrainedModel", +] diff --git a/src/transformers/models/granite_speech/configuration_granite_speech.py b/src/transformers/models/granite_speech/configuration_granite_speech.py index e5532b3bf880..bf5eb93235d1 100644 --- a/src/transformers/models/granite_speech/configuration_granite_speech.py +++ b/src/transformers/models/granite_speech/configuration_granite_speech.py @@ -126,6 +126,7 @@ class GraniteSpeechConfig(PreTrainedConfig): has_lora_adapter: bool = True downsample_rate: int = 5 window_size: int = 15 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.text_config, dict): diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 898e51a306b4..a9db60f3427c 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -22,7 +22,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -35,16 +35,31 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig logger = logging.get_logger(__name__) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" - Base class for LlavaNext causal language model (or autoregressive) outputs. + Base class for Granite Speech causal language model (or autoregressive) outputs. """ ) @dataclass @@ -59,6 +74,8 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput): Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. """ loss: torch.FloatTensor | None = None @@ -66,6 +83,7 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput): past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None ### Projector @@ -262,9 +280,11 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> class GraniteSpeechPreTrainedModel(PreTrainedModel): config: GraniteSpeechConfig input_modalities = ("audio", "text") + base_model_prefix = "model" _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True + _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module: nn.Module): @@ -323,45 +343,18 @@ def forward( @auto_docstring( custom_intro=""" - The Granite Speech model, which consists of an audio encoder, projector, and language model. + The Granite Speech model, which consists of an audio encoder, projector, and language model, + without a language modeling head. """ ) -class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): - _supports_attention_backend = True - +class GraniteSpeechModel(GraniteSpeechPreTrainedModel): def __init__(self, config: GraniteSpeechConfig): super().__init__(config) - # NOTE: It doesn't matter when we initialize from config, but we should be careful - # to make sure this does not pick up the adapter_config if in the future we use - # from_pretrained or something similar, since that should be set by the composite - # model; don't need to consider it twice - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.encoder = GraniteSpeechCTCEncoder(config.encoder_config) self.projector = GraniteSpeechEncoderProjector(config) - - if config.has_lora_adapter and not is_peft_available(): - logger.warning( - "Config indicates that a lora adapter should be present, but " - "peft is not installed; this will cause the model to perform " - "incorrectly when audio inputs are provided. Please install " - "peft and reload the model!" - ) - + self.language_model = AutoModel.from_config(config.text_config) self.post_init() - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - @can_return_tuple @auto_docstring def get_audio_features( @@ -373,6 +366,61 @@ def get_audio_features( return audio_outputs + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Adds the audio token to the model's LLM vocabulary so that we can pass it + through the tokenizer; it's assumed that the embeddings corresponding to the + <|audio|> token will be clobbered with speech features. + + Args: + input_ids (`torch.Tensor`): + Input IDs containing one or more audio tokens. + audio_features (`torch.Tensor`): + Audio features to be masked into the language embeddings to form multimodal embeddings. + input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) + Mask to be applied to audio features prior to scattering into the language embeddings. + """ + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] + + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + audio_features = audio_features[input_features_mask] + + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple @auto_docstring def forward( self, @@ -383,29 +431,13 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **lm_kwargs, - ) -> tuple[torch.Tensor] | GraniteSpeechCausalLMOutputWithPast: + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GraniteSpeechModelOutputWithPast: r""" input_features_mask (`torch.Tensor`, *optional*): Mask to be applied to audio features prior to scattering into the language embeddings. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ - # TODO (@alex-jw-brooks) add an example to this docstring once models are released - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -423,6 +455,7 @@ def forward( llm_input_ids[is_audio_idx] = 0 inputs_embeds = self.get_input_embeddings()(llm_input_ids) + audio_embeds = None if input_features is not None: if input_features.dtype != self.dtype: input_features = input_features.to(self.dtype) @@ -442,13 +475,84 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, + ) + + return GraniteSpeechModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Granite Speech model, which consists of an audio encoder, projector, and language model. + """ +) +class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechConfig): + super().__init__(config) + self.model = GraniteSpeechModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + if config.has_lora_adapter and not is_peft_available(): + logger.warning( + "Config indicates that a lora adapter should be present, but " + "peft is not installed; this will cause the model to perform " + "incorrectly when audio inputs are provided. Please install " + "peft and reload the model!" + ) + + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs, + ) -> tuple | GraniteSpeechCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + # TODO (@alex-jw-brooks) add an example to this docstring once models are released + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, ) - logits = outputs[0] + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -468,16 +572,13 @@ def forward( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return GraniteSpeechCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation( @@ -493,7 +594,7 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -510,60 +611,6 @@ def prepare_inputs_for_generation( model_inputs["input_features"] = input_features return model_inputs - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. - """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. - """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - audio_features = audio_features[input_features_mask] - - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) - return inputs_embeds - def generate(self, *args, **kwargs) -> torch.LongTensor: # This model is expected to have a lora adapter, which is only # enabled when considering audio inputs. As such, we override generate @@ -597,5 +644,6 @@ def _get_adapter_name(self): __all__ = [ "GraniteSpeechCTCEncoder", "GraniteSpeechForConditionalGeneration", + "GraniteSpeechModel", "GraniteSpeechPreTrainedModel", ] diff --git a/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py index 1eec538091a4..f7d47a976e79 100644 --- a/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py @@ -137,6 +137,7 @@ class GraniteSpeechPlusConfig(PreTrainedConfig): has_lora_adapter: bool = True downsample_rate: int = 5 window_size: int = 15 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.text_config, dict): diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py index b625bc298ee1..c936e7fcf0c6 100644 --- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py @@ -28,7 +28,7 @@ from ... import initialization as init from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -41,7 +41,7 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_granite_speech_plus import GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig @@ -87,9 +87,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteSpeechPlusPreTrainedModel(PreTrainedModel): config: GraniteSpeechPlusConfig input_modalities = ("audio", "text") + base_model_prefix = "model" _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True + _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module: nn.Module): @@ -105,6 +107,167 @@ def _init_weights(self, module: nn.Module): init.copy_(module.attention_dists, attention_dists) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Granite Speech outputs, with hidden states and attentions. + """ +) +class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Granite Speech model, which consists of an audio encoder, projector, and language model, + without a language modeling head. + """ +) +class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel): + def __init__(self, config: GraniteSpeechPlusConfig): + super().__init__(config) + self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechPlusEncoderProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + @can_return_tuple + @auto_docstring + def get_audio_features( + self, input_features: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + ) -> tuple | BaseModelOutputWithPooling: + audio_outputs = self.encoder(input_features, return_dict=True, **kwargs) + projected_embeds = self.projector(audio_outputs.last_hidden_state) + audio_outputs.pooler_output = projected_embeds + + return audio_outputs + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + def get_merged_audio_embeddings( + self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Adds the audio token to the model's LLM vocabulary so that we can pass it + through the tokenizer; it's assumed that the embeddings corresponding to the + <|audio|> token will be clobbered with speech features. + + Args: + input_ids (`torch.Tensor`): + Input IDs containing one or more audio tokens. + audio_features (`torch.Tensor`): + Audio features to be masked into the language embeddings to form multimodal embeddings. + input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) + Mask to be applied to audio features prior to scattering into the language embeddings. + """ + is_audio_index = input_ids == self.config.audio_token_id + llm_input_ids = torch.where(is_audio_index, 0, input_ids) + inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] + + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + if input_features_mask is not None: + audio_features = audio_features[input_features_mask] + + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + return inputs_embeds + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GraniteSpeechPlusModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor`, *optional*): + Mask to be applied to audio features prior to scattering into the language embeddings. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if input_features is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + # Get the base embeddings; set all audio tokens to 0 index + # to avoid out of vocabulary issues with the LLM embedding. + # Audio features will be masked into is_audio_idx indices later. + is_audio_idx = input_ids == self.config.audio_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[is_audio_idx] = 0 + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + audio_embeds = None + if input_features is not None: + if input_features.dtype != self.dtype: + input_features = input_features.to(self.dtype) + # Get the audio features from the encoder / projector + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # Merge the audio features into the LLM embeddings + inputs_embeds = self.get_merged_audio_embeddings( + input_ids=input_ids, + audio_features=audio_embeds, + input_features_mask=input_features_mask, + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return GraniteSpeechPlusModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + ### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git class GraniteSpeechPlusConformerFeedForward(nn.Module): """Feedforward module for conformer encoder blocks.""" @@ -317,7 +480,7 @@ def forward( @auto_docstring( custom_intro=""" - Base class for LlavaNext causal language model (or autoregressive) outputs. + Base class for Granite Speech causal language model (or autoregressive) outputs. """ ) @dataclass @@ -332,6 +495,8 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. """ loss: torch.FloatTensor | None = None @@ -339,6 +504,7 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor] | None = None attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None @auto_docstring( @@ -348,18 +514,12 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput): """ ) class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin): - _supports_attention_backend = True + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: GraniteSpeechPlusConfig): super().__init__(config) - # NOTE: It doesn't matter when we initialize from config, but we should be careful - # to make sure this does not pick up the adapter_config if in the future we use - # from_pretrained or something similar, since that should be set by the composite - # model; don't need to consider it twice - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - - self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config) - self.projector = GraniteSpeechPlusEncoderProjector(config) + self.model = GraniteSpeechPlusModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) if config.has_lora_adapter and not is_peft_available(): logger.warning( @@ -371,29 +531,10 @@ def __init__(self, config: GraniteSpeechPlusConfig): self.post_init() - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) @can_return_tuple - @auto_docstring - def get_audio_features( - self, input_features: torch.Tensor, **kwargs: Unpack[TransformersKwargs] - ) -> tuple | BaseModelOutputWithPooling: - audio_outputs = self.encoder(input_features, return_dict=True, **kwargs) - projected_embeds = self.projector(audio_outputs.last_hidden_state) - audio_outputs.pooler_output = projected_embeds - - return audio_outputs - @auto_docstring def forward( self, @@ -406,12 +547,9 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, logits_to_keep: int | torch.Tensor = 0, - **lm_kwargs, - ) -> tuple[torch.Tensor] | GraniteSpeechPlusCausalLMOutputWithPast: + **kwargs, + ) -> tuple | GraniteSpeechPlusCausalLMOutputWithPast: r""" input_features_mask (`torch.Tensor`, *optional*): Mask to be applied to audio features prior to scattering into the language embeddings. @@ -421,55 +559,21 @@ def forward( (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ # TODO (@alex-jw-brooks) add an example to this docstring once models are released - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if input_features is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - # Get the base embeddings; set all audio tokens to 0 index - # to avoid out of vocabulary issues with the LLM embedding. - # Audio features will be masked into is_audio_idx indices later. - is_audio_idx = input_ids == self.config.audio_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[is_audio_idx] = 0 - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if input_features is not None: - if input_features.dtype != self.dtype: - input_features = input_features.to(self.dtype) - # Get the audio features from the encoder / projector - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # Merge the audio features into the LLM embeddings - inputs_embeds = self.get_merged_audio_embeddings( - input_ids=input_ids, - audio_features=audio_embeds, - input_features_mask=input_features_mask, - ) - - outputs = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - logits_to_keep=logits_to_keep, - **lm_kwargs, + **kwargs, ) - logits = outputs[0] + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -489,16 +593,13 @@ def forward( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return GraniteSpeechPlusCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation( @@ -514,7 +615,7 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model - model_inputs = self.language_model.prepare_inputs_for_generation( + model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -531,60 +632,6 @@ def prepare_inputs_for_generation( model_inputs["input_features"] = input_features return model_inputs - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. - """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - - def get_merged_audio_embeddings( - self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None - ) -> torch.Tensor: - """ - Adds the audio token to the model's LLM vocabulary so that we can pass it - through the tokenizer; it's assumed that the embeddings corresponding to the - <|audio|> token will be clobbered with speech features. - - Args: - input_ids (`torch.Tensor`): - Input IDs containing one or more audio tokens. - audio_features (`torch.Tensor`): - Audio features to be masked into the language embeddings to form multimodal embeddings. - input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) - Mask to be applied to audio features prior to scattering into the language embeddings. - """ - is_audio_index = input_ids == self.config.audio_token_id - llm_input_ids = torch.where(is_audio_index, 0, input_ids) - inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] - - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - if input_features_mask is not None: - audio_features = audio_features[input_features_mask] - - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) - return inputs_embeds - def generate(self, *args, **kwargs) -> torch.LongTensor: # This model is expected to have a lora adapter, which is only # enabled when considering audio inputs. As such, we override generate @@ -616,6 +663,7 @@ def _get_adapter_name(self): __all__ = [ + "GraniteSpeechPlusModel", "GraniteSpeechPlusCTCEncoder", "GraniteSpeechPlusForConditionalGeneration", "GraniteSpeechPlusPreTrainedModel", diff --git a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py index aa4f966246d2..f5865379fba1 100644 --- a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py +++ b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py @@ -27,6 +27,7 @@ from ..granite_speech.modeling_granite_speech import ( GraniteSpeechCTCEncoder, GraniteSpeechForConditionalGeneration, + GraniteSpeechModel, GraniteSpeechPreTrainedModel, ) @@ -122,6 +123,9 @@ def __post_init__(self, **kwargs): class GraniteSpeechPlusPreTrainedModel(GraniteSpeechPreTrainedModel): ... +class GraniteSpeechPlusModel(GraniteSpeechModel): ... + + class GraniteSpeechPlusCTCEncoder(GraniteSpeechCTCEncoder): @merge_with_config_defaults @capture_outputs @@ -166,6 +170,7 @@ class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechForConditionalGener __all__ = [ "GraniteSpeechPlusConfig", "GraniteSpeechPlusEncoderConfig", + "GraniteSpeechPlusModel", "GraniteSpeechPlusCTCEncoder", "GraniteSpeechPlusForConditionalGeneration", "GraniteSpeechPlusPreTrainedModel", diff --git a/src/transformers/models/musicflamingo/configuration_musicflamingo.py b/src/transformers/models/musicflamingo/configuration_musicflamingo.py index 825039a0eb4e..5a20c1286d32 100644 --- a/src/transformers/models/musicflamingo/configuration_musicflamingo.py +++ b/src/transformers/models/musicflamingo/configuration_musicflamingo.py @@ -66,6 +66,7 @@ class MusicFlamingoConfig(PreTrainedConfig): audio_token_id: int = 151669 projector_hidden_act: str = "gelu" projector_bias: bool = True + tie_word_embeddings: bool = False audio_bos_token_id: int = 151670 audio_eos_token_id: int = 151671 diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index 45e3d9e707d8..f59f3422887d 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -20,6 +20,7 @@ # limitations under the License. from collections.abc import Callable +from dataclasses import dataclass from math import pi from typing import Optional @@ -29,12 +30,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -141,6 +142,7 @@ class MusicFlamingoPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True @torch.no_grad() def _init_weights(self, module): @@ -150,6 +152,16 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) +@dataclass +class MusicFlamingoModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + class MusicFlamingoMultiModalProjector(nn.Module): """ Audio adaptor (small MLP) that projects MusicFlamingoEncoder features @@ -195,39 +207,24 @@ def apply_rotary_time_emb(hidden_states, cos, sin): @auto_docstring( custom_intro=""" - The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. + The MusicFlamingo model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _supports_attention_backend = True +class MusicFlamingoModel(MusicFlamingoPreTrainedModel): _tp_plan = None _sp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config: MusicFlamingoConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) self.pos_emb = MusicFlamingoRotaryEmbedding(config) - - # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -299,65 +296,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | MusicFlamingoModelOutputWithPast: r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/music-flamingo-2601-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversation = [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", - >>> }, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", - >>> }, - >>> ], - >>> } - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversation, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device, model.dtype) - - >>> outputs = model.generate(**inputs, max_new_tokens=100) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["This track is an uplifting Eurodance-style Trance-Pop anthem..."] - ```""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features( input_features, input_features_mask, input_ids=input_ids, return_dict=True @@ -369,31 +318,22 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs - def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): - input_features = kwargs.pop("input_features", None) - input_features_mask = kwargs.pop("input_features_mask", None) - - model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) - - if is_first_iteration or not model_inputs.get("use_cache", False): - if input_features is not None: - model_inputs["input_features"] = input_features - if input_features_mask is not None: - model_inputs["input_features_mask"] = input_features_mask - - return model_inputs + return MusicFlamingoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) def _build_audio_timestamps( self, @@ -437,4 +377,125 @@ def _build_audio_timestamps( return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets -__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"] +@dataclass +@auto_docstring( + custom_intro=""" + Base class for MusicFlamingo causal language model (or autoregressive) outputs. + """ +) +class MusicFlamingoCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. + """ +) +class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: MusicFlamingoConfig): + super().__init__(config) + self.model = MusicFlamingoModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MusicFlamingoCausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. + + Example: + + ```python + >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return MusicFlamingoCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) + + def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): + input_features = kwargs.pop("input_features", None) + input_features_mask = kwargs.pop("input_features_mask", None) + + model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + if is_first_iteration or not model_inputs.get("use_cache", False): + if input_features is not None: + model_inputs["input_features"] = input_features + if input_features_mask is not None: + model_inputs["input_features_mask"] = input_features_mask + + return model_inputs + + +__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoModel", "MusicFlamingoPreTrainedModel"] diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py index 8ca8fcb2462a..325845ad4da9 100644 --- a/src/transformers/models/musicflamingo/modular_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from math import pi from huggingface_hub.dataclasses import strict @@ -21,13 +22,15 @@ from ... import initialization as init from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, + AudioFlamingo3ModelOutputWithPast, AudioFlamingo3PreTrainedModel, ) from ..audioflamingo3.processing_audioflamingo3 import AudioFlamingo3Processor @@ -75,6 +78,7 @@ class MusicFlamingoConfig(AudioFlamingo3Config): audio_eos_token_id: int = 151671 audio_frame_step: float = 0.01 rope_parameters: dict | None = None + tie_word_embeddings: bool = False def __post_init__(self, **kwargs): if self.rope_parameters is None: @@ -231,12 +235,12 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) -@auto_docstring( - custom_intro=""" - The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. - """ -) -class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +@dataclass +class MusicFlamingoModelOutputWithPast(AudioFlamingo3ModelOutputWithPast): + pass + + +class MusicFlamingoModel(AudioFlamingo3Model): def __init__(self, config: MusicFlamingoConfig): super().__init__(config) self.pos_emb = MusicFlamingoRotaryEmbedding(config) @@ -329,65 +333,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ): r""" - input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/music-flamingo-2601-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversation = [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", - >>> }, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", - >>> }, - >>> ], - >>> } - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversation, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device, model.dtype) - - >>> outputs = model.generate(**inputs, max_new_tokens=100) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["This track is an uplifting Eurodance-style Trance-Pop anthem..."] - ```""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features( input_features, input_features_mask, input_ids=input_ids, return_dict=True @@ -399,22 +355,43 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return MusicFlamingoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model. + """ +) +class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): + def __init__(self, config: MusicFlamingoConfig): + super().__init__(config) + self.model = MusicFlamingoModel(config) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs) __all__ = [ "MusicFlamingoConfig", "MusicFlamingoProcessor", "MusicFlamingoForConditionalGeneration", + "MusicFlamingoModel", "MusicFlamingoPreTrainedModel", ] diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py index 10563490345d..d7a4fb4ed457 100644 --- a/src/transformers/models/parakeet/processing_parakeet.py +++ b/src/transformers/models/parakeet/processing_parakeet.py @@ -44,7 +44,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False): @auto_docstring class ParakeetProcessor(ProcessorMixin): def __init__(self, feature_extractor, tokenizer, blank_token=""): - """ + r""" blank_token (`str`, *optional*, defaults to `""`): Blank token for TDT decoding. """ diff --git a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py index 6aec9eace900..749f24123bb4 100644 --- a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py @@ -98,6 +98,7 @@ class Qwen2AudioConfig(PreTrainedConfig): audio_config: dict | PreTrainedConfig | None = None text_config: dict | PreTrainedConfig | None = None audio_token_index: int = 151646 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 170187291ff5..841c41993e81 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -25,19 +25,44 @@ from ...generation import GenerationMixin from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging, torch_compilable_check from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig logger = logging.get_logger(__name__) +@auto_docstring( + custom_intro=""" + Base class for Qwen2Audio outputs, with hidden states and attentions. + """ +) +@dataclass +class Qwen2AudioModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask, potentially updated by the audio merging logic so that audio tokens are unmasked. + labels (`torch.LongTensor`, *optional*): + Labels, potentially re-aligned by the legacy audio merging logic. Returned so the language-modeling + head can compute the loss against the expanded sequence. + """ + + attention_mask: torch.FloatTensor | None = None + labels: torch.LongTensor | None = None + + +@dataclass @auto_docstring( custom_intro=""" Base class for Qwen2Audio causal language model (or autoregressive) outputs. @@ -394,17 +419,17 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The QWEN2AUDIO model which consists of a audio backbone and a language model. + The Qwen2Audio model which consists of an audio backbone and a language model, without a language modeling head. """ ) -class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): +class Qwen2AudioModel(Qwen2AudioPreTrainedModel): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config) # Usually a `Qwen2AudioEncoder` instance self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = ( self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1 @@ -422,18 +447,6 @@ def padding_side(self, padding_side: str): raise ValueError(f"{padding_side} is not `left` or `right`.") self._padding_side = padding_side - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def _merge_input_ids_with_audio_features( self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels ): @@ -645,7 +658,7 @@ def forward( labels: torch.LongTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | Qwen2AudioCausalLMOutputWithPast: + ) -> tuple | Qwen2AudioModelOutputWithPast: r""" feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: @@ -653,32 +666,9 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from io import BytesIO - >>> from urllib.request import urlopen - >>> import librosa - >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration - - >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") - - >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" - >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" - >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) - - >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Generate the caption in English: Glass is breaking." - ```""" + Labels kept in the signature for the legacy merge path that may re-align them with audio tokens. + The loss is not computed here; `Qwen2AudioForConditionalGeneration` is responsible for that. + """ target_device = self.audio_tower.device @@ -761,7 +751,103 @@ def forward( **kwargs, ) - logits = outputs.logits + return Qwen2AudioModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + attention_mask=attention_mask, + labels=labels, + ) + + +@auto_docstring( + custom_intro=""" + The QWEN2AUDIO model which consists of an audio backbone and a language model. + """ +) +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: Qwen2AudioConfig): + super().__init__(config) + self.model = Qwen2AudioModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + @property + def padding_side(self): + return self.model.padding_side + + @padding_side.setter + def padding_side(self, padding_side: str): + self.model.padding_side = padding_side + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + feature_attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen2AudioCausalLMOutputWithPast: + r""" + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from io import BytesIO + >>> from urllib.request import urlopen + >>> import librosa + >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration + + >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") + + >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" + >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" + >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) + + >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Generate the caption in English: Glass is breaking." + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + feature_attention_mask=feature_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + attention_mask = outputs.attention_mask + labels = outputs.labels if outputs.labels is not None else labels loss = None if labels is not None: @@ -803,4 +889,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"] +__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder", "Qwen2AudioModel"] diff --git a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py index a673a5845871..11856b20b851 100644 --- a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py @@ -75,6 +75,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index 4e60c72ccd40..b7d2902aa6c1 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -17,6 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn @@ -25,17 +27,11 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - torch_compilable_check, -) -from ..auto import AutoModel, AutoModelForCausalLM +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ..auto import AutoModel from .configuration_vibevoice_asr import VibeVoiceAsrConfig @@ -247,6 +243,7 @@ class VibeVoiceAsrPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True def _init_weights(self, module): super()._init_weights(module) @@ -255,40 +252,69 @@ def _init_weights(self, module): init.constant_(module.ffn_gamma, self.config.layer_scale_init_value) +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. """ ) -class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - _supports_attention_backend = True - _tp_plan = None - _sp_plan = None - _pp_plan = None +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. + """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. + """ +) +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - - # Initialize weights and apply final processing + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() - def get_decoder(self): - return self.language_model.get_decoder() + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") @@ -298,17 +324,15 @@ def get_audio_features( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: + ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. """ - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -353,7 +377,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -364,29 +387,73 @@ def get_audio_features( return BaseModelOutputWithPooling(last_hidden_state=acoustic_latents, pooler_output=combined_features) - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrModelOutputWithPast: + r""" + padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing operations on padding feature indices. + acoustic_tokenizer_chunk_size (`int`, *optional*): + Size of audio chunks processed by the acoustic and semantic tokenizers. """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_values is not None and input_ids is not None: + audio_embeds = self.get_audio_features( + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + ).pooler_output + + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, ) - return special_audio_mask + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + """ +) +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) @can_return_tuple @auto_docstring @@ -399,14 +466,15 @@ def forward( input_values: torch.FloatTensor | None = None, padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. + Size of audio chunks processed by the acoustic and semantic tokenizers. Example: @@ -416,33 +484,35 @@ def forward( >>> model_id = "microsoft/VibeVoice-ASR-HF" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_values is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_values=input_values, - padding_mask=padding_mask, - acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, - ).pooler_output + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -463,4 +533,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg return model_inputs -__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrPreTrainedModel"] +__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel"] diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 52652de9adde..c664823c039a 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -11,16 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass + import torch from huggingface_hub.dataclasses import strict from torch import nn from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + ModelOutput, +) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ..audioflamingo3.modeling_audioflamingo3 import AudioFlamingo3ForConditionalGeneration from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from ..vibevoice_acoustic_tokenizer.modeling_vibevoice_acoustic_tokenizer import ( @@ -82,6 +88,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): @@ -159,21 +166,72 @@ class VibeVoiceAsrPreTrainedModel(VibeVoiceAcousticTokenizerPreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True + _supports_attention_backend = True +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. """ ) -class VibeVoiceAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): - _supports_attention_backend = True +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. + """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. + """ +) +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - del self.audio_tower + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") @@ -186,14 +244,12 @@ def get_audio_features( ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. """ - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -238,7 +294,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -261,34 +316,17 @@ def forward( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrModelOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. - - Example: - - ```python - >>> from transformers import VibeVoiceAsrForConditionalGeneration, AutoProcessor - - >>> model_id = "microsoft/VibeVoice-ASR-HF" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) - ```""" - + Size of audio chunks processed by the acoustic and semantic tokenizers. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_values is not None and input_ids is not None: audio_embeds = self.get_audio_features( input_values=input_values, @@ -296,14 +334,102 @@ def forward( acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, ).pooler_output - # Replace text-audio token placeholders with audio embeddings audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) inputs_embeds = inputs_embeds.masked_scatter( audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + """ +) +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: + r""" + padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing operations on padding feature indices. + acoustic_tokenizer_chunk_size (`int`, *optional*): + Size of audio chunks processed by the acoustic and semantic tokenizers. + + Example: + + ```python + >>> from transformers import VibeVoiceAsrForConditionalGeneration, AutoProcessor + + >>> model_id = "microsoft/VibeVoice-ASR-HF" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -327,5 +453,6 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg __all__ = [ "VibeVoiceAsrConfig", "VibeVoiceAsrForConditionalGeneration", + "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel", ] diff --git a/src/transformers/models/voxtral/configuration_voxtral.py b/src/transformers/models/voxtral/configuration_voxtral.py index 2ecbedfc1a9e..b476d80dd976 100644 --- a/src/transformers/models/voxtral/configuration_voxtral.py +++ b/src/transformers/models/voxtral/configuration_voxtral.py @@ -110,6 +110,7 @@ class VoxtralConfig(PreTrainedConfig): text_config: dict | PreTrainedConfig | None = None audio_token_id: int | None = None projector_hidden_act: str = "gelu" + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index f6aa02fd6c00..9fedfdd1345a 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -35,7 +36,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig @@ -359,36 +360,35 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model, + without a language modeling head. + """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -436,6 +436,68 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs: BaseModelOutputWithPast = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -450,7 +512,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | CausalLMOutputWithPast: r""" Example: @@ -484,29 +546,34 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -523,4 +590,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index c1089db09944..01abc9052d9f 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -13,6 +13,8 @@ # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn @@ -28,7 +30,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from ..qwen2_audio.modeling_qwen2_audio import ( Qwen2AudioAttention, Qwen2AudioEncoder, @@ -128,36 +130,35 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model, + without a language modeling head. + """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -205,6 +206,68 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs: BaseModelOutputWithPast = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -219,7 +282,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | CausalLMOutputWithPast: r""" Example: @@ -253,29 +316,34 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -292,4 +360,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py index 568c6f8748b9..77130e54db78 100644 --- a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py @@ -176,6 +176,7 @@ class VoxtralRealtimeConfig(PreTrainedConfig): audio_length_per_tok: int = 8 default_num_delay_tokens: int = 6 downsample_factor: int = 4 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index d5d2a92f18f7..aebd752d2aa8 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -39,17 +39,9 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, - torch_compilable_check, -) +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel from .configuration_voxtral_realtime import ( VoxtralRealtimeConfig, VoxtralRealtimeEncoderConfig, @@ -125,6 +117,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. + """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -487,6 +497,7 @@ class VoxtralRealtimePreTrainedModel(PreTrainedModel): _supports_attention_backend = True # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -827,82 +838,6 @@ def forward( ) -@auto_docstring -class VoxtralRealtimeTextForCausalLM(VoxtralRealtimeTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_allgather"} - _sp_plan = {"lm_head": "colwise_loss_parallel"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - _fsdp_plan = {"lm_head": "keep_full_weight"} - - def __init__(self, config): - super().__init__(config) - self.model = VoxtralRealtimeTextModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -937,34 +872,26 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The VoxtralRealtime model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. """ ) -class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) # Initialize weights and apply final processing self.post_init() - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring( @@ -981,11 +908,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -1010,30 +933,6 @@ def get_audio_features( return audio_outputs - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. - """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - @can_return_tuple @auto_docstring def forward( @@ -1047,43 +946,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. - - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1093,6 +969,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -1102,7 +980,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds += audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -1121,25 +1000,124 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] # broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. + + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -1171,7 +1149,7 @@ def _prepare_model_inputs( input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): - model_kwargs["encoder_inputs_embeds"] = self.audio_tower.embedder(model_kwargs.pop("input_features")) + model_kwargs["encoder_inputs_embeds"] = self.model.audio_tower.embedder(model_kwargs.pop("input_features")) elif isinstance(input_features, GeneratorType): input_features_generator = model_kwargs.pop("input_features") @@ -1335,4 +1313,9 @@ def _prepare_generated_length( return generation_config -__all__ = ["VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel"] +__all__ = [ + "VoxtralRealtimeForConditionalGeneration", + "VoxtralRealtimeEncoder", + "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", +] diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index edad37679927..81bf2259d726 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -31,13 +31,11 @@ from ...models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, - MistralForCausalLM, MistralMLP, MistralModel, MistralRMSNorm, ) from ...models.voxtral.modeling_voxtral import ( - VoxtralForConditionalGeneration, VoxtralMultiModalProjector, VoxtralPreTrainedModel, ) @@ -116,6 +114,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. + """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -255,6 +271,7 @@ def forward( class VoxtralRealtimePreTrainedModel(VoxtralPreTrainedModel, PreTrainedModel): # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -436,66 +453,6 @@ def __init__(self, config): self.rotary_emb = VoxtralRealtimeRotaryEmbedding(config=config) -class VoxtralRealtimeTextForCausalLM(MistralForCausalLM): - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -520,14 +477,29 @@ def __init__(self, config): ) -class VoxtralRealtimeForConditionalGeneration(VoxtralForConditionalGeneration, GenerationMixin): - _keep_in_fp32_modules_strict = None - +@auto_docstring( + custom_intro=""" + The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. + """ +) +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) + self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -543,11 +515,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -585,43 +553,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. - - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -631,6 +576,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -640,7 +587,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds += audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -659,25 +607,124 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] # broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. + + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -705,11 +752,11 @@ def _prepare_model_inputs( bos_token_id: torch.Tensor | None = None, model_kwargs: dict[str, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]: - inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + inputs, input_name, model_kwargs = super()._prepare_model_inputs(inputs, bos_token_id, model_kwargs) input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): - model_kwargs["encoder_inputs_embeds"] = self.audio_tower.embedder(model_kwargs.pop("input_features")) + model_kwargs["encoder_inputs_embeds"] = self.model.audio_tower.embedder(model_kwargs.pop("input_features")) elif isinstance(input_features, GeneratorType): input_features_generator = model_kwargs.pop("input_features") @@ -725,7 +772,7 @@ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, if getattr(self, "_stream_exhausted", False): self._stream_exhausted = False return False - return GenerationMixin._has_unfinished_sequences(this_peer_finished, synced_gpus, device) + return super()._has_unfinished_sequences(this_peer_finished, synced_gpus, device) def _update_model_kwargs_for_generation( self, @@ -734,7 +781,7 @@ def _update_model_kwargs_for_generation( is_encoder_decoder: bool = False, num_new_tokens: int = 1, ): - model_kwargs = GenerationMixin._update_model_kwargs_for_generation( + model_kwargs = super()._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder, num_new_tokens ) @@ -761,7 +808,7 @@ def _prepare_cache_for_generation( batch_size: int, max_cache_length: int, ): - GenerationMixin._prepare_cache_for_generation( + super()._prepare_cache_for_generation( generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) @@ -815,7 +862,7 @@ def _prepare_generation_config( generation_config is not None and generation_config.max_new_tokens is not None ) - generation_config, model_kwargs = GenerationMixin._prepare_generation_config(generation_config, **kwargs) + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): @@ -854,7 +901,7 @@ def _prepare_generated_length( if getattr(generation_config, "_voxtral_set_max_length", False): has_default_max_length = False - generation_config = GenerationMixin._prepare_generated_length( + generation_config = super()._prepare_generated_length( generation_config, has_default_max_length, has_default_min_length, @@ -877,4 +924,5 @@ def _prepare_generated_length( "VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", ] diff --git a/tests/alm_tester.py b/tests/alm_tester.py index af0d63e17fd5..4c05751f564b 100644 --- a/tests/alm_tester.py +++ b/tests/alm_tester.py @@ -14,7 +14,6 @@ import copy import random -import unittest from inspect import signature from unittest.mock import patch @@ -161,11 +160,6 @@ class ALMModelTest(MultiModalModelTest): - `pipeline_model_mapping`: Override if not using default from model_tester """ - # TODO: @eustlb, remove this once #45534 is merged - @unittest.skip("Audio-LMs have no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_sdpa_can_dispatch_on_flash(self): # `test_sdpa_can_dispatch_on_flash` already pops the attention mask, but we cannot simply pop the # audio mask here since it will raise an error in `get_audio_features` (cf. `test_mismatching_num_audio_tokens`). diff --git a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py index 7df6bcd254fc..57fbe90a352e 100644 --- a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py +++ b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py @@ -22,6 +22,7 @@ AudioFlamingo3Config, AudioFlamingo3EncoderConfig, AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, AutoProcessor, Qwen2Config, is_torch_available, @@ -42,6 +43,7 @@ class AudioFlamingo3ModelTester(ALMModelTester): config_class = AudioFlamingo3Config + base_model_class = AudioFlamingo3Model conditional_generation_class = AudioFlamingo3ForConditionalGeneration text_config_class = Qwen2Config audio_config_class = AudioFlamingo3EncoderConfig diff --git a/tests/models/glmasr/test_modeling_glmasr.py b/tests/models/glmasr/test_modeling_glmasr.py index b19e91a61209..1df06e819636 100644 --- a/tests/models/glmasr/test_modeling_glmasr.py +++ b/tests/models/glmasr/test_modeling_glmasr.py @@ -19,6 +19,7 @@ AutoProcessor, GlmAsrConfig, GlmAsrForConditionalGeneration, + GlmAsrModel, LlamaConfig, is_torch_available, ) @@ -39,6 +40,7 @@ class GlmAsrModelTester(ALMModelTester): config_class = GlmAsrConfig + base_model_class = GlmAsrModel conditional_generation_class = GlmAsrForConditionalGeneration text_config_class = LlamaConfig audio_config_class = GlmAsrEncoderConfig @@ -48,6 +50,12 @@ def __init__(self, parent, **kwargs): kwargs.setdefault("head_dim", 8) super().__init__(parent, **kwargs) + def create_audio_mask(self): + # Deterministic full-length mask: the base default randomizes lengths in [1, feat_seq_length], + # and short samples collapse to 0 audio tokens after conv2 (s=2) + merge_factor=4, breaking + # test_mismatching_num_audio_tokens (a 0-contribution sample makes "duplicate audio" a no-op). + return torch.ones([self.batch_size, self.feat_seq_length], dtype=torch.bool).to(torch_device) + def get_audio_embeds_mask(self, audio_mask): # conv1 (s=1) preserves length; conv2 (s=2, k=3, p=1) halves; merge_factor=4 post-projector. audio_lengths = audio_mask.sum(-1) diff --git a/tests/models/granite_speech/test_modeling_granite_speech.py b/tests/models/granite_speech/test_modeling_granite_speech.py index e4ecebbcb0ee..613790dbaa5a 100644 --- a/tests/models/granite_speech/test_modeling_granite_speech.py +++ b/tests/models/granite_speech/test_modeling_granite_speech.py @@ -23,6 +23,7 @@ GraniteSpeechConfig, GraniteSpeechEncoderConfig, GraniteSpeechForConditionalGeneration, + GraniteSpeechModel, ) from transformers.testing_utils import ( cleanup, @@ -49,6 +50,7 @@ class GraniteSpeechModelTester(ALMModelTester): config_class = GraniteSpeechConfig + base_model_class = GraniteSpeechModel conditional_generation_class = GraniteSpeechForConditionalGeneration text_config_class = GraniteConfig audio_config_class = GraniteSpeechEncoderConfig diff --git a/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py b/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py index 21f1d997efb4..04372995ffc3 100644 --- a/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py +++ b/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py @@ -20,6 +20,7 @@ GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig, GraniteSpeechPlusForConditionalGeneration, + GraniteSpeechPlusModel, ) from transformers.testing_utils import cleanup, require_torch, slow, torch_device from transformers.utils import is_datasets_available, is_torch_available @@ -43,6 +44,7 @@ class GraniteSpeechPlusForConditionalGenerationModelTester(GraniteSpeechModelTes """ config_class = GraniteSpeechPlusConfig + base_model_class = GraniteSpeechPlusModel conditional_generation_class = GraniteSpeechPlusForConditionalGeneration audio_config_class = GraniteSpeechPlusEncoderConfig diff --git a/tests/models/musicflamingo/test_modeling_musicflamingo.py b/tests/models/musicflamingo/test_modeling_musicflamingo.py index 7c6174e86834..111ec75bea4b 100644 --- a/tests/models/musicflamingo/test_modeling_musicflamingo.py +++ b/tests/models/musicflamingo/test_modeling_musicflamingo.py @@ -24,6 +24,7 @@ AutoProcessor, MusicFlamingoConfig, MusicFlamingoForConditionalGeneration, + MusicFlamingoModel, Qwen2Config, is_torch_available, ) @@ -51,6 +52,7 @@ class MusicFlamingoModelTester(ALMModelTester): """ config_class = MusicFlamingoConfig + base_model_class = MusicFlamingoModel conditional_generation_class = MusicFlamingoForConditionalGeneration text_config_class = Qwen2Config audio_config_class = AudioFlamingo3EncoderConfig @@ -96,7 +98,7 @@ class MusicFlamingoForConditionalGenerationModelTest(ALMModelTest, unittest.Test def test_rotary_window_axis_resets_per_audio(self): config = self.model_tester.get_config() - pos_emb = MusicFlamingoForConditionalGeneration(config).pos_emb.to(torch_device) + pos_emb = MusicFlamingoForConditionalGeneration(config).model.pos_emb.to(torch_device) timestamps = torch.tensor( [ @@ -126,7 +128,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self): input_ids[0, :45] = config.audio_token_id input_ids[1, :30] = config.audio_token_id - _, post_lengths = model.audio_tower._get_feat_extract_output_lengths( + _, post_lengths = model.model.audio_tower._get_feat_extract_output_lengths( input_features_mask.sum(-1).to(torch.long) ) max_post_length = int(post_lengths.max().item()) @@ -144,7 +146,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self): ] ) - inferred = model._build_audio_timestamps(input_ids, post_lengths, max_post_length) + inferred = model.model._build_audio_timestamps(input_ids, post_lengths, max_post_length) torch.testing.assert_close(inferred, audio_timestamps) @unittest.skip( diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index d8a8bd7c88f9..f6e7662ac5cc 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -24,6 +24,7 @@ Qwen2AudioConfig, Qwen2AudioEncoderConfig, Qwen2AudioForConditionalGeneration, + Qwen2AudioModel, Qwen2Config, is_torch_available, ) @@ -43,6 +44,7 @@ class Qwen2AudioModelTester(ALMModelTester): config_class = Qwen2AudioConfig + base_model_class = Qwen2AudioModel conditional_generation_class = Qwen2AudioForConditionalGeneration text_config_class = Qwen2Config audio_config_class = Qwen2AudioEncoderConfig diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py index fc8bb11568ea..dc01897a07ff 100644 --- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py +++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py @@ -22,6 +22,7 @@ from transformers import ( VibeVoiceAsrConfig, VibeVoiceAsrForConditionalGeneration, + VibeVoiceAsrModel, is_datasets_available, is_torch_available, ) @@ -133,11 +134,14 @@ def prepare_config_and_inputs_for_common(self): @require_torch class VibeVoiceAsrForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (VibeVoiceAsrForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (VibeVoiceAsrModel, VibeVoiceAsrForConditionalGeneration) if is_torch_available() else () pipeline_model_mapping = ( {"audio-text-to-text": VibeVoiceAsrForConditionalGeneration} if is_torch_available() else {} ) _is_composite = True + # Acoustic/semantic tokenizers run under torch.no_grad() in get_audio_features, + # so their params never receive grads — the mixin's force-unfreeze can't change that. + test_all_params_have_gradient = False def setUp(self): self.model_tester = VibeVoiceAsrModelTester(self) @@ -185,6 +189,14 @@ def test_model_outputs_equivalence(self): def test_left_padding_compatibility(self): pass + @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling.") + def test_forward_with_logits_to_keep(self): + pass + + @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling.") + def test_generate_methods_with_logits_to_keep(self): + pass + def test_sdpa_can_dispatch_composite_models(self): # VibeVoiceAsr is audio+text composite; but audio components do not use attention for model_class in self.all_model_classes: @@ -197,16 +209,17 @@ def test_sdpa_can_dispatch_composite_models(self): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + language_model_sdpa = model_sdpa.base_model.language_model + text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(language_model_sdpa.config._attn_implementation == text_attn) # Eager model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.base_model.language_model.config._attn_implementation == "eager") for _, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/voxtral/test_modeling_voxtral.py b/tests/models/voxtral/test_modeling_voxtral.py index 4f0c604ce05f..deeed284bc5d 100644 --- a/tests/models/voxtral/test_modeling_voxtral.py +++ b/tests/models/voxtral/test_modeling_voxtral.py @@ -21,6 +21,7 @@ VoxtralConfig, VoxtralEncoderConfig, VoxtralForConditionalGeneration, + VoxtralModel, is_torch_available, ) from transformers.testing_utils import ( @@ -40,6 +41,7 @@ class VoxtralModelTester(ALMModelTester): config_class = VoxtralConfig + base_model_class = VoxtralModel conditional_generation_class = VoxtralForConditionalGeneration text_config_class = LlamaConfig audio_config_class = VoxtralEncoderConfig diff --git a/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py b/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py index 150d7a894104..420b794cd8ee 100644 --- a/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py +++ b/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py @@ -20,6 +20,7 @@ AutoProcessor, VoxtralRealtimeConfig, VoxtralRealtimeForConditionalGeneration, + VoxtralRealtimeModel, is_datasets_available, is_torch_available, ) @@ -48,6 +49,7 @@ class VoxtralRealtimeModelTester(ALMModelTester): config_class = VoxtralRealtimeConfig + base_model_class = VoxtralRealtimeModel conditional_generation_class = VoxtralRealtimeForConditionalGeneration text_config_class = VoxtralRealtimeTextConfig audio_config_class = VoxtralRealtimeEncoderConfig