Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def forward(self, audio_features):
)
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
_supports_attention_backend = True
_tp_plan = None
_pp_plan = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(self, config: AudioFlamingo3Config):
"""
)
class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
_supports_attention_backend = True
_tp_plan = None
_pp_plan = None
_keep_in_fp32_modules_strict = None
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gpt_oss", "GptOssModel"),
("gptj", "GPTJModel"),
("granite", "GraniteModel"),
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("granitemoe", "GraniteMoeModel"),
("granitemoehybrid", "GraniteMoeHybridModel"),
("granitemoeshared", "GraniteMoeSharedModel"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/glmasr/modeling_glmasr.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def forward(self, audio_features):
)
class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
_supports_attention_backend = True
_tp_plan = None
_pp_plan = None

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/glmasr/modular_glmasr.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def __init__(self, config: GlmAsrConfig):
"""
)
class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
_supports_attention_backend = True

@can_return_tuple
@auto_docstring(
custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def forward(
"""
)
class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, GenerationMixin):
_supports_attention_backend = True

def __init__(self, config: GraniteSpeechConfig):
super().__init__(config)
# NOTE: It doesn't matter when we initialize from config, but we should be careful
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def apply_rotary_time_emb(hidden_states, cos, sin):
)
class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
_supports_attention_backend = True
_tp_plan = None
_pp_plan = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __call__(
max_length: int | None = None,
return_attention_mask: bool | None = True,
return_tensors: str | None = "pt",
**kwargs,
) -> BatchFeature:
"""
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _init_weights(self, module):
)
class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
_supports_attention_backend = True
_tp_plan = None
_pp_plan = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class VibeVoiceAsrPreTrainedModel(VibeVoiceAcousticTokenizerPreTrainedModel):
"""
)
class VibeVoiceAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
_supports_attention_backend = True

def __init__(self, config: VibeVoiceAsrConfig):
super().__init__(config)
self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config)
Expand Down
6 changes: 6 additions & 0 deletions tests/models/granite_speech/test_modeling_granite_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,12 @@ def setUp(self):
has_text_modality=False,
)

@unittest.skip(
reason="This test does not apply to GraniteSpeech since inputs_embeds corresponding to audio tokens are replaced when input features are provided."
)
def test_inputs_embeds_matches_input_ids(self):
pass

def test_inputs_embeds(self):
# overwrite inputs_embeds tests because we need to delete "input features" for the audio model
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down
Loading