From 7c816f903971facbda74c6906a23c7e5b1b24a46 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 12:38:43 +0000 Subject: [PATCH 01/39] support granite speech nar model --- src/transformers/models/__init__.py | 1 + src/transformers/models/auto/auto_mappings.py | 5 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/granite_speech_nar/__init__.py | 29 + .../configuration_granite_speech_nar.py | 215 ++++++ .../feature_extraction_granite_speech_nar.py | 105 +++ .../modeling_granite_speech_nar.py | 624 ++++++++++++++++++ .../processing_granite_speech_nar.py | 45 ++ tests/models/granite_speech_nar/__init__.py | 0 .../test_modeling_granite_speech_nar.py | 366 ++++++++++ 12 files changed, 1394 insertions(+) create mode 100644 src/transformers/models/granite_speech_nar/__init__.py create mode 100644 src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py create mode 100644 src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py create mode 100644 src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py create mode 100644 src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py create mode 100644 tests/models/granite_speech_nar/__init__.py create mode 100644 tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 406c5f7be0fc..e27f6198aa53 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -182,6 +182,7 @@ from .granite import * from .granite4_vision import * from .granite_speech import * + from .granite_speech_nar import * from .granite_speech_plus import * from .granitemoe import * from .granitemoehybrid import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 95a5e647e963..d30da1872a09 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -242,6 +242,9 @@ ("granite4_vision_text", "Granite4VisionTextConfig"), ("granite_speech", "GraniteSpeechConfig"), ("granite_speech_encoder", "GraniteSpeechEncoderConfig"), + ("granite_speech_nar", "GraniteSpeechNarConfig"), + ("granite_speech_nar_encoder", "GraniteSpeechNarEncoderConfig"), + ("granite_speech_nar_projector", "GraniteSpeechNarProjectorConfig"), ("granite_speech_plus", "GraniteSpeechPlusConfig"), ("granite_speech_plus_encoder", "GraniteSpeechPlusEncoderConfig"), ("granitemoe", "GraniteMoeConfig"), @@ -725,6 +728,8 @@ ("glmasr_encoder", "glmasr"), ("granite4_vision_text", "granite4_vision"), ("granite_speech_encoder", "granite_speech"), + ("granite_speech_nar_encoder", "granite_speech_nar"), + ("granite_speech_nar_projector", "granite_speech_nar"), ("granite_speech_plus_encoder", "granite_speech_plus"), ("grounding-dino", "grounding_dino"), ("groupvit_text_model", "groupvit"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 1b6ad6c44844..58e9f3e4f727 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -49,6 +49,7 @@ ("gemma4", "Gemma4AudioFeatureExtractor"), ("glmasr", "WhisperFeatureExtractor"), ("granite_speech", "GraniteSpeechFeatureExtractor"), + ("granite_speech_nar", "GraniteSpeechNarFeatureExtractor"), ("granite_speech_plus", "GraniteSpeechFeatureExtractor"), ("higgs_audio_v2_tokenizer", "DacFeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e43a3d90a200..58d6997147e7 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -215,6 +215,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), ("granite_speech", "GraniteSpeechForConditionalGeneration"), + ("granite_speech_nar", "GraniteSpeechNarForASR"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), ("granitemoeshared", "GraniteMoeSharedModel"), @@ -1657,6 +1658,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): [ # Model for Connectionist temporal classification (CTC) mapping ("data2vec-audio", "Data2VecAudioForCTC"), + ("granite_speech_nar", "GraniteSpeechNarForASR"), ("hubert", "HubertForCTC"), ("lasr_ctc", "LasrForCTC"), ("parakeet_ctc", "ParakeetForCTC"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 30c6fc520c49..cc59c4acbe50 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -92,6 +92,7 @@ ("got_ocr2", "GotOcr2Processor"), ("granite4_vision", "Granite4VisionProcessor"), ("granite_speech", "GraniteSpeechProcessor"), + ("granite_speech_nar", "GraniteSpeechNarProcessor"), ("granite_speech_plus", "GraniteSpeechProcessor"), ("grounding-dino", "GroundingDinoProcessor"), ("groupvit", "CLIPProcessor"), diff --git a/src/transformers/models/granite_speech_nar/__init__.py b/src/transformers/models/granite_speech_nar/__init__.py new file mode 100644 index 000000000000..419097100ba6 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_granite_speech_nar import * + from .feature_extraction_granite_speech_nar import * + from .modeling_granite_speech_nar import * + from .processing_granite_speech_nar import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py new file mode 100644 index 000000000000..3cf99a3be500 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -0,0 +1,215 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Config classes for Granite Speech NAR (Non-Autoregressive ASR).""" + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PretrainedConfig +from ...utils import auto_docstring +from ..auto import CONFIG_MAPPING + + +@auto_docstring +@strict +class GraniteSpeechNarEncoderConfig(PretrainedConfig): + r""" + Configuration for the conformer encoder component of GraniteSpeechNar. + + feedforward_mult (`int`, *optional*, defaults to 4): + Multiplier for the feedforward layers; intermediate dim = `hidden_dim * feedforward_mult`. + output_dim (`int`, *optional*, defaults to 348): + Output dimension of the mid-layer CTC prediction head. + context_size (`int`, *optional*, defaults to 200): + Context size for block-wise conformer attention. + max_pos_emb (`int`, *optional*, defaults to 512): + Maximum relative positional embedding index (Shaw's relative positional encoding). + pred_dropout (`float`, *optional*, defaults to 0.25): + Dropout applied to encoder hidden states before prediction heads. + conv_expansion_factor (`int`, *optional*, defaults to 2): + Expansion factor for conformer convolution module. + self_conditioning_layer (`int`, *optional*): + Layer index at which self-conditioning (mid-layer CTC feedback) is applied. + Defaults to `num_layers // 2`. + bpe_output_dim (`int`, *optional*): + Vocabulary size for the BPE CTC head (shifted by +1 for blank). If None, BPE head is disabled. + bpe_pooling_window (`int`, *optional*, defaults to 4): + Window size for posterior-weighted pooling before the BPE CTC head. + + Example: + + ```python + >>> from transformers import GraniteSpeechNarEncoderConfig + + >>> configuration = GraniteSpeechNarEncoderConfig() + >>> print(configuration.hidden_dim) + 1024 + ```""" + + model_type = "granite_speech_nar_encoder" + attribute_map = { + "hidden_size": "hidden_dim", + "num_hidden_layers": "num_layers", + "num_attention_heads": "num_heads", + "num_mel_bins": "input_dim", + } + + input_dim: int = 160 + num_layers: int = 16 + hidden_dim: int = 1024 + feedforward_mult: int = 4 + num_heads: int = 8 + dim_head: int | None = None + output_dim: int = 348 + context_size: int = 200 + max_pos_emb: int = 512 + dropout: float = 0.1 + pred_dropout: float = 0.25 + conv_kernel_size: int = 15 + conv_expansion_factor: int = 2 + self_conditioning_layer: int | None = None + bpe_output_dim: int | None = None + bpe_pooling_window: int = 4 + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + super().__post_init__(**kwargs) + if self.dim_head is None: + self.dim_head = self.hidden_dim // self.num_heads + if self.self_conditioning_layer is None: + self.self_conditioning_layer = self.num_layers // 2 + + +@auto_docstring +@strict +class GraniteSpeechNarProjectorConfig(PretrainedConfig): + r""" + Configuration for the QFormer-based audio projector in GraniteSpeechNar. + + encoder_dim (`int`, *optional*, defaults to 1024): + Hidden dimension of the encoder (per layer). + llm_dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the language model. + downsample_rate (`int`, *optional*, defaults to 5): + Temporal downsampling rate within each window block. + num_encoder_layers (`int`, *optional*, defaults to 4): + Number of encoder layers concatenated as projector input. + block_size (`int`, *optional*, defaults to 15): + Window size for blocked cross-attention in the projector. + layernorm_eps (`float`, *optional*, defaults to 1e-6): + Epsilon for layer normalization. + attn_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in attention projections. + + Example: + + ```python + >>> from transformers import GraniteSpeechNarProjectorConfig + + >>> configuration = GraniteSpeechNarProjectorConfig() + >>> print(configuration.hidden_size) + 2048 + ```""" + + model_type = "granite_speech_nar_projector" + + encoder_dim: int = 1024 + llm_dim: int = 2048 + downsample_rate: int = 5 + num_encoder_layers: int = 4 + hidden_size: int = 2048 + num_heads: int = 32 + num_layers: int = 2 + dropout_prob: float = 0.1 + block_size: int = 15 + mlp_ratio: int = 2 + layernorm_eps: float = 1e-6 + attn_bias: bool = True + mlp_bias: bool = True + + +@auto_docstring +@strict +class GraniteSpeechNarConfig(PretrainedConfig): + r""" + Configuration for the GraniteSpeechNar non-autoregressive ASR model. + + This model uses a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone for single-pass speech recognition. + + projector_config (`GraniteSpeechNarProjectorConfig` or `dict`, *optional*): + Configuration for the QFormer-based audio projector. + encoder_layer_indices (`list[int]`, *optional*, defaults to `[4, 8, 12, -1]`): + Indices of encoder layers whose hidden states are concatenated as projector input. + scale_projected_embeddings (`bool`, *optional*, defaults to `True`): + Whether to divide projected audio embeddings by the LLM's embedding multiplier. + blank_token_id (`int`, *optional*): + Token ID used as the CTC blank symbol. Defaults to `text_config.eos_token_id`. + min_edit_sequence_length (`int`, *optional*, defaults to 8): + Minimum length of the edit sequence (CTC tokens + insertion slots) fed to the LLM. + ce_loss_lambda (`float`, *optional*, defaults to 0.0): + Weight for auxiliary cross-entropy loss on the LLM output. + encoder_ctc_loss_lambda (`float`, *optional*, defaults to 0.0): + Weight for auxiliary encoder BPE CTC loss. + + Example: + + ```python + >>> from transformers import GraniteSpeechNarConfig, GraniteSpeechNarForASR + + >>> configuration = GraniteSpeechNarConfig() + >>> model = GraniteSpeechNarForASR(configuration) + >>> print(configuration.model_type) + 'granite_speech_nar' + ```""" + + model_type = "granite_speech_nar" + sub_configs = { + "encoder_config": GraniteSpeechNarEncoderConfig, + "projector_config": GraniteSpeechNarProjectorConfig, + "text_config": "AutoConfig", + } + + encoder_config: dict | PretrainedConfig | None = None + projector_config: dict | PretrainedConfig | None = None + text_config: dict | PretrainedConfig | None = None + encoder_layer_indices: list[int] | None = None + scale_projected_embeddings: bool = True + blank_token_id: int | None = None + min_edit_sequence_length: int = 8 + ce_loss_lambda: float = 0.0 + encoder_ctc_loss_lambda: float = 0.0 + + def __post_init__(self, **kwargs): + if isinstance(self.text_config, dict): + self.text_config["model_type"] = self.text_config.get("model_type", "granite") + self.text_config = CONFIG_MAPPING[self.text_config["model_type"]](**self.text_config) + elif self.text_config is None: + self.text_config = CONFIG_MAPPING["granite"]() + + if not isinstance(self.encoder_config, GraniteSpeechNarEncoderConfig): + self.encoder_config = GraniteSpeechNarEncoderConfig(**(self.encoder_config or {})) + + if not isinstance(self.projector_config, GraniteSpeechNarProjectorConfig): + self.projector_config = GraniteSpeechNarProjectorConfig(**(self.projector_config or {})) + + if self.encoder_layer_indices is None: + self.encoder_layer_indices = [4, 8, 12, -1] + + if self.blank_token_id is None: + self.blank_token_id = self.text_config.eos_token_id + + super().__post_init__(**kwargs) + + +__all__ = ["GraniteSpeechNarEncoderConfig", "GraniteSpeechNarProjectorConfig", "GraniteSpeechNarConfig"] diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py new file mode 100644 index 000000000000..84e95a6059f9 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -0,0 +1,105 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extraction for Granite Speech NAR.""" + +import torch +import torchaudio + +from ...feature_extraction_utils import FeatureExtractionMixin + + +class GraniteSpeechNarFeatureExtractor(FeatureExtractionMixin): + """Extracts log-mel spectrogram features for GraniteSpeechNar. + + Produces stacked pairs of 80-band mel frames, yielding 160-dim features + at half the original frame rate. + """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + sampling_rate: int = 16000, + n_fft: int = 512, + win_length: int = 400, + hop_length: int = 160, + n_mels: int = 80, + **kwargs, + ): + super().__init__(**kwargs) + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.n_mels = n_mels + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sampling_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mels, + ) + + @torch.no_grad() + def _extract_features(self, raw_audio: torch.Tensor) -> torch.Tensor: + mel_transform = self.mel_transform.to(raw_audio.device) + B, T = raw_audio.shape + l = 2 * (T // (2 * self.hop_length)) + mel = mel_transform(raw_audio.float())[..., :l] + logmel = mel.transpose(-1, -2).clamp_min_(1e-10).log10_() + mx = logmel.amax(dim=(-2, -1), keepdim=True) + logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) + return logmel.reshape(B, -1, 2 * self.n_mels) + + def __call__( + self, + audios: torch.Tensor | list[torch.Tensor], + device: str | torch.device | None = None, + ) -> dict: + if isinstance(audios, torch.Tensor): + if audios.ndim == 1: + audios = [audios] + elif audios.ndim == 2: + audios = [audios[i] for i in range(audios.shape[0])] + else: + raise ValueError(f"Expected 1-D or 2-D tensor, got {audios.ndim}-D") + + raw_lengths = [a.shape[-1] for a in audios] + encoder_frame_counts = [l // (2 * self.hop_length) for l in raw_lengths] + + raw_audio = torch.nn.utils.rnn.pad_sequence( + [a.squeeze(0) if a.ndim > 1 else a for a in audios], + batch_first=True, + padding_value=0.0, + ) + if device is not None: + raw_audio = raw_audio.to(device) + + input_features = self._extract_features(raw_audio) + + max_enc_frames = input_features.shape[1] + x_sizes = torch.tensor(encoder_frame_counts, dtype=torch.long) + attention_mask = torch.arange(max_enc_frames).unsqueeze(0) < x_sizes.unsqueeze(1) + + if device is not None: + input_features = input_features.to(device) + attention_mask = attention_mask.to(device) + + return { + "input_features": input_features, + "attention_mask": attention_mask, + } + + +__all__ = ["GraniteSpeechNarFeatureExtractor"] diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py new file mode 100644 index 000000000000..859e243582d3 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -0,0 +1,624 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model classes for Granite Speech NAR (Non-Autoregressive ASR).""" + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +from ...masking_utils import ( + create_bidirectional_mask, + find_packed_sequence_indices, + packed_sequence_mask_function, +) +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, logging +from ..granite.modeling_granite import GraniteForCausalLM, GraniteModel +from ..granite_speech.modeling_granite_speech import GraniteSpeechConformerBlock +from .configuration_granite_speech_nar import ( + GraniteSpeechNarConfig, + GraniteSpeechNarEncoderConfig, + GraniteSpeechNarProjectorConfig, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class GraniteSpeechNarEncoderOutput(ModelOutput): + """Output of the GraniteSpeechNar encoder.""" + + loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + all_hidden_states: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +class GraniteSpeechNarOutput(ModelOutput): + """Output of the GraniteSpeechNarForASR model. + + Attributes: + loss: Combined CTC + auxiliary losses (only when labels provided). + preds: List of predicted token ID tensors per sample (after CTC collapse, inference only). + logits: List of per-sample logit tensors from the LLM head. + encoder_logits: Flat BPE CTC logits from the encoder. + encoder_preds: List of CTC-collapsed encoder predictions per sample. + """ + + loss: torch.Tensor | None = None + preds: list[torch.Tensor] | None = None + logits: list[torch.Tensor] | None = None + encoder_logits: torch.Tensor | None = None + encoder_preds: list[torch.Tensor] | None = None + + +class GraniteSpeechNarConformerBlock(GraniteSpeechConformerBlock): + pass + + +def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: + B, T, D = hidden.shape + pad_len = (window_size - T % window_size) % window_size + if pad_len > 0: + hidden = F.pad(hidden, (0, 0, 0, pad_len)) + importance = F.pad(importance, (0, pad_len)) + num_windows = hidden.shape[1] // window_size + hidden = hidden.view(B, num_windows, window_size, D) + importance = importance.view(B, num_windows, window_size) + weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) + pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) + return pooled + + +class GraniteSpeechNarQFormerCrossAttention(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.num_heads = config.num_heads + self.head_dim = config.hidden_size // config.num_heads + self.hidden_size = config.hidden_size + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, query_len, _ = hidden_states.shape + encoder_len = encoder_hidden_states.shape[1] + + query_states = ( + self.q_proj(hidden_states).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + key_states = ( + self.k_proj(encoder_hidden_states) + .view(batch_size, encoder_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(encoder_hidden_states) + .view(batch_size, encoder_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=False) + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size) + return self.o_proj(attn_output) + + +class GraniteSpeechNarQFormerMLP(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + mlp_hidden_size = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(config.hidden_size, mlp_hidden_size, bias=config.mlp_bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(mlp_hidden_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.fc2(self.act(self.fc1(hidden_states))) + + +class GraniteSpeechNarQFormerLayer(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.cross_attention = GraniteSpeechNarQFormerCrossAttention(config) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.mlp = GraniteSpeechNarQFormerMLP(config) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states + self.cross_attention(self.attn_norm(hidden_states), encoder_hidden_states) + hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states)) + return hidden_states + + +class GraniteSpeechNarQFormer(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.layers = nn.ModuleList([GraniteSpeechNarQFormerLayer(config) for _ in range(config.num_layers)]) + + def forward(self, query_embeds: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = query_embeds + for layer in self.layers: + hidden_states = layer(hidden_states, encoder_hidden_states) + return hidden_states + + +class GraniteSpeechNarProjector(nn.Module): + """Windowed QFormer projector that maps multi-layer encoder features to LLM embedding space.""" + + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.config = config + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(config.encoder_dim, eps=config.layernorm_eps) for _ in range(config.num_encoder_layers)] + ) + self.layer_projector = nn.Linear(config.encoder_dim * config.num_encoder_layers, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_prob) + self.projector_act = nn.GELU() + self.qformer = GraniteSpeechNarQFormer(config) + + query_length = config.block_size // config.downsample_rate + embed_std = config.hidden_size**-0.5 + self.query = nn.Parameter(torch.randn(1, query_length, config.hidden_size) * embed_std) + self.window_positions = nn.Parameter(torch.randn(1, config.block_size, config.hidden_size) * embed_std) + self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.out_linear = nn.Linear(config.hidden_size, config.llm_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.size() + + x = x.view(batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim) + normalized_layers = [] + for i, layer_norm in enumerate(self.layer_norms): + normalized_layers.append(layer_norm(x[:, :, i])) + x = torch.cat(normalized_layers, dim=-1) + + x = self.projector_act(self.layer_projector(x)) + + block_size = self.config.block_size + nblocks = seq_len // block_size + rest = seq_len % block_size + if rest > 0: + x = F.pad(x, (0, 0, 0, block_size - rest), "constant", 0) + nblocks += 1 + + x = x.view(batch_size * nblocks, block_size, self.config.hidden_size) + query_length = self.query.shape[1] + mean_pool = x.view( + batch_size * nblocks, query_length, self.config.downsample_rate, self.config.hidden_size + ).mean(dim=-2) + + hidden_states = self.qformer( + query_embeds=self.dropout(self.query + mean_pool), + encoder_hidden_states=self.dropout(x + self.window_positions), + ) + + hidden_states = hidden_states.view(batch_size, nblocks * query_length, -1) + hidden_states = self.dropout(self.out_norm(hidden_states)) + return self.out_linear(hidden_states) + + +class GraniteSpeechNarBidirectionalGraniteModel(GraniteModel): + """GraniteModel with bidirectional (non-causal) attention. + + Replaces create_causal_mask() with create_bidirectional_mask() so all + attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. + """ + + def __init__(self, config): + super().__init__(config) + for layer in self.layers: + layer.self_attn.is_causal = False + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + **kwargs, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + packed_seq_mask = find_packed_sequence_indices(position_ids) + and_mask_fn = packed_sequence_mask_function(packed_seq_mask) if packed_seq_mask is not None else None + bidirectional_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + and_mask_function=and_mask_fn, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=bidirectional_mask, + position_ids=position_ids, + use_cache=False, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +@auto_docstring +class GraniteSpeechNarPreTrainedModel(PreTrainedModel): + config_class = GraniteSpeechNarConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteDecoderLayer"] + input_modalities = ("audio",) + + +class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): + """Conformer encoder with BPE CTC head and multi-layer output.""" + + config_class = GraniteSpeechNarEncoderConfig + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__(config) + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = nn.ModuleList([GraniteSpeechNarConformerBlock(config) for _ in range(config.num_layers)]) + self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) + self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) + self.out_bpe = None + if config.bpe_output_dim is not None: + self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) + self.dropout = nn.Dropout(config.pred_dropout) + self.post_init() + + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + ) -> GraniteSpeechNarEncoderOutput: + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + hidden_states = self.input_linear(input_features.to(self.dtype)) + all_hidden_states = (hidden_states,) if output_hidden_states else None + blank_probs = None + + context_size = self.config.context_size + seq = torch.arange(context_size, device=hidden_states.device) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + self.config.max_pos_emb + + for idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, attention_dists=attention_dists) + + if idx == self.config.self_conditioning_layer: + mid_logits = self.out(self.dropout(hidden_states)) + mid_probs = torch.softmax(mid_logits.float(), dim=-1) + blank_probs = mid_probs[:, :, 0] + hidden_states = hidden_states + self.out_mid(mid_probs.to(hidden_states.dtype)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.dropout(hidden_states) + + logits = None + pool_window = self.config.bpe_pooling_window + if self.out_bpe is not None and blank_probs is not None: + importance = 1.0 - blank_probs + pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( + hidden_states.dtype + ) + T = attention_mask.shape[1] + pad_len = (pool_window - T % pool_window) % pool_window + pooled_mask = F.pad(attention_mask, (0, pad_len), value=False)[:, ::pool_window] + logits = self.out_bpe(pooled[pooled_mask]) + + loss = None + if labels is not None and logits is not None: + encoder_lengths = attention_mask.sum(dim=1) + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + T_max = max(bpe_lengths) + B = len(bpe_lengths) + logits_padded = logits.new_zeros(B, T_max, logits.shape[-1]) + offset = 0 + for i, length in enumerate(bpe_lengths): + logits_padded[i, :length] = logits[offset : offset + length] + offset += length + + log_probs = torch.log_softmax(logits_padded.float(), dim=-1) + bpe_x_sizes = torch.tensor(bpe_lengths, device=logits.device) + loss = ( + F.ctc_loss( + log_probs.transpose(0, 1), + labels + 1, + bpe_x_sizes, + label_lengths, + blank=0, + reduction="sum", + zero_infinity=True, + ) + / bpe_x_sizes.sum() + ) + + return GraniteSpeechNarEncoderOutput( + loss=loss, + logits=logits, + last_hidden_state=hidden_states, + all_hidden_states=all_hidden_states, + ) + + +class GraniteSpeechNarLanguageModel(GraniteForCausalLM): + """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" + + def __init__(self, config): + super().__init__(config) + self.model = GraniteSpeechNarBidirectionalGraniteModel(config) + + +@auto_docstring( + custom_intro=""" + The GraniteSpeechNar model for non-autoregressive automatic speech recognition. + Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. + """ +) +class GraniteSpeechNarForASR(GraniteSpeechNarPreTrainedModel): + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + + self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechNarProjector(config.projector_config) + + text_config = config.text_config + if hasattr(config, "_attn_implementation"): + text_config._attn_implementation = config._attn_implementation + self.language_model = GraniteSpeechNarLanguageModel(text_config) + + self.post_init() + + def _ctc_collapse_decode( + self, + bpe_logits_flat: torch.Tensor, + bpe_lengths: list[int], + ) -> list[torch.Tensor]: + """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank -> shift.""" + preds_flat = bpe_logits_flat.argmax(dim=-1) + per_sample = preds_flat.split(bpe_lengths) + return [(collapsed := torch.unique_consecutive(seq))[collapsed != 0] - 1 for seq in per_sample] + + def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: + """Insert blank tokens between each CTC token as editing slots for the LLM.""" + blank_id = self.config.blank_token_id + n = token_ids.numel() + total_len = max(2 * n + 1, self.config.min_edit_sequence_length) + idx = torch.arange(n, device=token_ids.device) + out_idx = 2 * idx + 1 + out = torch.full((total_len,), fill_value=blank_id, dtype=token_ids.dtype, device=token_ids.device) + out[out_idx] = token_ids + return out + + def _build_flat_inputs( + self, + ctc_token_ids: list[torch.Tensor], + audio_embeds: torch.Tensor, + audio_lengths: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" + embed_tokens = self.language_model.model.embed_tokens + + embeds_list = [] + position_ids_list = [] + text_lengths = [] + + for i, audio_len in enumerate(audio_lengths): + audio_emb = audio_embeds[i, :audio_len] + text_ids_with_slots = self._add_insertion_slots(ctc_token_ids[i]) + text_emb = embed_tokens(text_ids_with_slots) + sample_embeds = torch.cat([audio_emb, text_emb], dim=0) + embeds_list.append(sample_embeds) + position_ids_list.append(torch.arange(sample_embeds.shape[0], device=audio_embeds.device)) + text_lengths.append(text_ids_with_slots.shape[0]) + + flat_embeds = torch.cat(embeds_list, dim=0).unsqueeze(0) + flat_position_ids = torch.cat(position_ids_list, dim=0).unsqueeze(0) + return flat_embeds, flat_position_ids, text_lengths + + def forward( + self, + *, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + ) -> GraniteSpeechNarOutput: + r""" + Args: + input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): + Mel spectrogram features. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Encoder attention mask (1 for valid frames, 0 for padding). + labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): + Ground truth LLM token IDs for training. + label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Number of valid tokens per sample in `labels`. + output_encoder_logits (`bool`, *optional*, defaults to `False`): + Whether to return encoder BPE logits. When False, the large logits + tensor is freed early to reduce peak memory. + + Returns: + [`GraniteSpeechNarOutput`] + """ + encoder_labels = labels if (labels is not None and self.config.encoder_ctc_loss_lambda > 0.0) else None + enc_out = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + labels=encoder_labels, + label_lengths=label_lengths if encoder_labels is not None else None, + ) + + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + encoder_lengths = attention_mask.sum(dim=1) + + pool_window = self.encoder.config.bpe_pooling_window + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + + multilayer_features = torch.cat( + [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + ) + + encoder_loss = enc_out.loss + encoder_logits = enc_out.logits if output_encoder_logits else None + del enc_out + + audio_embeds = self.projector(multilayer_features) + del multilayer_features + if self.config.scale_projected_embeddings: + embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) + audio_embeds = audio_embeds / embedding_multiplier + audio_embeds = audio_embeds.to(self.language_model.model.embed_tokens.weight.dtype) + + audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() + + flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( + ctc_token_ids, audio_embeds, audio_lengths + ) + + llm_out = self.language_model.model( + inputs_embeds=flat_embeds, + position_ids=flat_position_ids, + ) + llm_hidden = llm_out.last_hidden_state.squeeze(0) + + segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] + text_hidden = torch.cat(list(llm_hidden.split(segment_lengths)[1::2])) + + logits = self.language_model.lm_head(text_hidden) + logits = logits / self.language_model.config.logits_scaling + logits_per_sample = list(logits.split(text_lengths)) + + loss = None + if labels is not None: + log_probs = torch.log_softmax(logits.float(), dim=-1) + + T_max = max(text_lengths) + B = len(text_lengths) + V = log_probs.shape[-1] + log_probs_padded = log_probs.new_zeros(B, T_max, V) + offset = 0 + for i, tl in enumerate(text_lengths): + log_probs_padded[i, :tl] = log_probs[offset : offset + tl] + offset += tl + + input_lengths = torch.tensor(text_lengths, device=logits.device) + + loss = ( + F.ctc_loss( + log_probs_padded.transpose(0, 1), + labels, + input_lengths, + label_lengths, + blank=self.config.blank_token_id, + reduction="sum", + zero_infinity=True, + ) + / input_lengths.sum() + ) + + if self.config.ce_loss_lambda > 0.0: + ce_targets_list = [self._add_insertion_slots(ids) for ids in ctc_token_ids] + ce_targets = torch.cat(ce_targets_list) + ce_loss = F.cross_entropy( + logits, + ce_targets.long(), + reduction="mean", + ignore_index=-100, + ) + loss = loss + self.config.ce_loss_lambda * ce_loss + + if encoder_loss is not None: + loss = loss + self.config.encoder_ctc_loss_lambda * encoder_loss + + return GraniteSpeechNarOutput( + loss=loss, + logits=logits_per_sample, + encoder_logits=encoder_logits, + encoder_preds=ctc_token_ids, + ) + + @torch.inference_mode() + def transcribe( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_encoder_logits: bool = False, + ) -> GraniteSpeechNarOutput: + """Single-pass non-autoregressive inference: forward + CTC collapse on LLM output. + + Returns token ID tensors in `preds`. Use `GraniteSpeechNarProcessor.batch_decode()` + to convert to strings. + """ + output = self.forward( + input_features=input_features, + attention_mask=attention_mask, + output_encoder_logits=output_encoder_logits, + ) + + blank_id = self.config.blank_token_id + preds = [] + for sample_logits in output.logits: + pred = torch.unique_consecutive(sample_logits.argmax(-1)) + pred = pred[pred != blank_id] + preds.append(pred) + + return GraniteSpeechNarOutput( + preds=preds, + logits=output.logits, + encoder_logits=output.encoder_logits, + encoder_preds=output.encoder_preds, + ) + + +__all__ = [ + "GraniteSpeechNarCTCEncoder", + "GraniteSpeechNarForASR", + "GraniteSpeechNarPreTrainedModel", +] diff --git a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py new file mode 100644 index 000000000000..38d7abe3719d --- /dev/null +++ b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py @@ -0,0 +1,45 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor for Granite Speech NAR.""" + +import torch + +from ...processing_utils import ProcessorMixin +from .feature_extraction_granite_speech_nar import GraniteSpeechNarFeatureExtractor + + +class GraniteSpeechNarProcessor(ProcessorMixin): + """Processor combining audio feature extraction and tokenizer for GraniteSpeechNar.""" + + feature_extractor_class = "GraniteSpeechNarFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, feature_extractor: GraniteSpeechNarFeatureExtractor, tokenizer=None, **kwargs): + super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs) + + def __call__( + self, + audios: torch.Tensor | list[torch.Tensor], + device: str | torch.device | None = None, + **kwargs, + ) -> dict: + return self.feature_extractor(audios, device=device) + + def batch_decode(self, token_ids_list: list[torch.Tensor], **kwargs) -> list[str]: + if self.tokenizer is None: + raise ValueError("Tokenizer not set. Pass tokenizer to GraniteSpeechNarProcessor.") + return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in token_ids_list] + + +__all__ = ["GraniteSpeechNarProcessor"] diff --git a/tests/models/granite_speech_nar/__init__.py b/tests/models/granite_speech_nar/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py new file mode 100644 index 000000000000..740129a39f7d --- /dev/null +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -0,0 +1,366 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for GraniteSpeechNar model.""" + +import math + +import torch + +from transformers import ( + AutoConfig, + GraniteConfig, + GraniteSpeechNarConfig, +) +from transformers.models.granite_speech_nar.configuration_granite_speech_nar import ( + GraniteSpeechNarEncoderConfig, + GraniteSpeechNarProjectorConfig, +) +from transformers.models.granite_speech_nar.modeling_granite_speech_nar import ( + GraniteSpeechNarCTCEncoder, + GraniteSpeechNarForASR, + GraniteSpeechNarOutput, + GraniteSpeechNarProjector, +) + + +def _make_small_config(): + encoder_config = GraniteSpeechNarEncoderConfig( + num_layers=4, + hidden_dim=64, + num_heads=4, + dim_head=16, + input_dim=160, + output_dim=10, + context_size=50, + self_conditioning_layer=2, + bpe_output_dim=52, + bpe_pooling_window=4, + ) + projector_config = GraniteSpeechNarProjectorConfig( + encoder_dim=64, + llm_dim=128, + downsample_rate=5, + num_encoder_layers=4, + hidden_size=128, + num_heads=4, + num_layers=1, + block_size=15, + ) + text_config = GraniteConfig( + vocab_size=51, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=256, + max_position_embeddings=512, + tie_word_embeddings=True, + embedding_multiplier=1.0, + attention_multiplier=1.0, + residual_multiplier=1.0, + logits_scaling=1.0, + ) + return GraniteSpeechNarConfig( + encoder_config=encoder_config, + projector_config=projector_config, + text_config=text_config.to_dict(), + encoder_layer_indices=[1, 2, 3, -1], + scale_projected_embeddings=False, + ) + + +# === Configuration tests === + + +class TestConfiguration: + def test_encoder_config_defaults(self): + config = GraniteSpeechNarEncoderConfig() + assert config.model_type == "granite_speech_nar_encoder" + assert config.input_dim == 160 + assert config.num_layers == 16 + assert config.hidden_dim == 1024 + assert config.self_conditioning_layer == 8 + assert config.bpe_output_dim is None + + def test_projector_config_defaults(self): + config = GraniteSpeechNarProjectorConfig() + assert config.model_type == "granite_speech_nar_projector" + assert config.encoder_dim == 1024 + assert config.llm_dim == 2048 + assert config.downsample_rate == 5 + + def test_config_defaults(self): + config = GraniteSpeechNarConfig() + assert config.model_type == "granite_speech_nar" + assert config.encoder_layer_indices == [4, 8, 12, -1] + assert config.scale_projected_embeddings is True + + def test_config_serialization_roundtrip(self): + config = _make_small_config() + d = config.to_dict() + restored = GraniteSpeechNarConfig(**d) + assert restored.encoder_config.num_layers == 4 + assert restored.encoder_config.bpe_output_dim == 52 + assert restored.projector_config.num_layers == 1 + assert restored.encoder_layer_indices == [1, 2, 3, -1] + + def test_auto_config_resolution(self): + config = AutoConfig.for_model("granite_speech_nar") + assert isinstance(config, GraniteSpeechNarConfig) + + +# === Encoder tests === + + +class TestEncoder: + def test_output_shapes(self): + config = GraniteSpeechNarEncoderConfig( + num_layers=4, + hidden_dim=64, + num_heads=4, + dim_head=16, + input_dim=160, + output_dim=348, + context_size=50, + self_conditioning_layer=2, + bpe_output_dim=100, + bpe_pooling_window=4, + ) + encoder = GraniteSpeechNarCTCEncoder(config).eval() + + B, T = 2, 100 + features = torch.randn(B, T, 160) + mask = torch.ones(B, T, dtype=torch.bool) + mask[1, 80:] = False + + out = encoder(features, mask, output_hidden_states=True) + + assert out.logits is not None + assert out.logits.shape[1] == 100 + assert out.all_hidden_states is not None + assert len(out.all_hidden_states) == 5 # input + 4 layers + + def test_no_bpe_head(self): + config = GraniteSpeechNarEncoderConfig( + num_layers=2, + hidden_dim=64, + num_heads=4, + dim_head=16, + input_dim=160, + output_dim=348, + context_size=50, + self_conditioning_layer=1, + bpe_output_dim=None, + ) + encoder = GraniteSpeechNarCTCEncoder(config).eval() + + features = torch.randn(1, 50, 160) + out = encoder(features, output_hidden_states=False) + + assert out.logits is None + assert out.all_hidden_states is None + + +# === Projector tests === + + +class TestProjector: + def test_output_shape(self): + config = GraniteSpeechNarProjectorConfig( + encoder_dim=64, + llm_dim=128, + downsample_rate=5, + num_encoder_layers=2, + hidden_size=128, + num_heads=4, + num_layers=1, + block_size=15, + ) + projector = GraniteSpeechNarProjector(config) + + B, T = 2, 60 + x = torch.randn(B, T, 2 * 64) + out = projector(x) + expected_len = math.ceil(T / config.block_size) * (config.block_size // config.downsample_rate) + assert out.shape == (B, expected_len, 128) + + def test_handles_non_divisible_length(self): + config = GraniteSpeechNarProjectorConfig( + encoder_dim=64, + llm_dim=128, + downsample_rate=5, + num_encoder_layers=1, + hidden_size=64, + num_heads=4, + num_layers=1, + block_size=15, + ) + projector = GraniteSpeechNarProjector(config) + + x = torch.randn(1, 37, 64) + out = projector(x) + assert out.shape == (1, 9, 128) + + +# === Full model tests === + + +class TestGraniteSpeechNarForASR: + def test_forward(self): + config = _make_small_config() + model = GraniteSpeechNarForASR(config).eval() + + B, T = 2, 100 + features = torch.randn(B, T, 160) + mask = torch.ones(B, T, dtype=torch.bool) + mask[1, 80:] = False + + with torch.no_grad(): + output = model(input_features=features, attention_mask=mask) + + assert isinstance(output, GraniteSpeechNarOutput) + assert output.logits is not None + assert isinstance(output.logits, list) + assert len(output.logits) == B + for logits in output.logits: + assert logits.ndim == 2 + assert logits.shape[1] == 51 + + def test_transcribe(self): + config = _make_small_config() + model = GraniteSpeechNarForASR(config).eval() + + features = torch.randn(1, 60, 160) + output = model.transcribe(input_features=features) + + assert output.preds is not None + assert len(output.preds) == 1 + assert isinstance(output.preds[0], torch.Tensor) + + def test_loss(self): + config = _make_small_config() + model = GraniteSpeechNarForASR(config).train() + + B, T = 2, 100 + features = torch.randn(B, T, 160) + mask = torch.ones(B, T, dtype=torch.bool) + mask[1, 80:] = False + labels = torch.randint(0, 51, (B, 5)) + label_lengths = torch.tensor([5, 3]) + + output = model( + input_features=features, + attention_mask=mask, + labels=labels, + label_lengths=label_lengths, + ) + + assert output.loss is not None + assert output.loss.ndim == 0 + assert output.loss.requires_grad + output.loss.backward() + + def test_loss_with_ce(self): + config = _make_small_config() + config.ce_loss_lambda = 0.5 + model = GraniteSpeechNarForASR(config).train() + + features = torch.randn(1, 60, 160) + labels = torch.randint(0, 51, (1, 4)) + label_lengths = torch.tensor([4]) + + output = model( + input_features=features, + labels=labels, + label_lengths=label_lengths, + ) + + assert output.loss is not None + assert output.loss.requires_grad + output.loss.backward() + + def test_loss_with_encoder_ctc(self): + config = _make_small_config() + config.encoder_ctc_loss_lambda = 0.3 + model = GraniteSpeechNarForASR(config).train() + + features = torch.randn(1, 60, 160) + labels = torch.randint(0, 51, (1, 4)) + label_lengths = torch.tensor([4]) + + output = model( + input_features=features, + labels=labels, + label_lengths=label_lengths, + ) + + assert output.loss is not None + assert output.loss.requires_grad + output.loss.backward() + + def test_no_loss_without_labels(self): + config = _make_small_config() + model = GraniteSpeechNarForASR(config).eval() + + features = torch.randn(1, 60, 160) + with torch.no_grad(): + output = model(input_features=features) + + assert output.loss is None + + def test_output_encoder_logits_flag(self): + config = _make_small_config() + model = GraniteSpeechNarForASR(config).eval() + + features = torch.randn(1, 60, 160) + with torch.no_grad(): + out_no = model(input_features=features, output_encoder_logits=False) + out_yes = model(input_features=features, output_encoder_logits=True) + + assert out_no.encoder_logits is None + assert out_yes.encoder_logits is not None + assert out_no.encoder_preds is not None # always returned + + def test_automodel_resolves(self): + config = AutoConfig.for_model("granite_speech_nar") + assert isinstance(config, GraniteSpeechNarConfig) + assert config.model_type == "granite_speech_nar" + + +# === Bidirectional attention test === + + +class TestBidirectionalAttention: + def test_last_token_affects_first(self): + """Changing the last token must affect the first (bidirectional).""" + config = _make_small_config() + model = GraniteSpeechNarForASR(config).eval() + granite_model = model.language_model.model + + embeds_a = torch.randn(1, 10, 128) + embeds_b = embeds_a.clone() + embeds_b[0, -1, :] = torch.randn(128) + + with torch.no_grad(): + out_a = granite_model(inputs_embeds=embeds_a).last_hidden_state + out_b = granite_model(inputs_embeds=embeds_b).last_hidden_state + + diff_first = (out_a[0, 0] - out_b[0, 0]).abs().max().item() + assert diff_first > 1e-5, f"First token unchanged (diff={diff_first}). Attention appears causal." + + def test_is_causal_false_on_layers(self): + config = _make_small_config() + model = GraniteSpeechNarForASR(config) + for i, layer in enumerate(model.language_model.model.layers): + assert layer.self_attn.is_causal is False, f"Layer {i} is_causal is not False" From bb5b6482a87d8990e894a333ed1732146670f6ba Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 14:28:41 +0000 Subject: [PATCH 02/39] attempt to use modular - inherit from the Granite base LLM - changing the attention pattern to bidirectional. Inherit the conformer encoder from GraniteSpeech --- .../modeling_granite_speech_nar.py | 654 ++++++++++++++++-- .../modular_granite_speech_nar.py | 644 +++++++++++++++++ 2 files changed, 1246 insertions(+), 52 deletions(-) create mode 100644 src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 859e243582d3..f53f428a03b4 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite_speech_nar.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2026 IBM and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,24 +17,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Model classes for Granite Speech NAR (Non-Autoregressive ASR).""" +import math +from collections.abc import Callable from dataclasses import dataclass +from typing import Optional import torch import torch.nn.functional as F from torch import nn -from ...masking_utils import ( - create_bidirectional_mask, - find_packed_sequence_indices, - packed_sequence_mask_function, -) -from ...modeling_outputs import BaseModelOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...utils import ModelOutput, auto_docstring, logging -from ..granite.modeling_granite import GraniteForCausalLM, GraniteModel -from ..granite_speech.modeling_granite_speech import GraniteSpeechConformerBlock +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_bidirectional_mask, find_packed_sequence_indices, packed_sequence_mask_function +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs from .configuration_granite_speech_nar import ( GraniteSpeechNarConfig, GraniteSpeechNarEncoderConfig, @@ -36,9 +47,6 @@ ) -logger = logging.get_logger(__name__) - - @dataclass class GraniteSpeechNarEncoderOutput(ModelOutput): """Output of the GraniteSpeechNar encoder.""" @@ -68,22 +76,159 @@ class GraniteSpeechNarOutput(ModelOutput): encoder_preds: list[torch.Tensor] | None = None -class GraniteSpeechNarConformerBlock(GraniteSpeechConformerBlock): - pass +### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git +class GraniteSpeechNarConformerFeedForward(nn.Module): + """Feedforward module for conformer encoder blocks.""" + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.up_proj = nn.Linear(config.hidden_dim, config.hidden_dim * config.feedforward_mult) + self.silu = nn.SiLU() + self.dropout = nn.Dropout(config.dropout) + self.down_proj = nn.Linear(config.hidden_dim * config.feedforward_mult, config.hidden_dim) -def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: - B, T, D = hidden.shape - pad_len = (window_size - T % window_size) % window_size - if pad_len > 0: - hidden = F.pad(hidden, (0, 0, 0, pad_len)) - importance = F.pad(importance, (0, pad_len)) - num_windows = hidden.shape[1] // window_size - hidden = hidden.view(B, num_windows, window_size, D) - importance = importance.view(B, num_windows, window_size) - weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) - pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) - return pooled + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + hidden_states = self.up_proj(hidden_states) + hidden_states = self.dropout(self.silu(hidden_states)) + hidden_states = self.down_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GraniteSpeechNarConformerAttention(nn.Module): + """Attention for conformer blocks using Shaw's relative positional embeddings. + See the following [paper](https://huggingface.co/papers/1803.02155) for more details. + """ + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = self.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head) + self.dropout = nn.Dropout(config.dropout) + + if self.context_size <= 0 or self.context_size > self.max_pos_emb: + raise ValueError("Context size is either less than 0 or exceeds the max_pos_emb") + + def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + bsz, num_features, _ = hidden_states.shape + + num_blocks = math.ceil(num_features / self.context_size) + remainder = num_features % self.context_size + if remainder > 0: + # right padding to reach block size + hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, self.context_size - remainder)) + + query_states = self.to_q(hidden_states) + key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) + + query_states = query_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + key_states = key_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + value_states = value_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + + # shaw's relative positional embedding + rel_pos_emb = self.rel_pos_emb(attention_dists) + # alternative computation of `pos_attn` - for readability + # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) + # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale + # einsum implementation of pos_attn - gives x30 speedup over the alternative + # TODO (@avihu111) find a fast alternative to einsum + pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale + + if remainder > 0: + # masked attention in the extended block + mask = torch.ones(self.context_size, self.context_size, dtype=bool, device=hidden_states.device) + mask[:remainder, :remainder] = 0 + mask_value = -torch.finfo(pos_attn.dtype).max + pos_attn[:, -1, :].masked_fill_(mask, mask_value) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=pos_attn, scale=self.scale + ) + out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) + out = self.to_out(out[:, :num_features, :]) + return self.dropout(out) + + +class GraniteSpeechNarConformerDepthWiseConv1d(nn.Module): + """Wrapper for padded 1D pointwise convolution.""" + + def __init__(self, chan_in: int, chan_out: int, kernel_size: int): + super().__init__() + # Padding for the 1D conv is symmetric or close (i.e., offset by one). + pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.padding) + return self.conv(hidden_states) + + +class GraniteSpeechNarConformerConvModule(nn.Module): + """Conformer conv module consisting of several 1D/depthwise 1D convolutional layers.""" + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.glu = nn.GLU(dim=1) + self.depth_conv = GraniteSpeechNarConformerDepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=config.conv_kernel_size, + ) + self.silu = nn.SiLU() + self.batch_norm = nn.BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.up_conv(hidden_states.permute(0, 2, 1)) + hidden_states = self.glu(hidden_states) + hidden_states = self.depth_conv(hidden_states) + hidden_states = self.silu(self.batch_norm(hidden_states)) + hidden_states = self.down_conv(hidden_states).permute(0, 2, 1) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GraniteSpeechNarConformerBlock(nn.Module): + """Conformer block, consisting largely of linear layers, attention, and convolutional layers.""" + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + self.ff1 = GraniteSpeechNarConformerFeedForward(config) + self.attn = GraniteSpeechNarConformerAttention(config) + self.conv = GraniteSpeechNarConformerConvModule(config) + self.ff2 = GraniteSpeechNarConformerFeedForward(config) + self.post_norm = nn.LayerNorm(config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states + hidden_states = self.attn(hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = self.conv(hidden_states) + hidden_states + hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states + hidden_states = self.post_norm(hidden_states) + return hidden_states class GraniteSpeechNarQFormerCrossAttention(nn.Module): @@ -213,18 +358,358 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_linear(hidden_states) -class GraniteSpeechNarBidirectionalGraniteModel(GraniteModel): +@auto_docstring +class GraniteSpeechNarPreTrainedModel(PreTrainedModel): + config_class = GraniteSpeechNarConfig + base_model_prefix = "encoder" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteDecoderLayer"] + input_modalities = ("audio",) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class GraniteSpeechNarBidirectionalAttention(nn.Module): + """GraniteAttention with is_causal=False for bidirectional attention.""" + + is_causal = False + + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.attention_multiplier + self.attention_dropout = config.attention_dropout + self.is_causal = False + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class GraniteSpeechNarRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + GraniteSpeechNarRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteSpeechNarMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class GraniteSpeechNarBidirectionalDecoderLayer(GradientCheckpointingLayer): + """GraniteDecoderLayer using bidirectional attention.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GraniteSpeechNarBidirectionalAttention(config=config, layer_idx=layer_idx) + + self.mlp = GraniteSpeechNarMLP(config) + self.input_layernorm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +class GraniteSpeechNarRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GraniteSpeechNarConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: GraniteSpeechNarConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class GraniteSpeechNarBidirectionalGraniteModel(GraniteSpeechNarPreTrainedModel): """GraniteModel with bidirectional (non-causal) attention. - Replaces create_causal_mask() with create_bidirectional_mask() so all + Uses GraniteSpeechNarBidirectionalDecoderLayer which sets is_causal=False, + and replaces create_causal_mask() with create_bidirectional_mask() so all attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. """ def __init__(self, config): super().__init__(config) - for layer in self.layers: - layer.self_attn.is_causal = False + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + GraniteSpeechNarBidirectionalDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GraniteSpeechNarRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.embedding_multiplier = config.embedding_multiplier + + # Initialize weights and apply final processing + self.post_init() + @merge_with_config_defaults + @capture_outputs + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -256,12 +741,12 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + kwargs["use_cache"] = False for decoder_layer in self.layers: hidden_states = decoder_layer( hidden_states, attention_mask=bidirectional_mask, position_ids=position_ids, - use_cache=False, position_embeddings=position_embeddings, **kwargs, ) @@ -271,16 +756,18 @@ def forward( return BaseModelOutputWithPast(last_hidden_state=hidden_states) -@auto_docstring -class GraniteSpeechNarPreTrainedModel(PreTrainedModel): - config_class = GraniteSpeechNarConfig - base_model_prefix = "" - supports_gradient_checkpointing = True - _supports_flash_attn = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteDecoderLayer"] - input_modalities = ("audio",) +def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: + B, T, D = hidden.shape + pad_len = (window_size - T % window_size) % window_size + if pad_len > 0: + hidden = F.pad(hidden, (0, 0, 0, pad_len)) + importance = F.pad(importance, (0, pad_len)) + num_windows = hidden.shape[1] // window_size + hidden = hidden.view(B, num_windows, window_size, D) + importance = importance.view(B, num_windows, window_size) + weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) + pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) + return pooled class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): @@ -381,12 +868,80 @@ def forward( ) -class GraniteSpeechNarLanguageModel(GraniteForCausalLM): +@auto_docstring +class GraniteSpeechNarLanguageModel(GraniteSpeechNarPreTrainedModel, GenerationMixin): """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + def __init__(self, config): super().__init__(config) self.model = GraniteSpeechNarBidirectionalGraniteModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, GraniteSpeechNarLanguageModel + + >>> model = GraniteSpeechNarLanguageModel.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits / self.config.logits_scaling # main diff with Llama + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @auto_docstring( @@ -406,7 +961,7 @@ def __init__(self, config: GraniteSpeechNarConfig): text_config = config.text_config if hasattr(config, "_attn_implementation"): text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLanguageModel(text_config) + self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) self.post_init() @@ -564,8 +1119,7 @@ def forward( ) if self.config.ce_loss_lambda > 0.0: - ce_targets_list = [self._add_insertion_slots(ids) for ids in ctc_token_ids] - ce_targets = torch.cat(ce_targets_list) + ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) ce_loss = F.cross_entropy( logits, ce_targets.long(), @@ -617,8 +1171,4 @@ def transcribe( ) -__all__ = [ - "GraniteSpeechNarCTCEncoder", - "GraniteSpeechNarForASR", - "GraniteSpeechNarPreTrainedModel", -] +__all__ = ["GraniteSpeechNarCTCEncoder", "GraniteSpeechNarForASR", "GraniteSpeechNarPreTrainedModel"] diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py new file mode 100644 index 000000000000..8c6bffe0e8d9 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -0,0 +1,644 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GraniteSpeechNar: Non-autoregressive ASR with conformer encoder, QFormer projector, +and bidirectional Granite LLM backbone.""" + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +from ...masking_utils import ( + create_bidirectional_mask, + find_packed_sequence_indices, + packed_sequence_mask_function, +) +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, auto_docstring, logging +from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteForCausalLM, GraniteModel +from ..granite_speech.modeling_granite_speech import GraniteSpeechConformerBlock +from .configuration_granite_speech_nar import ( + GraniteSpeechNarConfig, + GraniteSpeechNarEncoderConfig, + GraniteSpeechNarProjectorConfig, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class GraniteSpeechNarEncoderOutput(ModelOutput): + """Output of the GraniteSpeechNar encoder.""" + + loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + all_hidden_states: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +class GraniteSpeechNarOutput(ModelOutput): + """Output of the GraniteSpeechNarForASR model. + + Attributes: + loss: Combined CTC + auxiliary losses (only when labels provided). + preds: List of predicted token ID tensors per sample (after CTC collapse, inference only). + logits: List of per-sample logit tensors from the LLM head. + encoder_logits: Flat BPE CTC logits from the encoder. + encoder_preds: List of CTC-collapsed encoder predictions per sample. + """ + + loss: torch.Tensor | None = None + preds: list[torch.Tensor] | None = None + logits: list[torch.Tensor] | None = None + encoder_logits: torch.Tensor | None = None + encoder_preds: list[torch.Tensor] | None = None + + +class GraniteSpeechNarConformerBlock(GraniteSpeechConformerBlock): + pass + + +def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: + B, T, D = hidden.shape + pad_len = (window_size - T % window_size) % window_size + if pad_len > 0: + hidden = F.pad(hidden, (0, 0, 0, pad_len)) + importance = F.pad(importance, (0, pad_len)) + num_windows = hidden.shape[1] // window_size + hidden = hidden.view(B, num_windows, window_size, D) + importance = importance.view(B, num_windows, window_size) + weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) + pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) + return pooled + + +class GraniteSpeechNarQFormerCrossAttention(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.num_heads = config.num_heads + self.head_dim = config.hidden_size // config.num_heads + self.hidden_size = config.hidden_size + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_bias) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, query_len, _ = hidden_states.shape + encoder_len = encoder_hidden_states.shape[1] + + query_states = ( + self.q_proj(hidden_states).view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2) + ) + key_states = ( + self.k_proj(encoder_hidden_states) + .view(batch_size, encoder_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(encoder_hidden_states) + .view(batch_size, encoder_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=False) + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size) + return self.o_proj(attn_output) + + +class GraniteSpeechNarQFormerMLP(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + mlp_hidden_size = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(config.hidden_size, mlp_hidden_size, bias=config.mlp_bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(mlp_hidden_size, config.hidden_size, bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.fc2(self.act(self.fc1(hidden_states))) + + +class GraniteSpeechNarQFormerLayer(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.cross_attention = GraniteSpeechNarQFormerCrossAttention(config) + self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.mlp = GraniteSpeechNarQFormerMLP(config) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states + self.cross_attention(self.attn_norm(hidden_states), encoder_hidden_states) + hidden_states = hidden_states + self.mlp(self.mlp_norm(hidden_states)) + return hidden_states + + +class GraniteSpeechNarQFormer(nn.Module): + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.layers = nn.ModuleList([GraniteSpeechNarQFormerLayer(config) for _ in range(config.num_layers)]) + + def forward(self, query_embeds: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = query_embeds + for layer in self.layers: + hidden_states = layer(hidden_states, encoder_hidden_states) + return hidden_states + + +class GraniteSpeechNarProjector(nn.Module): + """Windowed QFormer projector that maps multi-layer encoder features to LLM embedding space.""" + + def __init__(self, config: GraniteSpeechNarProjectorConfig): + super().__init__() + self.config = config + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(config.encoder_dim, eps=config.layernorm_eps) for _ in range(config.num_encoder_layers)] + ) + self.layer_projector = nn.Linear(config.encoder_dim * config.num_encoder_layers, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_prob) + self.projector_act = nn.GELU() + self.qformer = GraniteSpeechNarQFormer(config) + + query_length = config.block_size // config.downsample_rate + embed_std = config.hidden_size**-0.5 + self.query = nn.Parameter(torch.randn(1, query_length, config.hidden_size) * embed_std) + self.window_positions = nn.Parameter(torch.randn(1, config.block_size, config.hidden_size) * embed_std) + self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.out_linear = nn.Linear(config.hidden_size, config.llm_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.size() + + x = x.view(batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim) + normalized_layers = [] + for i, layer_norm in enumerate(self.layer_norms): + normalized_layers.append(layer_norm(x[:, :, i])) + x = torch.cat(normalized_layers, dim=-1) + + x = self.projector_act(self.layer_projector(x)) + + block_size = self.config.block_size + nblocks = seq_len // block_size + rest = seq_len % block_size + if rest > 0: + x = F.pad(x, (0, 0, 0, block_size - rest), "constant", 0) + nblocks += 1 + + x = x.view(batch_size * nblocks, block_size, self.config.hidden_size) + query_length = self.query.shape[1] + mean_pool = x.view( + batch_size * nblocks, query_length, self.config.downsample_rate, self.config.hidden_size + ).mean(dim=-2) + + hidden_states = self.qformer( + query_embeds=self.dropout(self.query + mean_pool), + encoder_hidden_states=self.dropout(x + self.window_positions), + ) + + hidden_states = hidden_states.view(batch_size, nblocks * query_length, -1) + hidden_states = self.dropout(self.out_norm(hidden_states)) + return self.out_linear(hidden_states) + + +@auto_docstring +class GraniteSpeechNarPreTrainedModel(PreTrainedModel): + config_class = GraniteSpeechNarConfig + base_model_prefix = "encoder" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteDecoderLayer"] + input_modalities = ("audio",) + + +class GraniteSpeechNarBidirectionalAttention(GraniteAttention): + """GraniteAttention with is_causal=False for bidirectional attention.""" + + is_causal = False + + def __init__(self, config, layer_idx=None): + super().__init__(config, layer_idx=layer_idx) + self.is_causal = False + + +class GraniteSpeechNarBidirectionalDecoderLayer(GraniteDecoderLayer): + """GraniteDecoderLayer using bidirectional attention.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = GraniteSpeechNarBidirectionalAttention(config=config, layer_idx=layer_idx) + + +class GraniteSpeechNarBidirectionalGraniteModel(GraniteModel): + """GraniteModel with bidirectional (non-causal) attention. + + Uses GraniteSpeechNarBidirectionalDecoderLayer which sets is_causal=False, + and replaces create_causal_mask() with create_bidirectional_mask() so all + attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. + """ + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList( + [GraniteSpeechNarBidirectionalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + **kwargs, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + packed_seq_mask = find_packed_sequence_indices(position_ids) + and_mask_fn = packed_sequence_mask_function(packed_seq_mask) if packed_seq_mask is not None else None + bidirectional_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + and_mask_function=and_mask_fn, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + kwargs["use_cache"] = False + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=bidirectional_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): + """Conformer encoder with BPE CTC head and multi-layer output.""" + + config_class = GraniteSpeechNarEncoderConfig + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__(config) + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = nn.ModuleList([GraniteSpeechNarConformerBlock(config) for _ in range(config.num_layers)]) + self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) + self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) + self.out_bpe = None + if config.bpe_output_dim is not None: + self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) + self.dropout = nn.Dropout(config.pred_dropout) + self.post_init() + + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + ) -> GraniteSpeechNarEncoderOutput: + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + hidden_states = self.input_linear(input_features.to(self.dtype)) + all_hidden_states = (hidden_states,) if output_hidden_states else None + blank_probs = None + + context_size = self.config.context_size + seq = torch.arange(context_size, device=hidden_states.device) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + self.config.max_pos_emb + + for idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, attention_dists=attention_dists) + + if idx == self.config.self_conditioning_layer: + mid_logits = self.out(self.dropout(hidden_states)) + mid_probs = torch.softmax(mid_logits.float(), dim=-1) + blank_probs = mid_probs[:, :, 0] + hidden_states = hidden_states + self.out_mid(mid_probs.to(hidden_states.dtype)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.dropout(hidden_states) + + logits = None + pool_window = self.config.bpe_pooling_window + if self.out_bpe is not None and blank_probs is not None: + importance = 1.0 - blank_probs + pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( + hidden_states.dtype + ) + T = attention_mask.shape[1] + pad_len = (pool_window - T % pool_window) % pool_window + pooled_mask = F.pad(attention_mask, (0, pad_len), value=False)[:, ::pool_window] + logits = self.out_bpe(pooled[pooled_mask]) + + loss = None + if labels is not None and logits is not None: + encoder_lengths = attention_mask.sum(dim=1) + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + T_max = max(bpe_lengths) + B = len(bpe_lengths) + logits_padded = logits.new_zeros(B, T_max, logits.shape[-1]) + offset = 0 + for i, length in enumerate(bpe_lengths): + logits_padded[i, :length] = logits[offset : offset + length] + offset += length + + log_probs = torch.log_softmax(logits_padded.float(), dim=-1) + bpe_x_sizes = torch.tensor(bpe_lengths, device=logits.device) + loss = ( + F.ctc_loss( + log_probs.transpose(0, 1), + labels + 1, + bpe_x_sizes, + label_lengths, + blank=0, + reduction="sum", + zero_infinity=True, + ) + / bpe_x_sizes.sum() + ) + + return GraniteSpeechNarEncoderOutput( + loss=loss, + logits=logits, + last_hidden_state=hidden_states, + all_hidden_states=all_hidden_states, + ) + + +class GraniteSpeechNarLanguageModel(GraniteForCausalLM): + """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" + + def __init__(self, config): + super().__init__(config) + self.model = GraniteSpeechNarBidirectionalGraniteModel(config) + + +@auto_docstring( + custom_intro=""" + The GraniteSpeechNar model for non-autoregressive automatic speech recognition. + Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. + """ +) +class GraniteSpeechNarForASR(GraniteSpeechNarPreTrainedModel): + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + + self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechNarProjector(config.projector_config) + + text_config = config.text_config + if hasattr(config, "_attn_implementation"): + text_config._attn_implementation = config._attn_implementation + self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) + + self.post_init() + + def _ctc_collapse_decode( + self, + bpe_logits_flat: torch.Tensor, + bpe_lengths: list[int], + ) -> list[torch.Tensor]: + """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank -> shift.""" + preds_flat = bpe_logits_flat.argmax(dim=-1) + per_sample = preds_flat.split(bpe_lengths) + return [(collapsed := torch.unique_consecutive(seq))[collapsed != 0] - 1 for seq in per_sample] + + def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: + """Insert blank tokens between each CTC token as editing slots for the LLM.""" + blank_id = self.config.blank_token_id + n = token_ids.numel() + total_len = max(2 * n + 1, self.config.min_edit_sequence_length) + idx = torch.arange(n, device=token_ids.device) + out_idx = 2 * idx + 1 + out = torch.full((total_len,), fill_value=blank_id, dtype=token_ids.dtype, device=token_ids.device) + out[out_idx] = token_ids + return out + + def _build_flat_inputs( + self, + ctc_token_ids: list[torch.Tensor], + audio_embeds: torch.Tensor, + audio_lengths: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" + embed_tokens = self.language_model.model.embed_tokens + + embeds_list = [] + position_ids_list = [] + text_lengths = [] + + for i, audio_len in enumerate(audio_lengths): + audio_emb = audio_embeds[i, :audio_len] + text_ids_with_slots = self._add_insertion_slots(ctc_token_ids[i]) + text_emb = embed_tokens(text_ids_with_slots) + sample_embeds = torch.cat([audio_emb, text_emb], dim=0) + embeds_list.append(sample_embeds) + position_ids_list.append(torch.arange(sample_embeds.shape[0], device=audio_embeds.device)) + text_lengths.append(text_ids_with_slots.shape[0]) + + flat_embeds = torch.cat(embeds_list, dim=0).unsqueeze(0) + flat_position_ids = torch.cat(position_ids_list, dim=0).unsqueeze(0) + return flat_embeds, flat_position_ids, text_lengths + + def forward( + self, + *, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + ) -> GraniteSpeechNarOutput: + r""" + Args: + input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): + Mel spectrogram features. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Encoder attention mask (1 for valid frames, 0 for padding). + labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): + Ground truth LLM token IDs for training. + label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Number of valid tokens per sample in `labels`. + output_encoder_logits (`bool`, *optional*, defaults to `False`): + Whether to return encoder BPE logits. When False, the large logits + tensor is freed early to reduce peak memory. + + Returns: + [`GraniteSpeechNarOutput`] + """ + encoder_labels = labels if (labels is not None and self.config.encoder_ctc_loss_lambda > 0.0) else None + enc_out = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + labels=encoder_labels, + label_lengths=label_lengths if encoder_labels is not None else None, + ) + + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + encoder_lengths = attention_mask.sum(dim=1) + + pool_window = self.encoder.config.bpe_pooling_window + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + + multilayer_features = torch.cat( + [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + ) + + encoder_loss = enc_out.loss + encoder_logits = enc_out.logits if output_encoder_logits else None + del enc_out + + audio_embeds = self.projector(multilayer_features) + del multilayer_features + if self.config.scale_projected_embeddings: + embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) + audio_embeds = audio_embeds / embedding_multiplier + audio_embeds = audio_embeds.to(self.language_model.model.embed_tokens.weight.dtype) + + audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() + + flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( + ctc_token_ids, audio_embeds, audio_lengths + ) + + llm_out = self.language_model.model( + inputs_embeds=flat_embeds, + position_ids=flat_position_ids, + ) + llm_hidden = llm_out.last_hidden_state.squeeze(0) + + segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] + text_hidden = torch.cat(list(llm_hidden.split(segment_lengths)[1::2])) + + logits = self.language_model.lm_head(text_hidden) + logits = logits / self.language_model.config.logits_scaling + logits_per_sample = list(logits.split(text_lengths)) + + loss = None + if labels is not None: + log_probs = torch.log_softmax(logits.float(), dim=-1) + + T_max = max(text_lengths) + B = len(text_lengths) + V = log_probs.shape[-1] + log_probs_padded = log_probs.new_zeros(B, T_max, V) + offset = 0 + for i, tl in enumerate(text_lengths): + log_probs_padded[i, :tl] = log_probs[offset : offset + tl] + offset += tl + + input_lengths = torch.tensor(text_lengths, device=logits.device) + + loss = ( + F.ctc_loss( + log_probs_padded.transpose(0, 1), + labels, + input_lengths, + label_lengths, + blank=self.config.blank_token_id, + reduction="sum", + zero_infinity=True, + ) + / input_lengths.sum() + ) + + if self.config.ce_loss_lambda > 0.0: + ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) + ce_loss = F.cross_entropy( + logits, + ce_targets.long(), + reduction="mean", + ignore_index=-100, + ) + loss = loss + self.config.ce_loss_lambda * ce_loss + + if encoder_loss is not None: + loss = loss + self.config.encoder_ctc_loss_lambda * encoder_loss + + return GraniteSpeechNarOutput( + loss=loss, + logits=logits_per_sample, + encoder_logits=encoder_logits, + encoder_preds=ctc_token_ids, + ) + + @torch.inference_mode() + def transcribe( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_encoder_logits: bool = False, + ) -> GraniteSpeechNarOutput: + """Single-pass non-autoregressive inference: forward + CTC collapse on LLM output. + + Returns token ID tensors in `preds`. Use `GraniteSpeechNarProcessor.batch_decode()` + to convert to strings. + """ + output = self.forward( + input_features=input_features, + attention_mask=attention_mask, + output_encoder_logits=output_encoder_logits, + ) + + blank_id = self.config.blank_token_id + preds = [] + for sample_logits in output.logits: + pred = torch.unique_consecutive(sample_logits.argmax(-1)) + pred = pred[pred != blank_id] + preds.append(pred) + + return GraniteSpeechNarOutput( + preds=preds, + logits=output.logits, + encoder_logits=output.encoder_logits, + encoder_preds=output.encoder_preds, + ) + + +__all__ = [ + "GraniteSpeechNarCTCEncoder", + "GraniteSpeechNarForASR", + "GraniteSpeechNarPreTrainedModel", +] From a46f94d28d2198ad3b9a58bdb608d5e905fdcb3a Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 14:42:09 +0000 Subject: [PATCH 03/39] minor --- .../configuration_granite_speech_nar.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py index 3cf99a3be500..2f2da61191b8 100644 --- a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -15,14 +15,14 @@ from huggingface_hub.dataclasses import strict -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PreTrainedConfig from ...utils import auto_docstring from ..auto import CONFIG_MAPPING @auto_docstring @strict -class GraniteSpeechNarEncoderConfig(PretrainedConfig): +class GraniteSpeechNarEncoderConfig(PreTrainedConfig): r""" Configuration for the conformer encoder component of GraniteSpeechNar. @@ -92,7 +92,7 @@ def __post_init__(self, **kwargs): @auto_docstring @strict -class GraniteSpeechNarProjectorConfig(PretrainedConfig): +class GraniteSpeechNarProjectorConfig(PreTrainedConfig): r""" Configuration for the QFormer-based audio projector in GraniteSpeechNar. @@ -140,7 +140,7 @@ class GraniteSpeechNarProjectorConfig(PretrainedConfig): @auto_docstring @strict -class GraniteSpeechNarConfig(PretrainedConfig): +class GraniteSpeechNarConfig(PreTrainedConfig): r""" Configuration for the GraniteSpeechNar non-autoregressive ASR model. @@ -149,6 +149,8 @@ class GraniteSpeechNarConfig(PretrainedConfig): projector_config (`GraniteSpeechNarProjectorConfig` or `dict`, *optional*): Configuration for the QFormer-based audio projector. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the LLM's input and output word embeddings should be tied. encoder_layer_indices (`list[int]`, *optional*, defaults to `[4, 8, 12, -1]`): Indices of encoder layers whose hidden states are concatenated as projector input. scale_projected_embeddings (`bool`, *optional*, defaults to `True`): @@ -180,9 +182,10 @@ class GraniteSpeechNarConfig(PretrainedConfig): "text_config": "AutoConfig", } - encoder_config: dict | PretrainedConfig | None = None - projector_config: dict | PretrainedConfig | None = None - text_config: dict | PretrainedConfig | None = None + encoder_config: dict | PreTrainedConfig | None = None + projector_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + tie_word_embeddings: bool = True encoder_layer_indices: list[int] | None = None scale_projected_embeddings: bool = True blank_token_id: int | None = None From 925bfdb804f927d6ea68e43bcc1a7971f3427ada Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 16:33:02 +0000 Subject: [PATCH 04/39] minor --- .../granite_speech_nar/modeling_granite_speech_nar.py | 8 +++++++- .../granite_speech_nar/modular_granite_speech_nar.py | 7 ++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index f53f428a03b4..ff1e3ce54c2e 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -1171,4 +1171,10 @@ def transcribe( ) -__all__ = ["GraniteSpeechNarCTCEncoder", "GraniteSpeechNarForASR", "GraniteSpeechNarPreTrainedModel"] +__all__ = [ + "GraniteSpeechNarBidirectionalGraniteModel", + "GraniteSpeechNarCTCEncoder", + "GraniteSpeechNarForASR", + "GraniteSpeechNarLanguageModel", + "GraniteSpeechNarPreTrainedModel", +] diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 8c6bffe0e8d9..1a14cd1972de 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -255,7 +255,10 @@ class GraniteSpeechNarBidirectionalGraniteModel(GraniteModel): def __init__(self, config): super().__init__(config) self.layers = nn.ModuleList( - [GraniteSpeechNarBidirectionalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ + GraniteSpeechNarBidirectionalDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] ) def forward( @@ -638,7 +641,9 @@ def transcribe( __all__ = [ + "GraniteSpeechNarBidirectionalGraniteModel", "GraniteSpeechNarCTCEncoder", "GraniteSpeechNarForASR", + "GraniteSpeechNarLanguageModel", "GraniteSpeechNarPreTrainedModel", ] From daf4f9313e9ba622ae7ed8fca2e86d01211e24b6 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 16:43:31 +0000 Subject: [PATCH 05/39] minor --- utils/check_repo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/check_repo.py b/utils/check_repo.py index 5a7484409e31..6adf34d98526 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -308,6 +308,7 @@ "models/sam3_video/test_modeling_sam3_video.py", "models/edgetam_video/test_modeling_edgetam_video.py", "models/gemma4_assistant/test_modeling_gemma4_assistant.py", + "models/granite_speech_nar/test_modeling_granite_speech_nar.py", ] # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and From 65729c2f2ba0b8f60b69973df2953de8e01a55e3 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 18:04:33 +0000 Subject: [PATCH 06/39] add docs, rename components, skip submodule tests --- docs/source/en/_toctree.yml | 2 + .../source/en/model_doc/granite_speech_nar.md | 71 +++++++++++++++++++ .../modeling_granite_speech_nar.py | 38 +++++----- .../modular_granite_speech_nar.py | 33 +++++---- utils/check_repo.py | 2 + 5 files changed, 115 insertions(+), 31 deletions(-) create mode 100644 docs/source/en/model_doc/granite_speech_nar.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 80ef11db201c..f973b971dc86 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1071,6 +1071,8 @@ title: GLM-ASR - local: model_doc/granite_speech title: GraniteSpeech + - local: model_doc/granite_speech_nar + title: GraniteSpeechNar - local: model_doc/granite_speech_plus title: GraniteSpeechPlus - local: model_doc/higgs_audio_v2 diff --git a/docs/source/en/model_doc/granite_speech_nar.md b/docs/source/en/model_doc/granite_speech_nar.md new file mode 100644 index 000000000000..a65cb0688579 --- /dev/null +++ b/docs/source/en/model_doc/granite_speech_nar.md @@ -0,0 +1,71 @@ + + +# GraniteSpeechNar + +## Overview + +GraniteSpeechNar is a non-autoregressive (NAR) speech recognition model based on [NLE: Non-autoregressive LLM-based ASR by Transcript Editing](https://huggingface.co/papers/2603.08397). It formulates ASR as conditional transcript editing, achieving fully parallel prediction with significant speedups over autoregressive baselines. + +The model consists of: + +1. **Conformer Encoder**: A conformer encoder trained with CTC on BPE targets, using block-attention and self-conditioned CTC from the middle layer. + +2. **QFormer Projector**: A windowed query-transformer that maps multi-layer encoder features to the LLM embedding space with temporal downsampling. + +3. **Bidirectional Granite LLM**: A Granite language model with bidirectional (non-causal) attention that refines CTC predictions in a single forward pass. + +The model performs inference in a single pass: the encoder produces initial CTC predictions, which are interleaved with blank insertion slots (exploiting the identity mapping bias of Transformers) and fed alongside projected audio embeddings to the bidirectional LLM for refinement via a latent alignment objective. + +This model was contributed by [Avihu Dekel](https://huggingface.co/Avihu). + +## GraniteSpeechNarConfig + +[[autodoc]] GraniteSpeechNarConfig + +## GraniteSpeechNarEncoderConfig + +[[autodoc]] GraniteSpeechNarEncoderConfig + +## GraniteSpeechNarProjectorConfig + +[[autodoc]] GraniteSpeechNarProjectorConfig + +## GraniteSpeechNarProcessor + +[[autodoc]] GraniteSpeechNarProcessor + - __call__ + - batch_decode + +## GraniteSpeechNarFeatureExtractor + +[[autodoc]] GraniteSpeechNarFeatureExtractor + +## GraniteSpeechNarModel + +[[autodoc]] GraniteSpeechNarModel + - forward + +## GraniteSpeechNarLM + +[[autodoc]] GraniteSpeechNarLM + - forward + +## GraniteSpeechNarForASR + +[[autodoc]] GraniteSpeechNarForASR + - forward + - transcribe diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index ff1e3ce54c2e..5b0509213237 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -441,7 +441,7 @@ def eager_attention_forward( @use_kernelized_func(apply_rotary_pos_emb) -class GraniteSpeechNarBidirectionalAttention(nn.Module): +class GraniteSpeechNarAttention(nn.Module): """GraniteAttention with is_causal=False for bidirectional attention.""" is_causal = False @@ -547,13 +547,13 @@ def forward(self, x): return down_proj -class GraniteSpeechNarBidirectionalDecoderLayer(GradientCheckpointingLayer): +class GraniteSpeechNarDecoderLayer(GradientCheckpointingLayer): """GraniteDecoderLayer using bidirectional attention.""" def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GraniteSpeechNarBidirectionalAttention(config=config, layer_idx=layer_idx) + self.self_attn = GraniteSpeechNarAttention(config=config, layer_idx=layer_idx) self.mlp = GraniteSpeechNarMLP(config) self.input_layernorm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -679,10 +679,10 @@ def forward(self, x, position_ids): @auto_docstring -class GraniteSpeechNarBidirectionalGraniteModel(GraniteSpeechNarPreTrainedModel): +class GraniteSpeechNarModel(GraniteSpeechNarPreTrainedModel): """GraniteModel with bidirectional (non-causal) attention. - Uses GraniteSpeechNarBidirectionalDecoderLayer which sets is_causal=False, + Uses GraniteSpeechNarDecoderLayer which sets is_causal=False, and replaces create_causal_mask() with create_bidirectional_mask() so all attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. """ @@ -694,10 +694,7 @@ def __init__(self, config): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( - [ - GraniteSpeechNarBidirectionalDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [GraniteSpeechNarDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = GraniteSpeechNarRotaryEmbedding(config=config) @@ -794,6 +791,7 @@ def forward( output_hidden_states: bool | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, + **kwargs, ) -> GraniteSpeechNarEncoderOutput: if attention_mask is None: attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) @@ -868,8 +866,13 @@ def forward( ) -@auto_docstring -class GraniteSpeechNarLanguageModel(GraniteSpeechNarPreTrainedModel, GenerationMixin): +@auto_docstring( + custom_intro=""" + The bidirectional language model component of GraniteSpeechNar, used internally + to refine CTC predictions in a single non-autoregressive pass. + """ +) +class GraniteSpeechNarLM(GraniteSpeechNarPreTrainedModel, GenerationMixin): """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} @@ -878,7 +881,7 @@ class GraniteSpeechNarLanguageModel(GraniteSpeechNarPreTrainedModel, GenerationM def __init__(self, config): super().__init__(config) - self.model = GraniteSpeechNarBidirectionalGraniteModel(config) + self.model = GraniteSpeechNarModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -903,9 +906,9 @@ def forward( Example: ```python - >>> from transformers import AutoTokenizer, GraniteSpeechNarLanguageModel + >>> from transformers import AutoTokenizer, GraniteSpeechNarLM - >>> model = GraniteSpeechNarLanguageModel.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") + >>> model = GraniteSpeechNarLM.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" @@ -961,7 +964,7 @@ def __init__(self, config: GraniteSpeechNarConfig): text_config = config.text_config if hasattr(config, "_attn_implementation"): text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) + self.language_model = GraniteSpeechNarLM._from_config(text_config) self.post_init() @@ -1020,6 +1023,7 @@ def forward( labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, + **kwargs, ) -> GraniteSpeechNarOutput: r""" Args: @@ -1172,9 +1176,9 @@ def transcribe( __all__ = [ - "GraniteSpeechNarBidirectionalGraniteModel", + "GraniteSpeechNarModel", "GraniteSpeechNarCTCEncoder", "GraniteSpeechNarForASR", - "GraniteSpeechNarLanguageModel", + "GraniteSpeechNarLM", "GraniteSpeechNarPreTrainedModel", ] diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 1a14cd1972de..e5c37a5e0e60 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -226,7 +226,7 @@ class GraniteSpeechNarPreTrainedModel(PreTrainedModel): input_modalities = ("audio",) -class GraniteSpeechNarBidirectionalAttention(GraniteAttention): +class GraniteSpeechNarAttention(GraniteAttention): """GraniteAttention with is_causal=False for bidirectional attention.""" is_causal = False @@ -236,18 +236,18 @@ def __init__(self, config, layer_idx=None): self.is_causal = False -class GraniteSpeechNarBidirectionalDecoderLayer(GraniteDecoderLayer): +class GraniteSpeechNarDecoderLayer(GraniteDecoderLayer): """GraniteDecoderLayer using bidirectional attention.""" def __init__(self, config, layer_idx: int): super().__init__(config, layer_idx) - self.self_attn = GraniteSpeechNarBidirectionalAttention(config=config, layer_idx=layer_idx) + self.self_attn = GraniteSpeechNarAttention(config=config, layer_idx=layer_idx) -class GraniteSpeechNarBidirectionalGraniteModel(GraniteModel): +class GraniteSpeechNarModel(GraniteModel): """GraniteModel with bidirectional (non-causal) attention. - Uses GraniteSpeechNarBidirectionalDecoderLayer which sets is_causal=False, + Uses GraniteSpeechNarDecoderLayer which sets is_causal=False, and replaces create_causal_mask() with create_bidirectional_mask() so all attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. """ @@ -255,10 +255,7 @@ class GraniteSpeechNarBidirectionalGraniteModel(GraniteModel): def __init__(self, config): super().__init__(config) self.layers = nn.ModuleList( - [ - GraniteSpeechNarBidirectionalDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] + [GraniteSpeechNarDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) def forward( @@ -331,6 +328,7 @@ def forward( output_hidden_states: bool | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, + **kwargs, ) -> GraniteSpeechNarEncoderOutput: if attention_mask is None: attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) @@ -405,12 +403,18 @@ def forward( ) -class GraniteSpeechNarLanguageModel(GraniteForCausalLM): +@auto_docstring( + custom_intro=""" + The bidirectional language model component of GraniteSpeechNar, used internally + to refine CTC predictions in a single non-autoregressive pass. + """ +) +class GraniteSpeechNarLM(GraniteForCausalLM): """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" def __init__(self, config): super().__init__(config) - self.model = GraniteSpeechNarBidirectionalGraniteModel(config) + self.model = GraniteSpeechNarModel(config) @auto_docstring( @@ -430,7 +434,7 @@ def __init__(self, config: GraniteSpeechNarConfig): text_config = config.text_config if hasattr(config, "_attn_implementation"): text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) + self.language_model = GraniteSpeechNarLM._from_config(text_config) self.post_init() @@ -489,6 +493,7 @@ def forward( labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, + **kwargs, ) -> GraniteSpeechNarOutput: r""" Args: @@ -641,9 +646,9 @@ def transcribe( __all__ = [ - "GraniteSpeechNarBidirectionalGraniteModel", + "GraniteSpeechNarModel", "GraniteSpeechNarCTCEncoder", "GraniteSpeechNarForASR", - "GraniteSpeechNarLanguageModel", + "GraniteSpeechNarLM", "GraniteSpeechNarPreTrainedModel", ] diff --git a/utils/check_repo.py b/utils/check_repo.py index 6adf34d98526..0e33364dbfcd 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -514,6 +514,8 @@ "Ernie4_5_VL_MoeTextModel", # BC Alias "UVDocBridge", # Building part of a bigger model, tested implicitly through UVDocModel "Granite4VisionTextModel", # Building part of bigger (tested) model. + "GraniteSpeechNarModel", # Building part of bigger (tested) model. + "GraniteSpeechNarLM", # Building part of bigger (tested) model. ] From fa81168ea1f5cfbe0016f54f3df1d254c5db9edc Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 18:20:27 +0000 Subject: [PATCH 07/39] fix check_config_docstrings_have_checkpoints --- .../granite_speech_nar/configuration_granite_speech_nar.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py index 2f2da61191b8..3dc3da7facc1 100644 --- a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -20,7 +20,7 @@ from ..auto import CONFIG_MAPPING -@auto_docstring +@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") @strict class GraniteSpeechNarEncoderConfig(PreTrainedConfig): r""" @@ -90,7 +90,7 @@ def __post_init__(self, **kwargs): self.self_conditioning_layer = self.num_layers // 2 -@auto_docstring +@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") @strict class GraniteSpeechNarProjectorConfig(PreTrainedConfig): r""" @@ -138,7 +138,7 @@ class GraniteSpeechNarProjectorConfig(PreTrainedConfig): mlp_bias: bool = True -@auto_docstring +@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") @strict class GraniteSpeechNarConfig(PreTrainedConfig): r""" From f72b333519fd7d7361b46504e18fe72f36f202f4 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 18:29:07 +0000 Subject: [PATCH 08/39] add dates --- docs/source/en/model_doc/granite_speech_nar.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/model_doc/granite_speech_nar.md b/docs/source/en/model_doc/granite_speech_nar.md index a65cb0688579..d2d43c3d4095 100644 --- a/docs/source/en/model_doc/granite_speech_nar.md +++ b/docs/source/en/model_doc/granite_speech_nar.md @@ -13,6 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> +*This model was released on 2026-03-09 and added to Hugging Face Transformers on 2026-05-18.* # GraniteSpeechNar From 151c7090c355454897c4238b5e615c9e6a609206 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Mon, 18 May 2026 19:20:02 +0000 Subject: [PATCH 09/39] save processor imports --- .../feature_extraction_granite_speech_nar.py | 13 ++++++++++--- .../processing_granite_speech_nar.py | 7 +++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py index 84e95a6059f9..1abe099bf154 100644 --- a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -13,10 +13,16 @@ # limitations under the License. """Feature extraction for Granite Speech NAR.""" -import torch -import torchaudio - from ...feature_extraction_utils import FeatureExtractionMixin +from ...utils import is_torch_available, is_torchaudio_available +from ...utils.import_utils import requires_backends + + +if is_torch_available(): + import torch + +if is_torchaudio_available(): + import torchaudio class GraniteSpeechNarFeatureExtractor(FeatureExtractionMixin): @@ -37,6 +43,7 @@ def __init__( n_mels: int = 80, **kwargs, ): + requires_backends(self, ["torch", "torchaudio"]) super().__init__(**kwargs) self.sampling_rate = sampling_rate self.n_fft = n_fft diff --git a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py index 38d7abe3719d..d3fae65d0299 100644 --- a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py @@ -13,12 +13,15 @@ # limitations under the License. """Processor for Granite Speech NAR.""" -import torch - from ...processing_utils import ProcessorMixin +from ...utils import is_torch_available from .feature_extraction_granite_speech_nar import GraniteSpeechNarFeatureExtractor +if is_torch_available(): + import torch + + class GraniteSpeechNarProcessor(ProcessorMixin): """Processor combining audio feature extraction and tokenizer for GraniteSpeechNar.""" From d5d851c36d63c1ba5b2d839ab23ecc0462603172 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 19 May 2026 05:06:25 +0000 Subject: [PATCH 10/39] avoid a crash without torch available --- .../granite_speech_nar/feature_extraction_granite_speech_nar.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py index 1abe099bf154..58a5d804322b 100644 --- a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -58,7 +58,6 @@ def __init__( n_mels=n_mels, ) - @torch.no_grad() def _extract_features(self, raw_audio: torch.Tensor) -> torch.Tensor: mel_transform = self.mel_transform.to(raw_audio.device) B, T = raw_audio.shape From 0c29229b57b5765fa4f92e592286f467390256b2 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 19 May 2026 05:40:05 +0000 Subject: [PATCH 11/39] fix unguarded torch usage in typing --- .../feature_extraction_granite_speech_nar.py | 24 ++++++++++--------- .../processing_granite_speech_nar.py | 7 +++--- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py index 58a5d804322b..740a4f00d116 100644 --- a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -14,6 +14,7 @@ """Feature extraction for Granite Speech NAR.""" from ...feature_extraction_utils import FeatureExtractionMixin +from ...tokenization_utils_base import AudioInput from ...utils import is_torch_available, is_torchaudio_available from ...utils.import_utils import requires_backends @@ -58,20 +59,21 @@ def __init__( n_mels=n_mels, ) - def _extract_features(self, raw_audio: torch.Tensor) -> torch.Tensor: - mel_transform = self.mel_transform.to(raw_audio.device) - B, T = raw_audio.shape - l = 2 * (T // (2 * self.hop_length)) - mel = mel_transform(raw_audio.float())[..., :l] - logmel = mel.transpose(-1, -2).clamp_min_(1e-10).log10_() - mx = logmel.amax(dim=(-2, -1), keepdim=True) - logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) - return logmel.reshape(B, -1, 2 * self.n_mels) + def _extract_features(self, raw_audio: "torch.Tensor") -> "torch.Tensor": + with torch.no_grad(): + mel_transform = self.mel_transform.to(raw_audio.device) + B, T = raw_audio.shape + l = 2 * (T // (2 * self.hop_length)) + mel = mel_transform(raw_audio.float())[..., :l] + logmel = mel.transpose(-1, -2).clamp_min_(1e-10).log10_() + mx = logmel.amax(dim=(-2, -1), keepdim=True) + logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) + return logmel.reshape(B, -1, 2 * self.n_mels) def __call__( self, - audios: torch.Tensor | list[torch.Tensor], - device: str | torch.device | None = None, + audios: AudioInput, + device: str | "torch.device" | None = None, ) -> dict: if isinstance(audios, torch.Tensor): if audios.ndim == 1: diff --git a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py index d3fae65d0299..190462a23321 100644 --- a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py @@ -14,6 +14,7 @@ """Processor for Granite Speech NAR.""" from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import AudioInput from ...utils import is_torch_available from .feature_extraction_granite_speech_nar import GraniteSpeechNarFeatureExtractor @@ -33,13 +34,13 @@ def __init__(self, feature_extractor: GraniteSpeechNarFeatureExtractor, tokenize def __call__( self, - audios: torch.Tensor | list[torch.Tensor], - device: str | torch.device | None = None, + audios: AudioInput, + device: str | "torch.device" | None = None, **kwargs, ) -> dict: return self.feature_extractor(audios, device=device) - def batch_decode(self, token_ids_list: list[torch.Tensor], **kwargs) -> list[str]: + def batch_decode(self, token_ids_list: list["torch.Tensor"], **kwargs) -> list[str]: if self.tokenizer is None: raise ValueError("Tokenizer not set. Pass tokenizer to GraniteSpeechNarProcessor.") return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in token_ids_list] From e592181120643bd7b8b48ba81a1594682283ab99 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 19 May 2026 05:46:13 +0000 Subject: [PATCH 12/39] minor --- .../granite_speech_nar/feature_extraction_granite_speech_nar.py | 2 +- .../models/granite_speech_nar/processing_granite_speech_nar.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py index 740a4f00d116..0f4aae05ba10 100644 --- a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -73,7 +73,7 @@ def _extract_features(self, raw_audio: "torch.Tensor") -> "torch.Tensor": def __call__( self, audios: AudioInput, - device: str | "torch.device" | None = None, + device: str | None = None, ) -> dict: if isinstance(audios, torch.Tensor): if audios.ndim == 1: diff --git a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py index 190462a23321..b8cfc95f5d3a 100644 --- a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py @@ -35,7 +35,7 @@ def __init__(self, feature_extractor: GraniteSpeechNarFeatureExtractor, tokenize def __call__( self, audios: AudioInput, - device: str | "torch.device" | None = None, + device: str | None = None, **kwargs, ) -> dict: return self.feature_extractor(audios, device=device) From 480801cce359a5ea16152e3b8596e7df0abc69a8 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 19 May 2026 06:38:25 +0000 Subject: [PATCH 13/39] clean up variable names, avoid spliting the language_model/lm_head forwards --- .../modeling_granite_speech_nar.py | 114 ++++++++---------- .../modular_granite_speech_nar.py | 114 ++++++++---------- 2 files changed, 106 insertions(+), 122 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 5b0509213237..4a9241ccf621 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -324,33 +324,35 @@ def __init__(self, config: GraniteSpeechNarProjectorConfig): self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) self.out_linear = nn.Linear(config.hidden_size, config.llm_dim) - def forward(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.size() + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() - x = x.view(batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim) + hidden_states = hidden_states.view( + batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim + ) normalized_layers = [] for i, layer_norm in enumerate(self.layer_norms): - normalized_layers.append(layer_norm(x[:, :, i])) - x = torch.cat(normalized_layers, dim=-1) + normalized_layers.append(layer_norm(hidden_states[:, :, i])) + hidden_states = torch.cat(normalized_layers, dim=-1) - x = self.projector_act(self.layer_projector(x)) + hidden_states = self.projector_act(self.layer_projector(hidden_states)) block_size = self.config.block_size nblocks = seq_len // block_size rest = seq_len % block_size if rest > 0: - x = F.pad(x, (0, 0, 0, block_size - rest), "constant", 0) + hidden_states = F.pad(hidden_states, (0, 0, 0, block_size - rest), "constant", 0) nblocks += 1 - x = x.view(batch_size * nblocks, block_size, self.config.hidden_size) + hidden_states = hidden_states.view(batch_size * nblocks, block_size, self.config.hidden_size) query_length = self.query.shape[1] - mean_pool = x.view( + mean_pool = hidden_states.view( batch_size * nblocks, query_length, self.config.downsample_rate, self.config.hidden_size ).mean(dim=-2) hidden_states = self.qformer( query_embeds=self.dropout(self.query + mean_pool), - encoder_hidden_states=self.dropout(x + self.window_positions), + encoder_hidden_states=self.dropout(hidden_states + self.window_positions), ) hidden_states = hidden_states.view(batch_size, nblocks * query_length, -1) @@ -738,6 +740,7 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + # KV cache is not needed in a non-autoregressive model kwargs["use_cache"] = False for decoder_layer in self.layers: hidden_states = decoder_layer( @@ -754,14 +757,14 @@ def forward( def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: - B, T, D = hidden.shape - pad_len = (window_size - T % window_size) % window_size + batch_size, seq_len, hidden_dim = hidden.shape + pad_len = (window_size - seq_len % window_size) % window_size if pad_len > 0: hidden = F.pad(hidden, (0, 0, 0, pad_len)) importance = F.pad(importance, (0, pad_len)) num_windows = hidden.shape[1] // window_size - hidden = hidden.view(B, num_windows, window_size, D) - importance = importance.view(B, num_windows, window_size) + hidden = hidden.view(batch_size, num_windows, window_size, hidden_dim) + importance = importance.view(batch_size, num_windows, window_size) weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) return pooled @@ -805,10 +808,10 @@ def forward( relpos_dist = seq.view(-1, 1) - seq.view(1, -1) attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + self.config.max_pos_emb - for idx, layer in enumerate(self.layers, start=1): + for layer_idx, layer in enumerate(self.layers, start=1): hidden_states = layer(hidden_states, attention_dists=attention_dists) - if idx == self.config.self_conditioning_layer: + if layer_idx == self.config.self_conditioning_layer: mid_logits = self.out(self.dropout(hidden_states)) mid_probs = torch.softmax(mid_logits.float(), dim=-1) blank_probs = mid_probs[:, :, 0] @@ -820,43 +823,38 @@ def forward( hidden_states = self.dropout(hidden_states) logits = None - pool_window = self.config.bpe_pooling_window + loss = None if self.out_bpe is not None and blank_probs is not None: + pool_window = self.config.bpe_pooling_window importance = 1.0 - blank_probs pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( hidden_states.dtype ) - T = attention_mask.shape[1] - pad_len = (pool_window - T % pool_window) % pool_window - pooled_mask = F.pad(attention_mask, (0, pad_len), value=False)[:, ::pool_window] - logits = self.out_bpe(pooled[pooled_mask]) - - loss = None - if labels is not None and logits is not None: encoder_lengths = attention_mask.sum(dim=1) - bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() - T_max = max(bpe_lengths) - B = len(bpe_lengths) - logits_padded = logits.new_zeros(B, T_max, logits.shape[-1]) - offset = 0 - for i, length in enumerate(bpe_lengths): - logits_padded[i, :length] = logits[offset : offset + length] - offset += length - - log_probs = torch.log_softmax(logits_padded.float(), dim=-1) - bpe_x_sizes = torch.tensor(bpe_lengths, device=logits.device) - loss = ( - F.ctc_loss( - log_probs.transpose(0, 1), - labels + 1, - bpe_x_sizes, - label_lengths, - blank=0, - reduction="sum", - zero_infinity=True, + lengths = -(encoder_lengths // -pool_window) + lengths_list = lengths.tolist() + logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) + + if labels is not None: + logits_padded = logits.new_zeros(len(lengths_list), max(lengths_list), logits.shape[-1]) + offset = 0 + for i, length in enumerate(lengths_list): + logits_padded[i, :length] = logits[offset : offset + length] + offset += length + + log_probs = torch.log_softmax(logits_padded.float(), dim=-1) + loss = ( + F.ctc_loss( + log_probs.transpose(0, 1), + labels + 1, + lengths, + label_lengths, + blank=0, + reduction="sum", + zero_infinity=True, + ) + / lengths.sum() ) - / bpe_x_sizes.sum() - ) return GraniteSpeechNarEncoderOutput( loss=loss, @@ -1042,7 +1040,7 @@ def forward( Returns: [`GraniteSpeechNarOutput`] """ - encoder_labels = labels if (labels is not None and self.config.encoder_ctc_loss_lambda > 0.0) else None + encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None enc_out = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -1081,33 +1079,27 @@ def forward( ctc_token_ids, audio_embeds, audio_lengths ) - llm_out = self.language_model.model( + llm_out = self.language_model( inputs_embeds=flat_embeds, position_ids=flat_position_ids, ) - llm_hidden = llm_out.last_hidden_state.squeeze(0) + all_logits = llm_out.logits.squeeze(0) segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] - text_hidden = torch.cat(list(llm_hidden.split(segment_lengths)[1::2])) - - logits = self.language_model.lm_head(text_hidden) - logits = logits / self.language_model.config.logits_scaling - logits_per_sample = list(logits.split(text_lengths)) + text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) + logits_per_sample = list(text_logits.split(text_lengths)) loss = None if labels is not None: - log_probs = torch.log_softmax(logits.float(), dim=-1) + log_probs = torch.log_softmax(text_logits.float(), dim=-1) - T_max = max(text_lengths) - B = len(text_lengths) - V = log_probs.shape[-1] - log_probs_padded = log_probs.new_zeros(B, T_max, V) + log_probs_padded = log_probs.new_zeros(len(text_lengths), max(text_lengths), log_probs.shape[-1]) offset = 0 for i, tl in enumerate(text_lengths): log_probs_padded[i, :tl] = log_probs[offset : offset + tl] offset += tl - input_lengths = torch.tensor(text_lengths, device=logits.device) + input_lengths = torch.tensor(text_lengths, device=text_logits.device) loss = ( F.ctc_loss( @@ -1125,7 +1117,7 @@ def forward( if self.config.ce_loss_lambda > 0.0: ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) ce_loss = F.cross_entropy( - logits, + text_logits, ce_targets.long(), reduction="mean", ignore_index=-100, diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index e5c37a5e0e60..76b4bf1aa95f 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -74,14 +74,14 @@ class GraniteSpeechNarConformerBlock(GraniteSpeechConformerBlock): def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: - B, T, D = hidden.shape - pad_len = (window_size - T % window_size) % window_size + batch_size, seq_len, hidden_dim = hidden.shape + pad_len = (window_size - seq_len % window_size) % window_size if pad_len > 0: hidden = F.pad(hidden, (0, 0, 0, pad_len)) importance = F.pad(importance, (0, pad_len)) num_windows = hidden.shape[1] // window_size - hidden = hidden.view(B, num_windows, window_size, D) - importance = importance.view(B, num_windows, window_size) + hidden = hidden.view(batch_size, num_windows, window_size, hidden_dim) + importance = importance.view(batch_size, num_windows, window_size) weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) return pooled @@ -180,33 +180,35 @@ def __init__(self, config: GraniteSpeechNarProjectorConfig): self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) self.out_linear = nn.Linear(config.hidden_size, config.llm_dim) - def forward(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.size() + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() - x = x.view(batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim) + hidden_states = hidden_states.view( + batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim + ) normalized_layers = [] for i, layer_norm in enumerate(self.layer_norms): - normalized_layers.append(layer_norm(x[:, :, i])) - x = torch.cat(normalized_layers, dim=-1) + normalized_layers.append(layer_norm(hidden_states[:, :, i])) + hidden_states = torch.cat(normalized_layers, dim=-1) - x = self.projector_act(self.layer_projector(x)) + hidden_states = self.projector_act(self.layer_projector(hidden_states)) block_size = self.config.block_size nblocks = seq_len // block_size rest = seq_len % block_size if rest > 0: - x = F.pad(x, (0, 0, 0, block_size - rest), "constant", 0) + hidden_states = F.pad(hidden_states, (0, 0, 0, block_size - rest), "constant", 0) nblocks += 1 - x = x.view(batch_size * nblocks, block_size, self.config.hidden_size) + hidden_states = hidden_states.view(batch_size * nblocks, block_size, self.config.hidden_size) query_length = self.query.shape[1] - mean_pool = x.view( + mean_pool = hidden_states.view( batch_size * nblocks, query_length, self.config.downsample_rate, self.config.hidden_size ).mean(dim=-2) hidden_states = self.qformer( query_embeds=self.dropout(self.query + mean_pool), - encoder_hidden_states=self.dropout(x + self.window_positions), + encoder_hidden_states=self.dropout(hidden_states + self.window_positions), ) hidden_states = hidden_states.view(batch_size, nblocks * query_length, -1) @@ -289,6 +291,7 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + # KV cache is not needed in a non-autoregressive model kwargs["use_cache"] = False for decoder_layer in self.layers: hidden_states = decoder_layer( @@ -342,10 +345,10 @@ def forward( relpos_dist = seq.view(-1, 1) - seq.view(1, -1) attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + self.config.max_pos_emb - for idx, layer in enumerate(self.layers, start=1): + for layer_idx, layer in enumerate(self.layers, start=1): hidden_states = layer(hidden_states, attention_dists=attention_dists) - if idx == self.config.self_conditioning_layer: + if layer_idx == self.config.self_conditioning_layer: mid_logits = self.out(self.dropout(hidden_states)) mid_probs = torch.softmax(mid_logits.float(), dim=-1) blank_probs = mid_probs[:, :, 0] @@ -357,43 +360,38 @@ def forward( hidden_states = self.dropout(hidden_states) logits = None - pool_window = self.config.bpe_pooling_window + loss = None if self.out_bpe is not None and blank_probs is not None: + pool_window = self.config.bpe_pooling_window importance = 1.0 - blank_probs pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( hidden_states.dtype ) - T = attention_mask.shape[1] - pad_len = (pool_window - T % pool_window) % pool_window - pooled_mask = F.pad(attention_mask, (0, pad_len), value=False)[:, ::pool_window] - logits = self.out_bpe(pooled[pooled_mask]) - - loss = None - if labels is not None and logits is not None: encoder_lengths = attention_mask.sum(dim=1) - bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() - T_max = max(bpe_lengths) - B = len(bpe_lengths) - logits_padded = logits.new_zeros(B, T_max, logits.shape[-1]) - offset = 0 - for i, length in enumerate(bpe_lengths): - logits_padded[i, :length] = logits[offset : offset + length] - offset += length - - log_probs = torch.log_softmax(logits_padded.float(), dim=-1) - bpe_x_sizes = torch.tensor(bpe_lengths, device=logits.device) - loss = ( - F.ctc_loss( - log_probs.transpose(0, 1), - labels + 1, - bpe_x_sizes, - label_lengths, - blank=0, - reduction="sum", - zero_infinity=True, + lengths = -(encoder_lengths // -pool_window) + lengths_list = lengths.tolist() + logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) + + if labels is not None: + logits_padded = logits.new_zeros(len(lengths_list), max(lengths_list), logits.shape[-1]) + offset = 0 + for i, length in enumerate(lengths_list): + logits_padded[i, :length] = logits[offset : offset + length] + offset += length + + log_probs = torch.log_softmax(logits_padded.float(), dim=-1) + loss = ( + F.ctc_loss( + log_probs.transpose(0, 1), + labels + 1, + lengths, + label_lengths, + blank=0, + reduction="sum", + zero_infinity=True, + ) + / lengths.sum() ) - / bpe_x_sizes.sum() - ) return GraniteSpeechNarEncoderOutput( loss=loss, @@ -512,7 +510,7 @@ def forward( Returns: [`GraniteSpeechNarOutput`] """ - encoder_labels = labels if (labels is not None and self.config.encoder_ctc_loss_lambda > 0.0) else None + encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None enc_out = self.encoder( input_features=input_features, attention_mask=attention_mask, @@ -551,33 +549,27 @@ def forward( ctc_token_ids, audio_embeds, audio_lengths ) - llm_out = self.language_model.model( + llm_out = self.language_model( inputs_embeds=flat_embeds, position_ids=flat_position_ids, ) - llm_hidden = llm_out.last_hidden_state.squeeze(0) + all_logits = llm_out.logits.squeeze(0) segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] - text_hidden = torch.cat(list(llm_hidden.split(segment_lengths)[1::2])) - - logits = self.language_model.lm_head(text_hidden) - logits = logits / self.language_model.config.logits_scaling - logits_per_sample = list(logits.split(text_lengths)) + text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) + logits_per_sample = list(text_logits.split(text_lengths)) loss = None if labels is not None: - log_probs = torch.log_softmax(logits.float(), dim=-1) + log_probs = torch.log_softmax(text_logits.float(), dim=-1) - T_max = max(text_lengths) - B = len(text_lengths) - V = log_probs.shape[-1] - log_probs_padded = log_probs.new_zeros(B, T_max, V) + log_probs_padded = log_probs.new_zeros(len(text_lengths), max(text_lengths), log_probs.shape[-1]) offset = 0 for i, tl in enumerate(text_lengths): log_probs_padded[i, :tl] = log_probs[offset : offset + tl] offset += tl - input_lengths = torch.tensor(text_lengths, device=logits.device) + input_lengths = torch.tensor(text_lengths, device=text_logits.device) loss = ( F.ctc_loss( @@ -595,7 +587,7 @@ def forward( if self.config.ce_loss_lambda > 0.0: ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) ce_loss = F.cross_entropy( - logits, + text_logits, ce_targets.long(), reduction="mean", ignore_index=-100, From 2a95b12a7de7a4c71a082b2edb8a186317468164 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Thu, 21 May 2026 09:53:24 +0000 Subject: [PATCH 14/39] change encoder bpe prediction head to match the editor. (original vocab, same blank token) --- .../configuration_granite_speech_nar.py | 8 +++++++- .../feature_extraction_granite_speech_nar.py | 6 +++--- .../granite_speech_nar/modeling_granite_speech_nar.py | 9 +++++---- .../granite_speech_nar/modular_granite_speech_nar.py | 9 +++++---- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py index 3dc3da7facc1..f6667c40532d 100644 --- a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -42,9 +42,11 @@ class GraniteSpeechNarEncoderConfig(PreTrainedConfig): Layer index at which self-conditioning (mid-layer CTC feedback) is applied. Defaults to `num_layers // 2`. bpe_output_dim (`int`, *optional*): - Vocabulary size for the BPE CTC head (shifted by +1 for blank). If None, BPE head is disabled. + Vocabulary size for the BPE CTC head (same as LLM vocab, blank reuses eos_token_id). If None, BPE head is disabled. bpe_pooling_window (`int`, *optional*, defaults to 4): Window size for posterior-weighted pooling before the BPE CTC head. + blank_token_id (`int`, *optional*): + Token ID used as the CTC blank symbol. Defaults to the language model's `eos_token_id` if not set. Example: @@ -80,6 +82,7 @@ class GraniteSpeechNarEncoderConfig(PreTrainedConfig): self_conditioning_layer: int | None = None bpe_output_dim: int | None = None bpe_pooling_window: int = 4 + blank_token_id: int | None = None initializer_range: float = 0.02 def __post_init__(self, **kwargs): @@ -212,6 +215,9 @@ def __post_init__(self, **kwargs): if self.blank_token_id is None: self.blank_token_id = self.text_config.eos_token_id + # Propagate blank_token_id to encoder config + self.encoder_config.blank_token_id = self.blank_token_id + super().__post_init__(**kwargs) diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py index 0f4aae05ba10..328ba2578112 100644 --- a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -51,7 +51,7 @@ def __init__( self.win_length = win_length self.hop_length = hop_length self.n_mels = n_mels - self.mel_transform = torchaudio.transforms.MelSpectrogram( + self.mel_filters = torchaudio.transforms.MelSpectrogram( sample_rate=sampling_rate, n_fft=n_fft, win_length=win_length, @@ -61,10 +61,10 @@ def __init__( def _extract_features(self, raw_audio: "torch.Tensor") -> "torch.Tensor": with torch.no_grad(): - mel_transform = self.mel_transform.to(raw_audio.device) + mel_filters = self.mel_filters.to(raw_audio.device) B, T = raw_audio.shape l = 2 * (T // (2 * self.hop_length)) - mel = mel_transform(raw_audio.float())[..., :l] + mel = mel_filters(raw_audio.float())[..., :l] logmel = mel.transpose(-1, -2).clamp_min_(1e-10).log10_() mx = logmel.amax(dim=(-2, -1), keepdim=True) logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 4a9241ccf621..ee566c4dcfd4 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -846,10 +846,10 @@ def forward( loss = ( F.ctc_loss( log_probs.transpose(0, 1), - labels + 1, + labels, lengths, label_lengths, - blank=0, + blank=self.config.blank_token_id, reduction="sum", zero_infinity=True, ) @@ -971,10 +971,11 @@ def _ctc_collapse_decode( bpe_logits_flat: torch.Tensor, bpe_lengths: list[int], ) -> list[torch.Tensor]: - """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank -> shift.""" + """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank.""" + blank_id = self.config.blank_token_id preds_flat = bpe_logits_flat.argmax(dim=-1) per_sample = preds_flat.split(bpe_lengths) - return [(collapsed := torch.unique_consecutive(seq))[collapsed != 0] - 1 for seq in per_sample] + return [(collapsed := torch.unique_consecutive(seq))[collapsed != blank_id] for seq in per_sample] def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: """Insert blank tokens between each CTC token as editing slots for the LLM.""" diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 76b4bf1aa95f..3720464583a3 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -383,10 +383,10 @@ def forward( loss = ( F.ctc_loss( log_probs.transpose(0, 1), - labels + 1, + labels, lengths, label_lengths, - blank=0, + blank=self.config.blank_token_id, reduction="sum", zero_infinity=True, ) @@ -441,10 +441,11 @@ def _ctc_collapse_decode( bpe_logits_flat: torch.Tensor, bpe_lengths: list[int], ) -> list[torch.Tensor]: - """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank -> shift.""" + """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank.""" + blank_id = self.config.blank_token_id preds_flat = bpe_logits_flat.argmax(dim=-1) per_sample = preds_flat.split(bpe_lengths) - return [(collapsed := torch.unique_consecutive(seq))[collapsed != 0] - 1 for seq in per_sample] + return [(collapsed := torch.unique_consecutive(seq))[collapsed != blank_id] for seq in per_sample] def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: """Insert blank tokens between each CTC token as editing slots for the LLM.""" From 16efa3400abf1e00c451892df26e4bc768c7d437 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 26 May 2026 14:18:48 +0000 Subject: [PATCH 15/39] add integration tests for granite speech nar --- .../test_modeling_granite_speech_nar.py | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 740129a39f7d..28dea3fa6df0 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -14,11 +14,15 @@ """Tests for GraniteSpeechNar model.""" import math +import unittest +import pytest import torch from transformers import ( AutoConfig, + AutoModel, + AutoProcessor, GraniteConfig, GraniteSpeechNarConfig, ) @@ -32,6 +36,11 @@ GraniteSpeechNarOutput, GraniteSpeechNarProjector, ) +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import is_datasets_available + +if is_datasets_available(): + from datasets import load_dataset def _make_small_config(): @@ -364,3 +373,75 @@ def test_is_causal_false_on_layers(self): model = GraniteSpeechNarForASR(config) for i, layer in enumerate(model.language_model.model.layers): assert layer.self_attn.is_causal is False, f"Layer {i} is_causal is not False" + + +# === Integration tests === + + +@require_torch +class GraniteSpeechNarIntegrationTest(unittest.TestCase): + model_path = "ibm-granite/granite-speech-4.1-2b-nar" + _dataset = None + + @classmethod + def _load_dataset(cls): + if cls._dataset is None: + cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + def _load_datasamples(self, num_samples): + self._load_dataset() + samples = self._dataset.sort("id")[:num_samples]["audio"] + return [torch.tensor(x["array"], dtype=torch.float32) for x in samples] + + @slow + def test_single_sample_transcription(self): + model = AutoModel.from_pretrained( + self.model_path, attn_implementation="flash_attention_2", device_map=torch_device, dtype=torch.bfloat16 + ).eval() + processor = AutoProcessor.from_pretrained(self.model_path) + + waveforms = self._load_datasamples(1) + inputs = processor(waveforms, device=torch_device) + output = model.transcribe(**inputs) + transcriptions = processor.batch_decode(output.preds) + + expected = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" + self.assertEqual(transcriptions[0], expected) + + @slow + def test_batch_transcription(self): + model = AutoModel.from_pretrained( + self.model_path, attn_implementation="flash_attention_2", device_map=torch_device, dtype=torch.bfloat16 + ).eval() + processor = AutoProcessor.from_pretrained(self.model_path) + + waveforms = self._load_datasamples(2) + inputs = processor(waveforms, device=torch_device) + output = model.transcribe(**inputs) + transcriptions = processor.batch_decode(output.preds) + + expected = [ + "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel", + "nor is mister quilter's manner less interesting than his matter", + ] + self.assertEqual(len(transcriptions), 2) + self.assertEqual(transcriptions, expected) + + @slow + @pytest.mark.skipif(not is_datasets_available(), reason="datasets not installed") + def test_processor_output_shapes(self): + processor = AutoProcessor.from_pretrained(self.model_path) + + waveforms = self._load_datasamples(2) + inputs = processor(waveforms, device="cpu") + + self.assertEqual(inputs["input_features"].ndim, 3) + self.assertEqual(inputs["input_features"].shape[0], 2) + self.assertEqual(inputs["input_features"].shape[2], 160) + + self.assertEqual(inputs["attention_mask"].shape, inputs["input_features"].shape[:2]) + + # Shorter sample should have False values at end + mask_sums = inputs["attention_mask"].sum(dim=1) + self.assertEqual(mask_sums[0].item(), inputs["input_features"].shape[1]) + self.assertLess(mask_sums[1].item(), mask_sums[0].item()) From 4707df0de1c2bc62cd892e02f5e447655aaa19da Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 26 May 2026 14:21:54 +0000 Subject: [PATCH 16/39] ruff newline --- .../granite_speech_nar/test_modeling_granite_speech_nar.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 28dea3fa6df0..4add1821451a 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -39,6 +39,7 @@ from transformers.testing_utils import require_torch, slow, torch_device from transformers.utils import is_datasets_available + if is_datasets_available(): from datasets import load_dataset From 18abe6e93f74c2c9c6d7827b6444cedd126f49f9 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 26 May 2026 14:28:14 +0000 Subject: [PATCH 17/39] minor fixes after pulling main --- src/transformers/models/auto/modeling_auto.py | 2 +- .../models/granite_speech_nar/modeling_granite_speech_nar.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index eb00782df12d..0fbcbc279c3f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -217,8 +217,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), ("granite_speech", "GraniteSpeechModel"), - ("granite_speech_plus", "GraniteSpeechPlusModel"), ("granite_speech_nar", "GraniteSpeechNarForASR"), + ("granite_speech_plus", "GraniteSpeechPlusModel"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), ("granitemoeshared", "GraniteMoeSharedModel"), diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index ee566c4dcfd4..6aad05c28822 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -874,8 +874,10 @@ class GraniteSpeechNarLM(GraniteSpeechNarPreTrainedModel, GenerationMixin): """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} + _tp_plan = {"lm_head": "colwise_allgather"} + _sp_plan = {"lm_head": "colwise_loss_parallel"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _fsdp_plan = {"lm_head": "keep_full_weight"} def __init__(self, config): super().__init__(config) From 8da331d8d46d74c96f1d60f820bdb8afe317b370 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Tue, 26 May 2026 14:45:07 +0000 Subject: [PATCH 18/39] minor --- docs/source/en/model_doc/granite_speech_nar.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/granite_speech_nar.md b/docs/source/en/model_doc/granite_speech_nar.md index d2d43c3d4095..007342f3caac 100644 --- a/docs/source/en/model_doc/granite_speech_nar.md +++ b/docs/source/en/model_doc/granite_speech_nar.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2026-03-09 and added to Hugging Face Transformers on 2026-05-18.* +*This model was released on 2026-03-09 and added to Hugging Face Transformers on 2026-05-26.* # GraniteSpeechNar From e0a51e5db8826e872c0f058318b6a3aa3542025a Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 07:09:46 +0000 Subject: [PATCH 19/39] - renames (GraniteSpeechNarForCTC, generate) - GraniteSpeechNarForCTC has a base multimodal model + lm_head - fix test with older encoder bpe head. - reuse conversion mapping by granite speech --- src/transformers/conversion_mapping.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 +- .../modeling_granite_speech_nar.py | 133 +++++------------- .../modular_granite_speech_nar.py | 67 +++++---- .../test_modeling_granite_speech_nar.py | 38 ++--- 5 files changed, 93 insertions(+), 151 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 4b4a464caf05..7b513b34215e 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -92,6 +92,7 @@ "audioflamingo3": "qwen2_audio", "glmasr": "qwen2_audio", "musicflamingo": "qwen2_audio", + "granite_speech_nar": "granite_speech", "granite_speech_plus": "granite_speech", "gemma3n_text": "qwen3_5_text", "qwen3_5_moe_text": "qwen3_5_text", @@ -116,6 +117,7 @@ "AudioFlamingo3Model": "Qwen2AudioModel", "GlmAsrModel": "Qwen2AudioModel", "MusicFlamingoModel": "Qwen2AudioModel", + "GraniteSpeechNarModel": "GraniteSpeechModel", "GraniteSpeechPlusModel": "GraniteSpeechModel", "MaskFormerDetrDecoder": "DetrModel", "Qwen2_5_VLForConditionalGeneration": "Qwen2VLForConditionalGeneration", diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0fbcbc279c3f..ee47aac3103e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -217,7 +217,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), ("granite_speech", "GraniteSpeechModel"), - ("granite_speech_nar", "GraniteSpeechNarForASR"), + ("granite_speech_nar", "GraniteSpeechNarForCTC"), ("granite_speech_plus", "GraniteSpeechPlusModel"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), @@ -1673,7 +1673,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): [ # Model for Connectionist temporal classification (CTC) mapping ("data2vec-audio", "Data2VecAudioForCTC"), - ("granite_speech_nar", "GraniteSpeechNarForASR"), + ("granite_speech_nar", "GraniteSpeechNarForCTC"), ("hubert", "HubertForCTC"), ("lasr_ctc", "LasrForCTC"), ("parakeet_ctc", "ParakeetForCTC"), diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 6aad05c28822..d7dd15ae6502 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -29,15 +29,14 @@ from ...activations import ACT2FN from ...cache_utils import Cache -from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_bidirectional_mask, find_packed_sequence_indices, packed_sequence_mask_function from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import ModelOutput, TransformersKwargs, auto_docstring from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_granite_speech_nar import ( @@ -363,7 +362,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @auto_docstring class GraniteSpeechNarPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechNarConfig - base_model_prefix = "encoder" + base_model_prefix = "model" supports_gradient_checkpointing = True _supports_flash_attn = True _supports_flash_attn_2 = True @@ -681,7 +680,7 @@ def forward(self, x, position_ids): @auto_docstring -class GraniteSpeechNarModel(GraniteSpeechNarPreTrainedModel): +class GraniteSpeechNarLanguageModel(GraniteSpeechNarPreTrainedModel): """GraniteModel with bidirectional (non-causal) attention. Uses GraniteSpeechNarDecoderLayer which sets is_causal=False, @@ -866,105 +865,40 @@ def forward( @auto_docstring( custom_intro=""" - The bidirectional language model component of GraniteSpeechNar, used internally - to refine CTC predictions in a single non-autoregressive pass. + The GraniteSpeechNar base model consisting of a conformer encoder, QFormer projector, + and a bidirectional Granite language model backbone. """ ) -class GraniteSpeechNarLM(GraniteSpeechNarPreTrainedModel, GenerationMixin): - """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" - - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_allgather"} - _sp_plan = {"lm_head": "colwise_loss_parallel"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - _fsdp_plan = {"lm_head": "keep_full_weight"} - - def __init__(self, config): +class GraniteSpeechNarModel(GraniteSpeechNarPreTrainedModel): + def __init__(self, config: GraniteSpeechNarConfig): super().__init__(config) - self.model = GraniteSpeechNarModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, GraniteSpeechNarLM - - >>> model = GraniteSpeechNarLM.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite_speech_nar/GraniteSpeechNar-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - hidden_states = outputs.last_hidden_state - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - logits = logits / self.config.logits_scaling # main diff with Llama + self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechNarProjector(config.projector_config) - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + text_config = config.text_config + if hasattr(config, "_attn_implementation"): + text_config._attn_implementation = config._attn_implementation + self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + self.post_init() @auto_docstring( custom_intro=""" - The GraniteSpeechNar model for non-autoregressive automatic speech recognition. + The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. """ ) -class GraniteSpeechNarForASR(GraniteSpeechNarPreTrainedModel): +class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + def __init__(self, config: GraniteSpeechNarConfig): super().__init__(config) - self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) - self.projector = GraniteSpeechNarProjector(config.projector_config) - - text_config = config.text_config - if hasattr(config, "_attn_implementation"): - text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLM._from_config(text_config) + self.model = GraniteSpeechNarModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() @@ -997,7 +931,7 @@ def _build_flat_inputs( audio_lengths: list[int], ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" - embed_tokens = self.language_model.model.embed_tokens + embed_tokens = self.model.language_model.embed_tokens embeds_list = [] position_ids_list = [] @@ -1044,7 +978,7 @@ def forward( [`GraniteSpeechNarOutput`] """ encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None - enc_out = self.encoder( + enc_out = self.model.encoder( input_features=input_features, attention_mask=attention_mask, output_hidden_states=True, @@ -1057,7 +991,7 @@ def forward( encoder_lengths = attention_mask.sum(dim=1) - pool_window = self.encoder.config.bpe_pooling_window + pool_window = self.model.encoder.config.bpe_pooling_window bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) @@ -1069,24 +1003,25 @@ def forward( encoder_logits = enc_out.logits if output_encoder_logits else None del enc_out - audio_embeds = self.projector(multilayer_features) + audio_embeds = self.model.projector(multilayer_features) del multilayer_features if self.config.scale_projected_embeddings: embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) audio_embeds = audio_embeds / embedding_multiplier - audio_embeds = audio_embeds.to(self.language_model.model.embed_tokens.weight.dtype) + audio_embeds = audio_embeds.to(self.model.language_model.embed_tokens.weight.dtype) - audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() + audio_lengths = (encoder_lengths // self.model.projector.config.downsample_rate).cpu().tolist() flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( ctc_token_ids, audio_embeds, audio_lengths ) - llm_out = self.language_model( + llm_out = self.model.language_model( inputs_embeds=flat_embeds, position_ids=flat_position_ids, ) - all_logits = llm_out.logits.squeeze(0) + hidden_states = llm_out.last_hidden_state.squeeze(0) + all_logits = self.lm_head(hidden_states) segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) @@ -1138,7 +1073,7 @@ def forward( ) @torch.inference_mode() - def transcribe( + def generate( self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, @@ -1171,9 +1106,9 @@ def transcribe( __all__ = [ - "GraniteSpeechNarModel", "GraniteSpeechNarCTCEncoder", - "GraniteSpeechNarForASR", - "GraniteSpeechNarLM", + "GraniteSpeechNarForCTC", + "GraniteSpeechNarLanguageModel", + "GraniteSpeechNarModel", "GraniteSpeechNarPreTrainedModel", ] diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 3720464583a3..3c2efe0843c5 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -28,7 +28,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging -from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteForCausalLM, GraniteModel +from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteModel from ..granite_speech.modeling_granite_speech import GraniteSpeechConformerBlock from .configuration_granite_speech_nar import ( GraniteSpeechNarConfig, @@ -219,7 +219,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @auto_docstring class GraniteSpeechNarPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechNarConfig - base_model_prefix = "encoder" + base_model_prefix = "model" supports_gradient_checkpointing = True _supports_flash_attn = True _supports_flash_attn_2 = True @@ -246,7 +246,7 @@ def __init__(self, config, layer_idx: int): self.self_attn = GraniteSpeechNarAttention(config=config, layer_idx=layer_idx) -class GraniteSpeechNarModel(GraniteModel): +class GraniteSpeechNarLanguageModel(GraniteModel): """GraniteModel with bidirectional (non-causal) attention. Uses GraniteSpeechNarDecoderLayer which sets is_causal=False, @@ -403,36 +403,40 @@ def forward( @auto_docstring( custom_intro=""" - The bidirectional language model component of GraniteSpeechNar, used internally - to refine CTC predictions in a single non-autoregressive pass. + The GraniteSpeechNar base model consisting of a conformer encoder, QFormer projector, + and a bidirectional Granite language model backbone. """ ) -class GraniteSpeechNarLM(GraniteForCausalLM): - """GraniteForCausalLM with a bidirectional (non-causal) backbone.""" - - def __init__(self, config): +class GraniteSpeechNarModel(GraniteSpeechNarPreTrainedModel): + def __init__(self, config: GraniteSpeechNarConfig): super().__init__(config) - self.model = GraniteSpeechNarModel(config) + + self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechNarProjector(config.projector_config) + + text_config = config.text_config + if hasattr(config, "_attn_implementation"): + text_config._attn_implementation = config._attn_implementation + self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) + + self.post_init() @auto_docstring( custom_intro=""" - The GraniteSpeechNar model for non-autoregressive automatic speech recognition. + The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. """ ) -class GraniteSpeechNarForASR(GraniteSpeechNarPreTrainedModel): +class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + def __init__(self, config: GraniteSpeechNarConfig): super().__init__(config) - self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) - self.projector = GraniteSpeechNarProjector(config.projector_config) - - text_config = config.text_config - if hasattr(config, "_attn_implementation"): - text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLM._from_config(text_config) + self.model = GraniteSpeechNarModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() @@ -465,7 +469,7 @@ def _build_flat_inputs( audio_lengths: list[int], ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" - embed_tokens = self.language_model.model.embed_tokens + embed_tokens = self.model.language_model.embed_tokens embeds_list = [] position_ids_list = [] @@ -512,7 +516,7 @@ def forward( [`GraniteSpeechNarOutput`] """ encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None - enc_out = self.encoder( + enc_out = self.model.encoder( input_features=input_features, attention_mask=attention_mask, output_hidden_states=True, @@ -525,7 +529,7 @@ def forward( encoder_lengths = attention_mask.sum(dim=1) - pool_window = self.encoder.config.bpe_pooling_window + pool_window = self.model.encoder.config.bpe_pooling_window bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) @@ -537,24 +541,25 @@ def forward( encoder_logits = enc_out.logits if output_encoder_logits else None del enc_out - audio_embeds = self.projector(multilayer_features) + audio_embeds = self.model.projector(multilayer_features) del multilayer_features if self.config.scale_projected_embeddings: embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) audio_embeds = audio_embeds / embedding_multiplier - audio_embeds = audio_embeds.to(self.language_model.model.embed_tokens.weight.dtype) + audio_embeds = audio_embeds.to(self.model.language_model.embed_tokens.weight.dtype) - audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() + audio_lengths = (encoder_lengths // self.model.projector.config.downsample_rate).cpu().tolist() flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( ctc_token_ids, audio_embeds, audio_lengths ) - llm_out = self.language_model( + llm_out = self.model.language_model( inputs_embeds=flat_embeds, position_ids=flat_position_ids, ) - all_logits = llm_out.logits.squeeze(0) + hidden_states = llm_out.last_hidden_state.squeeze(0) + all_logits = self.lm_head(hidden_states) segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) @@ -606,7 +611,7 @@ def forward( ) @torch.inference_mode() - def transcribe( + def generate( self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, @@ -639,9 +644,9 @@ def transcribe( __all__ = [ - "GraniteSpeechNarModel", "GraniteSpeechNarCTCEncoder", - "GraniteSpeechNarForASR", - "GraniteSpeechNarLM", + "GraniteSpeechNarForCTC", + "GraniteSpeechNarLanguageModel", + "GraniteSpeechNarModel", "GraniteSpeechNarPreTrainedModel", ] diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 4add1821451a..28a6b466c9e4 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -32,7 +32,7 @@ ) from transformers.models.granite_speech_nar.modeling_granite_speech_nar import ( GraniteSpeechNarCTCEncoder, - GraniteSpeechNarForASR, + GraniteSpeechNarForCTC, GraniteSpeechNarOutput, GraniteSpeechNarProjector, ) @@ -54,7 +54,7 @@ def _make_small_config(): output_dim=10, context_size=50, self_conditioning_layer=2, - bpe_output_dim=52, + bpe_output_dim=51, bpe_pooling_window=4, ) projector_config = GraniteSpeechNarProjectorConfig( @@ -121,7 +121,7 @@ def test_config_serialization_roundtrip(self): d = config.to_dict() restored = GraniteSpeechNarConfig(**d) assert restored.encoder_config.num_layers == 4 - assert restored.encoder_config.bpe_output_dim == 52 + assert restored.encoder_config.bpe_output_dim == 51 assert restored.projector_config.num_layers == 1 assert restored.encoder_layer_indices == [1, 2, 3, -1] @@ -226,10 +226,10 @@ def test_handles_non_divisible_length(self): # === Full model tests === -class TestGraniteSpeechNarForASR: +class TestGraniteSpeechNarForCTC: def test_forward(self): config = _make_small_config() - model = GraniteSpeechNarForASR(config).eval() + model = GraniteSpeechNarForCTC(config).eval() B, T = 2, 100 features = torch.randn(B, T, 160) @@ -247,12 +247,12 @@ def test_forward(self): assert logits.ndim == 2 assert logits.shape[1] == 51 - def test_transcribe(self): + def test_generate(self): config = _make_small_config() - model = GraniteSpeechNarForASR(config).eval() + model = GraniteSpeechNarForCTC(config).eval() features = torch.randn(1, 60, 160) - output = model.transcribe(input_features=features) + output = model.generate(input_features=features) assert output.preds is not None assert len(output.preds) == 1 @@ -260,7 +260,7 @@ def test_transcribe(self): def test_loss(self): config = _make_small_config() - model = GraniteSpeechNarForASR(config).train() + model = GraniteSpeechNarForCTC(config).train() B, T = 2, 100 features = torch.randn(B, T, 160) @@ -284,7 +284,7 @@ def test_loss(self): def test_loss_with_ce(self): config = _make_small_config() config.ce_loss_lambda = 0.5 - model = GraniteSpeechNarForASR(config).train() + model = GraniteSpeechNarForCTC(config).train() features = torch.randn(1, 60, 160) labels = torch.randint(0, 51, (1, 4)) @@ -303,7 +303,7 @@ def test_loss_with_ce(self): def test_loss_with_encoder_ctc(self): config = _make_small_config() config.encoder_ctc_loss_lambda = 0.3 - model = GraniteSpeechNarForASR(config).train() + model = GraniteSpeechNarForCTC(config).train() features = torch.randn(1, 60, 160) labels = torch.randint(0, 51, (1, 4)) @@ -321,7 +321,7 @@ def test_loss_with_encoder_ctc(self): def test_no_loss_without_labels(self): config = _make_small_config() - model = GraniteSpeechNarForASR(config).eval() + model = GraniteSpeechNarForCTC(config).eval() features = torch.randn(1, 60, 160) with torch.no_grad(): @@ -331,7 +331,7 @@ def test_no_loss_without_labels(self): def test_output_encoder_logits_flag(self): config = _make_small_config() - model = GraniteSpeechNarForASR(config).eval() + model = GraniteSpeechNarForCTC(config).eval() features = torch.randn(1, 60, 160) with torch.no_grad(): @@ -355,8 +355,8 @@ class TestBidirectionalAttention: def test_last_token_affects_first(self): """Changing the last token must affect the first (bidirectional).""" config = _make_small_config() - model = GraniteSpeechNarForASR(config).eval() - granite_model = model.language_model.model + model = GraniteSpeechNarForCTC(config).eval() + granite_model = model.model.language_model embeds_a = torch.randn(1, 10, 128) embeds_b = embeds_a.clone() @@ -371,8 +371,8 @@ def test_last_token_affects_first(self): def test_is_causal_false_on_layers(self): config = _make_small_config() - model = GraniteSpeechNarForASR(config) - for i, layer in enumerate(model.language_model.model.layers): + model = GraniteSpeechNarForCTC(config) + for i, layer in enumerate(model.model.language_model.layers): assert layer.self_attn.is_causal is False, f"Layer {i} is_causal is not False" @@ -403,7 +403,7 @@ def test_single_sample_transcription(self): waveforms = self._load_datasamples(1) inputs = processor(waveforms, device=torch_device) - output = model.transcribe(**inputs) + output = model.generate(**inputs) transcriptions = processor.batch_decode(output.preds) expected = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" @@ -418,7 +418,7 @@ def test_batch_transcription(self): waveforms = self._load_datasamples(2) inputs = processor(waveforms, device=torch_device) - output = model.transcribe(**inputs) + output = model.generate(**inputs) transcriptions = processor.batch_decode(output.preds) expected = [ From 3fc93c21b7716487f56f7c9935e8c65298cfe2f3 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 07:57:42 +0000 Subject: [PATCH 20/39] shared helper methods for ctc loss and ctc decoding, shared between the encoder/editor. minor docs fix --- .../source/en/model_doc/granite_speech_nar.md | 10 +-- .../modeling_granite_speech_nar.py | 90 +++++++++---------- .../modular_granite_speech_nar.py | 90 +++++++++---------- 3 files changed, 89 insertions(+), 101 deletions(-) diff --git a/docs/source/en/model_doc/granite_speech_nar.md b/docs/source/en/model_doc/granite_speech_nar.md index 007342f3caac..a6cf43a57551 100644 --- a/docs/source/en/model_doc/granite_speech_nar.md +++ b/docs/source/en/model_doc/granite_speech_nar.md @@ -60,13 +60,13 @@ This model was contributed by [Avihu Dekel](https://huggingface.co/Avihu). [[autodoc]] GraniteSpeechNarModel - forward -## GraniteSpeechNarLM +## GraniteSpeechNarLanguageModel -[[autodoc]] GraniteSpeechNarLM +[[autodoc]] GraniteSpeechNarLanguageModel - forward -## GraniteSpeechNarForASR +## GraniteSpeechNarForCTC -[[autodoc]] GraniteSpeechNarForASR +[[autodoc]] GraniteSpeechNarForCTC - forward - - transcribe + - generate diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index d7dd15ae6502..a06afbfcf704 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -769,6 +769,36 @@ def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, win return pooled +def _ctc_loss_from_flat_logits( + flat_logits: torch.Tensor, + lengths: list[int], + labels: torch.Tensor, + label_lengths: torch.Tensor, + blank_id: int, +) -> torch.Tensor: + """Compute mean CTC loss from flat (concatenated) logits with per-sample lengths.""" + log_probs = torch.log_softmax(flat_logits.float(), dim=-1) + log_probs_padded = log_probs.new_zeros(len(lengths), max(lengths), log_probs.shape[-1]) + offset = 0 + for i, length in enumerate(lengths): + log_probs_padded[i, :length] = log_probs[offset : offset + length] + offset += length + + lengths_t = torch.tensor(lengths, device=flat_logits.device) + return ( + F.ctc_loss( + log_probs_padded.transpose(0, 1), + labels, + lengths_t, + label_lengths, + blank=blank_id, + reduction="sum", + zero_infinity=True, + ) + / lengths_t.sum() + ) + + class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): """Conformer encoder with BPE CTC head and multi-layer output.""" @@ -835,24 +865,8 @@ def forward( logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) if labels is not None: - logits_padded = logits.new_zeros(len(lengths_list), max(lengths_list), logits.shape[-1]) - offset = 0 - for i, length in enumerate(lengths_list): - logits_padded[i, :length] = logits[offset : offset + length] - offset += length - - log_probs = torch.log_softmax(logits_padded.float(), dim=-1) - loss = ( - F.ctc_loss( - log_probs.transpose(0, 1), - labels, - lengths, - label_lengths, - blank=self.config.blank_token_id, - reduction="sum", - zero_infinity=True, - ) - / lengths.sum() + loss = _ctc_loss_from_flat_logits( + logits, lengths_list, labels, label_lengths, self.config.blank_token_id ) return GraniteSpeechNarEncoderOutput( @@ -884,6 +898,12 @@ def __init__(self, config: GraniteSpeechNarConfig): self.post_init() +def _ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> torch.Tensor: + """CTC greedy decode a single sequence: argmax -> unique_consecutive -> remove blank.""" + pred = torch.unique_consecutive(logits.argmax(-1)) + return pred[pred != blank_id] + + @auto_docstring( custom_intro=""" The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. @@ -909,9 +929,8 @@ def _ctc_collapse_decode( ) -> list[torch.Tensor]: """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank.""" blank_id = self.config.blank_token_id - preds_flat = bpe_logits_flat.argmax(dim=-1) - per_sample = preds_flat.split(bpe_lengths) - return [(collapsed := torch.unique_consecutive(seq))[collapsed != blank_id] for seq in per_sample] + per_sample = bpe_logits_flat.split(bpe_lengths) + return [_ctc_greedy_decode(seq_logits, blank_id) for seq_logits in per_sample] def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: """Insert blank tokens between each CTC token as editing slots for the LLM.""" @@ -1029,27 +1048,8 @@ def forward( loss = None if labels is not None: - log_probs = torch.log_softmax(text_logits.float(), dim=-1) - - log_probs_padded = log_probs.new_zeros(len(text_lengths), max(text_lengths), log_probs.shape[-1]) - offset = 0 - for i, tl in enumerate(text_lengths): - log_probs_padded[i, :tl] = log_probs[offset : offset + tl] - offset += tl - - input_lengths = torch.tensor(text_lengths, device=text_logits.device) - - loss = ( - F.ctc_loss( - log_probs_padded.transpose(0, 1), - labels, - input_lengths, - label_lengths, - blank=self.config.blank_token_id, - reduction="sum", - zero_infinity=True, - ) - / input_lengths.sum() + loss = _ctc_loss_from_flat_logits( + text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id ) if self.config.ce_loss_lambda > 0.0: @@ -1091,11 +1091,7 @@ def generate( ) blank_id = self.config.blank_token_id - preds = [] - for sample_logits in output.logits: - pred = torch.unique_consecutive(sample_logits.argmax(-1)) - pred = pred[pred != blank_id] - preds.append(pred) + preds = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] return GraniteSpeechNarOutput( preds=preds, diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 3c2efe0843c5..8d820bff3188 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -87,6 +87,42 @@ def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, win return pooled +def _ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> torch.Tensor: + """CTC greedy decode a single sequence: argmax -> unique_consecutive -> remove blank.""" + pred = torch.unique_consecutive(logits.argmax(-1)) + return pred[pred != blank_id] + + +def _ctc_loss_from_flat_logits( + flat_logits: torch.Tensor, + lengths: list[int], + labels: torch.Tensor, + label_lengths: torch.Tensor, + blank_id: int, +) -> torch.Tensor: + """Compute mean CTC loss from flat (concatenated) logits with per-sample lengths.""" + log_probs = torch.log_softmax(flat_logits.float(), dim=-1) + log_probs_padded = log_probs.new_zeros(len(lengths), max(lengths), log_probs.shape[-1]) + offset = 0 + for i, length in enumerate(lengths): + log_probs_padded[i, :length] = log_probs[offset : offset + length] + offset += length + + lengths_t = torch.tensor(lengths, device=flat_logits.device) + return ( + F.ctc_loss( + log_probs_padded.transpose(0, 1), + labels, + lengths_t, + label_lengths, + blank=blank_id, + reduction="sum", + zero_infinity=True, + ) + / lengths_t.sum() + ) + + class GraniteSpeechNarQFormerCrossAttention(nn.Module): def __init__(self, config: GraniteSpeechNarProjectorConfig): super().__init__() @@ -373,25 +409,7 @@ def forward( logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) if labels is not None: - logits_padded = logits.new_zeros(len(lengths_list), max(lengths_list), logits.shape[-1]) - offset = 0 - for i, length in enumerate(lengths_list): - logits_padded[i, :length] = logits[offset : offset + length] - offset += length - - log_probs = torch.log_softmax(logits_padded.float(), dim=-1) - loss = ( - F.ctc_loss( - log_probs.transpose(0, 1), - labels, - lengths, - label_lengths, - blank=self.config.blank_token_id, - reduction="sum", - zero_infinity=True, - ) - / lengths.sum() - ) + loss = _ctc_loss_from_flat_logits(logits, lengths_list, labels, label_lengths, self.config.blank_token_id) return GraniteSpeechNarEncoderOutput( loss=loss, @@ -447,9 +465,8 @@ def _ctc_collapse_decode( ) -> list[torch.Tensor]: """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank.""" blank_id = self.config.blank_token_id - preds_flat = bpe_logits_flat.argmax(dim=-1) - per_sample = preds_flat.split(bpe_lengths) - return [(collapsed := torch.unique_consecutive(seq))[collapsed != blank_id] for seq in per_sample] + per_sample = bpe_logits_flat.split(bpe_lengths) + return [_ctc_greedy_decode(seq_logits, blank_id) for seq_logits in per_sample] def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: """Insert blank tokens between each CTC token as editing slots for the LLM.""" @@ -567,28 +584,7 @@ def forward( loss = None if labels is not None: - log_probs = torch.log_softmax(text_logits.float(), dim=-1) - - log_probs_padded = log_probs.new_zeros(len(text_lengths), max(text_lengths), log_probs.shape[-1]) - offset = 0 - for i, tl in enumerate(text_lengths): - log_probs_padded[i, :tl] = log_probs[offset : offset + tl] - offset += tl - - input_lengths = torch.tensor(text_lengths, device=text_logits.device) - - loss = ( - F.ctc_loss( - log_probs_padded.transpose(0, 1), - labels, - input_lengths, - label_lengths, - blank=self.config.blank_token_id, - reduction="sum", - zero_infinity=True, - ) - / input_lengths.sum() - ) + loss = _ctc_loss_from_flat_logits(text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id) if self.config.ce_loss_lambda > 0.0: ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) @@ -629,11 +625,7 @@ def generate( ) blank_id = self.config.blank_token_id - preds = [] - for sample_logits in output.logits: - pred = torch.unique_consecutive(sample_logits.argmax(-1)) - pred = pred[pred != blank_id] - preds.append(pred) + preds = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] return GraniteSpeechNarOutput( preds=preds, From 97e7c1bb18f3f40c7d649cb9c00dfc099712c1b9 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 08:21:02 +0000 Subject: [PATCH 21/39] move the encoding code to GraniteSpeechNarModel, keep just the lm_head + losses in the ForCTC. Create a new output type for for the NarModel class. --- .../modeling_granite_speech_nar.py | 171 ++++++++++++------ .../modular_granite_speech_nar.py | 159 ++++++++++------ 2 files changed, 222 insertions(+), 108 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index a06afbfcf704..a23a2e8fc5fa 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -56,9 +56,30 @@ class GraniteSpeechNarEncoderOutput(ModelOutput): all_hidden_states: tuple[torch.FloatTensor, ...] | None = None +@dataclass +class GraniteSpeechNarModelOutput(ModelOutput): + """Output of GraniteSpeechNarModel (backbone without lm_head). + + Attributes: + last_hidden_state: Hidden states from the LLM backbone (flat, batch dim squeezed). + ctc_token_ids: List of CTC-collapsed encoder predictions per sample. + text_lengths: Per-sample text sequence lengths (after insertion slots). + audio_lengths: Per-sample projected audio lengths. + encoder_loss: Encoder BPE CTC loss (when encoder_ctc_loss_lambda > 0 and labels provided). + encoder_logits: Flat BPE CTC logits from the encoder (when output_encoder_logits=True). + """ + + last_hidden_state: torch.FloatTensor | None = None + ctc_token_ids: list[torch.Tensor] | None = None + text_lengths: list[int] | None = None + audio_lengths: list[int] | None = None + encoder_loss: torch.Tensor | None = None + encoder_logits: torch.Tensor | None = None + + @dataclass class GraniteSpeechNarOutput(ModelOutput): - """Output of the GraniteSpeechNarForASR model. + """Output of the GraniteSpeechNarForCTC model. Attributes: loss: Combined CTC + auxiliary losses (only when labels provided). @@ -66,6 +87,8 @@ class GraniteSpeechNarOutput(ModelOutput): logits: List of per-sample logit tensors from the LLM head. encoder_logits: Flat BPE CTC logits from the encoder. encoder_preds: List of CTC-collapsed encoder predictions per sample. + encoder_loss: Encoder BPE CTC loss component (for logging). + ce_loss: Cross-entropy auxiliary loss component (for logging). """ loss: torch.Tensor | None = None @@ -73,6 +96,8 @@ class GraniteSpeechNarOutput(ModelOutput): logits: list[torch.Tensor] | None = None encoder_logits: torch.Tensor | None = None encoder_preds: list[torch.Tensor] | None = None + encoder_loss: torch.Tensor | None = None + ce_loss: torch.Tensor | None = None ### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git @@ -877,6 +902,12 @@ def forward( ) +def _ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> torch.Tensor: + """CTC greedy decode a single sequence: argmax -> unique_consecutive -> remove blank.""" + pred = torch.unique_consecutive(logits.argmax(-1)) + return pred[pred != blank_id] + + @auto_docstring( custom_intro=""" The GraniteSpeechNar base model consisting of a conformer encoder, QFormer projector, @@ -897,31 +928,6 @@ def __init__(self, config: GraniteSpeechNarConfig): self.post_init() - -def _ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> torch.Tensor: - """CTC greedy decode a single sequence: argmax -> unique_consecutive -> remove blank.""" - pred = torch.unique_consecutive(logits.argmax(-1)) - return pred[pred != blank_id] - - -@auto_docstring( - custom_intro=""" - The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. - Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, - and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. - """ -) -class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): - _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} - - def __init__(self, config: GraniteSpeechNarConfig): - super().__init__(config) - - self.model = GraniteSpeechNarModel(config) - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - - self.post_init() - def _ctc_collapse_decode( self, bpe_logits_flat: torch.Tensor, @@ -950,7 +956,7 @@ def _build_flat_inputs( audio_lengths: list[int], ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" - embed_tokens = self.model.language_model.embed_tokens + embed_tokens = self.language_model.embed_tokens embeds_list = [] position_ids_list = [] @@ -971,33 +977,15 @@ def _build_flat_inputs( def forward( self, - *, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, **kwargs, - ) -> GraniteSpeechNarOutput: - r""" - Args: - input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): - Mel spectrogram features. - attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): - Encoder attention mask (1 for valid frames, 0 for padding). - labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): - Ground truth LLM token IDs for training. - label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): - Number of valid tokens per sample in `labels`. - output_encoder_logits (`bool`, *optional*, defaults to `False`): - Whether to return encoder BPE logits. When False, the large logits - tensor is freed early to reduce peak memory. - - Returns: - [`GraniteSpeechNarOutput`] - """ + ) -> GraniteSpeechNarModelOutput: encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None - enc_out = self.model.encoder( + enc_out = self.encoder( input_features=input_features, attention_mask=attention_mask, output_hidden_states=True, @@ -1010,7 +998,7 @@ def forward( encoder_lengths = attention_mask.sum(dim=1) - pool_window = self.model.encoder.config.bpe_pooling_window + pool_window = self.encoder.config.bpe_pooling_window bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) @@ -1022,38 +1010,105 @@ def forward( encoder_logits = enc_out.logits if output_encoder_logits else None del enc_out - audio_embeds = self.model.projector(multilayer_features) + audio_embeds = self.projector(multilayer_features) del multilayer_features if self.config.scale_projected_embeddings: embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) audio_embeds = audio_embeds / embedding_multiplier - audio_embeds = audio_embeds.to(self.model.language_model.embed_tokens.weight.dtype) + audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) - audio_lengths = (encoder_lengths // self.model.projector.config.downsample_rate).cpu().tolist() + audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( ctc_token_ids, audio_embeds, audio_lengths ) - llm_out = self.model.language_model( + llm_out = self.language_model( inputs_embeds=flat_embeds, position_ids=flat_position_ids, ) - hidden_states = llm_out.last_hidden_state.squeeze(0) - all_logits = self.lm_head(hidden_states) + return GraniteSpeechNarModelOutput( + last_hidden_state=llm_out.last_hidden_state.squeeze(0), + ctc_token_ids=ctc_token_ids, + text_lengths=text_lengths, + audio_lengths=audio_lengths, + encoder_loss=encoder_loss, + encoder_logits=encoder_logits, + ) + + +@auto_docstring( + custom_intro=""" + The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. + Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. + """ +) +class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + + self.model = GraniteSpeechNarModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def forward( + self, + *, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + **kwargs, + ) -> GraniteSpeechNarOutput: + r""" + Args: + input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): + Mel spectrogram features. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Encoder attention mask (1 for valid frames, 0 for padding). + labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): + Ground truth LLM token IDs for training. + label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Number of valid tokens per sample in `labels`. + output_encoder_logits (`bool`, *optional*, defaults to `False`): + Whether to return encoder BPE logits. When False, the large logits + tensor is freed early to reduce peak memory. + + Returns: + [`GraniteSpeechNarOutput`] + """ + model_out = self.model( + input_features=input_features, + attention_mask=attention_mask, + labels=labels, + label_lengths=label_lengths, + output_encoder_logits=output_encoder_logits, + ) + + all_logits = self.lm_head(model_out.last_hidden_state) + + audio_lengths = model_out.audio_lengths + text_lengths = model_out.text_lengths segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) logits_per_sample = list(text_logits.split(text_lengths)) loss = None + encoder_loss = model_out.encoder_loss + ce_loss = None if labels is not None: loss = _ctc_loss_from_flat_logits( text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id ) if self.config.ce_loss_lambda > 0.0: - ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) + ce_targets = torch.cat([self.model._add_insertion_slots(ids) for ids in model_out.ctc_token_ids]) ce_loss = F.cross_entropy( text_logits, ce_targets.long(), @@ -1068,8 +1123,10 @@ def forward( return GraniteSpeechNarOutput( loss=loss, logits=logits_per_sample, - encoder_logits=encoder_logits, - encoder_preds=ctc_token_ids, + encoder_logits=model_out.encoder_logits, + encoder_preds=model_out.ctc_token_ids, + encoder_loss=encoder_loss, + ce_loss=ce_loss, ) @torch.inference_mode() diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 8d820bff3188..1d1e369553e3 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -50,9 +50,30 @@ class GraniteSpeechNarEncoderOutput(ModelOutput): all_hidden_states: tuple[torch.FloatTensor, ...] | None = None +@dataclass +class GraniteSpeechNarModelOutput(ModelOutput): + """Output of GraniteSpeechNarModel (backbone without lm_head). + + Attributes: + last_hidden_state: Hidden states from the LLM backbone (flat, batch dim squeezed). + ctc_token_ids: List of CTC-collapsed encoder predictions per sample. + text_lengths: Per-sample text sequence lengths (after insertion slots). + audio_lengths: Per-sample projected audio lengths. + encoder_loss: Encoder BPE CTC loss (when encoder_ctc_loss_lambda > 0 and labels provided). + encoder_logits: Flat BPE CTC logits from the encoder (when output_encoder_logits=True). + """ + + last_hidden_state: torch.FloatTensor | None = None + ctc_token_ids: list[torch.Tensor] | None = None + text_lengths: list[int] | None = None + audio_lengths: list[int] | None = None + encoder_loss: torch.Tensor | None = None + encoder_logits: torch.Tensor | None = None + + @dataclass class GraniteSpeechNarOutput(ModelOutput): - """Output of the GraniteSpeechNarForASR model. + """Output of the GraniteSpeechNarForCTC model. Attributes: loss: Combined CTC + auxiliary losses (only when labels provided). @@ -60,6 +81,8 @@ class GraniteSpeechNarOutput(ModelOutput): logits: List of per-sample logit tensors from the LLM head. encoder_logits: Flat BPE CTC logits from the encoder. encoder_preds: List of CTC-collapsed encoder predictions per sample. + encoder_loss: Encoder BPE CTC loss component (for logging). + ce_loss: Cross-entropy auxiliary loss component (for logging). """ loss: torch.Tensor | None = None @@ -67,6 +90,8 @@ class GraniteSpeechNarOutput(ModelOutput): logits: list[torch.Tensor] | None = None encoder_logits: torch.Tensor | None = None encoder_preds: list[torch.Tensor] | None = None + encoder_loss: torch.Tensor | None = None + ce_loss: torch.Tensor | None = None class GraniteSpeechNarConformerBlock(GraniteSpeechConformerBlock): @@ -439,25 +464,6 @@ def __init__(self, config: GraniteSpeechNarConfig): self.post_init() - -@auto_docstring( - custom_intro=""" - The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. - Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, - and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. - """ -) -class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): - _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} - - def __init__(self, config: GraniteSpeechNarConfig): - super().__init__(config) - - self.model = GraniteSpeechNarModel(config) - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - - self.post_init() - def _ctc_collapse_decode( self, bpe_logits_flat: torch.Tensor, @@ -486,7 +492,7 @@ def _build_flat_inputs( audio_lengths: list[int], ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" - embed_tokens = self.model.language_model.embed_tokens + embed_tokens = self.language_model.embed_tokens embeds_list = [] position_ids_list = [] @@ -507,33 +513,15 @@ def _build_flat_inputs( def forward( self, - *, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, **kwargs, - ) -> GraniteSpeechNarOutput: - r""" - Args: - input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): - Mel spectrogram features. - attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): - Encoder attention mask (1 for valid frames, 0 for padding). - labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): - Ground truth LLM token IDs for training. - label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): - Number of valid tokens per sample in `labels`. - output_encoder_logits (`bool`, *optional*, defaults to `False`): - Whether to return encoder BPE logits. When False, the large logits - tensor is freed early to reduce peak memory. - - Returns: - [`GraniteSpeechNarOutput`] - """ + ) -> GraniteSpeechNarModelOutput: encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None - enc_out = self.model.encoder( + enc_out = self.encoder( input_features=input_features, attention_mask=attention_mask, output_hidden_states=True, @@ -546,7 +534,7 @@ def forward( encoder_lengths = attention_mask.sum(dim=1) - pool_window = self.model.encoder.config.bpe_pooling_window + pool_window = self.encoder.config.bpe_pooling_window bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) @@ -558,36 +546,103 @@ def forward( encoder_logits = enc_out.logits if output_encoder_logits else None del enc_out - audio_embeds = self.model.projector(multilayer_features) + audio_embeds = self.projector(multilayer_features) del multilayer_features if self.config.scale_projected_embeddings: embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) audio_embeds = audio_embeds / embedding_multiplier - audio_embeds = audio_embeds.to(self.model.language_model.embed_tokens.weight.dtype) + audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) - audio_lengths = (encoder_lengths // self.model.projector.config.downsample_rate).cpu().tolist() + audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( ctc_token_ids, audio_embeds, audio_lengths ) - llm_out = self.model.language_model( + llm_out = self.language_model( inputs_embeds=flat_embeds, position_ids=flat_position_ids, ) - hidden_states = llm_out.last_hidden_state.squeeze(0) - all_logits = self.lm_head(hidden_states) + return GraniteSpeechNarModelOutput( + last_hidden_state=llm_out.last_hidden_state.squeeze(0), + ctc_token_ids=ctc_token_ids, + text_lengths=text_lengths, + audio_lengths=audio_lengths, + encoder_loss=encoder_loss, + encoder_logits=encoder_logits, + ) + + +@auto_docstring( + custom_intro=""" + The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. + Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. + """ +) +class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + + self.model = GraniteSpeechNarModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def forward( + self, + *, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + **kwargs, + ) -> GraniteSpeechNarOutput: + r""" + Args: + input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): + Mel spectrogram features. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Encoder attention mask (1 for valid frames, 0 for padding). + labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): + Ground truth LLM token IDs for training. + label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Number of valid tokens per sample in `labels`. + output_encoder_logits (`bool`, *optional*, defaults to `False`): + Whether to return encoder BPE logits. When False, the large logits + tensor is freed early to reduce peak memory. + + Returns: + [`GraniteSpeechNarOutput`] + """ + model_out = self.model( + input_features=input_features, + attention_mask=attention_mask, + labels=labels, + label_lengths=label_lengths, + output_encoder_logits=output_encoder_logits, + ) + + all_logits = self.lm_head(model_out.last_hidden_state) + + audio_lengths = model_out.audio_lengths + text_lengths = model_out.text_lengths segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) logits_per_sample = list(text_logits.split(text_lengths)) loss = None + encoder_loss = model_out.encoder_loss + ce_loss = None if labels is not None: loss = _ctc_loss_from_flat_logits(text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id) if self.config.ce_loss_lambda > 0.0: - ce_targets = torch.cat([self._add_insertion_slots(ids) for ids in ctc_token_ids]) + ce_targets = torch.cat([self.model._add_insertion_slots(ids) for ids in model_out.ctc_token_ids]) ce_loss = F.cross_entropy( text_logits, ce_targets.long(), @@ -602,8 +657,10 @@ def forward( return GraniteSpeechNarOutput( loss=loss, logits=logits_per_sample, - encoder_logits=encoder_logits, - encoder_preds=ctc_token_ids, + encoder_logits=model_out.encoder_logits, + encoder_preds=model_out.ctc_token_ids, + encoder_loss=encoder_loss, + ce_loss=ce_loss, ) @torch.inference_mode() From 367e0ee0475a321d2c6ded33cf2cc1f723cdbf95 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 08:21:43 +0000 Subject: [PATCH 22/39] add logits scaling (as used in granite lm_head) --- .../models/granite_speech_nar/modeling_granite_speech_nar.py | 2 ++ .../models/granite_speech_nar/modular_granite_speech_nar.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index a23a2e8fc5fa..c00ec7178498 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -1092,6 +1092,8 @@ def forward( ) all_logits = self.lm_head(model_out.last_hidden_state) + if hasattr(self.config.text_config, "logits_scaling"): + all_logits = all_logits / self.config.text_config.logits_scaling audio_lengths = model_out.audio_lengths text_lengths = model_out.text_lengths diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 1d1e369553e3..6972e466cf6e 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -628,6 +628,8 @@ def forward( ) all_logits = self.lm_head(model_out.last_hidden_state) + if hasattr(self.config.text_config, "logits_scaling"): + all_logits = all_logits / self.config.text_config.logits_scaling audio_lengths = model_out.audio_lengths text_lengths = model_out.text_lengths From 3db6757ca77f1137785478dab1d3f6351a3d7835 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 09:11:37 +0000 Subject: [PATCH 23/39] add multi-step editing support. (without recomputing audio_embeds) --- .../modeling_granite_speech_nar.py | 110 ++++++++++++------ .../modular_granite_speech_nar.py | 110 ++++++++++++------ .../test_modeling_granite_speech_nar.py | 20 ++++ 3 files changed, 166 insertions(+), 74 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index c00ec7178498..a706a73188b6 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -70,6 +70,7 @@ class GraniteSpeechNarModelOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None + audio_embeds: torch.FloatTensor | None = None ctc_token_ids: list[torch.Tensor] | None = None text_lengths: list[int] | None = None audio_lengths: list[int] | None = None @@ -94,6 +95,7 @@ class GraniteSpeechNarOutput(ModelOutput): loss: torch.Tensor | None = None preds: list[torch.Tensor] | None = None logits: list[torch.Tensor] | None = None + audio_embeds: torch.FloatTensor | None = None encoder_logits: torch.Tensor | None = None encoder_preds: list[torch.Tensor] | None = None encoder_loss: torch.Tensor | None = None @@ -977,46 +979,53 @@ def _build_flat_inputs( def forward( self, - input_features: torch.Tensor, + input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, **kwargs, ) -> GraniteSpeechNarModelOutput: - encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None - enc_out = self.encoder( - input_features=input_features, - attention_mask=attention_mask, - output_hidden_states=True, - labels=encoder_labels, - label_lengths=label_lengths if encoder_labels is not None else None, - ) - if attention_mask is None: attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) - encoder_lengths = attention_mask.sum(dim=1) + encoder_loss = None + encoder_logits = None + + if audio_embeds is None: + encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None + enc_out = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + labels=encoder_labels, + label_lengths=label_lengths if encoder_labels is not None else None, + ) - pool_window = self.encoder.config.bpe_pooling_window - bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() - ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + encoder_lengths = attention_mask.sum(dim=1) - multilayer_features = torch.cat( - [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 - ) + pool_window = self.encoder.config.bpe_pooling_window + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + + multilayer_features = torch.cat( + [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + ) - encoder_loss = enc_out.loss - encoder_logits = enc_out.logits if output_encoder_logits else None - del enc_out + encoder_loss = enc_out.loss + encoder_logits = enc_out.logits if output_encoder_logits else None + del enc_out - audio_embeds = self.projector(multilayer_features) - del multilayer_features - if self.config.scale_projected_embeddings: - embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) - audio_embeds = audio_embeds / embedding_multiplier - audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) + audio_embeds = self.projector(multilayer_features) + del multilayer_features + if self.config.scale_projected_embeddings: + embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) + audio_embeds = audio_embeds / embedding_multiplier + audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) + encoder_lengths = attention_mask.sum(dim=1) audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( @@ -1030,6 +1039,7 @@ def forward( return GraniteSpeechNarModelOutput( last_hidden_state=llm_out.last_hidden_state.squeeze(0), + audio_embeds=audio_embeds, ctc_token_ids=ctc_token_ids, text_lengths=text_lengths, audio_lengths=audio_lengths, @@ -1059,11 +1069,13 @@ def __init__(self, config: GraniteSpeechNarConfig): def forward( self, *, - input_features: torch.Tensor, + input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, **kwargs, ) -> GraniteSpeechNarOutput: r""" @@ -1079,6 +1091,12 @@ def forward( output_encoder_logits (`bool`, *optional*, defaults to `False`): Whether to return encoder BPE logits. When False, the large logits tensor is freed early to reduce peak memory. + audio_embeds (`torch.Tensor`, *optional*): + Pre-computed projected audio embeddings. When provided, encoder and + projector are skipped (used for multi-step editing). + ctc_token_ids (`list[torch.Tensor]`, *optional*): + Pre-computed CTC token predictions to use as text input. When provided + with `audio_embeds`, replaces encoder CTC predictions. Returns: [`GraniteSpeechNarOutput`] @@ -1089,6 +1107,8 @@ def forward( labels=labels, label_lengths=label_lengths, output_encoder_logits=output_encoder_logits, + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, ) all_logits = self.lm_head(model_out.last_hidden_state) @@ -1125,6 +1145,7 @@ def forward( return GraniteSpeechNarOutput( loss=loss, logits=logits_per_sample, + audio_embeds=model_out.audio_embeds, encoder_logits=model_out.encoder_logits, encoder_preds=model_out.ctc_token_ids, encoder_loss=encoder_loss, @@ -1137,26 +1158,41 @@ def generate( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, output_encoder_logits: bool = False, + num_editing_steps: int = 1, ) -> GraniteSpeechNarOutput: - """Single-pass non-autoregressive inference: forward + CTC collapse on LLM output. + """Non-autoregressive inference with iterative editing. + + Each editing step collapses the LLM output via CTC and feeds it back + as input for refinement, reusing cached audio embeddings. Returns token ID tensors in `preds`. Use `GraniteSpeechNarProcessor.batch_decode()` to convert to strings. """ - output = self.forward( - input_features=input_features, - attention_mask=attention_mask, - output_encoder_logits=output_encoder_logits, - ) + audio_embeds = None + ctc_token_ids = None + encoder_preds = None + + for step in range(num_editing_steps): + output = self.forward( + input_features=input_features, + attention_mask=attention_mask, + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, + output_encoder_logits=(output_encoder_logits and step == 0), + ) - blank_id = self.config.blank_token_id - preds = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] + blank_id = self.config.blank_token_id + ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] + audio_embeds = output.audio_embeds + + if step == 0: + encoder_preds = output.encoder_preds return GraniteSpeechNarOutput( - preds=preds, + preds=ctc_token_ids, logits=output.logits, encoder_logits=output.encoder_logits, - encoder_preds=output.encoder_preds, + encoder_preds=encoder_preds, ) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 6972e466cf6e..7218bc4db141 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -64,6 +64,7 @@ class GraniteSpeechNarModelOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor | None = None + audio_embeds: torch.FloatTensor | None = None ctc_token_ids: list[torch.Tensor] | None = None text_lengths: list[int] | None = None audio_lengths: list[int] | None = None @@ -88,6 +89,7 @@ class GraniteSpeechNarOutput(ModelOutput): loss: torch.Tensor | None = None preds: list[torch.Tensor] | None = None logits: list[torch.Tensor] | None = None + audio_embeds: torch.FloatTensor | None = None encoder_logits: torch.Tensor | None = None encoder_preds: list[torch.Tensor] | None = None encoder_loss: torch.Tensor | None = None @@ -513,46 +515,53 @@ def _build_flat_inputs( def forward( self, - input_features: torch.Tensor, + input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, **kwargs, ) -> GraniteSpeechNarModelOutput: - encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None - enc_out = self.encoder( - input_features=input_features, - attention_mask=attention_mask, - output_hidden_states=True, - labels=encoder_labels, - label_lengths=label_lengths if encoder_labels is not None else None, - ) - if attention_mask is None: attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) - encoder_lengths = attention_mask.sum(dim=1) + encoder_loss = None + encoder_logits = None + + if audio_embeds is None: + encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None + enc_out = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + labels=encoder_labels, + label_lengths=label_lengths if encoder_labels is not None else None, + ) - pool_window = self.encoder.config.bpe_pooling_window - bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() - ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + encoder_lengths = attention_mask.sum(dim=1) - multilayer_features = torch.cat( - [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 - ) + pool_window = self.encoder.config.bpe_pooling_window + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + + multilayer_features = torch.cat( + [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + ) - encoder_loss = enc_out.loss - encoder_logits = enc_out.logits if output_encoder_logits else None - del enc_out + encoder_loss = enc_out.loss + encoder_logits = enc_out.logits if output_encoder_logits else None + del enc_out - audio_embeds = self.projector(multilayer_features) - del multilayer_features - if self.config.scale_projected_embeddings: - embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) - audio_embeds = audio_embeds / embedding_multiplier - audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) + audio_embeds = self.projector(multilayer_features) + del multilayer_features + if self.config.scale_projected_embeddings: + embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) + audio_embeds = audio_embeds / embedding_multiplier + audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) + encoder_lengths = attention_mask.sum(dim=1) audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( @@ -566,6 +575,7 @@ def forward( return GraniteSpeechNarModelOutput( last_hidden_state=llm_out.last_hidden_state.squeeze(0), + audio_embeds=audio_embeds, ctc_token_ids=ctc_token_ids, text_lengths=text_lengths, audio_lengths=audio_lengths, @@ -595,11 +605,13 @@ def __init__(self, config: GraniteSpeechNarConfig): def forward( self, *, - input_features: torch.Tensor, + input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, **kwargs, ) -> GraniteSpeechNarOutput: r""" @@ -615,6 +627,12 @@ def forward( output_encoder_logits (`bool`, *optional*, defaults to `False`): Whether to return encoder BPE logits. When False, the large logits tensor is freed early to reduce peak memory. + audio_embeds (`torch.Tensor`, *optional*): + Pre-computed projected audio embeddings. When provided, encoder and + projector are skipped (used for multi-step editing). + ctc_token_ids (`list[torch.Tensor]`, *optional*): + Pre-computed CTC token predictions to use as text input. When provided + with `audio_embeds`, replaces encoder CTC predictions. Returns: [`GraniteSpeechNarOutput`] @@ -625,6 +643,8 @@ def forward( labels=labels, label_lengths=label_lengths, output_encoder_logits=output_encoder_logits, + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, ) all_logits = self.lm_head(model_out.last_hidden_state) @@ -659,6 +679,7 @@ def forward( return GraniteSpeechNarOutput( loss=loss, logits=logits_per_sample, + audio_embeds=model_out.audio_embeds, encoder_logits=model_out.encoder_logits, encoder_preds=model_out.ctc_token_ids, encoder_loss=encoder_loss, @@ -671,26 +692,41 @@ def generate( input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, output_encoder_logits: bool = False, + num_editing_steps: int = 1, ) -> GraniteSpeechNarOutput: - """Single-pass non-autoregressive inference: forward + CTC collapse on LLM output. + """Non-autoregressive inference with iterative editing. + + Each editing step collapses the LLM output via CTC and feeds it back + as input for refinement, reusing cached audio embeddings. Returns token ID tensors in `preds`. Use `GraniteSpeechNarProcessor.batch_decode()` to convert to strings. """ - output = self.forward( - input_features=input_features, - attention_mask=attention_mask, - output_encoder_logits=output_encoder_logits, - ) + audio_embeds = None + ctc_token_ids = None + encoder_preds = None + + for step in range(num_editing_steps): + output = self.forward( + input_features=input_features, + attention_mask=attention_mask, + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, + output_encoder_logits=(output_encoder_logits and step == 0), + ) - blank_id = self.config.blank_token_id - preds = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] + blank_id = self.config.blank_token_id + ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] + audio_embeds = output.audio_embeds + + if step == 0: + encoder_preds = output.encoder_preds return GraniteSpeechNarOutput( - preds=preds, + preds=ctc_token_ids, logits=output.logits, encoder_logits=output.encoder_logits, - encoder_preds=output.encoder_preds, + encoder_preds=encoder_preds, ) diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 28a6b466c9e4..fb2347825f85 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -258,6 +258,26 @@ def test_generate(self): assert len(output.preds) == 1 assert isinstance(output.preds[0], torch.Tensor) + def test_generate_multi_step(self): + config = _make_small_config() + model = GraniteSpeechNarForCTC(config).eval() + + features = torch.randn(2, 80, 160) + mask = torch.ones(2, 80, dtype=torch.bool) + mask[1, 60:] = False + + out1 = model.generate(input_features=features, attention_mask=mask, num_editing_steps=1) + out2 = model.generate(input_features=features, attention_mask=mask, num_editing_steps=3) + + assert out1.preds is not None + assert out2.preds is not None + assert len(out1.preds) == 2 + assert len(out2.preds) == 2 + # Multi-step should produce valid predictions (may or may not differ) + for pred in out2.preds: + assert isinstance(pred, torch.Tensor) + assert pred.ndim == 1 + def test_loss(self): config = _make_small_config() model = GraniteSpeechNarForCTC(config).train() From 17bc4b6dcc2b7eb93b55b6d34b7ce83724912978 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 10:18:02 +0000 Subject: [PATCH 24/39] minor --- .../granite_speech_nar/modular_granite_speech_nar.py | 8 ++++++-- utils/check_repo.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 7218bc4db141..adc3a9c07781 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -436,7 +436,9 @@ def forward( logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) if labels is not None: - loss = _ctc_loss_from_flat_logits(logits, lengths_list, labels, label_lengths, self.config.blank_token_id) + loss = _ctc_loss_from_flat_logits( + logits, lengths_list, labels, label_lengths, self.config.blank_token_id + ) return GraniteSpeechNarEncoderOutput( loss=loss, @@ -661,7 +663,9 @@ def forward( encoder_loss = model_out.encoder_loss ce_loss = None if labels is not None: - loss = _ctc_loss_from_flat_logits(text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id) + loss = _ctc_loss_from_flat_logits( + text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id + ) if self.config.ce_loss_lambda > 0.0: ce_targets = torch.cat([self.model._add_insertion_slots(ids) for ids in model_out.ctc_token_ids]) diff --git a/utils/check_repo.py b/utils/check_repo.py index b17ebad3c70d..9a3cf828d7f0 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -520,7 +520,7 @@ "UVDocBridge", # Building part of a bigger model, tested implicitly through UVDocModel "Granite4VisionTextModel", # Building part of bigger (tested) model. "GraniteSpeechNarModel", # Building part of bigger (tested) model. - "GraniteSpeechNarLM", # Building part of bigger (tested) model. + "GraniteSpeechNarLanguageModel", # Building part of bigger (tested) model. ] From 77eb54348cc64f4c0ecd8485bcc60153b6926a07 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 27 May 2026 11:13:32 +0000 Subject: [PATCH 25/39] - resolve a depracation warning - frame stacking as a parameter, pad instead of truncating audio. --- .../feature_extraction_granite_speech_nar.py | 18 +++++++++++++----- .../processing_granite_speech_nar.py | 1 - 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py index 328ba2578112..c11791599aca 100644 --- a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -42,6 +42,7 @@ def __init__( win_length: int = 400, hop_length: int = 160, n_mels: int = 80, + frame_stacking: int = 2, **kwargs, ): requires_backends(self, ["torch", "torchaudio"]) @@ -51,6 +52,7 @@ def __init__( self.win_length = win_length self.hop_length = hop_length self.n_mels = n_mels + self.frame_stacking = frame_stacking self.mel_filters = torchaudio.transforms.MelSpectrogram( sample_rate=sampling_rate, n_fft=n_fft, @@ -59,16 +61,22 @@ def __init__( n_mels=n_mels, ) + def get_num_encoder_frames(self, num_raw_samples): + mel_frames = num_raw_samples // self.hop_length + 1 + return -(-mel_frames // self.frame_stacking) + def _extract_features(self, raw_audio: "torch.Tensor") -> "torch.Tensor": with torch.no_grad(): mel_filters = self.mel_filters.to(raw_audio.device) - B, T = raw_audio.shape - l = 2 * (T // (2 * self.hop_length)) - mel = mel_filters(raw_audio.float())[..., :l] + mel = mel_filters(raw_audio.float()) + num_frames = mel.shape[-1] + remainder = num_frames % self.frame_stacking + if remainder != 0: + mel = torch.nn.functional.pad(mel, (0, self.frame_stacking - remainder)) logmel = mel.transpose(-1, -2).clamp_min_(1e-10).log10_() mx = logmel.amax(dim=(-2, -1), keepdim=True) logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) - return logmel.reshape(B, -1, 2 * self.n_mels) + return logmel.reshape(logmel.shape[0], -1, self.frame_stacking * self.n_mels) def __call__( self, @@ -84,7 +92,7 @@ def __call__( raise ValueError(f"Expected 1-D or 2-D tensor, got {audios.ndim}-D") raw_lengths = [a.shape[-1] for a in audios] - encoder_frame_counts = [l // (2 * self.hop_length) for l in raw_lengths] + encoder_frame_counts = [self.get_num_encoder_frames(l) for l in raw_lengths] raw_audio = torch.nn.utils.rnn.pad_sequence( [a.squeeze(0) if a.ndim > 1 else a for a in audios], diff --git a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py index b8cfc95f5d3a..679b984d22f9 100644 --- a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py @@ -26,7 +26,6 @@ class GraniteSpeechNarProcessor(ProcessorMixin): """Processor combining audio feature extraction and tokenizer for GraniteSpeechNar.""" - feature_extractor_class = "GraniteSpeechNarFeatureExtractor" tokenizer_class = "AutoTokenizer" def __init__(self, feature_extractor: GraniteSpeechNarFeatureExtractor, tokenizer=None, **kwargs): From 36f274541300d61f1957ded964639ac136ff7a43 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 12:18:01 +0300 Subject: [PATCH 26/39] Update src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> --- .../models/granite_speech_nar/modular_granite_speech_nar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index adc3a9c07781..1cd6733f0755 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -464,7 +464,7 @@ def __init__(self, config: GraniteSpeechNarConfig): text_config = config.text_config if hasattr(config, "_attn_implementation"): text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) + self.language_model = GraniteSpeechNarLanguageModel(config.text_config) self.post_init() From 70c06b58e3fcb36352d0a186ad3e74b104d54c47 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 12:18:23 +0300 Subject: [PATCH 27/39] Update src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> --- .../models/granite_speech_nar/modular_granite_speech_nar.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 1cd6733f0755..04a9236e595a 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -461,9 +461,6 @@ def __init__(self, config: GraniteSpeechNarConfig): self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) self.projector = GraniteSpeechNarProjector(config.projector_config) - text_config = config.text_config - if hasattr(config, "_attn_implementation"): - text_config._attn_implementation = config._attn_implementation self.language_model = GraniteSpeechNarLanguageModel(config.text_config) self.post_init() From 3ffc779892e784afedc3b56a3017bc3be1cf1428 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 11:26:12 +0000 Subject: [PATCH 28/39] simplify: encoder always has bpe head --- .../configuration_granite_speech_nar.py | 6 +-- .../modeling_granite_speech_nar.py | 38 +++++++------------ .../modular_granite_speech_nar.py | 33 +++++++--------- .../test_modeling_granite_speech_nar.py | 22 +---------- 4 files changed, 31 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py index f6667c40532d..f7b64a715d1d 100644 --- a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -41,8 +41,8 @@ class GraniteSpeechNarEncoderConfig(PreTrainedConfig): self_conditioning_layer (`int`, *optional*): Layer index at which self-conditioning (mid-layer CTC feedback) is applied. Defaults to `num_layers // 2`. - bpe_output_dim (`int`, *optional*): - Vocabulary size for the BPE CTC head (same as LLM vocab, blank reuses eos_token_id). If None, BPE head is disabled. + bpe_output_dim (`int`, *optional*, defaults to 49153): + Vocabulary size for the BPE CTC head (same as LLM vocab, blank reuses eos_token_id). bpe_pooling_window (`int`, *optional*, defaults to 4): Window size for posterior-weighted pooling before the BPE CTC head. blank_token_id (`int`, *optional*): @@ -80,7 +80,7 @@ class GraniteSpeechNarEncoderConfig(PreTrainedConfig): conv_kernel_size: int = 15 conv_expansion_factor: int = 2 self_conditioning_layer: int | None = None - bpe_output_dim: int | None = None + bpe_output_dim: int = 49153 bpe_pooling_window: int = 4 blank_token_id: int | None = None initializer_range: float = 0.02 diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index a706a73188b6..79ec98203c05 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -837,9 +837,7 @@ def __init__(self, config: GraniteSpeechNarEncoderConfig): self.layers = nn.ModuleList([GraniteSpeechNarConformerBlock(config) for _ in range(config.num_layers)]) self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) - self.out_bpe = None - if config.bpe_output_dim is not None: - self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) + self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) self.dropout = nn.Dropout(config.pred_dropout) self.post_init() @@ -857,7 +855,6 @@ def forward( hidden_states = self.input_linear(input_features.to(self.dtype)) all_hidden_states = (hidden_states,) if output_hidden_states else None - blank_probs = None context_size = self.config.context_size seq = torch.arange(context_size, device=hidden_states.device) @@ -878,23 +875,19 @@ def forward( hidden_states = self.dropout(hidden_states) - logits = None - loss = None - if self.out_bpe is not None and blank_probs is not None: - pool_window = self.config.bpe_pooling_window - importance = 1.0 - blank_probs - pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( - hidden_states.dtype - ) - encoder_lengths = attention_mask.sum(dim=1) - lengths = -(encoder_lengths // -pool_window) - lengths_list = lengths.tolist() - logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) + pool_window = self.config.bpe_pooling_window + importance = 1.0 - blank_probs + pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( + hidden_states.dtype + ) + encoder_lengths = attention_mask.sum(dim=1) + lengths = -(encoder_lengths // -pool_window) + lengths_list = lengths.tolist() + logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) - if labels is not None: - loss = _ctc_loss_from_flat_logits( - logits, lengths_list, labels, label_lengths, self.config.blank_token_id - ) + loss = None + if labels is not None: + loss = _ctc_loss_from_flat_logits(logits, lengths_list, labels, label_lengths, self.config.blank_token_id) return GraniteSpeechNarEncoderOutput( loss=loss, @@ -923,10 +916,7 @@ def __init__(self, config: GraniteSpeechNarConfig): self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) self.projector = GraniteSpeechNarProjector(config.projector_config) - text_config = config.text_config - if hasattr(config, "_attn_implementation"): - text_config._attn_implementation = config._attn_implementation - self.language_model = GraniteSpeechNarLanguageModel._from_config(text_config) + self.language_model = GraniteSpeechNarLanguageModel(config.text_config) self.post_init() diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 04a9236e595a..911145d71fa1 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -381,9 +381,7 @@ def __init__(self, config: GraniteSpeechNarEncoderConfig): self.layers = nn.ModuleList([GraniteSpeechNarConformerBlock(config) for _ in range(config.num_layers)]) self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) - self.out_bpe = None - if config.bpe_output_dim is not None: - self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) + self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) self.dropout = nn.Dropout(config.pred_dropout) self.post_init() @@ -401,7 +399,6 @@ def forward( hidden_states = self.input_linear(input_features.to(self.dtype)) all_hidden_states = (hidden_states,) if output_hidden_states else None - blank_probs = None context_size = self.config.context_size seq = torch.arange(context_size, device=hidden_states.device) @@ -422,23 +419,19 @@ def forward( hidden_states = self.dropout(hidden_states) - logits = None - loss = None - if self.out_bpe is not None and blank_probs is not None: - pool_window = self.config.bpe_pooling_window - importance = 1.0 - blank_probs - pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( - hidden_states.dtype - ) - encoder_lengths = attention_mask.sum(dim=1) - lengths = -(encoder_lengths // -pool_window) - lengths_list = lengths.tolist() - logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) + pool_window = self.config.bpe_pooling_window + importance = 1.0 - blank_probs + pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( + hidden_states.dtype + ) + encoder_lengths = attention_mask.sum(dim=1) + lengths = -(encoder_lengths // -pool_window) + lengths_list = lengths.tolist() + logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) - if labels is not None: - loss = _ctc_loss_from_flat_logits( - logits, lengths_list, labels, label_lengths, self.config.blank_token_id - ) + loss = None + if labels is not None: + loss = _ctc_loss_from_flat_logits(logits, lengths_list, labels, label_lengths, self.config.blank_token_id) return GraniteSpeechNarEncoderOutput( loss=loss, diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index fb2347825f85..e9843fae0582 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -101,7 +101,7 @@ def test_encoder_config_defaults(self): assert config.num_layers == 16 assert config.hidden_dim == 1024 assert config.self_conditioning_layer == 8 - assert config.bpe_output_dim is None + assert config.bpe_output_dim == 49153 def test_projector_config_defaults(self): config = GraniteSpeechNarProjectorConfig() @@ -161,26 +161,6 @@ def test_output_shapes(self): assert out.all_hidden_states is not None assert len(out.all_hidden_states) == 5 # input + 4 layers - def test_no_bpe_head(self): - config = GraniteSpeechNarEncoderConfig( - num_layers=2, - hidden_dim=64, - num_heads=4, - dim_head=16, - input_dim=160, - output_dim=348, - context_size=50, - self_conditioning_layer=1, - bpe_output_dim=None, - ) - encoder = GraniteSpeechNarCTCEncoder(config).eval() - - features = torch.randn(1, 50, 160) - out = encoder(features, output_hidden_states=False) - - assert out.logits is None - assert out.all_hidden_states is None - # === Projector tests === From 0d6af4182f6a0e36edb39cea1e80905512e21263 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 12:00:00 +0000 Subject: [PATCH 29/39] use @capture_outputs for encoder hidden states Co-Authored-By: Claude Opus 4.6 (1M context) --- .../modeling_granite_speech_nar.py | 16 +++++++--------- .../modular_granite_speech_nar.py | 17 ++++++++--------- .../test_modeling_granite_speech_nar.py | 4 ++-- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 79ec98203c05..385f6522b9ef 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -53,7 +53,7 @@ class GraniteSpeechNarEncoderOutput(ModelOutput): loss: torch.Tensor | None = None logits: torch.FloatTensor | None = None last_hidden_state: torch.FloatTensor | None = None - all_hidden_states: tuple[torch.FloatTensor, ...] | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None @dataclass @@ -830,6 +830,9 @@ class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): """Conformer encoder with BPE CTC head and multi-layer output.""" config_class = GraniteSpeechNarEncoderConfig + _can_record_outputs = { + "hidden_states": GraniteSpeechNarConformerBlock, + } def __init__(self, config: GraniteSpeechNarEncoderConfig): super().__init__(config) @@ -841,11 +844,11 @@ def __init__(self, config: GraniteSpeechNarEncoderConfig): self.dropout = nn.Dropout(config.pred_dropout) self.post_init() + @capture_outputs def forward( self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, - output_hidden_states: bool | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, **kwargs, @@ -854,7 +857,6 @@ def forward( attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) hidden_states = self.input_linear(input_features.to(self.dtype)) - all_hidden_states = (hidden_states,) if output_hidden_states else None context_size = self.config.context_size seq = torch.arange(context_size, device=hidden_states.device) @@ -868,10 +870,7 @@ def forward( mid_logits = self.out(self.dropout(hidden_states)) mid_probs = torch.softmax(mid_logits.float(), dim=-1) blank_probs = mid_probs[:, :, 0] - hidden_states = hidden_states + self.out_mid(mid_probs.to(hidden_states.dtype)) - - if output_hidden_states: - all_hidden_states += (hidden_states,) + hidden_states += self.out_mid(mid_probs.to(hidden_states.dtype)) hidden_states = self.dropout(hidden_states) @@ -893,7 +892,6 @@ def forward( loss=loss, logits=logits, last_hidden_state=hidden_states, - all_hidden_states=all_hidden_states, ) @@ -1001,7 +999,7 @@ def forward( ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) multilayer_features = torch.cat( - [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + [enc_out.hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 ) encoder_loss = enc_out.loss diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 911145d71fa1..5b80be15c670 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.output_capturing import capture_outputs from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteModel from ..granite_speech.modeling_granite_speech import GraniteSpeechConformerBlock from .configuration_granite_speech_nar import ( @@ -47,7 +48,7 @@ class GraniteSpeechNarEncoderOutput(ModelOutput): loss: torch.Tensor | None = None logits: torch.FloatTensor | None = None last_hidden_state: torch.FloatTensor | None = None - all_hidden_states: tuple[torch.FloatTensor, ...] | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None @dataclass @@ -374,6 +375,9 @@ class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): """Conformer encoder with BPE CTC head and multi-layer output.""" config_class = GraniteSpeechNarEncoderConfig + _can_record_outputs = { + "hidden_states": GraniteSpeechNarConformerBlock, + } def __init__(self, config: GraniteSpeechNarEncoderConfig): super().__init__(config) @@ -385,11 +389,11 @@ def __init__(self, config: GraniteSpeechNarEncoderConfig): self.dropout = nn.Dropout(config.pred_dropout) self.post_init() + @capture_outputs def forward( self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None, - output_hidden_states: bool | None = None, labels: torch.Tensor | None = None, label_lengths: torch.Tensor | None = None, **kwargs, @@ -398,7 +402,6 @@ def forward( attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) hidden_states = self.input_linear(input_features.to(self.dtype)) - all_hidden_states = (hidden_states,) if output_hidden_states else None context_size = self.config.context_size seq = torch.arange(context_size, device=hidden_states.device) @@ -412,10 +415,7 @@ def forward( mid_logits = self.out(self.dropout(hidden_states)) mid_probs = torch.softmax(mid_logits.float(), dim=-1) blank_probs = mid_probs[:, :, 0] - hidden_states = hidden_states + self.out_mid(mid_probs.to(hidden_states.dtype)) - - if output_hidden_states: - all_hidden_states += (hidden_states,) + hidden_states += self.out_mid(mid_probs.to(hidden_states.dtype)) hidden_states = self.dropout(hidden_states) @@ -437,7 +437,6 @@ def forward( loss=loss, logits=logits, last_hidden_state=hidden_states, - all_hidden_states=all_hidden_states, ) @@ -539,7 +538,7 @@ def forward( ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) multilayer_features = torch.cat( - [enc_out.all_hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + [enc_out.hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 ) encoder_loss = enc_out.loss diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index e9843fae0582..8506728fba48 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -158,8 +158,8 @@ def test_output_shapes(self): assert out.logits is not None assert out.logits.shape[1] == 100 - assert out.all_hidden_states is not None - assert len(out.all_hidden_states) == 5 # input + 4 layers + assert out.hidden_states is not None + assert len(out.hidden_states) == 5 # input + 4 layers # === Projector tests === From 67b3303e84bba059851f8e7620207482e0f11c21 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 15:33:12 +0300 Subject: [PATCH 30/39] Update src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> --- .../models/granite_speech_nar/modular_granite_speech_nar.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 5b80be15c670..1c97ed1e0fc3 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -318,12 +318,6 @@ class GraniteSpeechNarLanguageModel(GraniteModel): attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. """ - def __init__(self, config): - super().__init__(config) - self.layers = nn.ModuleList( - [GraniteSpeechNarDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - def forward( self, input_ids: torch.LongTensor | None = None, From 7d03457193ba3448fd7982439a3439ef8be087cf Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 15:33:26 +0300 Subject: [PATCH 31/39] Update src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> --- .../models/granite_speech_nar/modular_granite_speech_nar.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 1c97ed1e0fc3..728bfea4fbe7 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -302,12 +302,6 @@ def __init__(self, config, layer_idx=None): self.is_causal = False -class GraniteSpeechNarDecoderLayer(GraniteDecoderLayer): - """GraniteDecoderLayer using bidirectional attention.""" - - def __init__(self, config, layer_idx: int): - super().__init__(config, layer_idx) - self.self_attn = GraniteSpeechNarAttention(config=config, layer_idx=layer_idx) class GraniteSpeechNarLanguageModel(GraniteModel): From fd8fa8af05aa4712cec8f728f7fc94ef5b915bd3 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 15:33:39 +0300 Subject: [PATCH 32/39] Update src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com> --- .../models/granite_speech_nar/modular_granite_speech_nar.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 728bfea4fbe7..039a0b387143 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -295,8 +295,6 @@ class GraniteSpeechNarPreTrainedModel(PreTrainedModel): class GraniteSpeechNarAttention(GraniteAttention): """GraniteAttention with is_causal=False for bidirectional attention.""" - is_causal = False - def __init__(self, config, layer_idx=None): super().__init__(config, layer_idx=layer_idx) self.is_causal = False From 212668f2b6f84ef156476beee46878476f2659a3 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 12:39:50 +0000 Subject: [PATCH 33/39] simplify setting is_causal=False --- .../granite_speech_nar/modeling_granite_speech_nar.py | 8 ++------ .../granite_speech_nar/modular_granite_speech_nar.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 385f6522b9ef..0ca6282a89b9 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -472,8 +472,6 @@ def eager_attention_forward( class GraniteSpeechNarAttention(nn.Module): """GraniteAttention with is_causal=False for bidirectional attention.""" - is_causal = False - def __init__(self, config, layer_idx=None): super().__init__() self.config = config @@ -576,9 +574,7 @@ def forward(self, x): class GraniteSpeechNarDecoderLayer(GradientCheckpointingLayer): - """GraniteDecoderLayer using bidirectional attention.""" - - def __init__(self, config, layer_idx: int): + def __init__(self, config: GraniteSpeechNarConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = GraniteSpeechNarAttention(config=config, layer_idx=layer_idx) @@ -715,7 +711,7 @@ class GraniteSpeechNarLanguageModel(GraniteSpeechNarPreTrainedModel): attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. """ - def __init__(self, config): + def __init__(self, config: GraniteSpeechNarConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 039a0b387143..3533db477b5a 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -29,7 +29,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from ...utils.output_capturing import capture_outputs -from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteModel +from ..granite.modeling_granite import GraniteAttention, GraniteModel from ..granite_speech.modeling_granite_speech import GraniteSpeechConformerBlock from .configuration_granite_speech_nar import ( GraniteSpeechNarConfig, From c1223d46abb7fe2d84eaf3b5a9f4d119e99dc9fb Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 13:01:07 +0000 Subject: [PATCH 34/39] use more standard tests --- .../test_modeling_granite_speech_nar.py | 674 ++++++++++-------- utils/check_repo.py | 3 +- 2 files changed, 389 insertions(+), 288 deletions(-) diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 8506728fba48..3e476097398a 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -11,41 +11,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for GraniteSpeechNar model.""" +"""Testing suite for the PyTorch GraniteSpeechNar model.""" -import math +import tempfile import unittest -import pytest -import torch - -from transformers import ( - AutoConfig, - AutoModel, - AutoProcessor, - GraniteConfig, - GraniteSpeechNarConfig, -) -from transformers.models.granite_speech_nar.configuration_granite_speech_nar import ( - GraniteSpeechNarEncoderConfig, - GraniteSpeechNarProjectorConfig, -) -from transformers.models.granite_speech_nar.modeling_granite_speech_nar import ( - GraniteSpeechNarCTCEncoder, - GraniteSpeechNarForCTC, - GraniteSpeechNarOutput, - GraniteSpeechNarProjector, -) -from transformers.testing_utils import require_torch, slow, torch_device -from transformers.utils import is_datasets_available +from transformers import is_datasets_available, is_torch_available +from transformers.testing_utils import cleanup, require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor if is_datasets_available(): from datasets import load_dataset +if is_torch_available(): + import torch + + from transformers import ( + AutoModel, + AutoProcessor, + GraniteConfig, + GraniteSpeechNarConfig, + ) + from transformers.models.granite_speech_nar.configuration_granite_speech_nar import ( + GraniteSpeechNarEncoderConfig, + GraniteSpeechNarProjectorConfig, + ) + from transformers.models.granite_speech_nar.modeling_granite_speech_nar import ( + GraniteSpeechNarCTCEncoder, + GraniteSpeechNarForCTC, + ) + -def _make_small_config(): - encoder_config = GraniteSpeechNarEncoderConfig( +class GraniteSpeechNarEncoderModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=100, + is_training=True, num_layers=4, hidden_dim=64, num_heads=4, @@ -56,334 +62,441 @@ def _make_small_config(): self_conditioning_layer=2, bpe_output_dim=51, bpe_pooling_window=4, - ) - projector_config = GraniteSpeechNarProjectorConfig( - encoder_dim=64, - llm_dim=128, - downsample_rate=5, - num_encoder_layers=4, - hidden_size=128, - num_heads=4, - num_layers=1, - block_size=15, - ) - text_config = GraniteConfig( + dropout=0.0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.dim_head = dim_head + self.input_dim = input_dim + self.output_dim = output_dim + self.context_size = context_size + self.self_conditioning_layer = self_conditioning_layer + self.bpe_output_dim = bpe_output_dim + self.bpe_pooling_window = bpe_pooling_window + self.dropout = dropout + + def get_config(self): + return GraniteSpeechNarEncoderConfig( + num_layers=self.num_layers, + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + dim_head=self.dim_head, + input_dim=self.input_dim, + output_dim=self.output_dim, + context_size=self.context_size, + self_conditioning_layer=self.self_conditioning_layer, + bpe_output_dim=self.bpe_output_dim, + bpe_pooling_window=self.bpe_pooling_window, + dropout=self.dropout, + ) + + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.seq_length, self.input_dim]) + attention_mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.bool) + attention_mask[1, 80:] = False + config = self.get_config() + return config, input_features, attention_mask + + def create_and_check_model(self, config, input_features, attention_mask): + model = GraniteSpeechNarCTCEncoder(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_features, attention_mask=attention_mask, output_hidden_states=True) + + self.parent.assertIsNotNone(result.logits) + self.parent.assertEqual(result.logits.shape[-1], self.bpe_output_dim) + self.parent.assertIsNotNone(result.hidden_states) + self.parent.assertEqual(len(result.hidden_states), self.num_layers + 1) + + def prepare_config_and_inputs_for_common(self): + config, input_features, attention_mask = self.prepare_config_and_inputs() + inputs_dict = { + "input_features": input_features, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class GraniteSpeechNarForCTCModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=100, + is_training=True, vocab_size=51, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=256, - max_position_embeddings=512, - tie_word_embeddings=True, - embedding_multiplier=1.0, - attention_multiplier=1.0, - residual_multiplier=1.0, - logits_scaling=1.0, - ) - return GraniteSpeechNarConfig( - encoder_config=encoder_config, - projector_config=projector_config, - text_config=text_config.to_dict(), - encoder_layer_indices=[1, 2, 3, -1], - scale_projected_embeddings=False, - ) + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + + self.encoder_model_tester = GraniteSpeechNarEncoderModelTester( + parent, + batch_size=batch_size, + seq_length=seq_length, + bpe_output_dim=vocab_size, + ) + def get_config(self): + encoder_config = self.encoder_model_tester.get_config() + projector_config = GraniteSpeechNarProjectorConfig( + encoder_dim=self.encoder_model_tester.hidden_dim, + llm_dim=128, + downsample_rate=5, + num_encoder_layers=4, + hidden_size=128, + num_heads=4, + num_layers=1, + block_size=15, + ) + text_config = GraniteConfig( + vocab_size=self.vocab_size, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=256, + max_position_embeddings=512, + tie_word_embeddings=True, + embedding_multiplier=1.0, + attention_multiplier=1.0, + residual_multiplier=1.0, + logits_scaling=1.0, + ) + return GraniteSpeechNarConfig( + encoder_config=encoder_config, + projector_config=projector_config, + text_config=text_config.to_dict(), + encoder_layer_indices=[1, 2, 3, -1], + scale_projected_embeddings=False, + ) -# === Configuration tests === + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.seq_length, self.encoder_model_tester.input_dim]) + attention_mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.bool) + attention_mask[1, 80:] = False + config = self.get_config() + return config, input_features, attention_mask + + def create_and_check_model(self, config, input_features, attention_mask): + model = GraniteSpeechNarForCTC(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_features=input_features, attention_mask=attention_mask) + + self.parent.assertIsNotNone(result.logits) + self.parent.assertIsInstance(result.logits, list) + self.parent.assertEqual(len(result.logits), self.batch_size) + for logits in result.logits: + self.parent.assertEqual(logits.ndim, 2) + self.parent.assertEqual(logits.shape[1], self.vocab_size) + + def create_and_check_generate(self, config, input_features, attention_mask): + model = GraniteSpeechNarForCTC(config=config) + model.to(torch_device) + model.eval() + output = model.generate(input_features=input_features, attention_mask=attention_mask) + + self.parent.assertIsNotNone(output.preds) + self.parent.assertEqual(len(output.preds), self.batch_size) + for pred in output.preds: + self.parent.assertIsInstance(pred, torch.Tensor) + self.parent.assertEqual(pred.ndim, 1) + + def create_and_check_generate_multi_step(self, config, input_features, attention_mask): + model = GraniteSpeechNarForCTC(config=config) + model.to(torch_device) + model.eval() + output = model.generate(input_features=input_features, attention_mask=attention_mask, num_editing_steps=3) + + self.parent.assertIsNotNone(output.preds) + self.parent.assertEqual(len(output.preds), self.batch_size) + for pred in output.preds: + self.parent.assertIsInstance(pred, torch.Tensor) + self.parent.assertEqual(pred.ndim, 1) + + def prepare_config_and_inputs_for_common(self): + config, input_features, attention_mask = self.prepare_config_and_inputs() + inputs_dict = { + "input_features": input_features, + "attention_mask": attention_mask, + } + return config, inputs_dict -class TestConfiguration: - def test_encoder_config_defaults(self): - config = GraniteSpeechNarEncoderConfig() - assert config.model_type == "granite_speech_nar_encoder" - assert config.input_dim == 160 - assert config.num_layers == 16 - assert config.hidden_dim == 1024 - assert config.self_conditioning_layer == 8 - assert config.bpe_output_dim == 49153 +@require_torch +class GraniteSpeechNarEncoderModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (GraniteSpeechNarCTCEncoder,) if is_torch_available() else () - def test_projector_config_defaults(self): - config = GraniteSpeechNarProjectorConfig() - assert config.model_type == "granite_speech_nar_projector" - assert config.encoder_dim == 1024 - assert config.llm_dim == 2048 - assert config.downsample_rate == 5 + test_resize_embeddings = False + test_attention_outputs = False + has_attentions = False - def test_config_defaults(self): - config = GraniteSpeechNarConfig() - assert config.model_type == "granite_speech_nar" - assert config.encoder_layer_indices == [4, 8, 12, -1] - assert config.scale_projected_embeddings is True + @unittest.skip(reason="GraniteSpeechNarCTCEncoder does not use inputs_embeds") + def test_model_get_set_embeddings(self): + pass - def test_config_serialization_roundtrip(self): - config = _make_small_config() - d = config.to_dict() - restored = GraniteSpeechNarConfig(**d) - assert restored.encoder_config.num_layers == 4 - assert restored.encoder_config.bpe_output_dim == 51 - assert restored.projector_config.num_layers == 1 - assert restored.encoder_layer_indices == [1, 2, 3, -1] + @unittest.skip(reason="GraniteSpeechNarCTCEncoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass - def test_auto_config_resolution(self): - config = AutoConfig.for_model("granite_speech_nar") - assert isinstance(config, GraniteSpeechNarConfig) + @unittest.skip(reason="GraniteSpeechNarCTCEncoder does not use inputs_embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + @unittest.skip(reason="Conformer encoder does not support standard hidden_states output interface") + def test_hidden_states_output(self): + pass -# === Encoder tests === + @unittest.skip(reason="Conformer encoder does not support standard hidden_states output interface") + def test_retain_grad_hidden_states_attentions(self): + pass + @unittest.skip(reason="Conformer encoder uses input_features, not input_ids") + def test_model_main_input_name(self): + pass -class TestEncoder: - def test_output_shapes(self): - config = GraniteSpeechNarEncoderConfig( - num_layers=4, - hidden_dim=64, - num_heads=4, - dim_head=16, - input_dim=160, - output_dim=348, - context_size=50, - self_conditioning_layer=2, - bpe_output_dim=100, - bpe_pooling_window=4, - ) - encoder = GraniteSpeechNarCTCEncoder(config).eval() + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_training(self): + pass - B, T = 2, 100 - features = torch.randn(B, T, 160) - mask = torch.ones(B, T, dtype=torch.bool) - mask[1, 80:] = False + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_training_gradient_checkpointing(self): + pass - out = encoder(features, mask, output_hidden_states=True) + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass - assert out.logits is not None - assert out.logits.shape[1] == 100 - assert out.hidden_states is not None - assert len(out.hidden_states) == 5 # input + 4 layers + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_training_gradient_checkpointing_use_reentrant_true(self): + pass + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_gradient_checkpointing_backward_compatibility(self): + pass -# === Projector tests === + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_gradient_checkpointing_enable_disable(self): + pass + @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") + def test_peft_gradient_checkpointing_enable_disable(self): + pass -class TestProjector: - def test_output_shape(self): - config = GraniteSpeechNarProjectorConfig( - encoder_dim=64, - llm_dim=128, - downsample_rate=5, - num_encoder_layers=2, - hidden_size=128, - num_heads=4, - num_layers=1, - block_size=15, - ) - projector = GraniteSpeechNarProjector(config) + def setUp(self): + self.model_tester = GraniteSpeechNarEncoderModelTester(self) + self.config_tester = ConfigTester(self, config_class=GraniteSpeechNarEncoderConfig, has_text_modality=False) - B, T = 2, 60 - x = torch.randn(B, T, 2 * 64) - out = projector(x) - expected_len = math.ceil(T / config.block_size) * (config.block_size // config.downsample_rate) - assert out.shape == (B, expected_len, 128) + def test_config(self): + self.config_tester.run_common_tests() - def test_handles_non_divisible_length(self): - config = GraniteSpeechNarProjectorConfig( - encoder_dim=64, - llm_dim=128, - downsample_rate=5, - num_encoder_layers=1, - hidden_size=64, - num_heads=4, - num_layers=1, - block_size=15, - ) - projector = GraniteSpeechNarProjector(config) + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) - x = torch.randn(1, 37, 64) - out = projector(x) - assert out.shape == (1, 9, 128) +@require_torch +class GraniteSpeechNarForCTCModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (GraniteSpeechNarForCTC,) if is_torch_available() else () -# === Full model tests === + test_resize_embeddings = False + test_attention_outputs = False + has_attentions = False + _is_composite = True + @unittest.skip(reason="GraniteSpeechNarForCTC does not use inputs_embeds directly") + def test_model_get_set_embeddings(self): + pass -class TestGraniteSpeechNarForCTC: - def test_forward(self): - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).eval() + @unittest.skip(reason="GraniteSpeechNarForCTC does not use inputs_embeds directly") + def test_inputs_embeds(self): + pass - B, T = 2, 100 - features = torch.randn(B, T, 160) - mask = torch.ones(B, T, dtype=torch.bool) - mask[1, 80:] = False + @unittest.skip(reason="GraniteSpeechNarForCTC does not use inputs_embeds directly") + def test_inputs_embeds_matches_input_ids(self): + pass - with torch.no_grad(): - output = model(input_features=features, attention_mask=mask) + @unittest.skip(reason="GraniteSpeechNarForCTC has a custom generate method, not standard GenerationMixin") + def test_generation_tester_mixin_inheritance(self): + pass - assert isinstance(output, GraniteSpeechNarOutput) - assert output.logits is not None - assert isinstance(output.logits, list) - assert len(output.logits) == B - for logits in output.logits: - assert logits.ndim == 2 - assert logits.shape[1] == 51 + @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + def test_determinism(self): + pass - def test_generate(self): - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).eval() + @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + def test_model_outputs_equivalence(self): + pass - features = torch.randn(1, 60, 160) - output = model.generate(input_features=features) + @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + def test_hidden_states_output(self): + pass - assert output.preds is not None - assert len(output.preds) == 1 - assert isinstance(output.preds[0], torch.Tensor) + @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + def test_retain_grad_hidden_states_attentions(self): + pass - def test_generate_multi_step(self): - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).eval() + @unittest.skip(reason="Non-standard keyword-only forward signature") + def test_model_main_input_name(self): + pass + + @unittest.skip(reason="Non-standard keyword-only forward signature") + def test_model_is_small(self): + pass + + @unittest.skip(reason="Non-standard keyword-only forward signature") + def test_enable_input_require_grads_with_gradient_checkpointing(self): + pass + + @unittest.skip(reason="Projector query/window_positions nn.Parameters need _init_weights handling — TODO") + def test_can_init_all_missing_weights(self): + pass + + def setUp(self): + self.model_tester = GraniteSpeechNarForCTCModelTester(self) + self.config_tester = ConfigTester(self, config_class=GraniteSpeechNarConfig, has_text_modality=False) - features = torch.randn(2, 80, 160) - mask = torch.ones(2, 80, dtype=torch.bool) - mask[1, 60:] = False + @unittest.skip(reason="ConfigTester composite sub_config resolution incompatible with AutoConfig string ref") + def test_config(self): + self.config_tester.run_common_tests() - out1 = model.generate(input_features=features, attention_mask=mask, num_editing_steps=1) - out2 = model.generate(input_features=features, attention_mask=mask, num_editing_steps=3) + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) - assert out1.preds is not None - assert out2.preds is not None - assert len(out1.preds) == 2 - assert len(out2.preds) == 2 - # Multi-step should produce valid predictions (may or may not differ) - for pred in out2.preds: - assert isinstance(pred, torch.Tensor) - assert pred.ndim == 1 + def test_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate(*config_and_inputs) + + def test_generate_multi_step(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_multi_step(*config_and_inputs) def test_loss(self): - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).train() - - B, T = 2, 100 - features = torch.randn(B, T, 160) - mask = torch.ones(B, T, dtype=torch.bool) - mask[1, 80:] = False - labels = torch.randint(0, 51, (B, 5)) + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = GraniteSpeechNarForCTC(config).to(torch_device).train() + + labels = torch.randint(0, self.model_tester.vocab_size, (self.model_tester.batch_size, 5)) label_lengths = torch.tensor([5, 3]) output = model( - input_features=features, - attention_mask=mask, + input_features=input_features, + attention_mask=attention_mask, labels=labels, label_lengths=label_lengths, ) - assert output.loss is not None - assert output.loss.ndim == 0 - assert output.loss.requires_grad + self.assertIsNotNone(output.loss) + self.assertEqual(output.loss.ndim, 0) + self.assertTrue(output.loss.requires_grad) output.loss.backward() def test_loss_with_ce(self): - config = _make_small_config() + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() config.ce_loss_lambda = 0.5 - model = GraniteSpeechNarForCTC(config).train() + model = GraniteSpeechNarForCTC(config).to(torch_device).train() - features = torch.randn(1, 60, 160) - labels = torch.randint(0, 51, (1, 4)) - label_lengths = torch.tensor([4]) + labels = torch.randint(0, self.model_tester.vocab_size, (self.model_tester.batch_size, 4)) + label_lengths = torch.tensor([4, 3]) output = model( - input_features=features, - labels=labels, - label_lengths=label_lengths, + input_features=input_features, attention_mask=attention_mask, labels=labels, label_lengths=label_lengths ) - - assert output.loss is not None - assert output.loss.requires_grad + self.assertIsNotNone(output.loss) + self.assertTrue(output.loss.requires_grad) output.loss.backward() def test_loss_with_encoder_ctc(self): - config = _make_small_config() + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() config.encoder_ctc_loss_lambda = 0.3 - model = GraniteSpeechNarForCTC(config).train() + model = GraniteSpeechNarForCTC(config).to(torch_device).train() - features = torch.randn(1, 60, 160) - labels = torch.randint(0, 51, (1, 4)) - label_lengths = torch.tensor([4]) + labels = torch.randint(0, self.model_tester.vocab_size, (self.model_tester.batch_size, 4)) + label_lengths = torch.tensor([4, 3]) output = model( - input_features=features, - labels=labels, - label_lengths=label_lengths, + input_features=input_features, attention_mask=attention_mask, labels=labels, label_lengths=label_lengths ) - - assert output.loss is not None - assert output.loss.requires_grad + self.assertIsNotNone(output.loss) + self.assertTrue(output.loss.requires_grad) output.loss.backward() def test_no_loss_without_labels(self): - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).eval() + config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs() + model = GraniteSpeechNarForCTC(config).to(torch_device).eval() - features = torch.randn(1, 60, 160) with torch.no_grad(): - output = model(input_features=features) + output = model(input_features=input_features, attention_mask=attention_mask) - assert output.loss is None + self.assertIsNone(output.loss) - def test_output_encoder_logits_flag(self): - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).eval() - - features = torch.randn(1, 60, 160) - with torch.no_grad(): - out_no = model(input_features=features, output_encoder_logits=False) - out_yes = model(input_features=features, output_encoder_logits=True) - - assert out_no.encoder_logits is None - assert out_yes.encoder_logits is not None - assert out_no.encoder_preds is not None # always returned - - def test_automodel_resolves(self): - config = AutoConfig.for_model("granite_speech_nar") - assert isinstance(config, GraniteSpeechNarConfig) - assert config.model_type == "granite_speech_nar" - - -# === Bidirectional attention test === - - -class TestBidirectionalAttention: - def test_last_token_affects_first(self): - """Changing the last token must affect the first (bidirectional).""" - config = _make_small_config() - model = GraniteSpeechNarForCTC(config).eval() + def test_bidirectional_attention(self): + config = self.model_tester.get_config() + model = GraniteSpeechNarForCTC(config).to(torch_device).eval() granite_model = model.model.language_model - embeds_a = torch.randn(1, 10, 128) + embeds_a = torch.randn(1, 10, 128, device=torch_device) embeds_b = embeds_a.clone() - embeds_b[0, -1, :] = torch.randn(128) + embeds_b[0, -1, :] = torch.randn(128, device=torch_device) with torch.no_grad(): out_a = granite_model(inputs_embeds=embeds_a).last_hidden_state out_b = granite_model(inputs_embeds=embeds_b).last_hidden_state diff_first = (out_a[0, 0] - out_b[0, 0]).abs().max().item() - assert diff_first > 1e-5, f"First token unchanged (diff={diff_first}). Attention appears causal." + self.assertGreater(diff_first, 1e-5, "First token unchanged — attention appears causal.") def test_is_causal_false_on_layers(self): - config = _make_small_config() + config = self.model_tester.get_config() model = GraniteSpeechNarForCTC(config) for i, layer in enumerate(model.model.language_model.layers): - assert layer.self_attn.is_causal is False, f"Layer {i} is_causal is not False" + self.assertFalse(layer.self_attn.is_causal, f"Layer {i} is_causal is not False") + + def test_sdpa_can_dispatch_composite_models(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) -# === Integration tests === + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") @require_torch class GraniteSpeechNarIntegrationTest(unittest.TestCase): - model_path = "ibm-granite/granite-speech-4.1-2b-nar" + checkpoint_name = "ibm-granite/granite-speech-4.1-2b-nar" _dataset = None + def tearDown(self): + cleanup(torch_device, gc_collect=True) + @classmethod def _load_dataset(cls): if cls._dataset is None: @@ -397,9 +510,12 @@ def _load_datasamples(self, num_samples): @slow def test_single_sample_transcription(self): model = AutoModel.from_pretrained( - self.model_path, attn_implementation="flash_attention_2", device_map=torch_device, dtype=torch.bfloat16 + self.checkpoint_name, + attn_implementation="flash_attention_2", + device_map=torch_device, + dtype=torch.bfloat16, ).eval() - processor = AutoProcessor.from_pretrained(self.model_path) + processor = AutoProcessor.from_pretrained(self.checkpoint_name) waveforms = self._load_datasamples(1) inputs = processor(waveforms, device=torch_device) @@ -412,9 +528,12 @@ def test_single_sample_transcription(self): @slow def test_batch_transcription(self): model = AutoModel.from_pretrained( - self.model_path, attn_implementation="flash_attention_2", device_map=torch_device, dtype=torch.bfloat16 + self.checkpoint_name, + attn_implementation="flash_attention_2", + device_map=torch_device, + dtype=torch.bfloat16, ).eval() - processor = AutoProcessor.from_pretrained(self.model_path) + processor = AutoProcessor.from_pretrained(self.checkpoint_name) waveforms = self._load_datasamples(2) inputs = processor(waveforms, device=torch_device) @@ -427,22 +546,3 @@ def test_batch_transcription(self): ] self.assertEqual(len(transcriptions), 2) self.assertEqual(transcriptions, expected) - - @slow - @pytest.mark.skipif(not is_datasets_available(), reason="datasets not installed") - def test_processor_output_shapes(self): - processor = AutoProcessor.from_pretrained(self.model_path) - - waveforms = self._load_datasamples(2) - inputs = processor(waveforms, device="cpu") - - self.assertEqual(inputs["input_features"].ndim, 3) - self.assertEqual(inputs["input_features"].shape[0], 2) - self.assertEqual(inputs["input_features"].shape[2], 160) - - self.assertEqual(inputs["attention_mask"].shape, inputs["input_features"].shape[:2]) - - # Shorter sample should have False values at end - mask_sums = inputs["attention_mask"].sum(dim=1) - self.assertEqual(mask_sums[0].item(), inputs["input_features"].shape[1]) - self.assertLess(mask_sums[1].item(), mask_sums[0].item()) diff --git a/utils/check_repo.py b/utils/check_repo.py index 9a3cf828d7f0..b833622b7b4b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -290,6 +290,8 @@ "Sam3LiteTextTextModel", # Building part of a bigger model, tested implicitly through Sam3LiteTextModel "Exaone4_5_VisionModel", # Building part of a bigger model "Granite4VisionTextModel", # Building part of bigger (tested) model. Tested implicitly through Granite4VisionModel. + "GraniteSpeechNarLanguageModel", # Building part of bigger (tested) model. Tested implicitly through GraniteSpeechNarForCTC. + "GraniteSpeechNarModel", # Building part of bigger (tested) model. Tested implicitly through GraniteSpeechNarForCTC. ] ) @@ -313,7 +315,6 @@ "models/sam3_video/test_modeling_sam3_video.py", "models/edgetam_video/test_modeling_edgetam_video.py", "models/gemma4_assistant/test_modeling_gemma4_assistant.py", - "models/granite_speech_nar/test_modeling_granite_speech_nar.py", ] # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and From 032cc088e138f2a7001ad63f3f09057bdf3a64f9 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 13:26:31 +0000 Subject: [PATCH 35/39] minor fixes after refactoring the tests --- .../configuration_granite_speech_nar.py | 4 ++-- .../granite_speech_nar/modeling_granite_speech_nar.py | 9 +++++++++ .../granite_speech_nar/modular_granite_speech_nar.py | 11 +++++++++-- .../test_modeling_granite_speech_nar.py | 4 +--- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py index f7b64a715d1d..3c94b3faddd9 100644 --- a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -17,7 +17,7 @@ from ...configuration_utils import PreTrainedConfig from ...utils import auto_docstring -from ..auto import CONFIG_MAPPING +from ..auto import CONFIG_MAPPING, AutoConfig @auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") @@ -182,7 +182,7 @@ class GraniteSpeechNarConfig(PreTrainedConfig): sub_configs = { "encoder_config": GraniteSpeechNarEncoderConfig, "projector_config": GraniteSpeechNarProjectorConfig, - "text_config": "AutoConfig", + "text_config": AutoConfig, } encoder_config: dict | PreTrainedConfig | None = None diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 0ca6282a89b9..a3055c805f16 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func @@ -397,6 +398,14 @@ class GraniteSpeechNarPreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteDecoderLayer"] input_modalities = ("audio",) + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GraniteSpeechNarProjector): + std = module.config.hidden_size**-0.5 + init.normal_(module.query, mean=0.0, std=std) + init.normal_(module.window_positions, mean=0.0, std=std) + def rotate_half(x): """Rotates half the hidden dims of the input.""" diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 3533db477b5a..8f6f79465691 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from torch import nn +from ... import initialization as init from ...masking_utils import ( create_bidirectional_mask, find_packed_sequence_indices, @@ -291,6 +292,14 @@ class GraniteSpeechNarPreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteDecoderLayer"] input_modalities = ("audio",) + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, GraniteSpeechNarProjector): + std = module.config.hidden_size**-0.5 + init.normal_(module.query, mean=0.0, std=std) + init.normal_(module.window_positions, mean=0.0, std=std) + class GraniteSpeechNarAttention(GraniteAttention): """GraniteAttention with is_causal=False for bidirectional attention.""" @@ -300,8 +309,6 @@ def __init__(self, config, layer_idx=None): self.is_causal = False - - class GraniteSpeechNarLanguageModel(GraniteModel): """GraniteModel with bidirectional (non-causal) attention. diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 3e476097398a..841bc0ce2402 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -359,15 +359,13 @@ def test_model_is_small(self): def test_enable_input_require_grads_with_gradient_checkpointing(self): pass - @unittest.skip(reason="Projector query/window_positions nn.Parameters need _init_weights handling — TODO") def test_can_init_all_missing_weights(self): - pass + super().test_can_init_all_missing_weights() def setUp(self): self.model_tester = GraniteSpeechNarForCTCModelTester(self) self.config_tester = ConfigTester(self, config_class=GraniteSpeechNarConfig, has_text_modality=False) - @unittest.skip(reason="ConfigTester composite sub_config resolution incompatible with AutoConfig string ref") def test_config(self): self.config_tester.run_common_tests() From dc7901b83a91046fd7c635bff9a961d93e7aff84 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 13:26:58 +0000 Subject: [PATCH 36/39] minor --- docs/source/en/model_doc/granite_speech_nar.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/granite_speech_nar.md b/docs/source/en/model_doc/granite_speech_nar.md index a6cf43a57551..cccd1bf844d6 100644 --- a/docs/source/en/model_doc/granite_speech_nar.md +++ b/docs/source/en/model_doc/granite_speech_nar.md @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. rendered properly in your Markdown viewer. --> -*This model was released on 2026-03-09 and added to Hugging Face Transformers on 2026-05-26.* +*This model was released on 2026-03-09 and added to Hugging Face Transformers on 2026-06-03.* # GraniteSpeechNar From 28fed5e3ec4d12c02e961c607f3c02c93c2fc64b Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 14:04:44 +0000 Subject: [PATCH 37/39] resolve some issues in the tests: support positional arguments main_input_name enable encoder gradient checkpointing --- .../modeling_granite_speech_nar.py | 6 +- .../modular_granite_speech_nar.py | 5 +- .../test_modeling_granite_speech_nar.py | 75 ++++++------------- 3 files changed, 28 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index a3055c805f16..382fcaf70ec9 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -238,7 +238,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class GraniteSpeechNarConformerBlock(nn.Module): +class GraniteSpeechNarConformerBlock( + GradientCheckpointingLayer, +): """Conformer block, consisting largely of linear layers, attention, and convolutional layers.""" def __init__(self, config: GraniteSpeechNarEncoderConfig): @@ -391,6 +393,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteSpeechNarPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechNarConfig base_model_prefix = "model" + main_input_name = "input_features" supports_gradient_checkpointing = True _supports_flash_attn = True _supports_flash_attn_2 = True @@ -1061,7 +1064,6 @@ def __init__(self, config: GraniteSpeechNarConfig): def forward( self, - *, input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index 8f6f79465691..dbf1be674cfe 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -26,6 +26,7 @@ find_packed_sequence_indices, packed_sequence_mask_function, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -98,7 +99,7 @@ class GraniteSpeechNarOutput(ModelOutput): ce_loss: torch.Tensor | None = None -class GraniteSpeechNarConformerBlock(GraniteSpeechConformerBlock): +class GraniteSpeechNarConformerBlock(GradientCheckpointingLayer, GraniteSpeechConformerBlock): pass @@ -285,6 +286,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteSpeechNarPreTrainedModel(PreTrainedModel): config_class = GraniteSpeechNarConfig base_model_prefix = "model" + main_input_name = "input_features" supports_gradient_checkpointing = True _supports_flash_attn = True _supports_flash_attn_2 = True @@ -588,7 +590,6 @@ def __init__(self, config: GraniteSpeechNarConfig): def forward( self, - *, input_features: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, labels: torch.Tensor | None = None, diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 841bc0ce2402..7d2dac0f76c2 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -52,14 +52,14 @@ def __init__( batch_size=2, seq_length=100, is_training=True, - num_layers=4, + num_layers=2, hidden_dim=64, num_heads=4, dim_head=16, input_dim=160, output_dim=10, context_size=50, - self_conditioning_layer=2, + self_conditioning_layer=1, bpe_output_dim=51, bpe_pooling_window=4, dropout=0.0, @@ -70,7 +70,9 @@ def __init__( self.is_training = is_training self.num_layers = num_layers + self.num_hidden_layers = num_layers self.hidden_dim = hidden_dim + self.hidden_size = hidden_dim self.num_heads = num_heads self.dim_head = dim_head self.input_dim = input_dim @@ -94,6 +96,7 @@ def get_config(self): bpe_output_dim=self.bpe_output_dim, bpe_pooling_window=self.bpe_pooling_window, dropout=self.dropout, + blank_token_id=0, ) def prepare_config_and_inputs(self): @@ -152,7 +155,7 @@ def get_config(self): encoder_dim=self.encoder_model_tester.hidden_dim, llm_dim=128, downsample_rate=5, - num_encoder_layers=4, + num_encoder_layers=2, hidden_size=128, num_heads=4, num_layers=1, @@ -164,7 +167,7 @@ def get_config(self): num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, - intermediate_size=256, + intermediate_size=128, max_position_embeddings=512, tie_word_embeddings=True, embedding_multiplier=1.0, @@ -176,7 +179,7 @@ def get_config(self): encoder_config=encoder_config, projector_config=projector_config, text_config=text_config.to_dict(), - encoder_layer_indices=[1, 2, 3, -1], + encoder_layer_indices=[1, -1], scale_projected_embeddings=False, ) @@ -254,45 +257,21 @@ def test_inputs_embeds(self): def test_inputs_embeds_matches_input_ids(self): pass - @unittest.skip(reason="Conformer encoder does not support standard hidden_states output interface") - def test_hidden_states_output(self): - pass - - @unittest.skip(reason="Conformer encoder does not support standard hidden_states output interface") + @unittest.skip(reason="Conformer encoder does not expose attention outputs") def test_retain_grad_hidden_states_attentions(self): pass - @unittest.skip(reason="Conformer encoder uses input_features, not input_ids") - def test_model_main_input_name(self): - pass - - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_training(self): - pass - - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_training_gradient_checkpointing(self): - pass - - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_training_gradient_checkpointing_use_reentrant_false(self): - pass - - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_training_gradient_checkpointing_use_reentrant_true(self): - pass - - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_gradient_checkpointing_backward_compatibility(self): - pass - - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_gradient_checkpointing_enable_disable(self): + @unittest.skip(reason="Self-conditioning injection between layers causes hidden_states mismatch in tuple vs dict") + def test_model_outputs_equivalence(self): pass - @unittest.skip(reason="Conformer encoder does not support gradient checkpointing") - def test_peft_gradient_checkpointing_enable_disable(self): - pass + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if return_labels: + batch_size = self.model_tester.batch_size + inputs_dict["labels"] = torch.randint(0, self.model_tester.bpe_output_dim, (batch_size, 5)) + inputs_dict["label_lengths"] = torch.tensor([5] * batch_size) + return inputs_dict def setUp(self): self.model_tester = GraniteSpeechNarEncoderModelTester(self) @@ -315,15 +294,11 @@ class GraniteSpeechNarForCTCModelTest(ModelTesterMixin, unittest.TestCase): has_attentions = False _is_composite = True - @unittest.skip(reason="GraniteSpeechNarForCTC does not use inputs_embeds directly") - def test_model_get_set_embeddings(self): - pass - - @unittest.skip(reason="GraniteSpeechNarForCTC does not use inputs_embeds directly") + @unittest.skip(reason="GraniteSpeechNarForCTC takes audio input_features, not input_ids/inputs_embeds") def test_inputs_embeds(self): pass - @unittest.skip(reason="GraniteSpeechNarForCTC does not use inputs_embeds directly") + @unittest.skip(reason="GraniteSpeechNarForCTC takes audio input_features, not input_ids/inputs_embeds") def test_inputs_embeds_matches_input_ids(self): pass @@ -347,15 +322,7 @@ def test_hidden_states_output(self): def test_retain_grad_hidden_states_attentions(self): pass - @unittest.skip(reason="Non-standard keyword-only forward signature") - def test_model_main_input_name(self): - pass - - @unittest.skip(reason="Non-standard keyword-only forward signature") - def test_model_is_small(self): - pass - - @unittest.skip(reason="Non-standard keyword-only forward signature") + @unittest.skip(reason="Encoder does not have standard embedding layer for gradient checkpointing") def test_enable_input_require_grads_with_gradient_checkpointing(self): pass From 77473e55ec355f721ef9c645cf4c54c6e471edf3 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 14:23:28 +0000 Subject: [PATCH 38/39] stack logits to pass test_determinism --- .../modeling_granite_speech_nar.py | 16 +++++++++----- .../modular_granite_speech_nar.py | 16 +++++++++----- .../test_modeling_granite_speech_nar.py | 21 ++++++++----------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py index 382fcaf70ec9..3ba49279a2c4 100644 --- a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -85,8 +85,11 @@ class GraniteSpeechNarOutput(ModelOutput): Attributes: loss: Combined CTC + auxiliary losses (only when labels provided). + logits: Flat (concatenated) LLM logits of shape `(sum(text_lengths), vocab_size)`. + Split with `logits.split(text_lengths)` to get per-sample tensors. + text_lengths: Per-sample text sequence lengths for splitting logits. preds: List of predicted token ID tensors per sample (after CTC collapse, inference only). - logits: List of per-sample logit tensors from the LLM head. + audio_embeds: Projected audio embeddings (cached for multi-step editing). encoder_logits: Flat BPE CTC logits from the encoder. encoder_preds: List of CTC-collapsed encoder predictions per sample. encoder_loss: Encoder BPE CTC loss component (for logging). @@ -94,8 +97,9 @@ class GraniteSpeechNarOutput(ModelOutput): """ loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + text_lengths: list[int] | None = None preds: list[torch.Tensor] | None = None - logits: list[torch.Tensor] | None = None audio_embeds: torch.FloatTensor | None = None encoder_logits: torch.Tensor | None = None encoder_preds: list[torch.Tensor] | None = None @@ -1114,7 +1118,6 @@ def forward( text_lengths = model_out.text_lengths segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) - logits_per_sample = list(text_logits.split(text_lengths)) loss = None encoder_loss = model_out.encoder_loss @@ -1139,7 +1142,8 @@ def forward( return GraniteSpeechNarOutput( loss=loss, - logits=logits_per_sample, + logits=text_logits, + text_lengths=text_lengths, audio_embeds=model_out.audio_embeds, encoder_logits=model_out.encoder_logits, encoder_preds=model_out.ctc_token_ids, @@ -1177,7 +1181,8 @@ def generate( ) blank_id = self.config.blank_token_id - ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] + logits_per_sample = output.logits.split(output.text_lengths) + ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in logits_per_sample] audio_embeds = output.audio_embeds if step == 0: @@ -1186,6 +1191,7 @@ def generate( return GraniteSpeechNarOutput( preds=ctc_token_ids, logits=output.logits, + text_lengths=output.text_lengths, encoder_logits=output.encoder_logits, encoder_preds=encoder_preds, ) diff --git a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py index dbf1be674cfe..e73517d2bbd7 100644 --- a/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py +++ b/src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py @@ -81,8 +81,11 @@ class GraniteSpeechNarOutput(ModelOutput): Attributes: loss: Combined CTC + auxiliary losses (only when labels provided). + logits: Flat (concatenated) LLM logits of shape `(sum(text_lengths), vocab_size)`. + Split with `logits.split(text_lengths)` to get per-sample tensors. + text_lengths: Per-sample text sequence lengths for splitting logits. preds: List of predicted token ID tensors per sample (after CTC collapse, inference only). - logits: List of per-sample logit tensors from the LLM head. + audio_embeds: Projected audio embeddings (cached for multi-step editing). encoder_logits: Flat BPE CTC logits from the encoder. encoder_preds: List of CTC-collapsed encoder predictions per sample. encoder_loss: Encoder BPE CTC loss component (for logging). @@ -90,8 +93,9 @@ class GraniteSpeechNarOutput(ModelOutput): """ loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + text_lengths: list[int] | None = None preds: list[torch.Tensor] | None = None - logits: list[torch.Tensor] | None = None audio_embeds: torch.FloatTensor | None = None encoder_logits: torch.Tensor | None = None encoder_preds: list[torch.Tensor] | None = None @@ -640,7 +644,6 @@ def forward( text_lengths = model_out.text_lengths segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) - logits_per_sample = list(text_logits.split(text_lengths)) loss = None encoder_loss = model_out.encoder_loss @@ -665,7 +668,8 @@ def forward( return GraniteSpeechNarOutput( loss=loss, - logits=logits_per_sample, + logits=text_logits, + text_lengths=text_lengths, audio_embeds=model_out.audio_embeds, encoder_logits=model_out.encoder_logits, encoder_preds=model_out.ctc_token_ids, @@ -703,7 +707,8 @@ def generate( ) blank_id = self.config.blank_token_id - ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in output.logits] + logits_per_sample = output.logits.split(output.text_lengths) + ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in logits_per_sample] audio_embeds = output.audio_embeds if step == 0: @@ -712,6 +717,7 @@ def generate( return GraniteSpeechNarOutput( preds=ctc_token_ids, logits=output.logits, + text_lengths=output.text_lengths, encoder_logits=output.encoder_logits, encoder_preds=encoder_preds, ) diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 7d2dac0f76c2..9d0f9b6a7470 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -198,11 +198,12 @@ def create_and_check_model(self, config, input_features, attention_mask): result = model(input_features=input_features, attention_mask=attention_mask) self.parent.assertIsNotNone(result.logits) - self.parent.assertIsInstance(result.logits, list) - self.parent.assertEqual(len(result.logits), self.batch_size) - for logits in result.logits: - self.parent.assertEqual(logits.ndim, 2) - self.parent.assertEqual(logits.shape[1], self.vocab_size) + self.parent.assertIsInstance(result.logits, torch.Tensor) + self.parent.assertEqual(result.logits.ndim, 2) + self.parent.assertEqual(result.logits.shape[1], self.vocab_size) + self.parent.assertIsNotNone(result.text_lengths) + self.parent.assertEqual(len(result.text_lengths), self.batch_size) + self.parent.assertEqual(sum(result.text_lengths), result.logits.shape[0]) def create_and_check_generate(self, config, input_features, attention_mask): model = GraniteSpeechNarForCTC(config=config) @@ -306,19 +307,15 @@ def test_inputs_embeds_matches_input_ids(self): def test_generation_tester_mixin_inheritance(self): pass - @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") - def test_determinism(self): - pass - - @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + @unittest.skip(reason="text_lengths (list[int]) in output breaks recursive tuple/dict comparison") def test_model_outputs_equivalence(self): pass - @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + @unittest.skip(reason="Composite model with flat packed sequences; hidden_states not piped to top-level output") def test_hidden_states_output(self): pass - @unittest.skip(reason="Non-standard output format (logits is a list of tensors)") + @unittest.skip(reason="Composite model with flat packed sequences; hidden_states not piped to top-level output") def test_retain_grad_hidden_states_attentions(self): pass From 8d089f5d14f215e5f80c7057d60393e13c3e2d41 Mon Sep 17 00:00:00 2001 From: Avihu Dekel Date: Wed, 3 Jun 2026 14:54:04 +0000 Subject: [PATCH 39/39] make labels test consistent --- .../granite_speech_nar/test_modeling_granite_speech_nar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py index 9d0f9b6a7470..df67b7c735bd 100644 --- a/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py +++ b/tests/models/granite_speech_nar/test_modeling_granite_speech_nar.py @@ -270,7 +270,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) if return_labels: batch_size = self.model_tester.batch_size - inputs_dict["labels"] = torch.randint(0, self.model_tester.bpe_output_dim, (batch_size, 5)) + inputs_dict["labels"] = torch.arange(1, 6).unsqueeze(0).expand(batch_size, -1) inputs_dict["label_lengths"] = torch.tensor([5] * batch_size) return inputs_dict