Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
3562c7f
audio tester
tarekziade Apr 13, 2026
0817bdb
tweak check repo for audio tester
tarekziade Apr 13, 2026
356c922
audio -> ALM
eustlb Apr 13, 2026
9663a8e
ALMTester: no audio/text defaults; better input prep
eustlb Apr 13, 2026
73c4548
Merge branch 'main' into tarekziade-audio-test
eustlb Apr 19, 2026
a599b1d
update test_sdpa_can_dispatch_composite_models to handle ALMs
eustlb Apr 19, 2026
a7d54dc
propagate to other model classes
eustlb Apr 20, 2026
a302c3e
cleaner
eustlb Apr 20, 2026
8fcba58
updates
eustlb Apr 20, 2026
66acc9e
audio_mask_key + updates
eustlb Apr 20, 2026
63ca77e
typo
eustlb Apr 20, 2026
7588135
simplify granite speech
eustlb Apr 20, 2026
41fed1c
nits
eustlb Apr 20, 2026
e5971c7
some more cleaning
eustlb Apr 20, 2026
59703dd
add test_mismatching_num_audio_tokens
eustlb Apr 21, 2026
6a67f32
add get_placeholder_mask
eustlb Apr 21, 2026
b59f958
specific to musicflamingo
eustlb Apr 21, 2026
bb986b6
granite speech fix
eustlb Apr 21, 2026
670c68c
let's factorise alm/vlm testers
eustlb Apr 22, 2026
c953443
make fix-repo
eustlb Apr 22, 2026
8740409
unskip test_sdpa_can_dispatch_on_flash on qwen2_audio
eustlb Apr 22, 2026
dde65f6
should not be skipped
eustlb Apr 22, 2026
19b37c5
make fix-repo
eustlb Apr 22, 2026
b47621a
test_mismatching_num_audio_tokens should be skipped for voxtral_realtime
eustlb Apr 22, 2026
b9d30be
nit
eustlb Apr 27, 2026
8d2e4b7
_special_token_ids as property and skipped in prepare_config_and_inpu…
eustlb Apr 27, 2026
cbd526f
MoE params in common class
eustlb Apr 27, 2026
12dfcd0
add _TEXT_MODEL_TESTER_DEFAULTS to avoid divergence
eustlb Apr 27, 2026
95b1f20
nit
eustlb Apr 27, 2026
c2aa666
clearer inits
eustlb Apr 27, 2026
5e36c9f
_prepare_modality_inputs return dict
eustlb Apr 27, 2026
ca5ff0b
Merge branch 'main' into tarekziade-audio-test
tarekziade May 4, 2026
184227c
format
tarekziade May 4, 2026
d77fbb9
split line for readability
tarekziade May 4, 2026
902dbba
ran python utils/check_modular_conversion.py --fix_and_overwrite
tarekziade May 4, 2026
dcdead1
testing auto cancel
tarekziade May 5, 2026
628343d
testing auto cancel - part 2
tarekziade May 5, 2026
4c35768
Merge branch 'main' into tarekziade-audio-test
tarekziade May 6, 2026
3f5f4d5
Merge branch 'main' into tarekziade-audio-test
eustlb May 11, 2026
c1a4772
remove comment
eustlb May 11, 2026
9322315
update granite speech plus tests
eustlb May 11, 2026
95da798
fix test
eustlb May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel, AutoModelForCausalLM
Expand Down Expand Up @@ -474,6 +474,30 @@ def get_audio_features(

return audio_output

def get_placeholder_mask(
    self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
):
    """
    Locate the audio placeholder positions and return them as a mask shaped like `inputs_embeds`.

    When `input_ids` is provided, placeholders are the positions equal to `self.config.audio_token_id`.
    Otherwise they are recovered from `inputs_embeds` by comparing every embedding vector with the
    embedding of the audio token (all entries along the last dimension must match).

    The number of embedding entries selected by the mask is then checked against
    `audio_features.numel()` through `torch_compilable_check`, which is expected to raise when the
    two differ.

    Returns:
        The placeholder mask expanded to the shape of `inputs_embeds`, on the device of `inputs_embeds`.
    """
    if input_ids is not None:
        placeholder_positions = input_ids == self.config.audio_token_id
    else:
        # No token ids available: match against the embedding of the audio placeholder token.
        audio_token = torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
        audio_token_embedding = self.get_input_embeddings()(audio_token)
        placeholder_positions = (inputs_embeds == audio_token_embedding).all(-1)

    # Counts are only used to build the error message below.
    n_audio_tokens = placeholder_positions.sum()
    n_audio_features = audio_features.shape[0]

    expanded_mask = placeholder_positions.unsqueeze(-1)
    expanded_mask = expanded_mask.expand_as(inputs_embeds)
    expanded_mask = expanded_mask.to(inputs_embeds.device)

    n_selected_entries = inputs_embeds[expanded_mask].numel()
    torch_compilable_check(
        n_selected_entries == audio_features.numel(),
        f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
    )
    return expanded_mask

@can_return_tuple
@auto_docstring
def forward(
Expand Down Expand Up @@ -560,10 +584,10 @@ def forward(
audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output

# replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
special_audio_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))

outputs: CausalLMOutputWithPast = self.language_model(
inputs_embeds=inputs_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,10 @@ def forward(
audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output

# replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
special_audio_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))

outputs: CausalLMOutputWithPast = self.language_model(
inputs_embeds=inputs_embeds,
Expand Down
32 changes: 28 additions & 4 deletions src/transformers/models/glmasr/modeling_glmasr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, is_torch_available
from ...utils import TransformersKwargs, auto_docstring, is_torch_available, torch_compilable_check
from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel, AutoModelForCausalLM
Expand Down Expand Up @@ -426,6 +426,30 @@ def get_audio_features(

return audio_outputs

def get_placeholder_mask(
    self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
):
    """
    Locate the audio placeholder positions and return them as a mask shaped like `inputs_embeds`.

    When `input_ids` is provided, placeholders are the positions equal to `self.config.audio_token_id`.
    Otherwise they are recovered from `inputs_embeds` by comparing every embedding vector with the
    embedding of the audio token (all entries along the last dimension must match).

    The number of embedding entries selected by the mask is then checked against
    `audio_features.numel()` through `torch_compilable_check`, which is expected to raise when the
    two differ.

    Returns:
        The placeholder mask expanded to the shape of `inputs_embeds`, on the device of `inputs_embeds`.
    """
    if input_ids is not None:
        placeholder_positions = input_ids == self.config.audio_token_id
    else:
        # No token ids available: match against the embedding of the audio placeholder token.
        audio_token = torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
        audio_token_embedding = self.get_input_embeddings()(audio_token)
        placeholder_positions = (inputs_embeds == audio_token_embedding).all(-1)

    # Counts are only used to build the error message below.
    n_audio_tokens = placeholder_positions.sum()
    n_audio_features = audio_features.shape[0]

    expanded_mask = placeholder_positions.unsqueeze(-1)
    expanded_mask = expanded_mask.expand_as(inputs_embeds)
    expanded_mask = expanded_mask.to(inputs_embeds.device)

    n_selected_entries = inputs_embeds[expanded_mask].numel()
    torch_compilable_check(
        n_selected_entries == audio_features.numel(),
        f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
    )
    return expanded_mask

@can_return_tuple
@auto_docstring
def forward(
Expand Down Expand Up @@ -478,10 +502,10 @@ def forward(
audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output

# replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
special_audio_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))

outputs: CausalLMOutputWithPast = self.language_model(
inputs_embeds=inputs_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,31 @@ class GraniteSpeechEncoderConfig(PreTrainedConfig):
```"""

model_type = "granite_speech_encoder"
attribute_map = {
"hidden_size": "hidden_dim",
"num_hidden_layers": "num_layers",
"num_attention_heads": "num_heads",
"num_mel_bins": "input_dim",
}

input_dim: int = 160
num_layers: int = 10
hidden_dim: int = 1024
feedforward_mult: int = 4
num_heads: int = 8
dim_head: int = 128
dim_head: int | None = None
output_dim: int = 42
context_size: int = 200
max_pos_emb: int = 512
dropout: float | int = 0.1
conv_kernel_size: int = 15
conv_expansion_factor: int = 2

def __post_init__(self, **kwargs):
    """
    Run the parent post-init hook, then fill in `dim_head` when it was left unset.

    An explicitly provided `dim_head` is kept as is; a `None` value is replaced by
    `hidden_dim // num_heads` (integer floor division).
    """
    super().__post_init__(**kwargs)
    if self.dim_head is not None:
        # Caller supplied a per-head dimension explicitly; nothing to derive.
        return
    self.dim_head = self.hidden_dim // self.num_heads


@auto_docstring(checkpoint="ibm-granite/granite-speech-3.3-2b")
@strict
Expand Down
36 changes: 27 additions & 9 deletions src/transformers/models/granite_speech/modeling_granite_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,30 @@ def prepare_inputs_for_generation(
model_inputs["input_features"] = input_features
return model_inputs

def get_placeholder_mask(
    self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
):
    """
    Locate the audio placeholder positions and return them as a mask shaped like `inputs_embeds`.

    When `input_ids` is provided, placeholders are the positions equal to `self.config.audio_token_id`.
    Otherwise they are recovered from `inputs_embeds` by comparing every embedding vector with the
    embedding of the audio token (all entries along the last dimension must match).

    The number of embedding entries selected by the mask is then checked against
    `audio_features.numel()` through `torch_compilable_check`, which is expected to raise when the
    two differ.

    Returns:
        The placeholder mask expanded to the shape of `inputs_embeds`, on the device of `inputs_embeds`.
    """
    if input_ids is not None:
        placeholder_positions = input_ids == self.config.audio_token_id
    else:
        # No token ids available: match against the embedding of the audio placeholder token.
        audio_token = torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
        audio_token_embedding = self.get_input_embeddings()(audio_token)
        placeholder_positions = (inputs_embeds == audio_token_embedding).all(-1)

    # Counts are only used to build the error message below.
    n_audio_tokens = placeholder_positions.sum()
    n_audio_features = audio_features.shape[0]

    expanded_mask = placeholder_positions.unsqueeze(-1)
    expanded_mask = expanded_mask.expand_as(inputs_embeds)
    expanded_mask = expanded_mask.to(inputs_embeds.device)

    n_selected_entries = inputs_embeds[expanded_mask].numel()
    torch_compilable_check(
        n_selected_entries == audio_features.numel(),
        f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
    )
    return expanded_mask

def get_merged_audio_embeddings(
self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None
) -> torch.Tensor:
Expand All @@ -536,20 +560,14 @@ def get_merged_audio_embeddings(
llm_input_ids = torch.where(is_audio_index, 0, input_ids)
inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]

# Mask the audio features into the text embeddings
special_audio_mask = is_audio_index.unsqueeze(-1)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
if input_features_mask is not None:
torch_compilable_check(
not torch.all(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)),
"Number of audio tokens does not match number of audio features",
)
audio_features = audio_features[input_features_mask]

inputs_embeds = inputs_embeds.masked_scatter(
special_audio_mask,
audio_features,
special_audio_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
return inputs_embeds

def generate(self, *args, **kwargs) -> torch.LongTensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,19 @@ class GraniteSpeechPlusEncoderConfig(PreTrainedConfig):
```"""

model_type = "granite_speech_plus_encoder"
attribute_map = {
"hidden_size": "hidden_dim",
"num_hidden_layers": "num_layers",
"num_attention_heads": "num_heads",
"num_mel_bins": "input_dim",
}

input_dim: int = 160
num_layers: int = 10
hidden_dim: int = 1024
feedforward_mult: int = 4
num_heads: int = 8
dim_head: int = 128
dim_head: int | None = None
output_dim: int = 42
context_size: int = 200
max_pos_emb: int = 512
Expand All @@ -78,6 +84,11 @@ class GraniteSpeechPlusEncoderConfig(PreTrainedConfig):

cat_hidden_layers: list[int] | None = None

def __post_init__(self, **kwargs):
    """
    Run the parent post-init hook, then fill in `dim_head` when it was left unset.

    An explicitly provided `dim_head` is kept as is; a `None` value is replaced by
    `hidden_dim // num_heads` (integer floor division).
    """
    super().__post_init__(**kwargs)
    if self.dim_head is not None:
        # Caller supplied a per-head dimension explicitly; nothing to derive.
        return
    self.dim_head = self.hidden_dim // self.num_heads


@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-plus")
@strict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,30 @@ def prepare_inputs_for_generation(
model_inputs["input_features"] = input_features
return model_inputs

def get_placeholder_mask(
    self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
):
    """
    Locate the audio placeholder positions and return them as a mask shaped like `inputs_embeds`.

    When `input_ids` is provided, placeholders are the positions equal to `self.config.audio_token_id`.
    Otherwise they are recovered from `inputs_embeds` by comparing every embedding vector with the
    embedding of the audio token (all entries along the last dimension must match).

    The number of embedding entries selected by the mask is then checked against
    `audio_features.numel()` through `torch_compilable_check`, which is expected to raise when the
    two differ.

    Returns:
        The placeholder mask expanded to the shape of `inputs_embeds`, on the device of `inputs_embeds`.
    """
    if input_ids is not None:
        placeholder_positions = input_ids == self.config.audio_token_id
    else:
        # No token ids available: match against the embedding of the audio placeholder token.
        audio_token = torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
        audio_token_embedding = self.get_input_embeddings()(audio_token)
        placeholder_positions = (inputs_embeds == audio_token_embedding).all(-1)

    # Counts are only used to build the error message below.
    n_audio_tokens = placeholder_positions.sum()
    n_audio_features = audio_features.shape[0]

    expanded_mask = placeholder_positions.unsqueeze(-1)
    expanded_mask = expanded_mask.expand_as(inputs_embeds)
    expanded_mask = expanded_mask.to(inputs_embeds.device)

    n_selected_entries = inputs_embeds[expanded_mask].numel()
    torch_compilable_check(
        n_selected_entries == audio_features.numel(),
        f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
    )
    return expanded_mask

def get_merged_audio_embeddings(
self, input_ids: torch.Tensor, audio_features: torch.Tensor, input_features_mask: torch.Tensor | None = None
) -> torch.Tensor:
Expand All @@ -557,20 +581,14 @@ def get_merged_audio_embeddings(
llm_input_ids = torch.where(is_audio_index, 0, input_ids)
inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]

# Mask the audio features into the text embeddings
special_audio_mask = is_audio_index.unsqueeze(-1)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
if input_features_mask is not None:
torch_compilable_check(
not torch.all(is_audio_index.int().sum(dim=1) != input_features_mask.int().sum(dim=1)),
"Number of audio tokens does not match number of audio features",
)
audio_features = audio_features[input_features_mask]

inputs_embeds = inputs_embeds.masked_scatter(
special_audio_mask,
audio_features,
special_audio_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
return inputs_embeds

def generate(self, *args, **kwargs) -> torch.LongTensor:
Expand Down
39 changes: 35 additions & 4 deletions src/transformers/models/musicflamingo/modeling_musicflamingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_musicflamingo import MusicFlamingoConfig

Expand Down Expand Up @@ -269,6 +269,30 @@ def get_audio_features(

return audio_output

def get_placeholder_mask(
    self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor
):
    """
    Locate the audio placeholder positions and return them as a mask shaped like `inputs_embeds`.

    When `input_ids` is provided, placeholders are the positions equal to `self.config.audio_token_id`.
    Otherwise they are recovered from `inputs_embeds` by comparing every embedding vector with the
    embedding of the audio token (all entries along the last dimension must match).

    The number of embedding entries selected by the mask is then checked against
    `audio_features.numel()` through `torch_compilable_check`, which is expected to raise when the
    two differ.

    Returns:
        The placeholder mask expanded to the shape of `inputs_embeds`, on the device of `inputs_embeds`.
    """
    if input_ids is not None:
        placeholder_positions = input_ids == self.config.audio_token_id
    else:
        # No token ids available: match against the embedding of the audio placeholder token.
        audio_token = torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
        audio_token_embedding = self.get_input_embeddings()(audio_token)
        placeholder_positions = (inputs_embeds == audio_token_embedding).all(-1)

    # Counts are only used to build the error message below.
    n_audio_tokens = placeholder_positions.sum()
    n_audio_features = audio_features.shape[0]

    expanded_mask = placeholder_positions.unsqueeze(-1)
    expanded_mask = expanded_mask.expand_as(inputs_embeds)
    expanded_mask = expanded_mask.to(inputs_embeds.device)

    n_selected_entries = inputs_embeds[expanded_mask].numel()
    torch_compilable_check(
        n_selected_entries == audio_features.numel(),
        f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
    )
    return expanded_mask

@can_return_tuple
@auto_docstring
def forward(
Expand Down Expand Up @@ -345,10 +369,10 @@ def forward(
).pooler_output

# replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
special_audio_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))

outputs: CausalLMOutputWithPast = self.language_model(
inputs_embeds=inputs_embeds,
Expand Down Expand Up @@ -388,6 +412,13 @@ def _build_audio_timestamps(
_, ends = torch.where(diff == -1)
sample_lengths = (ends - starts).to(torch.long)

n_audio_tokens = audio_token_mask.sum()
n_audio_features = post_lengths.sum()
torch_compilable_check(
n_audio_tokens == n_audio_features,
f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}",
)

# Account for 4x downsampling in audio encoder (conv2 and avg pooling)
audio_embed_frame_step = self.config.audio_frame_step * 4
frame_offsets = (
Expand Down
Loading
Loading