diff --git a/docs/source/en/model_doc/audioflamingo3.md b/docs/source/en/model_doc/audioflamingo3.md
index 0249480b3317..95cf1d9caa84 100644
--- a/docs/source/en/model_doc/audioflamingo3.md
+++ b/docs/source/en/model_doc/audioflamingo3.md
@@ -403,6 +403,11 @@ are forwarded, so you can tweak padding or tensor formats just like when calling
[[autodoc]] AudioFlamingo3Encoder
- forward
+## AudioFlamingo3Model
+
+[[autodoc]] AudioFlamingo3Model
+ - forward
+
## AudioFlamingo3ForConditionalGeneration
[[autodoc]] AudioFlamingo3ForConditionalGeneration
diff --git a/docs/source/en/model_doc/glmasr.md b/docs/source/en/model_doc/glmasr.md
index a9acd132d8eb..ad7dda6757a6 100644
--- a/docs/source/en/model_doc/glmasr.md
+++ b/docs/source/en/model_doc/glmasr.md
@@ -231,6 +231,11 @@ assert decoded_outputs == EXPECTED_OUTPUT
[[autodoc]] GlmAsrEncoder
- forward
+## GlmAsrModel
+
+[[autodoc]] GlmAsrModel
+ - forward
+
## GlmAsrForConditionalGeneration
[[autodoc]] GlmAsrForConditionalGeneration
diff --git a/docs/source/en/model_doc/granite_speech.md b/docs/source/en/model_doc/granite_speech.md
index 5115292a84d8..aedfb5eea61a 100644
--- a/docs/source/en/model_doc/granite_speech.md
+++ b/docs/source/en/model_doc/granite_speech.md
@@ -163,6 +163,11 @@ for i, transcription in enumerate(transcriptions):
[[autodoc]] GraniteSpeechFeatureExtractor
+## GraniteSpeechModel
+
+[[autodoc]] GraniteSpeechModel
+ - forward
+
## GraniteSpeechForConditionalGeneration
[[autodoc]] GraniteSpeechForConditionalGeneration
diff --git a/docs/source/en/model_doc/granite_speech_plus.md b/docs/source/en/model_doc/granite_speech_plus.md
index 22cf57935c12..810c0330fda6 100644
--- a/docs/source/en/model_doc/granite_speech_plus.md
+++ b/docs/source/en/model_doc/granite_speech_plus.md
@@ -143,6 +143,11 @@ for k in range(NUM_SEGMENTS):
[[autodoc]] GraniteSpeechPlusEncoderConfig
+## GraniteSpeechPlusModel
+
+[[autodoc]] GraniteSpeechPlusModel
+ - forward
+
## GraniteSpeechPlusForConditionalGeneration
[[autodoc]] GraniteSpeechPlusForConditionalGeneration
diff --git a/docs/source/en/model_doc/hyperclovax.md b/docs/source/en/model_doc/hyperclovax.md
index 2725fbbc090f..0a26d0bde2b3 100644
--- a/docs/source/en/model_doc/hyperclovax.md
+++ b/docs/source/en/model_doc/hyperclovax.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on 2025-07-21 and added to Hugging Face Transformers on 2026-05-06.*
+*This model was released on 2025-07-21 and added to Hugging Face Transformers on 2026-05-08.*
diff --git a/docs/source/en/model_doc/musicflamingo.md b/docs/source/en/model_doc/musicflamingo.md
index 52cab097b3b0..bdb3fe3437c6 100644
--- a/docs/source/en/model_doc/musicflamingo.md
+++ b/docs/source/en/model_doc/musicflamingo.md
@@ -287,6 +287,11 @@ loss.backward()
[[autodoc]] MusicFlamingoProcessor
+## MusicFlamingoModel
+
+[[autodoc]] MusicFlamingoModel
+ - forward
+
## MusicFlamingoForConditionalGeneration
[[autodoc]] MusicFlamingoForConditionalGeneration
diff --git a/docs/source/en/model_doc/pe_audio.md b/docs/source/en/model_doc/pe_audio.md
index 21e0f8938b9e..6b2210f26d4e 100644
--- a/docs/source/en/model_doc/pe_audio.md
+++ b/docs/source/en/model_doc/pe_audio.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.*
+*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.*
# PE Audio
diff --git a/docs/source/en/model_doc/pe_audio_video.md b/docs/source/en/model_doc/pe_audio_video.md
index b091383413d1..cd2377d35781 100644
--- a/docs/source/en/model_doc/pe_audio_video.md
+++ b/docs/source/en/model_doc/pe_audio_video.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.*
+*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.*
# PE Audio Video
diff --git a/docs/source/en/model_doc/pe_video.md b/docs/source/en/model_doc/pe_video.md
index 872273f78597..48cb691792f7 100644
--- a/docs/source/en/model_doc/pe_video.md
+++ b/docs/source/en/model_doc/pe_video.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-16.*
+*This model was released on 2025-04-17 and added to Hugging Face Transformers on 2025-12-16.*
# PE Video
diff --git a/docs/source/en/model_doc/qwen2_audio.md b/docs/source/en/model_doc/qwen2_audio.md
index 2cfd00b8e808..785b4bf5ca06 100644
--- a/docs/source/en/model_doc/qwen2_audio.md
+++ b/docs/source/en/model_doc/qwen2_audio.md
@@ -251,6 +251,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
[[autodoc]] Qwen2AudioEncoder
- forward
+## Qwen2AudioModel
+
+[[autodoc]] Qwen2AudioModel
+ - forward
+
## Qwen2AudioForConditionalGeneration
[[autodoc]] Qwen2AudioForConditionalGeneration
diff --git a/docs/source/en/model_doc/vibevoice_asr.md b/docs/source/en/model_doc/vibevoice_asr.md
index f28485e6cb9e..8c29de227240 100644
--- a/docs/source/en/model_doc/vibevoice_asr.md
+++ b/docs/source/en/model_doc/vibevoice_asr.md
@@ -452,6 +452,11 @@ print(transcription)
- apply_transcription_request
- decode
+## VibeVoiceAsrModel
+
+[[autodoc]] VibeVoiceAsrModel
+ - forward
+
## VibeVoiceAsrForConditionalGeneration
[[autodoc]] VibeVoiceAsrForConditionalGeneration
diff --git a/docs/source/en/model_doc/voxtral.md b/docs/source/en/model_doc/voxtral.md
index 0f520f547584..adb2c3706ffb 100644
--- a/docs/source/en/model_doc/voxtral.md
+++ b/docs/source/en/model_doc/voxtral.md
@@ -352,6 +352,11 @@ This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb)
[[autodoc]] VoxtralEncoder
- forward
+## VoxtralModel
+
+[[autodoc]] VoxtralModel
+ - forward
+
## VoxtralForConditionalGeneration
[[autodoc]] VoxtralForConditionalGeneration
diff --git a/docs/source/en/model_doc/voxtral_realtime.md b/docs/source/en/model_doc/voxtral_realtime.md
index 49b274ea6659..97fe5fecff55 100644
--- a/docs/source/en/model_doc/voxtral_realtime.md
+++ b/docs/source/en/model_doc/voxtral_realtime.md
@@ -182,6 +182,11 @@ This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb)
[[autodoc]] VoxtralRealtimeEncoder
- forward
+## VoxtralRealtimeModel
+
+[[autodoc]] VoxtralRealtimeModel
+ - forward
+
## VoxtralRealtimeForConditionalGeneration
[[autodoc]] VoxtralRealtimeForConditionalGeneration
diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 858b8d8d0127..4b4a464caf05 100755
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -87,6 +87,14 @@
"vipllava": "llava",
"mistral3": "llava",
"pp_chart2table": "llava",
+ "voxtral": "qwen2_audio",
+ "voxtral_realtime": "qwen2_audio",
+ "audioflamingo3": "qwen2_audio",
+ "glmasr": "qwen2_audio",
+ "musicflamingo": "qwen2_audio",
+ "granite_speech_plus": "granite_speech",
+ "gemma3n_text": "qwen3_5_text",
+ "qwen3_5_moe_text": "qwen3_5_text",
"llava_next_video": "llava_next",
"llava_onevision": "llava_next",
# class-based mappings
@@ -103,6 +111,12 @@
"LlavaOnevisionModel": "LlavaModel",
"FuyuModel": "LlavaModel",
"MllamaModel": "LlavaModel",
+ "VoxtralModel": "Qwen2AudioModel",
+ "VoxtralRealtimeModel": "Qwen2AudioModel",
+ "AudioFlamingo3Model": "Qwen2AudioModel",
+ "GlmAsrModel": "Qwen2AudioModel",
+ "MusicFlamingoModel": "Qwen2AudioModel",
+ "GraniteSpeechPlusModel": "GraniteSpeechModel",
"MaskFormerDetrDecoder": "DetrModel",
"Qwen2_5_VLForConditionalGeneration": "Qwen2VLForConditionalGeneration",
# ViT-style vision models (old HuggingFace checkpoint format → new modular format)
@@ -420,6 +434,38 @@ def _build_checkpoint_conversion_mapping():
WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"),
WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
],
+ "qwen2_audio": [
+ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
+ WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
+ WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"),
+ WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
+ ],
+ "Qwen2AudioModel": [
+ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"),
+ ],
+ "granite_speech": [
+ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
+ WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
+ WeightRenaming(source_patterns=r"^encoder", target_patterns="model.encoder"),
+ WeightRenaming(source_patterns=r"^projector", target_patterns="model.projector"),
+ ],
+ "GraniteSpeechModel": [
+ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"),
+ ],
+ "vibevoice_asr": [
+ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
+ WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
+ WeightRenaming(
+ source_patterns=r"^acoustic_tokenizer_encoder", target_patterns="model.acoustic_tokenizer_encoder"
+ ),
+ WeightRenaming(
+ source_patterns=r"^semantic_tokenizer_encoder", target_patterns="model.semantic_tokenizer_encoder"
+ ),
+ WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
+ ],
+ "VibeVoiceAsrModel": [
+ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="language_model"),
+ ],
"llava_next": [
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"),
diff --git a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py
index 096e263d856d..cd81ef805205 100644
--- a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py
+++ b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py
@@ -100,6 +100,7 @@ class AudioFlamingo3Config(PreTrainedConfig):
audio_token_id: int = 151669
projector_hidden_act: str = "gelu"
projector_bias: bool = True
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.audio_config, dict):
diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
index ca77ef4a6cb1..26972fcf18b2 100644
--- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
+++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
@@ -21,6 +21,7 @@
import math
from collections.abc import Callable
+from dataclasses import dataclass
import torch
from torch import nn
@@ -31,13 +32,13 @@
from ...masking_utils import create_bidirectional_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig
@@ -254,6 +255,43 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
+ _supports_attention_backend = True
+
+
+@dataclass
+class AudioFlamingo3ModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Audio embeddings produced by the audio encoder followed by the multi-modal projector. `None` unless both
+ `input_features` and `input_ids` were passed to the model.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for AudioFlamingo3 causal language model (or autoregressive) outputs.
+ """
+)
+class AudioFlamingo3CausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Audio embeddings produced by the audio encoder followed by the multi-modal projector.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
@auto_docstring(
@@ -403,38 +441,23 @@ def forward(self, audio_features):
@auto_docstring(
custom_intro="""
- The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
+ The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model),
+ without a language modeling head.
"""
)
-class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = None
- _supports_attention_backend = True
+class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel):
_tp_plan = None
_sp_plan = None
_pp_plan = None
+ _keep_in_fp32_modules_strict = None
def __init__(self, config):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
-
- # Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
@can_return_tuple
@auto_docstring(
custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
@@ -447,11 +470,7 @@ def get_audio_features(
) -> tuple | BaseModelOutputWithPooling:
r"""
input_features (`torch.FloatTensor`):
- Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
- obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
- `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
- `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
- and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+ Float values of mel features extracted from the raw speech waveform.
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padded feature indices.
"""
@@ -504,77 +523,17 @@ def forward(
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ) -> tuple | AudioFlamingo3ModelOutputWithPast:
r"""
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
- Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Example:
-
- ```python
- >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
-
- >>> model_id = "nvidia/audio-flamingo-3-hf"
- >>> processor = AutoProcessor.from_pretrained(model_id)
- >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
-
- >>> conversations = [
- >>> [
- >>> {
- >>> "role": "user",
- >>> "content": [
- >>> {"type": "text", "text": "Transcribe the input speech."},
- >>> {
- >>> "type": "audio",
- >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
- >>> },
- >>> ],
- >>> }
- >>> ],
- >>> [
- >>> {
- >>> "role": "user",
- >>> "content": [
- >>> {
- >>> "type": "text",
- >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
- >>> },
- >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
- >>> ],
- >>> }
- >>> ],
- >>> ]
-
- >>> inputs = processor.apply_chat_template(
- >>> conversations,
- >>> tokenize=True,
- >>> add_generation_prompt=True,
- >>> return_dict=True,
- >>> ).to(model.device)
-
- >>> outputs = model.generate(**inputs, max_new_tokens=500)
-
- >>> decoded_outputs = processor.batch_decode(
- >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
- >>> )
- >>> print(decoded_outputs)
- ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
- ```"""
-
+ Mask to avoid performing attention on padding feature indices.
+ """
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_embeds = None
if input_features is not None and input_ids is not None:
audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
@@ -584,17 +543,103 @@ def forward(
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
+
+ return AudioFlamingo3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
+ """
+)
+class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
+ _keep_in_fp32_modules_strict = ["embed_positions"]
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ # Headless base model: audio tower, multi-modal projector and language model.
+ self.model = AudioFlamingo3Model(config)
+ # Bias-free head mapping the text hidden size onto the text vocabulary.
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor, **kwargs):
+ """Forward the call, with the same arguments, to `self.model.get_audio_features`."""
+ return self.model.get_audio_features(input_features, input_features_mask, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ input_features_mask: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | AudioFlamingo3CausalLMOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling (next-token prediction) loss.
+
+ Example:
+
+ ```python
+ >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
+
+ >>> model_id = "nvidia/audio-flamingo-3-hf"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
+ ```"""
+ # The headless model embeds the text, merges the audio embeddings and runs the language model.
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # An int keeps the trailing `logits_to_keep` positions (0 keeps all of them, because `slice(-0, None)`
+ # spans the whole sequence); a tensor is used directly as the index along the sequence dimension.
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ # The loss stays `None` unless labels are supplied.
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return AudioFlamingo3CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
input_features = kwargs.pop("input_features", None)
@@ -611,4 +656,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False,
return model_inputs
-__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]
+__all__ = [
+ "AudioFlamingo3ForConditionalGeneration",
+ "AudioFlamingo3PreTrainedModel",
+ "AudioFlamingo3Encoder",
+ "AudioFlamingo3Model",
+]
diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py
index a170d8ee4124..86be1ffa376f 100644
--- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py
+++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
+
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...masking_utils import create_bidirectional_mask
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import merge_with_config_defaults
@@ -28,7 +30,12 @@
Qwen2AudioEncoder,
Qwen2AudioPreTrainedModel,
)
-from ..voxtral.modeling_voxtral import VoxtralForConditionalGeneration, VoxtralMultiModalProjector
+from ..voxtral.modeling_voxtral import (
+ VoxtralForConditionalGeneration,
+ VoxtralModel,
+ VoxtralModelOutputWithPast,
+ VoxtralMultiModalProjector,
+)
from ..whisper.modeling_whisper import WhisperAttention, WhisperEncoderLayer
from .configuration_audioflamingo3 import AudioFlamingo3Config
@@ -45,9 +52,40 @@ class AudioFlamingo3EncoderLayer(WhisperEncoderLayer):
class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel):
+ _supports_attention_backend = True
+
+
+@dataclass
+class AudioFlamingo3ModelOutputWithPast(VoxtralModelOutputWithPast):
pass
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for AudioFlamingo3 causal language model (or autoregressive) outputs.
+ """
+)
+class AudioFlamingo3CausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Audio embeddings produced by the audio encoder followed by the multi-modal projector.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
@auto_docstring(
custom_intro="""
The audio model from AudioFlamingo3 without any head or projection on top.
@@ -138,11 +176,11 @@ def __init__(self, config: AudioFlamingo3Config):
@auto_docstring(
custom_intro="""
- The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
+ The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model),
+ without a language modeling head.
"""
)
-class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
- _supports_attention_backend = True
+class AudioFlamingo3Model(VoxtralModel):
_tp_plan = None
_sp_plan = None
_pp_plan = None
@@ -163,11 +201,7 @@ def get_audio_features(
) -> tuple | BaseModelOutputWithPooling:
r"""
input_features (`torch.FloatTensor`):
- Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
- obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
- `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
- `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
- and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+ Float values of mel features extracted from the raw speech waveform.
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padded feature indices.
"""
@@ -196,77 +230,17 @@ def forward(
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ):
r"""
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
- Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Example:
-
- ```python
- >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
-
- >>> model_id = "nvidia/audio-flamingo-3-hf"
- >>> processor = AutoProcessor.from_pretrained(model_id)
- >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
-
- >>> conversations = [
- >>> [
- >>> {
- >>> "role": "user",
- >>> "content": [
- >>> {"type": "text", "text": "Transcribe the input speech."},
- >>> {
- >>> "type": "audio",
- >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
- >>> },
- >>> ],
- >>> }
- >>> ],
- >>> [
- >>> {
- >>> "role": "user",
- >>> "content": [
- >>> {
- >>> "type": "text",
- >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
- >>> },
- >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
- >>> ],
- >>> }
- >>> ],
- >>> ]
-
- >>> inputs = processor.apply_chat_template(
- >>> conversations,
- >>> tokenize=True,
- >>> add_generation_prompt=True,
- >>> return_dict=True,
- >>> ).to(model.device)
-
- >>> outputs = model.generate(**inputs, max_new_tokens=500)
-
- >>> decoded_outputs = processor.batch_decode(
- >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
- >>> )
- >>> print(decoded_outputs)
- ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
- ```"""
-
+ Mask to avoid performing attention on padding feature indices.
+ """
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_embeds = None
if input_features is not None and input_ids is not None:
audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
@@ -276,17 +250,99 @@ def forward(
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
+
+ return AudioFlamingo3ModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
+ """
+)
+class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = AudioFlamingo3Model(config)
+ self.post_init()
+
+ def get_audio_features(self, input_features, input_features_mask, **kwargs):
+ return self.model.get_audio_features(input_features, input_features_mask, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ input_features_mask: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | AudioFlamingo3CausalLMOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss.
+
+ Example:
+
+ ```python
+ >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
+
+ >>> model_id = "nvidia/audio-flamingo-3-hf"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return AudioFlamingo3CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
input_features = kwargs.pop("input_features", None)
@@ -303,4 +359,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False,
return model_inputs
-__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]
+__all__ = [
+ "AudioFlamingo3ForConditionalGeneration",
+ "AudioFlamingo3PreTrainedModel",
+ "AudioFlamingo3Encoder",
+ "AudioFlamingo3Model",
+]
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e9de6747076f..cf421c30e7c6 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -52,7 +52,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("aria", "AriaModel"),
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
- ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
+ ("audioflamingo3", "AudioFlamingo3Model"),
("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
("autoformer", "AutoformerModel"),
("aya_vision", "AyaVisionModel"),
@@ -201,7 +201,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("glm_ocr", "GlmOcrModel"),
("glm_ocr_text", "GlmOcrTextModel"),
("glm_ocr_vision", "GlmOcrVisionModel"),
- ("glmasr", "GlmAsrForConditionalGeneration"),
+ ("glmasr", "GlmAsrModel"),
("glmasr_encoder", "GlmAsrEncoder"),
("glpn", "GLPNModel"),
("got_ocr2", "GotOcr2Model"),
@@ -215,7 +215,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gptj", "GPTJModel"),
("granite", "GraniteModel"),
("granite4_vision", "Granite4VisionModel"),
- ("granite_speech", "GraniteSpeechForConditionalGeneration"),
+ ("granite_speech", "GraniteSpeechModel"),
+ ("granite_speech_plus", "GraniteSpeechPlusModel"),
("granitemoe", "GraniteMoeModel"),
("granitemoehybrid", "GraniteMoeHybridModel"),
("granitemoeshared", "GraniteMoeSharedModel"),
@@ -320,7 +321,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("mpt", "MptModel"),
("mra", "MraModel"),
("mt5", "MT5Model"),
- ("musicflamingo", "MusicFlamingoForConditionalGeneration"),
+ ("musicflamingo", "MusicFlamingoModel"),
("musicgen", "MusicgenModel"),
("musicgen_melody", "MusicgenMelodyModel"),
("mvp", "MvpModel"),
@@ -384,6 +385,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("qwen2", "Qwen2Model"),
("qwen2_5_vl", "Qwen2_5_VLModel"),
("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
+ ("qwen2_audio", "Qwen2AudioModel"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoeModel"),
("qwen2_vl", "Qwen2VLModel"),
@@ -479,7 +481,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"),
("vibevoice_acoustic_tokenizer_decoder", "VibeVoiceAcousticTokenizerDecoderModel"),
("vibevoice_acoustic_tokenizer_encoder", "VibeVoiceAcousticTokenizerEncoderModel"),
- ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"),
+ ("vibevoice_asr", "VibeVoiceAsrModel"),
("video_llama_3", "VideoLlama3Model"),
("video_llama_3_vision", "VideoLlama3VisionModel"),
("video_llava", "VideoLlavaModel"),
@@ -495,9 +497,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("vits", "VitsModel"),
("vivit", "VivitModel"),
("vjepa2", "VJEPA2Model"),
- ("voxtral", "VoxtralForConditionalGeneration"),
+ ("voxtral", "VoxtralModel"),
("voxtral_encoder", "VoxtralEncoder"),
- ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"),
+ ("voxtral_realtime", "VoxtralRealtimeModel"),
("voxtral_realtime_encoder", "VoxtralRealtimeEncoder"),
("voxtral_realtime_text", "VoxtralRealtimeTextModel"),
("wav2vec2", "Wav2Vec2Model"),
diff --git a/src/transformers/models/glmasr/configuration_glmasr.py b/src/transformers/models/glmasr/configuration_glmasr.py
index c3d320bb1db4..c89379ead3e7 100644
--- a/src/transformers/models/glmasr/configuration_glmasr.py
+++ b/src/transformers/models/glmasr/configuration_glmasr.py
@@ -101,6 +101,7 @@ class GlmAsrConfig(PreTrainedConfig):
text_config: dict | PreTrainedConfig | None = None
audio_token_id: int = 59260
projector_hidden_act: str = "gelu"
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.audio_config, dict):
diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py
index df70cfb3884d..29b82adfc52c 100644
--- a/src/transformers/models/glmasr/modeling_glmasr.py
+++ b/src/transformers/models/glmasr/modeling_glmasr.py
@@ -19,6 +19,7 @@
# limitations under the License.
from collections.abc import Callable
+from dataclasses import dataclass
from typing import Optional
from ...activations import ACT2FN
@@ -26,14 +27,19 @@
from ...generation import GenerationMixin
from ...integrations import use_kernelized_func
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+ ModelOutput,
+)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, is_torch_available, torch_compilable_check
from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_glmasr import GlmAsrConfig, GlmAsrEncoderConfig
@@ -284,6 +290,7 @@ class GlmAsrPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
+ _supports_attention_backend = True
# TODO: @eustlb, this is what WhisperEncoder should look like
@@ -349,40 +356,34 @@ def forward(self, audio_features):
return hidden_states
+@dataclass
+class GlmAsrModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
@auto_docstring(
custom_intro="""
The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model.
"""
)
-class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = None
- _supports_attention_backend = True
+class GlmAsrModel(GlmAsrPreTrainedModel):
_tp_plan = None
_sp_plan = None
_pp_plan = None
+ _keep_in_fp32_modules_strict = None
def __init__(self, config):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.multi_modal_projector = GlmAsrMultiModalProjector(config)
-
- # Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
@can_return_tuple
@auto_docstring(
custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector."
@@ -395,11 +396,7 @@ def get_audio_features(
) -> tuple | BaseModelOutputWithPooling:
r"""
input_features (`torch.FloatTensor`):
- Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
- obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
- `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
- `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
- and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+ Float values of mel features extracted from the raw speech waveform.
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padded feature indices.
"""
@@ -445,6 +442,99 @@ def get_placeholder_mask(
)
return special_audio_mask
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ input_features_mask: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GlmAsrModelOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices.
+ """
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ audio_embeds = None
+ if input_features is not None and input_ids is not None:
+ audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
+
+ # replace text-audio token placeholders with audio embeddings
+ special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
+
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return GlmAsrModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for GlmAsr causal language model (or autoregressive) outputs.
+ """
+)
+class GlmAsrCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Hidden states of the audio encoder after projection.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model.
+ """
+)
+class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin):
+ _keep_in_fp32_modules_strict = ["embed_positions"]
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GlmAsrModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, input_features, input_features_mask, **kwargs):
+ return self.model.get_audio_features(input_features, input_features_mask, **kwargs)
+
@can_return_tuple
@auto_docstring
def forward(
@@ -489,30 +579,36 @@ def forward(
>>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
>>> print(decoded_outputs)
```"""
-
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if input_features is not None and input_ids is not None:
- audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output
-
- # replace text-audio token placeholders with audio embeddings
- special_audio_mask = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
-
- outputs: CausalLMOutputWithPast = self.language_model(
- inputs_embeds=inputs_embeds,
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ input_features_mask=input_features_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
- labels=labels,
+ inputs_embeds=inputs_embeds,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return GlmAsrCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
input_features = kwargs.pop("input_features", None)
@@ -529,4 +625,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False,
return model_inputs
-__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrPreTrainedModel"]
+__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrModel", "GlmAsrPreTrainedModel"]
diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py
index f46c64224b15..22820df66947 100644
--- a/src/transformers/models/glmasr/modular_glmasr.py
+++ b/src/transformers/models/glmasr/modular_glmasr.py
@@ -29,6 +29,7 @@
from ...utils.output_capturing import capture_outputs
from ..audioflamingo3.modeling_audioflamingo3 import (
AudioFlamingo3ForConditionalGeneration,
+ AudioFlamingo3Model,
AudioFlamingo3MultiModalProjector,
AudioFlamingo3PreTrainedModel,
)
@@ -342,9 +343,7 @@ def __init__(self, config: GlmAsrConfig):
The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model.
"""
)
-class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
- _supports_attention_backend = True
-
+class GlmAsrModel(AudioFlamingo3Model):
@can_return_tuple
@auto_docstring(
custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector."
@@ -373,6 +372,18 @@ def get_audio_features(
return audio_outputs
+
+@auto_docstring(
+ custom_intro="""
+ The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model.
+ """
+)
+class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GlmAsrModel(config)
+ self.post_init()
+
def forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -428,4 +439,10 @@ def forward(
)
-__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrProcessor", "GlmAsrPreTrainedModel"]
+__all__ = [
+ "GlmAsrEncoder",
+ "GlmAsrForConditionalGeneration",
+ "GlmAsrModel",
+ "GlmAsrProcessor",
+ "GlmAsrPreTrainedModel",
+]
diff --git a/src/transformers/models/granite_speech/configuration_granite_speech.py b/src/transformers/models/granite_speech/configuration_granite_speech.py
index e5532b3bf880..bf5eb93235d1 100644
--- a/src/transformers/models/granite_speech/configuration_granite_speech.py
+++ b/src/transformers/models/granite_speech/configuration_granite_speech.py
@@ -126,6 +126,7 @@ class GraniteSpeechConfig(PreTrainedConfig):
has_lora_adapter: bool = True
downsample_rate: int = 5
window_size: int = 15
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.text_config, dict):
diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py
index 898e51a306b4..a9db60f3427c 100644
--- a/src/transformers/models/granite_speech/modeling_granite_speech.py
+++ b/src/transformers/models/granite_speech/modeling_granite_speech.py
@@ -22,7 +22,7 @@
from ... import initialization as init
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -35,16 +35,31 @@
)
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_granite_speech import GraniteSpeechConfig, GraniteSpeechEncoderConfig
logger = logging.get_logger(__name__)
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Granite Speech outputs, with hidden states and attentions.
+ """
+)
+class GraniteSpeechModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
@auto_docstring(
custom_intro="""
- Base class for LlavaNext causal language model (or autoregressive) outputs.
+ Base class for Granite Speech causal language model (or autoregressive) outputs.
"""
)
@dataclass
@@ -59,6 +74,8 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput):
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
"""
loss: torch.FloatTensor | None = None
@@ -66,6 +83,7 @@ class GraniteSpeechCausalLMOutputWithPast(ModelOutput):
past_key_values: Cache | None = None
hidden_states: tuple[torch.FloatTensor] | None = None
attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
### Projector
@@ -262,9 +280,11 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) ->
class GraniteSpeechPreTrainedModel(PreTrainedModel):
config: GraniteSpeechConfig
input_modalities = ("audio", "text")
+ base_model_prefix = "model"
_supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this
_supports_sdpa = True
+ _supports_attention_backend = True
@torch.no_grad()
def _init_weights(self, module: nn.Module):
@@ -323,45 +343,18 @@ def forward(
@auto_docstring(
custom_intro="""
- The Granite Speech model, which consists of an audio encoder, projector, and language model.
+ The Granite Speech model, which consists of an audio encoder, projector, and language model,
+ without a language modeling head.
"""
)
-class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
- _supports_attention_backend = True
-
+class GraniteSpeechModel(GraniteSpeechPreTrainedModel):
def __init__(self, config: GraniteSpeechConfig):
super().__init__(config)
- # NOTE: It doesn't matter when we initialize from config, but we should be careful
- # to make sure this does not pick up the adapter_config if in the future we use
- # from_pretrained or something similar, since that should be set by the composite
- # model; don't need to consider it twice
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
self.encoder = GraniteSpeechCTCEncoder(config.encoder_config)
self.projector = GraniteSpeechEncoderProjector(config)
-
- if config.has_lora_adapter and not is_peft_available():
- logger.warning(
- "Config indicates that a lora adapter should be present, but "
- "peft is not installed; this will cause the model to perform "
- "incorrectly when audio inputs are provided. Please install "
- "peft and reload the model!"
- )
-
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
@can_return_tuple
@auto_docstring
def get_audio_features(
@@ -373,6 +366,61 @@ def get_audio_features(
return audio_outputs
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_audio_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_audio_mask = special_audio_mask.all(-1)
+ else:
+ special_audio_mask = input_ids == self.config.audio_token_id
+
+ n_audio_tokens = special_audio_mask.sum()
+ n_audio_features = audio_features.shape[0]
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ torch_compilable_check(
+ inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
+ f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
+ )
+ return special_audio_mask
+
+ def get_merged_audio_embeddings(
+ self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ """
+ Adds the audio token to the model's LLM vocabulary so that we can pass it
+ through the tokenizer; it's assumed that the embeddings corresponding to the
+ <|audio|> token will be clobbered with speech features.
+
+ Args:
+ input_ids (`torch.Tensor`):
+ Input IDs containing one or more audio tokens.
+ audio_features (`torch.Tensor`):
+ Audio features to be masked into the language embeddings to form multimodal embeddings.
+ input_features_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ Mask to be applied to audio features prior to scattering into the language embeddings.
+ """
+ is_audio_index = input_ids == self.config.audio_token_id
+ llm_input_ids = torch.where(is_audio_index, 0, input_ids)
+ inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
+
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ if input_features_mask is not None:
+ audio_features = audio_features[input_features_mask]
+
+ special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+ return inputs_embeds
+
+ @can_return_tuple
@auto_docstring
def forward(
self,
@@ -383,29 +431,13 @@ def forward(
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **lm_kwargs,
- ) -> tuple[torch.Tensor] | GraniteSpeechCausalLMOutputWithPast:
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GraniteSpeechModelOutputWithPast:
r"""
input_features_mask (`torch.Tensor`, *optional*):
Mask to be applied to audio features prior to scattering into the language embeddings.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
- # TODO (@alex-jw-brooks) add an example to this docstring once models are released
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.return_dict
-
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -423,6 +455,7 @@ def forward(
llm_input_ids[is_audio_idx] = 0
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+ audio_embeds = None
if input_features is not None:
if input_features.dtype != self.dtype:
input_features = input_features.to(self.dtype)
@@ -442,13 +475,84 @@ def forward(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- logits_to_keep=logits_to_keep,
- **lm_kwargs,
+ **kwargs,
+ )
+
+ return GraniteSpeechModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Granite Speech model, which consists of an audio encoder, projector, and language model.
+ """
+)
+class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: GraniteSpeechConfig):
+ super().__init__(config)
+ self.model = GraniteSpeechModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+ if config.has_lora_adapter and not is_peft_available():
+ logger.warning(
+ "Config indicates that a lora adapter should be present, but "
+ "peft is not installed; this will cause the model to perform "
+ "incorrectly when audio inputs are provided. Please install "
+ "peft and reload the model!"
+ )
+
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ input_features_mask: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs,
+ ) -> tuple | GraniteSpeechCausalLMOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor`, *optional*):
+ Mask to be applied to audio features prior to scattering into the language embeddings.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ # TODO (@alex-jw-brooks) add an example to this docstring once models are released
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
)
- logits = outputs[0]
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@@ -468,16 +572,13 @@ def forward(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
return GraniteSpeechCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
)
def prepare_inputs_for_generation(
@@ -493,7 +594,7 @@ def prepare_inputs_for_generation(
):
# Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -510,60 +611,6 @@ def prepare_inputs_for_generation(
model_inputs["input_features"] = input_features
return model_inputs
- def get_placeholder_mask(
- self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
- ):
- """
- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
- equal to the length of multimodal features. If the lengths are different, an error is raised.
- """
- if input_ids is None:
- special_audio_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- special_audio_mask = special_audio_mask.all(-1)
- else:
- special_audio_mask = input_ids == self.config.audio_token_id
-
- n_audio_tokens = special_audio_mask.sum()
- n_audio_features = audio_features.shape[0]
- special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
- f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
- )
- return special_audio_mask
-
- def get_merged_audio_embeddings(
- self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None
- ) -> torch.Tensor:
- """
- Adds the audio token to the model's LLM vocabulary so that we can pass it
- through the tokenizer; it's assumed that the embeddings corresponding to the
- <|audio|> token will be clobbered with speech features.
-
- Args:
- input_ids (`torch.Tensor`):
- Input IDs containing one or more audio tokens.
- audio_features (`torch.Tensor`):
- Audio features to be masked into the language embeddings to form multimodal embeddings.
- input_features_mask (`torch.Tensor`, *optional*, defaults to `None`)
- Mask to be applied to audio features prior to scattering into the language embeddings.
- """
- is_audio_index = input_ids == self.config.audio_token_id
- llm_input_ids = torch.where(is_audio_index, 0, input_ids)
- inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
-
- audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
- if input_features_mask is not None:
- audio_features = audio_features[input_features_mask]
-
- special_audio_mask = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
- return inputs_embeds
-
def generate(self, *args, **kwargs) -> torch.LongTensor:
# This model is expected to have a lora adapter, which is only
# enabled when considering audio inputs. As such, we override generate
@@ -597,5 +644,6 @@ def _get_adapter_name(self):
__all__ = [
"GraniteSpeechCTCEncoder",
"GraniteSpeechForConditionalGeneration",
+ "GraniteSpeechModel",
"GraniteSpeechPreTrainedModel",
]
diff --git a/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py
index 1eec538091a4..f7d47a976e79 100644
--- a/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py
+++ b/src/transformers/models/granite_speech_plus/configuration_granite_speech_plus.py
@@ -137,6 +137,7 @@ class GraniteSpeechPlusConfig(PreTrainedConfig):
has_lora_adapter: bool = True
downsample_rate: int = 5
window_size: int = 15
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.text_config, dict):
diff --git a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py
index b625bc298ee1..c936e7fcf0c6 100644
--- a/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py
+++ b/src/transformers/models/granite_speech_plus/modeling_granite_speech_plus.py
@@ -28,7 +28,7 @@
from ... import initialization as init
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -41,7 +41,7 @@
)
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_granite_speech_plus import GraniteSpeechPlusConfig, GraniteSpeechPlusEncoderConfig
@@ -87,9 +87,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class GraniteSpeechPlusPreTrainedModel(PreTrainedModel):
config: GraniteSpeechPlusConfig
input_modalities = ("audio", "text")
+ base_model_prefix = "model"
_supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this
_supports_sdpa = True
+ _supports_attention_backend = True
@torch.no_grad()
def _init_weights(self, module: nn.Module):
@@ -105,6 +107,167 @@ def _init_weights(self, module: nn.Module):
init.copy_(module.attention_dists, attention_dists)
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Granite Speech outputs, with hidden states and attentions.
+ """
+)
+class GraniteSpeechPlusModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The Granite Speech model, which consists of an audio encoder, projector, and language model,
+ without a language modeling head.
+ """
+)
+class GraniteSpeechPlusModel(GraniteSpeechPlusPreTrainedModel):
+ def __init__(self, config: GraniteSpeechPlusConfig):
+ super().__init__(config)
+ self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config)
+ self.projector = GraniteSpeechPlusEncoderProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def get_audio_features(
+ self, input_features: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
+ ) -> tuple | BaseModelOutputWithPooling:
+ audio_outputs = self.encoder(input_features, return_dict=True, **kwargs)
+ projected_embeds = self.projector(audio_outputs.last_hidden_state)
+ audio_outputs.pooler_output = projected_embeds
+
+ return audio_outputs
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_audio_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_audio_mask = special_audio_mask.all(-1)
+ else:
+ special_audio_mask = input_ids == self.config.audio_token_id
+
+ n_audio_tokens = special_audio_mask.sum()
+ n_audio_features = audio_features.shape[0]
+ special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ torch_compilable_check(
+ inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
+ f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
+ )
+ return special_audio_mask
+
+ def get_merged_audio_embeddings(
+ self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ """
+ Adds the audio token to the model's LLM vocabulary so that we can pass it
+ through the tokenizer; it's assumed that the embeddings corresponding to the
+ <|audio|> token will be clobbered with speech features.
+
+ Args:
+ input_ids (`torch.Tensor`):
+ Input IDs containing one or more audio tokens.
+ audio_features (`torch.Tensor`):
+ Audio features to be masked into the language embeddings to form multimodal embeddings.
+ input_features_mask (`torch.Tensor`, *optional*, defaults to `None`)
+ Mask to be applied to audio features prior to scattering into the language embeddings.
+ """
+ is_audio_index = input_ids == self.config.audio_token_id
+ llm_input_ids = torch.where(is_audio_index, 0, input_ids)
+ inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
+
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ if input_features_mask is not None:
+ audio_features = audio_features[input_features_mask]
+
+ special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+ return inputs_embeds
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ input_features_mask: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GraniteSpeechPlusModelOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor`, *optional*):
+ Mask to be applied to audio features prior to scattering into the language embeddings.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if input_features is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if inputs_embeds is None:
+ # Get the base embeddings; set all audio tokens to 0 index
+ # to avoid out of vocabulary issues with the LLM embedding.
+ # Audio features will be masked into is_audio_idx indices later.
+ is_audio_idx = input_ids == self.config.audio_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[is_audio_idx] = 0
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ audio_embeds = None
+ if input_features is not None:
+ if input_features.dtype != self.dtype:
+ input_features = input_features.to(self.dtype)
+ # Get the audio features from the encoder / projector
+ audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
+
+ # Merge the audio features into the LLM embeddings
+ inputs_embeds = self.get_merged_audio_embeddings(
+ input_ids=input_ids,
+ audio_features=audio_embeds,
+ input_features_mask=input_features_mask,
+ )
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return GraniteSpeechPlusModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
class GraniteSpeechPlusConformerFeedForward(nn.Module):
"""Feedforward module for conformer encoder blocks."""
@@ -317,7 +480,7 @@ def forward(
@auto_docstring(
custom_intro="""
- Base class for LlavaNext causal language model (or autoregressive) outputs.
+ Base class for Granite Speech causal language model (or autoregressive) outputs.
"""
)
@dataclass
@@ -332,6 +495,8 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput):
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
"""
loss: torch.FloatTensor | None = None
@@ -339,6 +504,7 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput):
past_key_values: Cache | None = None
hidden_states: tuple[torch.FloatTensor] | None = None
attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
@auto_docstring(
@@ -348,18 +514,12 @@ class GraniteSpeechPlusCausalLMOutputWithPast(ModelOutput):
"""
)
class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechPlusPreTrainedModel, GenerationMixin):
- _supports_attention_backend = True
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
def __init__(self, config: GraniteSpeechPlusConfig):
super().__init__(config)
- # NOTE: It doesn't matter when we initialize from config, but we should be careful
- # to make sure this does not pick up the adapter_config if in the future we use
- # from_pretrained or something similar, since that should be set by the composite
- # model; don't need to consider it twice
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
- self.encoder = GraniteSpeechPlusCTCEncoder(config.encoder_config)
- self.projector = GraniteSpeechPlusEncoderProjector(config)
+ self.model = GraniteSpeechPlusModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
if config.has_lora_adapter and not is_peft_available():
logger.warning(
@@ -371,29 +531,10 @@ def __init__(self, config: GraniteSpeechPlusConfig):
self.post_init()
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
@can_return_tuple
- @auto_docstring
- def get_audio_features(
- self, input_features: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
- ) -> tuple | BaseModelOutputWithPooling:
- audio_outputs = self.encoder(input_features, return_dict=True, **kwargs)
- projected_embeds = self.projector(audio_outputs.last_hidden_state)
- audio_outputs.pooler_output = projected_embeds
-
- return audio_outputs
-
@auto_docstring
def forward(
self,
@@ -406,12 +547,9 @@ def forward(
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
- **lm_kwargs,
- ) -> tuple[torch.Tensor] | GraniteSpeechPlusCausalLMOutputWithPast:
+ **kwargs,
+ ) -> tuple | GraniteSpeechPlusCausalLMOutputWithPast:
r"""
input_features_mask (`torch.Tensor`, *optional*):
Mask to be applied to audio features prior to scattering into the language embeddings.
@@ -421,55 +559,21 @@ def forward(
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
# TODO (@alex-jw-brooks) add an example to this docstring once models are released
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if input_features is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both input_features and inputs_embeds at the same time, and must specify either one"
- )
-
- if inputs_embeds is None:
- # Get the base embeddings; set all audio tokens to 0 index
- # to avoid out of vocabulary issues with the LLM embedding.
- # Audio features will be masked into is_audio_idx indices later.
- is_audio_idx = input_ids == self.config.audio_token_id
- llm_input_ids = input_ids.clone()
- llm_input_ids[is_audio_idx] = 0
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
- if input_features is not None:
- if input_features.dtype != self.dtype:
- input_features = input_features.to(self.dtype)
- # Get the audio features from the encoder / projector
- audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
-
- # Merge the audio features into the LLM embeddings
- inputs_embeds = self.get_merged_audio_embeddings(
- input_ids=input_ids,
- audio_features=audio_embeds,
- input_features_mask=input_features_mask,
- )
-
- outputs = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ input_features_mask=input_features_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- logits_to_keep=logits_to_keep,
- **lm_kwargs,
+ **kwargs,
)
- logits = outputs[0]
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@@ -489,16 +593,13 @@ def forward(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
return GraniteSpeechPlusCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
)
def prepare_inputs_for_generation(
@@ -514,7 +615,7 @@ def prepare_inputs_for_generation(
):
# Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
- model_inputs = self.language_model.prepare_inputs_for_generation(
+ model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -531,60 +632,6 @@ def prepare_inputs_for_generation(
model_inputs["input_features"] = input_features
return model_inputs
- def get_placeholder_mask(
- self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
- ):
- """
- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
- equal to the length of multimodal features. If the lengths are different, an error is raised.
- """
- if input_ids is None:
- special_audio_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- special_audio_mask = special_audio_mask.all(-1)
- else:
- special_audio_mask = input_ids == self.config.audio_token_id
-
- n_audio_tokens = special_audio_mask.sum()
- n_audio_features = audio_features.shape[0]
- special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
- f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
- )
- return special_audio_mask
-
- def get_merged_audio_embeddings(
- self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None
- ) -> torch.Tensor:
- """
- Adds the audio token to the model's LLM vocabulary so that we can pass it
- through the tokenizer; it's assumed that the embeddings corresponding to the
- <|audio|> token will be clobbered with speech features.
-
- Args:
- input_ids (`torch.Tensor`):
- Input IDs containing one or more audio tokens.
- audio_features (`torch.Tensor`):
- Audio features to be masked into the language embeddings to form multimodal embeddings.
- input_features_mask (`torch.Tensor`, *optional*, defaults to `None`)
- Mask to be applied to audio features prior to scattering into the language embeddings.
- """
- is_audio_index = input_ids == self.config.audio_token_id
- llm_input_ids = torch.where(is_audio_index, 0, input_ids)
- inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
-
- audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
- if input_features_mask is not None:
- audio_features = audio_features[input_features_mask]
-
- special_audio_mask = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
- return inputs_embeds
-
def generate(self, *args, **kwargs) -> torch.LongTensor:
# This model is expected to have a lora adapter, which is only
# enabled when considering audio inputs. As such, we override generate
@@ -616,6 +663,7 @@ def _get_adapter_name(self):
__all__ = [
+ "GraniteSpeechPlusModel",
"GraniteSpeechPlusCTCEncoder",
"GraniteSpeechPlusForConditionalGeneration",
"GraniteSpeechPlusPreTrainedModel",
diff --git a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py
index aa4f966246d2..f5865379fba1 100644
--- a/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py
+++ b/src/transformers/models/granite_speech_plus/modular_granite_speech_plus.py
@@ -27,6 +27,7 @@
from ..granite_speech.modeling_granite_speech import (
GraniteSpeechCTCEncoder,
GraniteSpeechForConditionalGeneration,
+ GraniteSpeechModel,
GraniteSpeechPreTrainedModel,
)
@@ -122,6 +123,9 @@ def __post_init__(self, **kwargs):
class GraniteSpeechPlusPreTrainedModel(GraniteSpeechPreTrainedModel): ...
+class GraniteSpeechPlusModel(GraniteSpeechModel): ...
+
+
class GraniteSpeechPlusCTCEncoder(GraniteSpeechCTCEncoder):
@merge_with_config_defaults
@capture_outputs
@@ -166,6 +170,7 @@ class GraniteSpeechPlusForConditionalGeneration(GraniteSpeechForConditionalGener
__all__ = [
"GraniteSpeechPlusConfig",
"GraniteSpeechPlusEncoderConfig",
+ "GraniteSpeechPlusModel",
"GraniteSpeechPlusCTCEncoder",
"GraniteSpeechPlusForConditionalGeneration",
"GraniteSpeechPlusPreTrainedModel",
diff --git a/src/transformers/models/musicflamingo/configuration_musicflamingo.py b/src/transformers/models/musicflamingo/configuration_musicflamingo.py
index 825039a0eb4e..5a20c1286d32 100644
--- a/src/transformers/models/musicflamingo/configuration_musicflamingo.py
+++ b/src/transformers/models/musicflamingo/configuration_musicflamingo.py
@@ -66,6 +66,7 @@ class MusicFlamingoConfig(PreTrainedConfig):
audio_token_id: int = 151669
projector_hidden_act: str = "gelu"
projector_bias: bool = True
+ tie_word_embeddings: bool = False
audio_bos_token_id: int = 151670
audio_eos_token_id: int = 151671
diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py
index 45e3d9e707d8..f59f3422887d 100644
--- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py
+++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py
@@ -20,6 +20,7 @@
# limitations under the License.
from collections.abc import Callable
+from dataclasses import dataclass
from math import pi
from typing import Optional
@@ -29,12 +30,12 @@
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_musicflamingo import MusicFlamingoConfig
@@ -141,6 +142,7 @@ class MusicFlamingoPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
+ _supports_attention_backend = True
@torch.no_grad()
def _init_weights(self, module):
@@ -150,6 +152,16 @@ def _init_weights(self, module):
init.copy_(module.position_angles, buffer_value)
+@dataclass
+class MusicFlamingoModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
class MusicFlamingoMultiModalProjector(nn.Module):
"""
Audio adaptor (small MLP) that projects MusicFlamingoEncoder features
@@ -195,39 +207,24 @@ def apply_rotary_time_emb(hidden_states, cos, sin):
@auto_docstring(
custom_intro="""
- The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
+ The MusicFlamingo model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model),
+ without a language modeling head.
"""
)
-class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = None
- _supports_attention_backend = True
+class MusicFlamingoModel(MusicFlamingoPreTrainedModel):
_tp_plan = None
_sp_plan = None
_pp_plan = None
+ _keep_in_fp32_modules_strict = None
def __init__(self, config: MusicFlamingoConfig):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.multi_modal_projector = MusicFlamingoMultiModalProjector(config)
self.pos_emb = MusicFlamingoRotaryEmbedding(config)
-
- # Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
@can_return_tuple
@auto_docstring(
custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
@@ -299,65 +296,17 @@ def forward(
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ) -> tuple | MusicFlamingoModelOutputWithPast:
r"""
- input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
- Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Example:
-
- ```python
- >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor
-
- >>> model_id = "nvidia/music-flamingo-2601-hf"
- >>> processor = AutoProcessor.from_pretrained(model_id)
- >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")
-
- >>> conversation = [
- >>> {
- >>> "role": "user",
- >>> "content": [
- >>> {
- >>> "type": "text",
- >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
- >>> },
- >>> {
- >>> "type": "audio",
- >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
- >>> },
- >>> ],
- >>> }
- >>> ]
-
- >>> inputs = processor.apply_chat_template(
- >>> conversation,
- >>> tokenize=True,
- >>> add_generation_prompt=True,
- >>> return_dict=True,
- >>> ).to(model.device, model.dtype)
-
- >>> outputs = model.generate(**inputs, max_new_tokens=100)
-
- >>> decoded_outputs = processor.batch_decode(
- >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
- >>> )
- >>> print(decoded_outputs)
- ["This track is an uplifting Eurodance-style Trance-Pop anthem..."]
- ```"""
+ input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices.
+ """
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_embeds = None
if input_features is not None and input_ids is not None:
audio_embeds = self.get_audio_features(
input_features, input_features_mask, input_ids=input_ids, return_dict=True
@@ -369,31 +318,22 @@ def forward(
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
- def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
- input_features = kwargs.pop("input_features", None)
- input_features_mask = kwargs.pop("input_features_mask", None)
-
- model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
-
- if is_first_iteration or not model_inputs.get("use_cache", False):
- if input_features is not None:
- model_inputs["input_features"] = input_features
- if input_features_mask is not None:
- model_inputs["input_features_mask"] = input_features_mask
-
- return model_inputs
+ return MusicFlamingoModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
def _build_audio_timestamps(
self,
@@ -437,4 +377,125 @@ def _build_audio_timestamps(
return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets
-__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"]
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for MusicFlamingo causal language model (or autoregressive) outputs.
+ """
+)
+class MusicFlamingoCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head.
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Hidden states of the audio encoder after projection.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
+ """
+)
+class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin):
+ _keep_in_fp32_modules_strict = ["embed_positions"]
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: MusicFlamingoConfig):
+ super().__init__(config)
+ self.model = MusicFlamingoModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs):
+ return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ input_features_mask: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | MusicFlamingoCausalLMOutputWithPast:
+ r"""
+ input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss.
+
+ Example:
+
+ ```python
+ >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor
+
+ >>> model_id = "nvidia/music-flamingo-2601-hf"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ input_features_mask=input_features_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return MusicFlamingoCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
+ input_features = kwargs.pop("input_features", None)
+ input_features_mask = kwargs.pop("input_features_mask", None)
+
+ model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+ if is_first_iteration or not model_inputs.get("use_cache", False):
+ if input_features is not None:
+ model_inputs["input_features"] = input_features
+ if input_features_mask is not None:
+ model_inputs["input_features_mask"] = input_features_mask
+
+ return model_inputs
+
+
+__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoModel", "MusicFlamingoPreTrainedModel"]
diff --git a/src/transformers/models/musicflamingo/modular_musicflamingo.py b/src/transformers/models/musicflamingo/modular_musicflamingo.py
index 8ca8fcb2462a..325845ad4da9 100644
--- a/src/transformers/models/musicflamingo/modular_musicflamingo.py
+++ b/src/transformers/models/musicflamingo/modular_musicflamingo.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
from math import pi
from huggingface_hub.dataclasses import strict
@@ -21,13 +22,15 @@
from ... import initialization as init
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check
from ..audioflamingo3.configuration_audioflamingo3 import AudioFlamingo3Config
from ..audioflamingo3.modeling_audioflamingo3 import (
AudioFlamingo3ForConditionalGeneration,
+ AudioFlamingo3Model,
+ AudioFlamingo3ModelOutputWithPast,
AudioFlamingo3PreTrainedModel,
)
from ..audioflamingo3.processing_audioflamingo3 import AudioFlamingo3Processor
@@ -75,6 +78,7 @@ class MusicFlamingoConfig(AudioFlamingo3Config):
audio_eos_token_id: int = 151671
audio_frame_step: float = 0.01
rope_parameters: dict | None = None
+ tie_word_embeddings: bool = False
def __post_init__(self, **kwargs):
if self.rope_parameters is None:
@@ -231,12 +235,12 @@ def _init_weights(self, module):
init.copy_(module.position_angles, buffer_value)
-@auto_docstring(
- custom_intro="""
- The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
- """
-)
-class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
+@dataclass
+class MusicFlamingoModelOutputWithPast(AudioFlamingo3ModelOutputWithPast):
+ pass
+
+
+class MusicFlamingoModel(AudioFlamingo3Model):
def __init__(self, config: MusicFlamingoConfig):
super().__init__(config)
self.pos_emb = MusicFlamingoRotaryEmbedding(config)
@@ -329,65 +333,17 @@ def forward(
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ):
r"""
- input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
- Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Example:
-
- ```python
- >>> from transformers import MusicFlamingoForConditionalGeneration, AutoProcessor
-
- >>> model_id = "nvidia/music-flamingo-2601-hf"
- >>> processor = AutoProcessor.from_pretrained(model_id)
- >>> model = MusicFlamingoForConditionalGeneration.from_pretrained(model_id, device_map="auto")
-
- >>> conversation = [
- >>> {
- >>> "role": "user",
- >>> "content": [
- >>> {
- >>> "type": "text",
- >>> "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
- >>> },
- >>> {
- >>> "type": "audio",
- >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
- >>> },
- >>> ],
- >>> }
- >>> ]
-
- >>> inputs = processor.apply_chat_template(
- >>> conversation,
- >>> tokenize=True,
- >>> add_generation_prompt=True,
- >>> return_dict=True,
- >>> ).to(model.device, model.dtype)
-
- >>> outputs = model.generate(**inputs, max_new_tokens=100)
-
- >>> decoded_outputs = processor.batch_decode(
- >>> outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
- >>> )
- >>> print(decoded_outputs)
- ["This track is an uplifting Eurodance-style Trance-Pop anthem..."]
- ```"""
+ input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices.
+ """
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_embeds = None
if input_features is not None and input_ids is not None:
audio_embeds = self.get_audio_features(
input_features, input_features_mask, input_ids=input_ids, return_dict=True
@@ -399,22 +355,43 @@ def forward(
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
+
+ return MusicFlamingoModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The MusicFlamingo model which consists of a fine-tuned Whisper encoder, rotary time embedding, a multi-modal projector, and a Qwen2 language model.
+ """
+)
+class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
+ def __init__(self, config: MusicFlamingoConfig):
+ super().__init__(config)
+ self.model = MusicFlamingoModel(config)
+ self.post_init()
+
+ def get_audio_features(self, input_features, input_features_mask, input_ids, **kwargs):
+ return self.model.get_audio_features(input_features, input_features_mask, input_ids, **kwargs)
__all__ = [
"MusicFlamingoConfig",
"MusicFlamingoProcessor",
"MusicFlamingoForConditionalGeneration",
+ "MusicFlamingoModel",
"MusicFlamingoPreTrainedModel",
]
diff --git a/src/transformers/models/parakeet/processing_parakeet.py b/src/transformers/models/parakeet/processing_parakeet.py
index 10563490345d..d7a4fb4ed457 100644
--- a/src/transformers/models/parakeet/processing_parakeet.py
+++ b/src/transformers/models/parakeet/processing_parakeet.py
@@ -44,7 +44,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
@auto_docstring
class ParakeetProcessor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer, blank_token=""):
- """
+ r"""
blank_token (`str`, *optional*, defaults to `""`):
Blank token for TDT decoding.
"""
diff --git a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py
index 6aec9eace900..749f24123bb4 100644
--- a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py
@@ -98,6 +98,7 @@ class Qwen2AudioConfig(PreTrainedConfig):
audio_config: dict | PreTrainedConfig | None = None
text_config: dict | PreTrainedConfig | None = None
audio_token_index: int = 151646
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.audio_config, dict):
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 170187291ff5..841c41993e81 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -25,19 +25,44 @@
from ...generation import GenerationMixin
from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutput, ModelOutput
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging, torch_compilable_check
from ...utils.generic import can_return_tuple, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
logger = logging.get_logger(__name__)
+@auto_docstring(
+ custom_intro="""
+ Base class for Qwen2Audio outputs, with hidden states and attentions.
+ """
+)
+@dataclass
+class Qwen2AudioModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Attention mask, potentially updated by the audio merging logic so that audio tokens are unmasked.
+ labels (`torch.LongTensor`, *optional*):
+ Labels, potentially re-aligned by the legacy audio merging logic. Returned so the language-modeling
+ head can compute the loss against the expanded sequence.
+ """
+
+ attention_mask: torch.FloatTensor | None = None
+ labels: torch.LongTensor | None = None
+
+
+@dataclass
@auto_docstring(
custom_intro="""
Base class for Qwen2Audio causal language model (or autoregressive) outputs.
@@ -394,17 +419,17 @@ def forward(self, audio_features):
@auto_docstring(
custom_intro="""
- The QWEN2AUDIO model which consists of a audio backbone and a language model.
+ The Qwen2Audio model which consists of an audio backbone and a language model, without a language modeling head.
"""
)
-class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
+class Qwen2AudioModel(Qwen2AudioPreTrainedModel):
def __init__(self, config: Qwen2AudioConfig):
super().__init__(config)
self.audio_tower = AutoModel.from_config(config.audio_config) # Usually a `Qwen2AudioEncoder` instance
self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = (
self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1
@@ -422,18 +447,6 @@ def padding_side(self, padding_side: str):
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
def _merge_input_ids_with_audio_features(
self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels
):
@@ -645,7 +658,7 @@ def forward(
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple | Qwen2AudioCausalLMOutputWithPast:
+ ) -> tuple | Qwen2AudioModelOutputWithPast:
r"""
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
@@ -653,32 +666,9 @@ def forward(
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Example:
-
- ```python
- >>> from io import BytesIO
- >>> from urllib.request import urlopen
- >>> import librosa
- >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
-
- >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
- >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
-
- >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
- >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
- >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate)
-
- >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(**inputs, max_length=30)
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Generate the caption in English: Glass is breaking."
- ```"""
+ Labels kept in the signature for the legacy merge path that may re-align them with audio tokens.
+ The loss is not computed here; `Qwen2AudioForConditionalGeneration` is responsible for that.
+ """
target_device = self.audio_tower.device
@@ -761,7 +751,103 @@ def forward(
**kwargs,
)
- logits = outputs.logits
+ return Qwen2AudioModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ attention_mask=attention_mask,
+ labels=labels,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The QWEN2AUDIO model which consists of an audio backbone and a language model.
+ """
+)
+class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: Qwen2AudioConfig):
+ super().__init__(config)
+ self.model = Qwen2AudioModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ @property
+ def padding_side(self):
+ return self.model.padding_side
+
+ @padding_side.setter
+ def padding_side(self, padding_side: str):
+ self.model.padding_side = padding_side
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ feature_attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Qwen2AudioCausalLMOutputWithPast:
+ r"""
+ feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
+ Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from io import BytesIO
+ >>> from urllib.request import urlopen
+ >>> import librosa
+ >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+
+ >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B")
+
+ >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
+ >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
+ >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
+
+ >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Generate the caption in English: Glass is breaking."
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ feature_attention_mask=feature_attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ logits = self.lm_head(hidden_states)
+ attention_mask = outputs.attention_mask
+ labels = outputs.labels if outputs.labels is not None else labels
loss = None
if labels is not None:
@@ -803,4 +889,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
return model_inputs
-__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
+__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder", "Qwen2AudioModel"]
diff --git a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py
index a673a5845871..11856b20b851 100644
--- a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py
+++ b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py
@@ -75,6 +75,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig):
audio_bos_token_id: int = 151646
audio_eos_token_id: int = 151647
acoustic_tokenizer_chunk_size: int = 1440000
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.acoustic_tokenizer_encoder_config, dict):
diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py
index 4e60c72ccd40..b7d2902aa6c1 100644
--- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py
+++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py
@@ -17,6 +17,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
+
import torch
from torch import nn
@@ -25,17 +27,11 @@
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import (
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- is_torchdynamo_compiling,
- torch_compilable_check,
-)
-from ..auto import AutoModel, AutoModelForCausalLM
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
+from ..auto import AutoModel
from .configuration_vibevoice_asr import VibeVoiceAsrConfig
@@ -247,6 +243,7 @@ class VibeVoiceAsrPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
+ _supports_attention_backend = True
def _init_weights(self, module):
super()._init_weights(module)
@@ -255,40 +252,69 @@ def _init_weights(self, module):
init.constant_(module.ffn_gamma, self.config.layer_scale_init_value)
+@dataclass
@auto_docstring(
custom_intro="""
- The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model.
+ Base class for VibeVoice ASR outputs, with hidden states and attentions.
"""
)
-class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = None
- _supports_attention_backend = True
- _tp_plan = None
- _sp_plan = None
- _pp_plan = None
+class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for VibeVoice ASR causal language model outputs.
+ """
+)
+class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores.
+ past_key_values (`Cache`, *optional*):
+ Cache instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model),
+ without a language modeling head.
+ """
+)
+class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel):
def __init__(self, config: VibeVoiceAsrConfig):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
- self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config)
self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config)
self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config)
-
- # Initialize weights and apply final processing
+ self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
- def get_decoder(self):
- return self.language_model.get_decoder()
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
@can_return_tuple
@auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.")
@@ -298,17 +324,15 @@ def get_audio_features(
padding_mask: torch.BoolTensor | None = None,
acoustic_tokenizer_chunk_size: int | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> tuple | BaseModelOutputWithPooling:
+ ):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
- Input audio tensor. Audio should be sampled at 24kHz.
+ Input audio tensor sampled at 24kHz.
padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing operations on padding feature indices.
acoustic_tokenizer_chunk_size (`int`, *optional*):
- Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`,
- but can be modified to fit the available memory.
+ Size of audio chunks to process at once through the tokenizers.
"""
-
if acoustic_tokenizer_chunk_size is None:
acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size
else:
@@ -353,7 +377,6 @@ def get_audio_features(
combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents)
if padding_mask is not None:
- # Adjust padding mask according to tokenizer compression
num_audio_tokens = torch.ceil(
padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length
).to(torch.int64)
@@ -364,29 +387,73 @@ def get_audio_features(
return BaseModelOutputWithPooling(last_hidden_state=acoustic_latents, pooler_output=combined_features)
- def get_placeholder_mask(
- self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
- ):
- """
- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
- equal to the length of multimodal features. If the lengths are different, an error is raised.
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ input_values: torch.FloatTensor | None = None,
+ padding_mask: torch.BoolTensor | None = None,
+ acoustic_tokenizer_chunk_size: int | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | VibeVoiceAsrModelOutputWithPast:
+ r"""
+ padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing operations on padding feature indices.
+ acoustic_tokenizer_chunk_size (`int`, *optional*):
+ Size of audio chunks processed by the acoustic and semantic tokenizers.
"""
- if input_ids is None:
- special_audio_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ audio_embeds = None
+ if input_values is not None and input_ids is not None:
+ audio_embeds = self.get_audio_features(
+ input_values=input_values,
+ padding_mask=padding_mask,
+ acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size,
+ ).pooler_output
+
+ audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
+ inputs_embeds = inputs_embeds.masked_scatter(
+ audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
)
- special_audio_mask = special_audio_mask.all(-1)
- else:
- special_audio_mask = input_ids == self.config.audio_token_id
-
- n_audio_tokens = special_audio_mask.sum()
- n_audio_features = audio_features.shape[0]
- special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
- f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
+
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ **kwargs,
)
- return special_audio_mask
+
+ return VibeVoiceAsrModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model.
+ """
+)
+class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: VibeVoiceAsrConfig):
+ super().__init__(config)
+ self.model = VibeVoiceAsrModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
@can_return_tuple
@auto_docstring
@@ -399,14 +466,15 @@ def forward(
input_values: torch.FloatTensor | None = None,
padding_mask: torch.BoolTensor | None = None,
acoustic_tokenizer_chunk_size: int | None = None,
+ labels: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast:
r"""
padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing operations on padding feature indices.
acoustic_tokenizer_chunk_size (`int`, *optional*):
- Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to
- `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory.
+ Size of audio chunks processed by the acoustic and semantic tokenizers.
Example:
@@ -416,33 +484,35 @@ def forward(
>>> model_id = "microsoft/VibeVoice-ASR-HF"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto")
-
- >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
- >>> inputs = inputs.to(model.device, dtype=model.dtype)
- >>> outputs = model.generate(**inputs)
-
- >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
- >>> print(decoded_outputs)
```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ input_values=input_values,
+ padding_mask=padding_mask,
+ acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size,
+ **kwargs,
+ )
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if input_values is not None and input_ids is not None:
- audio_embeds = self.get_audio_features(
- input_values=input_values,
- padding_mask=padding_mask,
- acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size,
- ).pooler_output
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
- # Replace text-audio token placeholders with audio embeddings
- audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
- inputs_embeds = inputs_embeds.masked_scatter(
- audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
- return self.language_model(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs
+ return VibeVoiceAsrCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
)
def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs):
@@ -463,4 +533,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg
return model_inputs
-__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrPreTrainedModel"]
+__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel"]
diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py
index 52652de9adde..c664823c039a 100644
--- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py
+++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py
@@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
+
import torch
from huggingface_hub.dataclasses import strict
from torch import nn
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
-from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
+from ...generation import GenerationMixin
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ ModelOutput,
+)
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ..audioflamingo3.modeling_audioflamingo3 import AudioFlamingo3ForConditionalGeneration
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
from ..vibevoice_acoustic_tokenizer.modeling_vibevoice_acoustic_tokenizer import (
@@ -82,6 +88,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig):
audio_bos_token_id: int = 151646
audio_eos_token_id: int = 151647
acoustic_tokenizer_chunk_size: int = 1440000
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.acoustic_tokenizer_encoder_config, dict):
@@ -159,21 +166,72 @@ class VibeVoiceAsrPreTrainedModel(VibeVoiceAcousticTokenizerPreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
+ _supports_attention_backend = True
+@dataclass
@auto_docstring(
custom_intro="""
- The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model.
+ Base class for VibeVoice ASR outputs, with hidden states and attentions.
"""
)
-class VibeVoiceAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
- _supports_attention_backend = True
+class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for VibeVoice ASR causal language model outputs.
+ """
+)
+class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores.
+ past_key_values (`Cache`, *optional*):
+ Cache instance.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+@auto_docstring(
+ custom_intro="""
+ The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model),
+ without a language modeling head.
+ """
+)
+class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel):
def __init__(self, config: VibeVoiceAsrConfig):
super().__init__(config)
self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config)
self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config)
- del self.audio_tower
+ self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
@can_return_tuple
@auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.")
@@ -186,14 +244,12 @@ def get_audio_features(
):
r"""
input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
- Input audio tensor. Audio should be sampled at 24kHz.
+ Input audio tensor sampled at 24kHz.
padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing operations on padding feature indices.
acoustic_tokenizer_chunk_size (`int`, *optional*):
- Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`,
- but can be modified to fit the available memory.
+ Size of audio chunks to process at once through the tokenizers.
"""
-
if acoustic_tokenizer_chunk_size is None:
acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size
else:
@@ -238,7 +294,6 @@ def get_audio_features(
combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents)
if padding_mask is not None:
- # Adjust padding mask according to tokenizer compression
num_audio_tokens = torch.ceil(
padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length
).to(torch.int64)
@@ -261,34 +316,17 @@ def forward(
padding_mask: torch.BoolTensor | None = None,
acoustic_tokenizer_chunk_size: int | None = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ) -> tuple | VibeVoiceAsrModelOutputWithPast:
r"""
padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing operations on padding feature indices.
acoustic_tokenizer_chunk_size (`int`, *optional*):
- Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to
- `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory.
-
- Example:
-
- ```python
- >>> from transformers import VibeVoiceAsrForConditionalGeneration, AutoProcessor
-
- >>> model_id = "microsoft/VibeVoice-ASR-HF"
- >>> processor = AutoProcessor.from_pretrained(model_id)
- >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto")
-
- >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
- >>> inputs = inputs.to(model.device, dtype=model.dtype)
- >>> outputs = model.generate(**inputs)
-
- >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
- >>> print(decoded_outputs)
- ```"""
-
+ Size of audio chunks processed by the acoustic and semantic tokenizers.
+ """
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_embeds = None
if input_values is not None and input_ids is not None:
audio_embeds = self.get_audio_features(
input_values=input_values,
@@ -296,14 +334,102 @@ def forward(
acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size,
).pooler_output
- # Replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
)
- return self.language_model(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ **kwargs,
+ )
+
+ return VibeVoiceAsrModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model.
+ """
+)
+class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: VibeVoiceAsrConfig):
+ super().__init__(config)
+ self.model = VibeVoiceAsrModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ input_values: torch.FloatTensor | None = None,
+ padding_mask: torch.BoolTensor | None = None,
+ acoustic_tokenizer_chunk_size: int | None = None,
+ labels: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast:
+ r"""
+ padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing operations on padding feature indices.
+ acoustic_tokenizer_chunk_size (`int`, *optional*):
+ Size of audio chunks processed by the acoustic and semantic tokenizers.
+
+ Example:
+
+ ```python
+ >>> from transformers import VibeVoiceAsrForConditionalGeneration, AutoProcessor
+
+ >>> model_id = "microsoft/VibeVoice-ASR-HF"
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+ >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto")
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ input_values=input_values,
+ padding_mask=padding_mask,
+ acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return VibeVoiceAsrCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=outputs.audio_hidden_states,
)
def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs):
@@ -327,5 +453,6 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg
__all__ = [
"VibeVoiceAsrConfig",
"VibeVoiceAsrForConditionalGeneration",
+ "VibeVoiceAsrModel",
"VibeVoiceAsrPreTrainedModel",
]
diff --git a/src/transformers/models/voxtral/configuration_voxtral.py b/src/transformers/models/voxtral/configuration_voxtral.py
index 2ecbedfc1a9e..b476d80dd976 100644
--- a/src/transformers/models/voxtral/configuration_voxtral.py
+++ b/src/transformers/models/voxtral/configuration_voxtral.py
@@ -110,6 +110,7 @@ class VoxtralConfig(PreTrainedConfig):
text_config: dict | PreTrainedConfig | None = None
audio_token_id: int | None = None
projector_hidden_act: str = "gelu"
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.audio_config, dict):
diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py
index f6aa02fd6c00..9fedfdd1345a 100644
--- a/src/transformers/models/voxtral/modeling_voxtral.py
+++ b/src/transformers/models/voxtral/modeling_voxtral.py
@@ -21,6 +21,7 @@
import math
from collections.abc import Callable
+from dataclasses import dataclass
import torch
from torch import nn
@@ -35,7 +36,7 @@
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig
@@ -359,36 +360,35 @@ def forward(self, audio_features):
return hidden_states
+@dataclass
@auto_docstring(
custom_intro="""
- The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
+ Base class for Voxtral outputs, with hidden states and attentions.
"""
)
-class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = ["embed_positions"]
+class VoxtralModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model,
+ without a language modeling head.
+ """
+)
+class VoxtralModel(VoxtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.multi_modal_projector = VoxtralMultiModalProjector(config)
-
- # Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
@can_return_tuple
@auto_docstring(
custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
@@ -436,6 +436,68 @@ def get_placeholder_mask(
)
return special_audio_mask
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | VoxtralModelOutputWithPast:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ audio_embeds = None
+ if input_features is not None and input_ids is not None:
+ audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
+
+ # replace text-audio token placeholders with audio embeddings
+ special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
+
+ outputs: BaseModelOutputWithPast = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return VoxtralModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model.
+ """
+)
+class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
+ _keep_in_fp32_modules_strict = ["embed_positions"]
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = VoxtralModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
+
@can_return_tuple
@auto_docstring
def forward(
@@ -450,7 +512,7 @@ def forward(
use_cache: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ) -> tuple | CausalLMOutputWithPast:
r"""
Example:
@@ -484,29 +546,34 @@ def forward(
>>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
```"""
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if input_features is not None and input_ids is not None:
- audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
-
- # replace text-audio token placeholders with audio embeddings
- special_audio_mask = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
-
- outputs: BaseModelOutputWithPast = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
def prepare_inputs_for_generation(self, *args, **kwargs):
# Overwritten -- we should not pass input_features when we are in cached decoding stage
@@ -523,4 +590,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
return model_inputs
-__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"]
+__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"]
diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py
index c1089db09944..01abc9052d9f 100644
--- a/src/transformers/models/voxtral/modular_voxtral.py
+++ b/src/transformers/models/voxtral/modular_voxtral.py
@@ -13,6 +13,8 @@
# limitations under the License.
+from dataclasses import dataclass
+
import torch
from torch import nn
@@ -28,7 +30,7 @@
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel, AutoModelForCausalLM
+from ..auto import AutoModel
from ..qwen2_audio.modeling_qwen2_audio import (
Qwen2AudioAttention,
Qwen2AudioEncoder,
@@ -128,36 +130,35 @@ def forward(self, audio_features):
return hidden_states
+@dataclass
@auto_docstring(
custom_intro="""
- The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
+ Base class for Voxtral outputs, with hidden states and attentions.
"""
)
-class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = ["embed_positions"]
+class VoxtralModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states.
+ """
+
+ audio_hidden_states: torch.FloatTensor | None = None
+
+@auto_docstring(
+ custom_intro="""
+ The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model,
+ without a language modeling head.
+ """
+)
+class VoxtralModel(VoxtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
- self.language_model = AutoModelForCausalLM.from_config(config.text_config)
+ self.language_model = AutoModel.from_config(config.text_config)
self.multi_modal_projector = VoxtralMultiModalProjector(config)
-
- # Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
@can_return_tuple
@auto_docstring(
custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
@@ -205,6 +206,68 @@ def get_placeholder_mask(
)
return special_audio_mask
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | VoxtralModelOutputWithPast:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ audio_embeds = None
+ if input_features is not None and input_ids is not None:
+ audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
+
+ # replace text-audio token placeholders with audio embeddings
+ special_audio_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
+
+ outputs: BaseModelOutputWithPast = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ return VoxtralModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model.
+ """
+)
+class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
+ _keep_in_fp32_modules_strict = ["embed_positions"]
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = VoxtralModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
+
@can_return_tuple
@auto_docstring
def forward(
@@ -219,7 +282,7 @@ def forward(
use_cache: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
+ ) -> tuple | CausalLMOutputWithPast:
r"""
Example:
@@ -253,29 +316,34 @@ def forward(
>>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
```"""
- if inputs_embeds is None:
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- if input_features is not None and input_ids is not None:
- audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output
-
- # replace text-audio token placeholders with audio embeddings
- special_audio_mask = self.get_placeholder_mask(
- input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
- )
- inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))
-
- outputs: BaseModelOutputWithPast = self.language_model(
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
**kwargs,
)
- return outputs
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
def prepare_inputs_for_generation(self, *args, **kwargs):
# Overwritten -- we should not pass input_features when we are in cached decoding stage
@@ -292,4 +360,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
return model_inputs
-__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"]
+__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"]
diff --git a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py
index 568c6f8748b9..77130e54db78 100644
--- a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py
+++ b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py
@@ -176,6 +176,7 @@ class VoxtralRealtimeConfig(PreTrainedConfig):
audio_length_per_tok: int = 8
default_num_delay_tokens: int = 6
downsample_factor: int = 4
+ tie_word_embeddings: bool = True
def __post_init__(self, **kwargs):
if isinstance(self.audio_config, dict):
diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py
index d5d2a92f18f7..aebd752d2aa8 100644
--- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py
+++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py
@@ -39,17 +39,9 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import (
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- is_torchdynamo_compiling,
- logging,
- torch_compilable_check,
-)
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
-from ..auto import AutoModel
from .configuration_voxtral_realtime import (
VoxtralRealtimeConfig,
VoxtralRealtimeEncoderConfig,
@@ -125,6 +117,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast):
padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None
+@dataclass
+class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ Args:
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder
+ that can be used to speed up sequential decoding.
+ padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
+ Cache for padding in convolutional layers to maintain state across streaming chunks.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states before they are added to the text embeddings.
+ """
+
+ encoder_past_key_values: Cache | None = None
+ padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
@dataclass
class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast):
r"""
@@ -487,6 +497,7 @@ class VoxtralRealtimePreTrainedModel(PreTrainedModel):
_supports_attention_backend = True
# TODO: @eustlb, this should be enabled soon
_can_compile_fullgraph = False
+ _keep_in_fp32_modules_strict = None
@torch.no_grad()
def _init_weights(self, module):
@@ -827,82 +838,6 @@ def forward(
)
-@auto_docstring
-class VoxtralRealtimeTextForCausalLM(VoxtralRealtimeTextPreTrainedModel, GenerationMixin):
- _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
- _tp_plan = {"lm_head": "colwise_allgather"}
- _sp_plan = {"lm_head": "colwise_loss_parallel"}
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
- _fsdp_plan = {"lm_head": "keep_full_weight"}
-
- def __init__(self, config):
- super().__init__(config)
- self.model = VoxtralRealtimeTextModel(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
- r"""
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM
-
- >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602")
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- outputs: BaseModelOutputWithPast = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- **kwargs,
- )
-
- hidden_states = outputs.last_hidden_state
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
class VoxtralRealtimeTimeEmbedding(nn.Module):
"""Sinusoidal Embedding for encoding time"""
@@ -937,34 +872,26 @@ def forward(self, audio_features):
@auto_docstring(
custom_intro="""
- The VoxtralRealtime model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
+ The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector,
+ a Mistral-based language model and a time embedding, without a language modeling head.
"""
)
-class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin):
- _keep_in_fp32_modules_strict = None
-
+class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.vocab_size = config.text_config.vocab_size
- self.audio_tower = AutoModel.from_config(config.audio_config)
- self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config)
+ self.audio_tower = VoxtralRealtimeEncoder(config.audio_config)
+ self.language_model = VoxtralRealtimeTextModel(config.text_config)
self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config)
self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size)
# Initialize weights and apply final processing
self.post_init()
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
@can_return_tuple
@auto_docstring(
@@ -981,11 +908,7 @@ def get_audio_features(
) -> tuple | BaseModelOutputWithPooling:
r"""
input_features (`torch.FloatTensor`):
- Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
- obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
- `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
- `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
- and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`]
+ Float values of mel features extracted from the raw speech waveform.
padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
Cache for padding in convolutional layers to maintain state across streaming chunks.
encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
@@ -1010,30 +933,6 @@ def get_audio_features(
return audio_outputs
- def get_placeholder_mask(
- self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
- ):
- """
- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
- equal to the length of multimodal features. If the lengths are different, an error is raised.
- """
- if input_ids is None:
- special_audio_mask = inputs_embeds == self.get_input_embeddings()(
- torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
- )
- special_audio_mask = special_audio_mask.all(-1)
- else:
- special_audio_mask = input_ids == self.config.audio_token_id
-
- n_audio_tokens = special_audio_mask.sum()
- n_audio_features = audio_features.shape[0]
- special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- torch_compilable_check(
- inputs_embeds[special_audio_mask].numel() == audio_features.numel(),
- f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
- )
- return special_audio_mask
-
@can_return_tuple
@auto_docstring
def forward(
@@ -1047,43 +946,20 @@ def forward(
padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
encoder_inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
num_delay_tokens: int | torch.Tensor = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> VoxtralRealtimeCausalLMOutputWithPast:
+ ) -> tuple | VoxtralRealtimeModelOutputWithPast:
r"""
encoder_past_key_values (`Cache`, *optional*):
- Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding.
+ Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder.
padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
Cache for padding in convolutional layers to maintain state across streaming chunks.
encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder.
num_delay_tokens (`int` or `torch.Tensor`, *optional*):
- Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details.
-
- Example:
-
- ```python
- >>> import torch
- >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor
- >>> from datasets import load_dataset
-
- >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602"
-
- >>> processor = AutoProcessor.from_pretrained(repo_id)
- >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto")
-
- >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
- >>> audio = ds[0]["audio"]["array"]
-
- >>> inputs = processor(audio, return_tensors="pt")
- >>> inputs = inputs.to(model.device, dtype=model.dtype)
-
- >>> outputs = model.generate(**inputs)
- >>> processor.batch_decode(outputs, skip_special_tokens=True)
- ```"""
+ Number of delay tokens used when preparing inputs.
+ """
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -1093,6 +969,8 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_outputs = None
+ audio_embeds = None
if input_features is not None or encoder_inputs_embeds is not None:
audio_outputs = self.get_audio_features(
input_features=input_features,
@@ -1102,7 +980,8 @@ def forward(
use_cache=use_cache,
return_dict=True,
)
- inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device)
+ audio_embeds = audio_outputs.pooler_output
+ inputs_embeds += audio_embeds.to(inputs_embeds.device)
if num_delay_tokens is None:
num_delay_tokens = self.config.default_num_delay_tokens
@@ -1121,25 +1000,124 @@ def forward(
t_cond = self.time_embedding(time_tensor)
t_cond = t_cond[None, ...] # broadcastable to batch size
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs: BaseModelOutputWithPast = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
t_cond=t_cond,
**kwargs,
)
+
+ return VoxtralRealtimeModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ encoder_past_key_values=audio_outputs.past_key_values
+ if (audio_outputs is not None and use_cache)
+ else None,
+ padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = VoxtralRealtimeModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ encoder_past_key_values: Cache | None = None,
+ padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ encoder_inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ num_delay_tokens: int | torch.Tensor = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast:
+ r"""
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding.
+ padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
+ Cache for padding in convolutional layers to maintain state across streaming chunks.
+ encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
+ Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder.
+ num_delay_tokens (`int` or `torch.Tensor`, *optional*):
+ Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor
+ >>> from datasets import load_dataset
+
+ >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602"
+
+ >>> processor = AutoProcessor.from_pretrained(repo_id)
+ >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> audio = ds[0]["audio"]["array"]
+
+ >>> inputs = processor(audio, return_tensors="pt")
+ >>> inputs = inputs.to(model.device, dtype=model.dtype)
+
+ >>> outputs = model.generate(**inputs)
+ >>> processor.batch_decode(outputs, skip_special_tokens=True)
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ encoder_past_key_values=encoder_past_key_values,
+ padding_cache=padding_cache,
+ inputs_embeds=inputs_embeds,
+ encoder_inputs_embeds=encoder_inputs_embeds,
+ use_cache=use_cache,
+ num_delay_tokens=num_delay_tokens,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
return VoxtralRealtimeCausalLMOutputWithPast(
- loss=outputs.loss,
- logits=outputs.logits,
+ loss=loss,
+ logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- encoder_past_key_values=audio_outputs.past_key_values if use_cache else None,
- padding_cache=audio_outputs.padding_cache if use_cache else None,
+ encoder_past_key_values=outputs.encoder_past_key_values,
+ padding_cache=outputs.padding_cache,
)
def prepare_inputs_for_generation(
@@ -1171,7 +1149,7 @@ def _prepare_model_inputs(
input_features = model_kwargs.get("input_features")
if input_features is not None and not isinstance(input_features, GeneratorType):
- model_kwargs["encoder_inputs_embeds"] = self.audio_tower.embedder(model_kwargs.pop("input_features"))
+ model_kwargs["encoder_inputs_embeds"] = self.model.audio_tower.embedder(model_kwargs.pop("input_features"))
elif isinstance(input_features, GeneratorType):
input_features_generator = model_kwargs.pop("input_features")
@@ -1335,4 +1313,9 @@ def _prepare_generated_length(
return generation_config
-__all__ = ["VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel"]
+__all__ = [
+ "VoxtralRealtimeForConditionalGeneration",
+ "VoxtralRealtimeEncoder",
+ "VoxtralRealtimePreTrainedModel",
+ "VoxtralRealtimeModel",
+]
diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py
index edad37679927..81bf2259d726 100644
--- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py
+++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py
@@ -31,13 +31,11 @@
from ...models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
- MistralForCausalLM,
MistralMLP,
MistralModel,
MistralRMSNorm,
)
from ...models.voxtral.modeling_voxtral import (
- VoxtralForConditionalGeneration,
VoxtralMultiModalProjector,
VoxtralPreTrainedModel,
)
@@ -116,6 +114,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast):
padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None
+@dataclass
+class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ Args:
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder
+ that can be used to speed up sequential decoding.
+ padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
+ Cache for padding in convolutional layers to maintain state across streaming chunks.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ Projected audio hidden states before they are added to the text embeddings.
+ """
+
+ encoder_past_key_values: Cache | None = None
+ padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None
+ audio_hidden_states: torch.FloatTensor | None = None
+
+
@dataclass
class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast):
r"""
@@ -255,6 +271,7 @@ def forward(
class VoxtralRealtimePreTrainedModel(VoxtralPreTrainedModel, PreTrainedModel):
# TODO: @eustlb, this should be enabled soon
_can_compile_fullgraph = False
+ _keep_in_fp32_modules_strict = None
@torch.no_grad()
def _init_weights(self, module):
@@ -436,66 +453,6 @@ def __init__(self, config):
self.rotary_emb = VoxtralRealtimeRotaryEmbedding(config=config)
-class VoxtralRealtimeTextForCausalLM(MistralForCausalLM):
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.LongTensor | None = None,
- past_key_values: Cache | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **kwargs: Unpack[TransformersKwargs],
- ) -> CausalLMOutputWithPast:
- r"""
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM
-
- >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602")
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- outputs: BaseModelOutputWithPast = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- **kwargs,
- )
-
- hidden_states = outputs.last_hidden_state
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
class VoxtralRealtimeTimeEmbedding(nn.Module):
"""Sinusoidal Embedding for encoding time"""
@@ -520,14 +477,29 @@ def __init__(self, config):
)
-class VoxtralRealtimeForConditionalGeneration(VoxtralForConditionalGeneration, GenerationMixin):
- _keep_in_fp32_modules_strict = None
-
+@auto_docstring(
+ custom_intro="""
+ The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector,
+ a Mistral-based language model and a time embedding, without a language modeling head.
+ """
+)
+class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config)
+ self.audio_tower = VoxtralRealtimeEncoder(config.audio_config)
+ self.language_model = VoxtralRealtimeTextModel(config.text_config)
+ self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config)
self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
@can_return_tuple
@auto_docstring(
custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector."
@@ -543,11 +515,7 @@ def get_audio_features(
) -> tuple | BaseModelOutputWithPooling:
r"""
input_features (`torch.FloatTensor`):
- Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
- obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
- `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
- `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
- and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`]
+ Float values of mel features extracted from the raw speech waveform.
padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
Cache for padding in convolutional layers to maintain state across streaming chunks.
encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
@@ -585,43 +553,20 @@ def forward(
padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
encoder_inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
- logits_to_keep: int | torch.Tensor = 0,
num_delay_tokens: int | torch.Tensor = None,
**kwargs: Unpack[TransformersKwargs],
- ) -> VoxtralRealtimeCausalLMOutputWithPast:
+ ) -> tuple | VoxtralRealtimeModelOutputWithPast:
r"""
encoder_past_key_values (`Cache`, *optional*):
- Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding.
+ Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder.
padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
Cache for padding in convolutional layers to maintain state across streaming chunks.
encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder.
num_delay_tokens (`int` or `torch.Tensor`, *optional*):
- Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details.
-
- Example:
-
- ```python
- >>> import torch
- >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor
- >>> from datasets import load_dataset
-
- >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602"
-
- >>> processor = AutoProcessor.from_pretrained(repo_id)
- >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto")
-
- >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
- >>> audio = ds[0]["audio"]["array"]
-
- >>> inputs = processor(audio, return_tensors="pt")
- >>> inputs = inputs.to(model.device, dtype=model.dtype)
-
- >>> outputs = model.generate(**inputs)
- >>> processor.batch_decode(outputs, skip_special_tokens=True)
- ```"""
+ Number of delay tokens used when preparing inputs.
+ """
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -631,6 +576,8 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
+ audio_outputs = None
+ audio_embeds = None
if input_features is not None or encoder_inputs_embeds is not None:
audio_outputs = self.get_audio_features(
input_features=input_features,
@@ -640,7 +587,8 @@ def forward(
use_cache=use_cache,
return_dict=True,
)
- inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device)
+ audio_embeds = audio_outputs.pooler_output
+ inputs_embeds += audio_embeds.to(inputs_embeds.device)
if num_delay_tokens is None:
num_delay_tokens = self.config.default_num_delay_tokens
@@ -659,25 +607,124 @@ def forward(
t_cond = self.time_embedding(time_tensor)
t_cond = t_cond[None, ...] # broadcastable to batch size
- outputs: CausalLMOutputWithPast = self.language_model(
+ outputs: BaseModelOutputWithPast = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
- labels=labels,
use_cache=use_cache,
- logits_to_keep=logits_to_keep,
t_cond=t_cond,
**kwargs,
)
+
+ return VoxtralRealtimeModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ encoder_past_key_values=audio_outputs.past_key_values
+ if (audio_outputs is not None and use_cache)
+ else None,
+ padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None,
+ audio_hidden_states=audio_embeds,
+ )
+
+
+class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = VoxtralRealtimeModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_audio_features(self, *args, **kwargs):
+ return self.model.get_audio_features(*args, **kwargs)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ input_features: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ encoder_past_key_values: Cache | None = None,
+ padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ encoder_inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ num_delay_tokens: int | torch.Tensor = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast:
+ r"""
+ encoder_past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding.
+ padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*):
+ Cache for padding in convolutional layers to maintain state across streaming chunks.
+ encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
+ Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder.
+ num_delay_tokens (`int` or `torch.Tensor`, *optional*):
+ Number of delay tokens used when preparing inputs; see [`~VoxtralRealtimeProcessor`] for more details.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor
+ >>> from datasets import load_dataset
+
+ >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602"
+
+ >>> processor = AutoProcessor.from_pretrained(repo_id)
+ >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> audio = ds[0]["audio"]["array"]
+
+ >>> inputs = processor(audio, return_tensors="pt")
+ >>> inputs = inputs.to(model.device, dtype=model.dtype)
+
+ >>> outputs = model.generate(**inputs)
+ >>> processor.batch_decode(outputs, skip_special_tokens=True)
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ encoder_past_key_values=encoder_past_key_values,
+ padding_cache=padding_cache,
+ inputs_embeds=inputs_embeds,
+ encoder_inputs_embeds=encoder_inputs_embeds,
+ use_cache=use_cache,
+ num_delay_tokens=num_delay_tokens,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
return VoxtralRealtimeCausalLMOutputWithPast(
- loss=outputs.loss,
- logits=outputs.logits,
+ loss=loss,
+ logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
- encoder_past_key_values=audio_outputs.past_key_values if use_cache else None,
- padding_cache=audio_outputs.padding_cache if use_cache else None,
+ encoder_past_key_values=outputs.encoder_past_key_values,
+ padding_cache=outputs.padding_cache,
)
def prepare_inputs_for_generation(
@@ -705,11 +752,11 @@ def _prepare_model_inputs(
bos_token_id: torch.Tensor | None = None,
model_kwargs: dict[str, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]:
- inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
+ inputs, input_name, model_kwargs = super()._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
input_features = model_kwargs.get("input_features")
if input_features is not None and not isinstance(input_features, GeneratorType):
- model_kwargs["encoder_inputs_embeds"] = self.audio_tower.embedder(model_kwargs.pop("input_features"))
+ model_kwargs["encoder_inputs_embeds"] = self.model.audio_tower.embedder(model_kwargs.pop("input_features"))
elif isinstance(input_features, GeneratorType):
input_features_generator = model_kwargs.pop("input_features")
@@ -725,7 +772,7 @@ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool,
if getattr(self, "_stream_exhausted", False):
self._stream_exhausted = False
return False
- return GenerationMixin._has_unfinished_sequences(this_peer_finished, synced_gpus, device)
+ return super()._has_unfinished_sequences(this_peer_finished, synced_gpus, device)
def _update_model_kwargs_for_generation(
self,
@@ -734,7 +781,7 @@ def _update_model_kwargs_for_generation(
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
):
- model_kwargs = GenerationMixin._update_model_kwargs_for_generation(
+ model_kwargs = super()._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder, num_new_tokens
)
@@ -761,7 +808,7 @@ def _prepare_cache_for_generation(
batch_size: int,
max_cache_length: int,
):
- GenerationMixin._prepare_cache_for_generation(
+ super()._prepare_cache_for_generation(
generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
)
@@ -815,7 +862,7 @@ def _prepare_generation_config(
generation_config is not None and generation_config.max_new_tokens is not None
)
- generation_config, model_kwargs = GenerationMixin._prepare_generation_config(generation_config, **kwargs)
+ generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
input_features = model_kwargs.get("input_features")
if input_features is not None and not isinstance(input_features, GeneratorType):
@@ -854,7 +901,7 @@ def _prepare_generated_length(
if getattr(generation_config, "_voxtral_set_max_length", False):
has_default_max_length = False
- generation_config = GenerationMixin._prepare_generated_length(
+ generation_config = super()._prepare_generated_length(
generation_config,
has_default_max_length,
has_default_min_length,
@@ -877,4 +924,5 @@ def _prepare_generated_length(
"VoxtralRealtimeForConditionalGeneration",
"VoxtralRealtimeEncoder",
"VoxtralRealtimePreTrainedModel",
+ "VoxtralRealtimeModel",
]
diff --git a/tests/alm_tester.py b/tests/alm_tester.py
index af0d63e17fd5..4c05751f564b 100644
--- a/tests/alm_tester.py
+++ b/tests/alm_tester.py
@@ -14,7 +14,6 @@
import copy
import random
-import unittest
from inspect import signature
from unittest.mock import patch
@@ -161,11 +160,6 @@ class ALMModelTest(MultiModalModelTest):
- `pipeline_model_mapping`: Override if not using default from model_tester
"""
- # TODO: @eustlb, remove this once #45534 is merged
- @unittest.skip("Audio-LMs have no separate base model without a head.")
- def test_model_base_model_prefix(self):
- pass
-
def test_sdpa_can_dispatch_on_flash(self):
# `test_sdpa_can_dispatch_on_flash` already pops the attention mask, but we cannot simply pop the
# audio mask here since it will raise an error in `get_audio_features` (cf. `test_mismatching_num_audio_tokens`).
diff --git a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py
index 7df6bcd254fc..57fbe90a352e 100644
--- a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py
+++ b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py
@@ -22,6 +22,7 @@
AudioFlamingo3Config,
AudioFlamingo3EncoderConfig,
AudioFlamingo3ForConditionalGeneration,
+ AudioFlamingo3Model,
AutoProcessor,
Qwen2Config,
is_torch_available,
@@ -42,6 +43,7 @@
class AudioFlamingo3ModelTester(ALMModelTester):
config_class = AudioFlamingo3Config
+ base_model_class = AudioFlamingo3Model
conditional_generation_class = AudioFlamingo3ForConditionalGeneration
text_config_class = Qwen2Config
audio_config_class = AudioFlamingo3EncoderConfig
diff --git a/tests/models/glmasr/test_modeling_glmasr.py b/tests/models/glmasr/test_modeling_glmasr.py
index b19e91a61209..1df06e819636 100644
--- a/tests/models/glmasr/test_modeling_glmasr.py
+++ b/tests/models/glmasr/test_modeling_glmasr.py
@@ -19,6 +19,7 @@
AutoProcessor,
GlmAsrConfig,
GlmAsrForConditionalGeneration,
+ GlmAsrModel,
LlamaConfig,
is_torch_available,
)
@@ -39,6 +40,7 @@
class GlmAsrModelTester(ALMModelTester):
config_class = GlmAsrConfig
+ base_model_class = GlmAsrModel
conditional_generation_class = GlmAsrForConditionalGeneration
text_config_class = LlamaConfig
audio_config_class = GlmAsrEncoderConfig
@@ -48,6 +50,12 @@ def __init__(self, parent, **kwargs):
kwargs.setdefault("head_dim", 8)
super().__init__(parent, **kwargs)
+ def create_audio_mask(self):
+ # Deterministic full-length mask: the base default randomizes lengths in [1, feat_seq_length],
+ # and short samples collapse to 0 audio tokens after conv2 (s=2) + merge_factor=4, breaking
+ # test_mismatching_num_audio_tokens (a 0-contribution sample makes "duplicate audio" a no-op).
+ return torch.ones([self.batch_size, self.feat_seq_length], dtype=torch.bool).to(torch_device)
+
def get_audio_embeds_mask(self, audio_mask):
# conv1 (s=1) preserves length; conv2 (s=2, k=3, p=1) halves; merge_factor=4 post-projector.
audio_lengths = audio_mask.sum(-1)
diff --git a/tests/models/granite_speech/test_modeling_granite_speech.py b/tests/models/granite_speech/test_modeling_granite_speech.py
index e4ecebbcb0ee..613790dbaa5a 100644
--- a/tests/models/granite_speech/test_modeling_granite_speech.py
+++ b/tests/models/granite_speech/test_modeling_granite_speech.py
@@ -23,6 +23,7 @@
GraniteSpeechConfig,
GraniteSpeechEncoderConfig,
GraniteSpeechForConditionalGeneration,
+ GraniteSpeechModel,
)
from transformers.testing_utils import (
cleanup,
@@ -49,6 +50,7 @@
class GraniteSpeechModelTester(ALMModelTester):
config_class = GraniteSpeechConfig
+ base_model_class = GraniteSpeechModel
conditional_generation_class = GraniteSpeechForConditionalGeneration
text_config_class = GraniteConfig
audio_config_class = GraniteSpeechEncoderConfig
diff --git a/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py b/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py
index 21f1d997efb4..04372995ffc3 100644
--- a/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py
+++ b/tests/models/granite_speech_plus/test_modeling_granite_speech_plus.py
@@ -20,6 +20,7 @@
GraniteSpeechPlusConfig,
GraniteSpeechPlusEncoderConfig,
GraniteSpeechPlusForConditionalGeneration,
+ GraniteSpeechPlusModel,
)
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from transformers.utils import is_datasets_available, is_torch_available
@@ -43,6 +44,7 @@ class GraniteSpeechPlusForConditionalGenerationModelTester(GraniteSpeechModelTes
"""
config_class = GraniteSpeechPlusConfig
+ base_model_class = GraniteSpeechPlusModel
conditional_generation_class = GraniteSpeechPlusForConditionalGeneration
audio_config_class = GraniteSpeechPlusEncoderConfig
diff --git a/tests/models/musicflamingo/test_modeling_musicflamingo.py b/tests/models/musicflamingo/test_modeling_musicflamingo.py
index 7c6174e86834..111ec75bea4b 100644
--- a/tests/models/musicflamingo/test_modeling_musicflamingo.py
+++ b/tests/models/musicflamingo/test_modeling_musicflamingo.py
@@ -24,6 +24,7 @@
AutoProcessor,
MusicFlamingoConfig,
MusicFlamingoForConditionalGeneration,
+ MusicFlamingoModel,
Qwen2Config,
is_torch_available,
)
@@ -51,6 +52,7 @@ class MusicFlamingoModelTester(ALMModelTester):
"""
config_class = MusicFlamingoConfig
+ base_model_class = MusicFlamingoModel
conditional_generation_class = MusicFlamingoForConditionalGeneration
text_config_class = Qwen2Config
audio_config_class = AudioFlamingo3EncoderConfig
@@ -96,7 +98,7 @@ class MusicFlamingoForConditionalGenerationModelTest(ALMModelTest, unittest.Test
def test_rotary_window_axis_resets_per_audio(self):
config = self.model_tester.get_config()
- pos_emb = MusicFlamingoForConditionalGeneration(config).pos_emb.to(torch_device)
+ pos_emb = MusicFlamingoForConditionalGeneration(config).model.pos_emb.to(torch_device)
timestamps = torch.tensor(
[
@@ -126,7 +128,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self):
input_ids[0, :45] = config.audio_token_id
input_ids[1, :30] = config.audio_token_id
- _, post_lengths = model.audio_tower._get_feat_extract_output_lengths(
+ _, post_lengths = model.model.audio_tower._get_feat_extract_output_lengths(
input_features_mask.sum(-1).to(torch.long)
)
max_post_length = int(post_lengths.max().item())
@@ -144,7 +146,7 @@ def test_build_audio_timestamps_reconstructs_windows_from_input_ids(self):
]
)
- inferred = model._build_audio_timestamps(input_ids, post_lengths, max_post_length)
+ inferred = model.model._build_audio_timestamps(input_ids, post_lengths, max_post_length)
torch.testing.assert_close(inferred, audio_timestamps)
@unittest.skip(
diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index d8a8bd7c88f9..f6e7662ac5cc 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -24,6 +24,7 @@
Qwen2AudioConfig,
Qwen2AudioEncoderConfig,
Qwen2AudioForConditionalGeneration,
+ Qwen2AudioModel,
Qwen2Config,
is_torch_available,
)
@@ -43,6 +44,7 @@
class Qwen2AudioModelTester(ALMModelTester):
config_class = Qwen2AudioConfig
+ base_model_class = Qwen2AudioModel
conditional_generation_class = Qwen2AudioForConditionalGeneration
text_config_class = Qwen2Config
audio_config_class = Qwen2AudioEncoderConfig
diff --git a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py
index fc8bb11568ea..dc01897a07ff 100644
--- a/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py
+++ b/tests/models/vibevoice_asr/test_modeling_vibevoice_asr.py
@@ -22,6 +22,7 @@
from transformers import (
VibeVoiceAsrConfig,
VibeVoiceAsrForConditionalGeneration,
+ VibeVoiceAsrModel,
is_datasets_available,
is_torch_available,
)
@@ -133,11 +134,14 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class VibeVoiceAsrForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
- all_model_classes = (VibeVoiceAsrForConditionalGeneration,) if is_torch_available() else ()
+ all_model_classes = (VibeVoiceAsrModel, VibeVoiceAsrForConditionalGeneration) if is_torch_available() else ()
pipeline_model_mapping = (
{"audio-text-to-text": VibeVoiceAsrForConditionalGeneration} if is_torch_available() else {}
)
_is_composite = True
+ # Acoustic/semantic tokenizers run under torch.no_grad() in get_audio_features,
+ # so their params never receive grads — the mixin's force-unfreeze can't change that.
+ test_all_params_have_gradient = False
def setUp(self):
self.model_tester = VibeVoiceAsrModelTester(self)
@@ -185,6 +189,14 @@ def test_model_outputs_equivalence(self):
def test_left_padding_compatibility(self):
pass
+ @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling.")
+ def test_forward_with_logits_to_keep(self):
+ pass
+
+ @unittest.skip(reason="VibeVoiceAsr has slight randomness due to VAE sampling.")
+ def test_generate_methods_with_logits_to_keep(self):
+ pass
+
def test_sdpa_can_dispatch_composite_models(self):
# VibeVoiceAsr is audio+text composite; but audio components do not use attention
for model_class in self.all_model_classes:
@@ -197,16 +209,17 @@ def test_sdpa_can_dispatch_composite_models(self):
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
- text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
+ language_model_sdpa = model_sdpa.base_model.language_model
+ text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
- self.assertTrue(model.language_model.config._attn_implementation == text_attn)
+ self.assertTrue(language_model_sdpa.config._attn_implementation == text_attn)
# Eager
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
- self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
+ self.assertTrue(model_eager.base_model.language_model.config._attn_implementation == "eager")
for _, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
diff --git a/tests/models/voxtral/test_modeling_voxtral.py b/tests/models/voxtral/test_modeling_voxtral.py
index 4f0c604ce05f..deeed284bc5d 100644
--- a/tests/models/voxtral/test_modeling_voxtral.py
+++ b/tests/models/voxtral/test_modeling_voxtral.py
@@ -21,6 +21,7 @@
VoxtralConfig,
VoxtralEncoderConfig,
VoxtralForConditionalGeneration,
+ VoxtralModel,
is_torch_available,
)
from transformers.testing_utils import (
@@ -40,6 +41,7 @@
class VoxtralModelTester(ALMModelTester):
config_class = VoxtralConfig
+ base_model_class = VoxtralModel
conditional_generation_class = VoxtralForConditionalGeneration
text_config_class = LlamaConfig
audio_config_class = VoxtralEncoderConfig
diff --git a/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py b/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py
index 150d7a894104..420b794cd8ee 100644
--- a/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py
+++ b/tests/models/voxtral_realtime/test_modeling_voxtral_realtime.py
@@ -20,6 +20,7 @@
AutoProcessor,
VoxtralRealtimeConfig,
VoxtralRealtimeForConditionalGeneration,
+ VoxtralRealtimeModel,
is_datasets_available,
is_torch_available,
)
@@ -48,6 +49,7 @@
class VoxtralRealtimeModelTester(ALMModelTester):
config_class = VoxtralRealtimeConfig
+ base_model_class = VoxtralRealtimeModel
conditional_generation_class = VoxtralRealtimeForConditionalGeneration
text_config_class = VoxtralRealtimeTextConfig
audio_config_class = VoxtralRealtimeEncoderConfig