diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d7b89784f4be..f7ef8d279189 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1078,6 +1078,8 @@ title: GLM-ASR - local: model_doc/granite_speech title: GraniteSpeech + - local: model_doc/granite_speech_nar + title: GraniteSpeechNar - local: model_doc/granite_speech_plus title: GraniteSpeechPlus - local: model_doc/higgs_audio_v2 diff --git a/docs/source/en/model_doc/granite_speech_nar.md b/docs/source/en/model_doc/granite_speech_nar.md new file mode 100644 index 000000000000..cccd1bf844d6 --- /dev/null +++ b/docs/source/en/model_doc/granite_speech_nar.md @@ -0,0 +1,72 @@ + +*This model was released on 2026-03-09 and added to Hugging Face Transformers on 2026-06-03.* + +# GraniteSpeechNar + +## Overview + +GraniteSpeechNar is a non-autoregressive (NAR) speech recognition model based on [NLE: Non-autoregressive LLM-based ASR by Transcript Editing](https://huggingface.co/papers/2603.08397). It formulates ASR as conditional transcript editing, achieving fully parallel prediction with significant speedups over autoregressive baselines. + +The model consists of: + +1. **Conformer Encoder**: A conformer encoder trained with CTC on BPE targets, using block-attention and self-conditioned CTC from the middle layer. + +2. **QFormer Projector**: A windowed query-transformer that maps multi-layer encoder features to the LLM embedding space with temporal downsampling. + +3. **Bidirectional Granite LLM**: A Granite language model with bidirectional (non-causal) attention that refines CTC predictions in a single forward pass. + +The model performs inference in a single pass: the encoder produces initial CTC predictions, which are interleaved with blank insertion slots (exploiting the identity mapping bias of Transformers) and fed alongside projected audio embeddings to the bidirectional LLM for refinement via a latent alignment objective. 
+ +This model was contributed by [Avihu Dekel](https://huggingface.co/Avihu). + +## GraniteSpeechNarConfig + +[[autodoc]] GraniteSpeechNarConfig + +## GraniteSpeechNarEncoderConfig + +[[autodoc]] GraniteSpeechNarEncoderConfig + +## GraniteSpeechNarProjectorConfig + +[[autodoc]] GraniteSpeechNarProjectorConfig + +## GraniteSpeechNarProcessor + +[[autodoc]] GraniteSpeechNarProcessor + - __call__ + - batch_decode + +## GraniteSpeechNarFeatureExtractor + +[[autodoc]] GraniteSpeechNarFeatureExtractor + +## GraniteSpeechNarModel + +[[autodoc]] GraniteSpeechNarModel + - forward + +## GraniteSpeechNarLanguageModel + +[[autodoc]] GraniteSpeechNarLanguageModel + - forward + +## GraniteSpeechNarForCTC + +[[autodoc]] GraniteSpeechNarForCTC + - forward + - generate diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 4b4a464caf05..7b513b34215e 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -92,6 +92,7 @@ "audioflamingo3": "qwen2_audio", "glmasr": "qwen2_audio", "musicflamingo": "qwen2_audio", + "granite_speech_nar": "granite_speech", "granite_speech_plus": "granite_speech", "gemma3n_text": "qwen3_5_text", "qwen3_5_moe_text": "qwen3_5_text", @@ -116,6 +117,7 @@ "AudioFlamingo3Model": "Qwen2AudioModel", "GlmAsrModel": "Qwen2AudioModel", "MusicFlamingoModel": "Qwen2AudioModel", + "GraniteSpeechNarModel": "GraniteSpeechModel", "GraniteSpeechPlusModel": "GraniteSpeechModel", "MaskFormerDetrDecoder": "DetrModel", "Qwen2_5_VLForConditionalGeneration": "Qwen2VLForConditionalGeneration", diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index d2308ae8d75a..5381b29fcd11 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -184,6 +184,7 @@ from .granite import * from .granite4_vision import * from .granite_speech import * + from .granite_speech_nar import * from .granite_speech_plus import * from .granitemoe import 
* from .granitemoehybrid import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 754c8c8a9c43..cb69de52c45d 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -244,6 +244,9 @@ ("granite4_vision_text", "Granite4VisionTextConfig"), ("granite_speech", "GraniteSpeechConfig"), ("granite_speech_encoder", "GraniteSpeechEncoderConfig"), + ("granite_speech_nar", "GraniteSpeechNarConfig"), + ("granite_speech_nar_encoder", "GraniteSpeechNarEncoderConfig"), + ("granite_speech_nar_projector", "GraniteSpeechNarProjectorConfig"), ("granite_speech_plus", "GraniteSpeechPlusConfig"), ("granite_speech_plus_encoder", "GraniteSpeechPlusEncoderConfig"), ("granitemoe", "GraniteMoeConfig"), @@ -733,6 +736,8 @@ ("glmasr_encoder", "glmasr"), ("granite4_vision_text", "granite4_vision"), ("granite_speech_encoder", "granite_speech"), + ("granite_speech_nar_encoder", "granite_speech_nar"), + ("granite_speech_nar_projector", "granite_speech_nar"), ("granite_speech_plus_encoder", "granite_speech_plus"), ("grounding-dino", "grounding_dino"), ("groupvit_text_model", "groupvit"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 953a8e7cf742..3d193f3cb7c5 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -49,6 +49,7 @@ ("gemma4", "Gemma4AudioFeatureExtractor"), ("glmasr", "WhisperFeatureExtractor"), ("granite_speech", "GraniteSpeechFeatureExtractor"), + ("granite_speech_nar", "GraniteSpeechNarFeatureExtractor"), ("granite_speech_plus", "GraniteSpeechFeatureExtractor"), ("higgs_audio_v2_tokenizer", "DacFeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7c6c87b10e88..ee47aac3103e 100644 --- 
a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -217,6 +217,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), ("granite_speech", "GraniteSpeechModel"), + ("granite_speech_nar", "GraniteSpeechNarForCTC"), ("granite_speech_plus", "GraniteSpeechPlusModel"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), @@ -1672,6 +1673,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): [ # Model for Connectionist temporal classification (CTC) mapping ("data2vec-audio", "Data2VecAudioForCTC"), + ("granite_speech_nar", "GraniteSpeechNarForCTC"), ("hubert", "HubertForCTC"), ("lasr_ctc", "LasrForCTC"), ("parakeet_ctc", "ParakeetForCTC"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 02784deae71c..e7ec95b76c85 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -93,6 +93,7 @@ ("got_ocr2", "GotOcr2Processor"), ("granite4_vision", "Granite4VisionProcessor"), ("granite_speech", "GraniteSpeechProcessor"), + ("granite_speech_nar", "GraniteSpeechNarProcessor"), ("granite_speech_plus", "GraniteSpeechProcessor"), ("grounding-dino", "GroundingDinoProcessor"), ("groupvit", "CLIPProcessor"), diff --git a/src/transformers/models/granite_speech_nar/__init__.py b/src/transformers/models/granite_speech_nar/__init__.py new file mode 100644 index 000000000000..419097100ba6 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_granite_speech_nar import * + from .feature_extraction_granite_speech_nar import * + from .modeling_granite_speech_nar import * + from .processing_granite_speech_nar import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py new file mode 100644 index 000000000000..3c94b3faddd9 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/configuration_granite_speech_nar.py @@ -0,0 +1,224 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Config classes for Granite Speech NAR (Non-Autoregressive ASR).""" + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring +from ..auto import CONFIG_MAPPING, AutoConfig + + +@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") +@strict +class GraniteSpeechNarEncoderConfig(PreTrainedConfig): + r""" + Configuration for the conformer encoder component of GraniteSpeechNar. + + feedforward_mult (`int`, *optional*, defaults to 4): + Multiplier for the feedforward layers; intermediate dim = `hidden_dim * feedforward_mult`. + output_dim (`int`, *optional*, defaults to 348): + Output dimension of the mid-layer CTC prediction head. + context_size (`int`, *optional*, defaults to 200): + Context size for block-wise conformer attention. + max_pos_emb (`int`, *optional*, defaults to 512): + Maximum relative positional embedding index (Shaw's relative positional encoding). + pred_dropout (`float`, *optional*, defaults to 0.25): + Dropout applied to encoder hidden states before prediction heads. + conv_expansion_factor (`int`, *optional*, defaults to 2): + Expansion factor for conformer convolution module. + self_conditioning_layer (`int`, *optional*): + Layer index at which self-conditioning (mid-layer CTC feedback) is applied. + Defaults to `num_layers // 2`. + bpe_output_dim (`int`, *optional*, defaults to 49153): + Vocabulary size for the BPE CTC head (same as LLM vocab, blank reuses eos_token_id). + bpe_pooling_window (`int`, *optional*, defaults to 4): + Window size for posterior-weighted pooling before the BPE CTC head. + blank_token_id (`int`, *optional*): + Token ID used as the CTC blank symbol. `GraniteSpeechNarConfig` always overwrites this field with its own `blank_token_id`, which defaults to `text_config.eos_token_id`. 
+ + Example: + + ```python + >>> from transformers import GraniteSpeechNarEncoderConfig + + >>> configuration = GraniteSpeechNarEncoderConfig() + >>> print(configuration.hidden_dim) + 1024 + ```""" + + model_type = "granite_speech_nar_encoder" + attribute_map = { + "hidden_size": "hidden_dim", + "num_hidden_layers": "num_layers", + "num_attention_heads": "num_heads", + "num_mel_bins": "input_dim", + } + + input_dim: int = 160 + num_layers: int = 16 + hidden_dim: int = 1024 + feedforward_mult: int = 4 + num_heads: int = 8 + dim_head: int | None = None + output_dim: int = 348 + context_size: int = 200 + max_pos_emb: int = 512 + dropout: float = 0.1 + pred_dropout: float = 0.25 + conv_kernel_size: int = 15 + conv_expansion_factor: int = 2 + self_conditioning_layer: int | None = None + bpe_output_dim: int = 49153 + bpe_pooling_window: int = 4 + blank_token_id: int | None = None + initializer_range: float = 0.02 + + def __post_init__(self, **kwargs): + super().__post_init__(**kwargs) + if self.dim_head is None: + self.dim_head = self.hidden_dim // self.num_heads + if self.self_conditioning_layer is None: + self.self_conditioning_layer = self.num_layers // 2 + + +@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") +@strict +class GraniteSpeechNarProjectorConfig(PreTrainedConfig): + r""" + Configuration for the QFormer-based audio projector in GraniteSpeechNar. + + encoder_dim (`int`, *optional*, defaults to 1024): + Hidden dimension of the encoder (per layer). + llm_dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the language model. + downsample_rate (`int`, *optional*, defaults to 5): + Temporal downsampling rate within each window block. + num_encoder_layers (`int`, *optional*, defaults to 4): + Number of encoder layers concatenated as projector input. + block_size (`int`, *optional*, defaults to 15): + Window size for blocked cross-attention in the projector. 
+ layernorm_eps (`float`, *optional*, defaults to 1e-6): + Epsilon for layer normalization. + attn_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in attention projections. + + Example: + + ```python + >>> from transformers import GraniteSpeechNarProjectorConfig + + >>> configuration = GraniteSpeechNarProjectorConfig() + >>> print(configuration.hidden_size) + 2048 + ```""" + + model_type = "granite_speech_nar_projector" + + encoder_dim: int = 1024 + llm_dim: int = 2048 + downsample_rate: int = 5 + num_encoder_layers: int = 4 + hidden_size: int = 2048 + num_heads: int = 32 + num_layers: int = 2 + dropout_prob: float = 0.1 + block_size: int = 15 + mlp_ratio: int = 2 + layernorm_eps: float = 1e-6 + attn_bias: bool = True + mlp_bias: bool = True + + +@auto_docstring(checkpoint="ibm-granite/granite-speech-4.1-2b-nar") +@strict +class GraniteSpeechNarConfig(PreTrainedConfig): + r""" + Configuration for the GraniteSpeechNar non-autoregressive ASR model. + + This model uses a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone for single-pass speech recognition. + + projector_config (`GraniteSpeechNarProjectorConfig` or `dict`, *optional*): + Configuration for the QFormer-based audio projector. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the LLM's input and output word embeddings should be tied. + encoder_layer_indices (`list[int]`, *optional*, defaults to `[4, 8, 12, -1]`): + Indices of encoder layers whose hidden states are concatenated as projector input. + scale_projected_embeddings (`bool`, *optional*, defaults to `True`): + Whether to divide projected audio embeddings by the LLM's embedding multiplier. + blank_token_id (`int`, *optional*): + Token ID used as the CTC blank symbol. Defaults to `text_config.eos_token_id`. 
+ min_edit_sequence_length (`int`, *optional*, defaults to 8): + Minimum length of the edit sequence (CTC tokens + insertion slots) fed to the LLM. + ce_loss_lambda (`float`, *optional*, defaults to 0.0): + Weight for auxiliary cross-entropy loss on the LLM output. + encoder_ctc_loss_lambda (`float`, *optional*, defaults to 0.0): + Weight for auxiliary encoder BPE CTC loss. + + Example: + + ```python + >>> from transformers import GraniteSpeechNarConfig, GraniteSpeechNarForCTC + + >>> configuration = GraniteSpeechNarConfig() + >>> model = GraniteSpeechNarForCTC(configuration) + >>> print(configuration.model_type) + granite_speech_nar + ```""" + + model_type = "granite_speech_nar" + sub_configs = { + "encoder_config": GraniteSpeechNarEncoderConfig, + "projector_config": GraniteSpeechNarProjectorConfig, + "text_config": AutoConfig, + } + + encoder_config: dict | PreTrainedConfig | None = None + projector_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + tie_word_embeddings: bool = True + encoder_layer_indices: list[int] | None = None + scale_projected_embeddings: bool = True + blank_token_id: int | None = None + min_edit_sequence_length: int = 8 + ce_loss_lambda: float = 0.0 + encoder_ctc_loss_lambda: float = 0.0 + + def __post_init__(self, **kwargs): + if isinstance(self.text_config, dict): + self.text_config["model_type"] = self.text_config.get("model_type", "granite") + self.text_config = CONFIG_MAPPING[self.text_config["model_type"]](**self.text_config) + elif self.text_config is None: + self.text_config = CONFIG_MAPPING["granite"]() + + if not isinstance(self.encoder_config, GraniteSpeechNarEncoderConfig): + self.encoder_config = GraniteSpeechNarEncoderConfig(**(self.encoder_config or {})) + + if not isinstance(self.projector_config, GraniteSpeechNarProjectorConfig): + self.projector_config = GraniteSpeechNarProjectorConfig(**(self.projector_config or {})) + + if self.encoder_layer_indices is None: + 
self.encoder_layer_indices = [4, 8, 12, -1] + + if self.blank_token_id is None: + self.blank_token_id = self.text_config.eos_token_id + + # Propagate blank_token_id to encoder config + self.encoder_config.blank_token_id = self.blank_token_id + + super().__post_init__(**kwargs) + + +__all__ = ["GraniteSpeechNarEncoderConfig", "GraniteSpeechNarProjectorConfig", "GraniteSpeechNarConfig"] diff --git a/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py new file mode 100644 index 000000000000..c11791599aca --- /dev/null +++ b/src/transformers/models/granite_speech_nar/feature_extraction_granite_speech_nar.py @@ -0,0 +1,121 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extraction for Granite Speech NAR.""" + +from ...feature_extraction_utils import FeatureExtractionMixin +from ...tokenization_utils_base import AudioInput +from ...utils import is_torch_available, is_torchaudio_available +from ...utils.import_utils import requires_backends + + +if is_torch_available(): + import torch + +if is_torchaudio_available(): + import torchaudio + + +class GraniteSpeechNarFeatureExtractor(FeatureExtractionMixin): + """Extracts log-mel spectrogram features for GraniteSpeechNar. + + Produces stacked pairs of 80-band mel frames, yielding 160-dim features + at half the original frame rate. 
+ """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + sampling_rate: int = 16000, + n_fft: int = 512, + win_length: int = 400, + hop_length: int = 160, + n_mels: int = 80, + frame_stacking: int = 2, + **kwargs, + ): + requires_backends(self, ["torch", "torchaudio"]) + super().__init__(**kwargs) + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.n_mels = n_mels + self.frame_stacking = frame_stacking + self.mel_filters = torchaudio.transforms.MelSpectrogram( + sample_rate=sampling_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mels, + ) + + def get_num_encoder_frames(self, num_raw_samples): + mel_frames = num_raw_samples // self.hop_length + 1 + return -(-mel_frames // self.frame_stacking) + + def _extract_features(self, raw_audio: "torch.Tensor") -> "torch.Tensor": + with torch.no_grad(): + mel_filters = self.mel_filters.to(raw_audio.device) + mel = mel_filters(raw_audio.float()) + num_frames = mel.shape[-1] + remainder = num_frames % self.frame_stacking + if remainder != 0: + mel = torch.nn.functional.pad(mel, (0, self.frame_stacking - remainder)) + logmel = mel.transpose(-1, -2).clamp_min_(1e-10).log10_() + mx = logmel.amax(dim=(-2, -1), keepdim=True) + logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) + return logmel.reshape(logmel.shape[0], -1, self.frame_stacking * self.n_mels) + + def __call__( + self, + audios: AudioInput, + device: str | None = None, + ) -> dict: + if isinstance(audios, torch.Tensor): + if audios.ndim == 1: + audios = [audios] + elif audios.ndim == 2: + audios = [audios[i] for i in range(audios.shape[0])] + else: + raise ValueError(f"Expected 1-D or 2-D tensor, got {audios.ndim}-D") + + raw_lengths = [a.shape[-1] for a in audios] + encoder_frame_counts = [self.get_num_encoder_frames(l) for l in raw_lengths] + + raw_audio = torch.nn.utils.rnn.pad_sequence( + [a.squeeze(0) if a.ndim > 1 
else a for a in audios], + batch_first=True, + padding_value=0.0, + ) + if device is not None: + raw_audio = raw_audio.to(device) + + input_features = self._extract_features(raw_audio) + + max_enc_frames = input_features.shape[1] + x_sizes = torch.tensor(encoder_frame_counts, dtype=torch.long) + attention_mask = torch.arange(max_enc_frames).unsqueeze(0) < x_sizes.unsqueeze(1) + + if device is not None: + input_features = input_features.to(device) + attention_mask = attention_mask.to(device) + + return { + "input_features": input_features, + "attention_mask": attention_mask, + } + + +__all__ = ["GraniteSpeechNarFeatureExtractor"] diff --git a/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py new file mode 100644 index 000000000000..3ba49279a2c4 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/modeling_granite_speech_nar.py @@ -0,0 +1,1206 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite_speech_nar/modular_granite_speech_nar.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite_speech_nar.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_bidirectional_mask, find_packed_sequence_indices, packed_sequence_mask_function +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, auto_docstring +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_granite_speech_nar import ( + GraniteSpeechNarConfig, + GraniteSpeechNarEncoderConfig, + GraniteSpeechNarProjectorConfig, +) + + +@dataclass +class GraniteSpeechNarEncoderOutput(ModelOutput): + """Output of the GraniteSpeechNar encoder.""" + + loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +class GraniteSpeechNarModelOutput(ModelOutput): + """Output of GraniteSpeechNarModel (backbone without lm_head). + + Attributes: + last_hidden_state: Hidden states from the LLM backbone (flat, batch dim squeezed). + ctc_token_ids: List of CTC-collapsed encoder predictions per sample. + text_lengths: Per-sample text sequence lengths (after insertion slots). 
+ audio_lengths: Per-sample projected audio lengths. + encoder_loss: Encoder BPE CTC loss (when encoder_ctc_loss_lambda > 0 and labels provided). + encoder_logits: Flat BPE CTC logits from the encoder (when output_encoder_logits=True). + """ + + last_hidden_state: torch.FloatTensor | None = None + audio_embeds: torch.FloatTensor | None = None + ctc_token_ids: list[torch.Tensor] | None = None + text_lengths: list[int] | None = None + audio_lengths: list[int] | None = None + encoder_loss: torch.Tensor | None = None + encoder_logits: torch.Tensor | None = None + + +@dataclass +class GraniteSpeechNarOutput(ModelOutput): + """Output of the GraniteSpeechNarForCTC model. + + Attributes: + loss: Combined CTC + auxiliary losses (only when labels provided). + logits: Flat (concatenated) LLM logits of shape `(sum(text_lengths), vocab_size)`. + Split with `logits.split(text_lengths)` to get per-sample tensors. + text_lengths: Per-sample text sequence lengths for splitting logits. + preds: List of predicted token ID tensors per sample (after CTC collapse, inference only). + audio_embeds: Projected audio embeddings (cached for multi-step editing). + encoder_logits: Flat BPE CTC logits from the encoder. + encoder_preds: List of CTC-collapsed encoder predictions per sample. + encoder_loss: Encoder BPE CTC loss component (for logging). + ce_loss: Cross-entropy auxiliary loss component (for logging). 
+ """ + + loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + text_lengths: list[int] | None = None + preds: list[torch.Tensor] | None = None + audio_embeds: torch.FloatTensor | None = None + encoder_logits: torch.Tensor | None = None + encoder_preds: list[torch.Tensor] | None = None + encoder_loss: torch.Tensor | None = None + ce_loss: torch.Tensor | None = None + + +### Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git +class GraniteSpeechNarConformerFeedForward(nn.Module): + """Feedforward module for conformer encoder blocks.""" + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.up_proj = nn.Linear(config.hidden_dim, config.hidden_dim * config.feedforward_mult) + self.silu = nn.SiLU() + self.dropout = nn.Dropout(config.dropout) + self.down_proj = nn.Linear(config.hidden_dim * config.feedforward_mult, config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + hidden_states = self.up_proj(hidden_states) + hidden_states = self.dropout(self.silu(hidden_states)) + hidden_states = self.down_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GraniteSpeechNarConformerAttention(nn.Module): + """Attention for conformer blocks using Shaw's relative positional embeddings. + See the following [paper](https://huggingface.co/papers/1803.02155) for more details. 
+ """ + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = self.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head) + self.dropout = nn.Dropout(config.dropout) + + if self.context_size <= 0 or self.context_size > self.max_pos_emb: + raise ValueError("Context size is either less than 0 or exceeds the max_pos_emb") + + def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + bsz, num_features, _ = hidden_states.shape + + num_blocks = math.ceil(num_features / self.context_size) + remainder = num_features % self.context_size + if remainder > 0: + # right padding to reach block size + hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, self.context_size - remainder)) + + query_states = self.to_q(hidden_states) + key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) + + query_states = query_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + key_states = key_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + value_states = value_states.reshape(bsz, num_blocks, self.context_size, self.num_heads, -1).transpose(2, 3) + + # shaw's relative positional embedding + rel_pos_emb = self.rel_pos_emb(attention_dists) + # alternative computation of `pos_attn` - for readability + # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) + # pos_attn = 
torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale + # einsum implementation of pos_attn - gives x30 speedup over the alternative + # TODO (@avihu111) find a fast alternative to einsum + pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale + + if remainder > 0: + # masked attention in the extended block + mask = torch.ones(self.context_size, self.context_size, dtype=bool, device=hidden_states.device) + mask[:remainder, :remainder] = 0 + mask_value = -torch.finfo(pos_attn.dtype).max + pos_attn[:, -1, :].masked_fill_(mask, mask_value) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=pos_attn, scale=self.scale + ) + out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) + out = self.to_out(out[:, :num_features, :]) + return self.dropout(out) + + +class GraniteSpeechNarConformerDepthWiseConv1d(nn.Module): + """Wrapper for padded 1D depthwise convolution.""" + + def __init__(self, chan_in: int, chan_out: int, kernel_size: int): + super().__init__() + # Padding for the 1D conv is symmetric or close (i.e., offset by one). + 
+ pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.padding) + return self.conv(hidden_states) + + +class GraniteSpeechNarConformerConvModule(nn.Module): + """Conformer conv module consisting of several 1D/depthwise 1D convolutional layers.""" + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.glu = nn.GLU(dim=1) + self.depth_conv = GraniteSpeechNarConformerDepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=config.conv_kernel_size, + ) + self.silu = nn.SiLU() + self.batch_norm = nn.BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.up_conv(hidden_states.permute(0, 2, 1)) + hidden_states = self.glu(hidden_states) + hidden_states = self.depth_conv(hidden_states) + hidden_states = self.silu(self.batch_norm(hidden_states)) + hidden_states = self.down_conv(hidden_states).permute(0, 2, 1) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GraniteSpeechNarConformerBlock( + GradientCheckpointingLayer, +): + """Conformer block, consisting largely of linear layers, attention, and convolutional layers.""" + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__() + self.ff1 = GraniteSpeechNarConformerFeedForward(config) + self.attn = GraniteSpeechNarConformerAttention(config) + self.conv = GraniteSpeechNarConformerConvModule(config) + self.ff2 = 
class GraniteSpeechNarQFormerCrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from `hidden_states`, keys and values from
    `encoder_hidden_states`. No attention mask is applied."""

    def __init__(self, config: GraniteSpeechNarProjectorConfig):
        super().__init__()
        width = config.hidden_size
        use_bias = config.attn_bias

        self.num_heads = config.num_heads
        self.head_dim = width // config.num_heads
        self.hidden_size = width

        # Projections are created in q, k, v, o order so seeded initialization is unchanged.
        self.q_proj = nn.Linear(width, width, bias=use_bias)
        self.k_proj = nn.Linear(width, width, bias=use_bias)
        self.v_proj = nn.Linear(width, width, bias=use_bias)
        self.o_proj = nn.Linear(width, width, bias=use_bias)

    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape `(batch, length, hidden)` into `(batch, heads, length, head_dim)`."""
        length = projected.shape[1]
        return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, query_len, _ = hidden_states.shape

        queries = self._split_heads(self.q_proj(hidden_states), batch_size)
        keys = self._split_heads(self.k_proj(encoder_hidden_states), batch_size)
        values = self._split_heads(self.v_proj(encoder_hidden_states), batch_size)

        context = F.scaled_dot_product_attention(queries, keys, values, is_causal=False)
        # Merge the heads back into a single hidden dimension before the output projection.
        merged = context.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size)
        return self.o_proj(merged)
self.config = config + self.layer_norms = nn.ModuleList( + [nn.LayerNorm(config.encoder_dim, eps=config.layernorm_eps) for _ in range(config.num_encoder_layers)] + ) + self.layer_projector = nn.Linear(config.encoder_dim * config.num_encoder_layers, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_prob) + self.projector_act = nn.GELU() + self.qformer = GraniteSpeechNarQFormer(config) + + query_length = config.block_size // config.downsample_rate + embed_std = config.hidden_size**-0.5 + self.query = nn.Parameter(torch.randn(1, query_length, config.hidden_size) * embed_std) + self.window_positions = nn.Parameter(torch.randn(1, config.block_size, config.hidden_size) * embed_std) + self.out_norm = nn.LayerNorm(config.hidden_size, eps=config.layernorm_eps) + self.out_linear = nn.Linear(config.hidden_size, config.llm_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() + + hidden_states = hidden_states.view( + batch_size, seq_len, self.config.num_encoder_layers, self.config.encoder_dim + ) + normalized_layers = [] + for i, layer_norm in enumerate(self.layer_norms): + normalized_layers.append(layer_norm(hidden_states[:, :, i])) + hidden_states = torch.cat(normalized_layers, dim=-1) + + hidden_states = self.projector_act(self.layer_projector(hidden_states)) + + block_size = self.config.block_size + nblocks = seq_len // block_size + rest = seq_len % block_size + if rest > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, block_size - rest), "constant", 0) + nblocks += 1 + + hidden_states = hidden_states.view(batch_size * nblocks, block_size, self.config.hidden_size) + query_length = self.query.shape[1] + mean_pool = hidden_states.view( + batch_size * nblocks, query_length, self.config.downsample_rate, self.config.hidden_size + ).mean(dim=-2) + + hidden_states = self.qformer( + query_embeds=self.dropout(self.query + mean_pool), + encoder_hidden_states=self.dropout(hidden_states + 
@auto_docstring
class GraniteSpeechNarPreTrainedModel(PreTrainedModel):
    config_class = GraniteSpeechNarConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    # Entries are class names of layers defined in this file. The decoder layer here is
    # `GraniteSpeechNarDecoderLayer`; the previous entry "GraniteDecoderLayer" named a class
    # that does not exist in this file.
    _no_split_modules = ["GraniteSpeechNarConformerBlock", "GraniteSpeechNarDecoderLayer"]
    input_modalities = ("audio",)

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the default initialization, then draw the projector's learned query and window
        position embeddings from a normal distribution with std `hidden_size ** -0.5`."""
        super()._init_weights(module)
        if isinstance(module, GraniteSpeechNarProjector):
            std = module.config.hidden_size**-0.5
            init.normal_(module.query, mean=0.0, std=std)
            init.normal_(module.window_positions, mean=0.0, std=std)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head dimension, turning
    (batch, num_key_value_heads, seqlen, head_dim) into (batch, num_attention_heads, seqlen, head_dim).
    Gives the same result as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: hand back the input tensor itself.
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(-1, -1, n_rep, -1, -1)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
@use_kernelized_func(apply_rotary_pos_emb)
class GraniteSpeechNarAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings; `is_causal` is fixed to
    False so the attention is bidirectional."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.attention_multiplier
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias

        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        # (..., seq, hidden) -> (..., seq, heads, head_dim); heads are moved ahead of seq below.
        per_head_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Use the configured attention backend, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Attention dropout is only active while training.
        dropout_p = self.attention_dropout if self.training else 0.0

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
class GraniteSpeechNarDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: GraniteSpeechNarConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GraniteSpeechNarAttention(config=config, layer_idx=layer_idx)

        self.mlp = GraniteSpeechNarMLP(config)
        self.input_layernorm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply pre-norm self-attention and a pre-norm MLP, each followed by a residual connection
        whose branch output is scaled by `residual_multiplier`.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                Position indices; forwarded unchanged to the attention module.
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            use_cache (`bool`, *optional*):
                Forwarded unchanged to the attention module. This layer returns only the hidden states.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Additional keyword arguments forwarded to the attention module.

        Returns:
            `torch.Tensor`: the layer output, with the same shape as `hidden_states`.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # The attention weights returned by the attention module are discarded.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states
+ device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class GraniteSpeechNarLanguageModel(GraniteSpeechNarPreTrainedModel): + """GraniteModel with bidirectional (non-causal) attention. + + Uses GraniteSpeechNarDecoderLayer which sets is_causal=False, + and replaces create_causal_mask() with create_bidirectional_mask() so all + attention backends (SDPA, FA2, eager, flex) get a proper non-causal mask. 
+ """ + + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GraniteSpeechNarDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GraniteSpeechNarRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GraniteSpeechNarRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.embedding_multiplier = config.embedding_multiplier + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + **kwargs, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + packed_seq_mask = find_packed_sequence_indices(position_ids) + and_mask_fn = packed_sequence_mask_function(packed_seq_mask) if packed_seq_mask is not None else None + bidirectional_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + and_mask_function=and_mask_fn, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # KV cache is not needed in a non-autoregressive model + kwargs["use_cache"] = False + for 
def _ctc_loss_from_flat_logits(
    flat_logits: torch.Tensor,
    lengths: list[int],
    labels: torch.Tensor,
    label_lengths: torch.Tensor,
    blank_id: int,
) -> torch.Tensor:
    """Compute the CTC loss from flat (concatenated) logits with per-sample lengths.

    Args:
        flat_logits: Logits of all samples concatenated along the first dimension.
        lengths: Number of rows of `flat_logits` belonging to each sample, in order.
        labels: Target token ids, passed to `F.ctc_loss` as the targets.
        label_lengths: Number of valid targets per sample.
        blank_id: Index of the CTC blank token.

    Returns:
        The summed CTC loss over the batch divided by the total number of input frames
        (`sum(lengths)`).
    """
    log_probs = torch.log_softmax(flat_logits.float(), dim=-1)

    # Split the flat tensor back into per-sample sequences and zero-pad them. `pad_sequence`
    # returns the time-major (max_len, batch, vocab) layout that `F.ctc_loss` expects, so no
    # transpose is needed. Rows beyond `sum(lengths)` are ignored.
    per_sample = log_probs[: sum(lengths)].split(lengths)
    log_probs_padded = torch.nn.utils.rnn.pad_sequence(per_sample)

    lengths_t = torch.tensor(lengths, device=flat_logits.device)
    return (
        F.ctc_loss(
            log_probs_padded,
            labels,
            lengths_t,
            label_lengths,
            blank=blank_id,
            reduction="sum",
            zero_infinity=True,
        )
        / lengths_t.sum()
    )
_can_record_outputs = { + "hidden_states": GraniteSpeechNarConformerBlock, + } + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__(config) + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = nn.ModuleList([GraniteSpeechNarConformerBlock(config) for _ in range(config.num_layers)]) + self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) + self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) + self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) + self.dropout = nn.Dropout(config.pred_dropout) + self.post_init() + + @capture_outputs + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + **kwargs, + ) -> GraniteSpeechNarEncoderOutput: + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + hidden_states = self.input_linear(input_features.to(self.dtype)) + + context_size = self.config.context_size + seq = torch.arange(context_size, device=hidden_states.device) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + self.config.max_pos_emb + + for layer_idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, attention_dists=attention_dists) + + if layer_idx == self.config.self_conditioning_layer: + mid_logits = self.out(self.dropout(hidden_states)) + mid_probs = torch.softmax(mid_logits.float(), dim=-1) + blank_probs = mid_probs[:, :, 0] + hidden_states += self.out_mid(mid_probs.to(hidden_states.dtype)) + + hidden_states = self.dropout(hidden_states) + + pool_window = self.config.bpe_pooling_window + importance = 1.0 - blank_probs + pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( + 
def _ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> torch.Tensor:
    """Greedy CTC decoding of one sequence.

    Takes the highest-scoring token per frame, merges runs of identical consecutive tokens,
    and then drops every `blank_id` entry.
    """
    best_path = logits.argmax(dim=-1)
    collapsed = torch.unique_consecutive(best_path)
    keep = collapsed != blank_id
    return collapsed[keep]
self.config.min_edit_sequence_length) + idx = torch.arange(n, device=token_ids.device) + out_idx = 2 * idx + 1 + out = torch.full((total_len,), fill_value=blank_id, dtype=token_ids.dtype, device=token_ids.device) + out[out_idx] = token_ids + return out + + def _build_flat_inputs( + self, + ctc_token_ids: list[torch.Tensor], + audio_embeds: torch.Tensor, + audio_lengths: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" + embed_tokens = self.language_model.embed_tokens + + embeds_list = [] + position_ids_list = [] + text_lengths = [] + + for i, audio_len in enumerate(audio_lengths): + audio_emb = audio_embeds[i, :audio_len] + text_ids_with_slots = self._add_insertion_slots(ctc_token_ids[i]) + text_emb = embed_tokens(text_ids_with_slots) + sample_embeds = torch.cat([audio_emb, text_emb], dim=0) + embeds_list.append(sample_embeds) + position_ids_list.append(torch.arange(sample_embeds.shape[0], device=audio_embeds.device)) + text_lengths.append(text_ids_with_slots.shape[0]) + + flat_embeds = torch.cat(embeds_list, dim=0).unsqueeze(0) + flat_position_ids = torch.cat(position_ids_list, dim=0).unsqueeze(0) + return flat_embeds, flat_position_ids, text_lengths + + def forward( + self, + input_features: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, + **kwargs, + ) -> GraniteSpeechNarModelOutput: + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + encoder_loss = None + encoder_logits = None + + if audio_embeds is None: + encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None + enc_out = self.encoder( + input_features=input_features, + 
@auto_docstring(
    custom_intro="""
    The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition.
    Consists of a conformer encoder with BPE CTC head, a QFormer-based projector,
    and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass.
    """
)
class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel):
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: GraniteSpeechNarConfig):
        super().__init__(config)

        self.model = GraniteSpeechNarModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        label_lengths: torch.Tensor | None = None,
        output_encoder_logits: bool = False,
        audio_embeds: torch.Tensor | None = None,
        ctc_token_ids: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> GraniteSpeechNarOutput:
        r"""
        Args:
            input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`):
                Mel spectrogram features.
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
                Encoder attention mask (1 for valid frames, 0 for padding).
            labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*):
                Ground truth LLM token IDs for training.
            label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*):
                Number of valid tokens per sample in `labels`.
            output_encoder_logits (`bool`, *optional*, defaults to `False`):
                Whether to return encoder BPE logits. When False, the large logits
                tensor is freed early to reduce peak memory.
            audio_embeds (`torch.Tensor`, *optional*):
                Pre-computed projected audio embeddings. When provided, encoder and
                projector are skipped (used for multi-step editing).
            ctc_token_ids (`list[torch.Tensor]`, *optional*):
                Pre-computed CTC token predictions to use as text input. When provided
                with `audio_embeds`, replaces encoder CTC predictions.

        Returns:
            [`GraniteSpeechNarOutput`]
        """
        model_out = self.model(
            input_features=input_features,
            attention_mask=attention_mask,
            labels=labels,
            label_lengths=label_lengths,
            output_encoder_logits=output_encoder_logits,
            audio_embeds=audio_embeds,
            ctc_token_ids=ctc_token_ids,
        )

        all_logits = self.lm_head(model_out.last_hidden_state)
        if hasattr(self.config.text_config, "logits_scaling"):
            all_logits = all_logits / self.config.text_config.logits_scaling

        # The flat LLM sequence is laid out as [audio_0, text_0, audio_1, text_1, ...];
        # keep only the text segments, i.e. every second chunk.
        audio_lengths = model_out.audio_lengths
        text_lengths = model_out.text_lengths
        segment_lengths = [
            length for audio_len, text_len in zip(audio_lengths, text_lengths) for length in (audio_len, text_len)
        ]
        text_logits = torch.cat(all_logits.split(segment_lengths)[1::2])

        loss = None
        encoder_loss = model_out.encoder_loss
        ce_loss = None
        if labels is not None:
            loss = _ctc_loss_from_flat_logits(
                text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id
            )

            if self.config.ce_loss_lambda > 0.0:
                # Cross-entropy against the slot-expanded CTC tokens that were fed to the LLM.
                ce_targets = torch.cat([self.model._add_insertion_slots(ids) for ids in model_out.ctc_token_ids])
                ce_loss = F.cross_entropy(
                    text_logits,
                    ce_targets.long(),
                    reduction="mean",
                    ignore_index=-100,
                )
                loss = loss + self.config.ce_loss_lambda * ce_loss

            if encoder_loss is not None:
                loss = loss + self.config.encoder_ctc_loss_lambda * encoder_loss

        return GraniteSpeechNarOutput(
            loss=loss,
            logits=text_logits,
            text_lengths=text_lengths,
            audio_embeds=model_out.audio_embeds,
            encoder_logits=model_out.encoder_logits,
            encoder_preds=model_out.ctc_token_ids,
            encoder_loss=encoder_loss,
            ce_loss=ce_loss,
        )

    @torch.inference_mode()
    def generate(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_encoder_logits: bool = False,
        num_editing_steps: int = 1,
    ) -> GraniteSpeechNarOutput:
        """Non-autoregressive inference with iterative editing.

        Each editing step collapses the LLM output via CTC and feeds it back
        as input for refinement, reusing cached audio embeddings.

        Returns token ID tensors in `preds`. Use `GraniteSpeechNarProcessor.batch_decode()`
        to convert to strings.

        Raises:
            ValueError: If `num_editing_steps` is smaller than 1.
        """
        if num_editing_steps < 1:
            raise ValueError(f"`num_editing_steps` must be at least 1, got {num_editing_steps}.")

        blank_id = self.config.blank_token_id
        audio_embeds = None
        ctc_token_ids = None
        encoder_preds = None
        encoder_logits = None

        for step in range(num_editing_steps):
            output = self.forward(
                input_features=input_features,
                attention_mask=attention_mask,
                audio_embeds=audio_embeds,
                ctc_token_ids=ctc_token_ids,
                output_encoder_logits=(output_encoder_logits and step == 0),
            )

            logits_per_sample = output.logits.split(output.text_lengths)
            ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in logits_per_sample]
            audio_embeds = output.audio_embeds

            if step == 0:
                # The encoder only runs on the first step (later steps reuse `audio_embeds`),
                # so its predictions and logits must be captured here; later outputs carry
                # no encoder logits.
                encoder_preds = output.encoder_preds
                encoder_logits = output.encoder_logits

        return GraniteSpeechNarOutput(
            preds=ctc_token_ids,
            logits=output.logits,
            text_lengths=output.text_lengths,
            encoder_logits=encoder_logits,
            encoder_preds=encoder_preds,
        )
@dataclass
class GraniteSpeechNarModelOutput(ModelOutput):
    """Output of GraniteSpeechNarModel (backbone without lm_head).

    Attributes:
        last_hidden_state: Hidden states from the LLM backbone (flat, batch dim squeezed).
        audio_embeds: Projected audio embeddings used as LLM input. Returned so that a caller
            can pass them back in and skip the encoder and projector.
        ctc_token_ids: List of CTC-collapsed encoder predictions per sample.
        text_lengths: Per-sample text sequence lengths (after insertion slots).
        audio_lengths: Per-sample projected audio lengths.
        encoder_loss: Encoder BPE CTC loss (when encoder_ctc_loss_lambda > 0 and labels provided).
        encoder_logits: Flat BPE CTC logits from the encoder (when output_encoder_logits=True).
    """

    last_hidden_state: torch.FloatTensor | None = None
    audio_embeds: torch.FloatTensor | None = None
    ctc_token_ids: list[torch.Tensor] | None = None
    text_lengths: list[int] | None = None
    audio_lengths: list[int] | None = None
    encoder_loss: torch.Tensor | None = None
    encoder_logits: torch.Tensor | None = None
+ """ + + loss: torch.Tensor | None = None + logits: torch.FloatTensor | None = None + text_lengths: list[int] | None = None + preds: list[torch.Tensor] | None = None + audio_embeds: torch.FloatTensor | None = None + encoder_logits: torch.Tensor | None = None + encoder_preds: list[torch.Tensor] | None = None + encoder_loss: torch.Tensor | None = None + ce_loss: torch.Tensor | None = None + + +class GraniteSpeechNarConformerBlock(GradientCheckpointingLayer, GraniteSpeechConformerBlock): + pass + + +def _posterior_weighted_pool(hidden: torch.Tensor, importance: torch.Tensor, window_size: int = 4) -> torch.Tensor: + batch_size, seq_len, hidden_dim = hidden.shape + pad_len = (window_size - seq_len % window_size) % window_size + if pad_len > 0: + hidden = F.pad(hidden, (0, 0, 0, pad_len)) + importance = F.pad(importance, (0, pad_len)) + num_windows = hidden.shape[1] // window_size + hidden = hidden.view(batch_size, num_windows, window_size, hidden_dim) + importance = importance.view(batch_size, num_windows, window_size) + weights = importance / (importance.sum(dim=-1, keepdim=True) + 1e-8) + pooled = (hidden * weights.unsqueeze(-1)).sum(dim=2) + return pooled + + +def _ctc_greedy_decode(logits: torch.Tensor, blank_id: int) -> torch.Tensor: + """CTC greedy decode a single sequence: argmax -> unique_consecutive -> remove blank.""" + pred = torch.unique_consecutive(logits.argmax(-1)) + return pred[pred != blank_id] + + +def _ctc_loss_from_flat_logits( + flat_logits: torch.Tensor, + lengths: list[int], + labels: torch.Tensor, + label_lengths: torch.Tensor, + blank_id: int, +) -> torch.Tensor: + """Compute mean CTC loss from flat (concatenated) logits with per-sample lengths.""" + log_probs = torch.log_softmax(flat_logits.float(), dim=-1) + log_probs_padded = log_probs.new_zeros(len(lengths), max(lengths), log_probs.shape[-1]) + offset = 0 + for i, length in enumerate(lengths): + log_probs_padded[i, :length] = log_probs[offset : offset + length] + offset += length + + 
class GraniteSpeechNarQFormerCrossAttention(nn.Module):
    """Multi-head cross-attention from query states to encoder states.

    Queries are projected from `hidden_states`; keys and values are projected from
    `encoder_hidden_states`. No attention mask is applied and attention is non-causal.
    """

    def __init__(self, config: GraniteSpeechNarProjectorConfig):
        super().__init__()
        hidden_size = config.hidden_size
        use_bias = config.attn_bias

        self.num_heads = config.num_heads
        self.head_dim = hidden_size // config.num_heads
        self.hidden_size = hidden_size
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=use_bias)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=use_bias)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=use_bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=use_bias)

    def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape `(batch, length, hidden)` into `(batch, heads, length, head_dim)`."""
        length = projected.shape[1]
        return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Attend from `hidden_states` to `encoder_hidden_states`.

        Args:
            hidden_states: Query input of shape `(batch, query_len, hidden_size)`.
            encoder_hidden_states: Key/value input of shape `(batch, encoder_len, hidden_size)`.

        Returns:
            Tensor of shape `(batch, query_len, hidden_size)`.
        """
        batch_size, query_len, _ = hidden_states.shape

        queries = self._split_heads(self.q_proj(hidden_states), batch_size)
        keys = self._split_heads(self.k_proj(encoder_hidden_states), batch_size)
        values = self._split_heads(self.v_proj(encoder_hidden_states), batch_size)

        context = F.scaled_dot_product_attention(queries, keys, values, is_causal=False)
        # Merge the heads back into a single hidden dimension before the output projection.
        merged = context.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size)
        return self.o_proj(merged)
class GraniteSpeechNarProjector(nn.Module):
    """Windowed QFormer projector from concatenated encoder layers to the LLM embedding space.

    The input is split back into its per-layer slices, each slice is layer-normalised,
    and the result is projected to `hidden_size`. The time axis is then cut into blocks
    of `block_size` frames; each block is summarised by `block_size // downsample_rate`
    learned queries that cross-attend to the frames of that block.
    """

    def __init__(self, config: GraniteSpeechNarProjectorConfig):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        norm_eps = config.layernorm_eps

        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(config.encoder_dim, eps=norm_eps) for _ in range(config.num_encoder_layers)]
        )
        self.layer_projector = nn.Linear(config.encoder_dim * config.num_encoder_layers, hidden_size)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.projector_act = nn.GELU()
        self.qformer = GraniteSpeechNarQFormer(config)

        # Learned queries and per-position offsets share the same initial scale.
        init_scale = hidden_size**-0.5
        num_queries = config.block_size // config.downsample_rate
        self.query = nn.Parameter(torch.randn(1, num_queries, hidden_size) * init_scale)
        self.window_positions = nn.Parameter(torch.randn(1, config.block_size, hidden_size) * init_scale)
        self.out_norm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.out_linear = nn.Linear(hidden_size, config.llm_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project encoder features to LLM-sized embeddings with temporal downsampling.

        Args:
            hidden_states: Tensor of shape
                `(batch, seq_len, num_encoder_layers * encoder_dim)`.

        Returns:
            Tensor of shape `(batch, num_blocks * query_length, llm_dim)`, where
            `num_blocks = ceil(seq_len / block_size)`.
        """
        cfg = self.config
        batch_size, seq_len, _ = hidden_states.size()

        # Normalise every encoder layer's slice with its own LayerNorm, then re-concatenate.
        per_layer = hidden_states.view(batch_size, seq_len, cfg.num_encoder_layers, cfg.encoder_dim)
        normed = torch.cat([norm(per_layer[:, :, idx]) for idx, norm in enumerate(self.layer_norms)], dim=-1)
        features = self.projector_act(self.layer_projector(normed))

        # Zero-pad the time axis up to a whole number of blocks.
        num_blocks, leftover = divmod(seq_len, cfg.block_size)
        if leftover > 0:
            features = F.pad(features, (0, 0, 0, cfg.block_size - leftover), "constant", 0)
            num_blocks += 1

        blocks = features.view(batch_size * num_blocks, cfg.block_size, cfg.hidden_size)
        query_length = self.query.shape[1]
        # Each query starts from the mean of the `downsample_rate` frames it summarises.
        block_means = blocks.view(
            batch_size * num_blocks, query_length, cfg.downsample_rate, cfg.hidden_size
        ).mean(dim=-2)

        queries = self.dropout(self.query + block_means)
        keys_values = self.dropout(blocks + self.window_positions)
        refined = self.qformer(query_embeds=queries, encoder_hidden_states=keys_values)

        refined = refined.view(batch_size, num_blocks * query_length, -1)
        return self.out_linear(self.dropout(self.out_norm(refined)))
class GraniteSpeechNarAttention(GraniteAttention):
    """GraniteAttention with is_causal=False for bidirectional attention.

    Everything except the `is_causal` flag is inherited from `GraniteAttention`.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)
        # Assigned after the parent constructor so this is the value left on the instance.
        self.is_causal = False
+ """ + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + **kwargs, + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + packed_seq_mask = find_packed_sequence_indices(position_ids) + and_mask_fn = packed_sequence_mask_function(packed_seq_mask) if packed_seq_mask is not None else None + bidirectional_mask = create_bidirectional_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + and_mask_function=and_mask_fn, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + # KV cache is not needed in a non-autoregressive model + kwargs["use_cache"] = False + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=bidirectional_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + +class GraniteSpeechNarCTCEncoder(GraniteSpeechNarPreTrainedModel): + """Conformer encoder with BPE CTC head and multi-layer output.""" + + config_class = GraniteSpeechNarEncoderConfig + _can_record_outputs = { + "hidden_states": GraniteSpeechNarConformerBlock, + } + + def __init__(self, config: GraniteSpeechNarEncoderConfig): + super().__init__(config) + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = 
nn.ModuleList([GraniteSpeechNarConformerBlock(config) for _ in range(config.num_layers)]) + self.out = nn.Linear(config.hidden_dim, config.output_dim, bias=True) + self.out_mid = nn.Linear(config.output_dim, config.hidden_dim, bias=True) + self.out_bpe = nn.Linear(config.hidden_dim, config.bpe_output_dim, bias=True) + self.dropout = nn.Dropout(config.pred_dropout) + self.post_init() + + @capture_outputs + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + **kwargs, + ) -> GraniteSpeechNarEncoderOutput: + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + hidden_states = self.input_linear(input_features.to(self.dtype)) + + context_size = self.config.context_size + seq = torch.arange(context_size, device=hidden_states.device) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + attention_dists = torch.clamp(relpos_dist, -context_size, context_size) + self.config.max_pos_emb + + for layer_idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, attention_dists=attention_dists) + + if layer_idx == self.config.self_conditioning_layer: + mid_logits = self.out(self.dropout(hidden_states)) + mid_probs = torch.softmax(mid_logits.float(), dim=-1) + blank_probs = mid_probs[:, :, 0] + hidden_states += self.out_mid(mid_probs.to(hidden_states.dtype)) + + hidden_states = self.dropout(hidden_states) + + pool_window = self.config.bpe_pooling_window + importance = 1.0 - blank_probs + pooled = _posterior_weighted_pool(hidden_states.float(), importance, window_size=pool_window).to( + hidden_states.dtype + ) + encoder_lengths = attention_mask.sum(dim=1) + lengths = -(encoder_lengths // -pool_window) + lengths_list = lengths.tolist() + logits = self.out_bpe(torch.cat([pooled[i, :length] for i, length in enumerate(lengths_list)])) + + loss = None 
+ if labels is not None: + loss = _ctc_loss_from_flat_logits(logits, lengths_list, labels, label_lengths, self.config.blank_token_id) + + return GraniteSpeechNarEncoderOutput( + loss=loss, + logits=logits, + last_hidden_state=hidden_states, + ) + + +@auto_docstring( + custom_intro=""" + The GraniteSpeechNar base model consisting of a conformer encoder, QFormer projector, + and a bidirectional Granite language model backbone. + """ +) +class GraniteSpeechNarModel(GraniteSpeechNarPreTrainedModel): + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + + self.encoder = GraniteSpeechNarCTCEncoder(config.encoder_config) + self.projector = GraniteSpeechNarProjector(config.projector_config) + + self.language_model = GraniteSpeechNarLanguageModel(config.text_config) + + self.post_init() + + def _ctc_collapse_decode( + self, + bpe_logits_flat: torch.Tensor, + bpe_lengths: list[int], + ) -> list[torch.Tensor]: + """GPU CTC greedy decode: argmax -> unique_consecutive -> remove blank.""" + blank_id = self.config.blank_token_id + per_sample = bpe_logits_flat.split(bpe_lengths) + return [_ctc_greedy_decode(seq_logits, blank_id) for seq_logits in per_sample] + + def _add_insertion_slots(self, token_ids: torch.Tensor) -> torch.Tensor: + """Insert blank tokens between each CTC token as editing slots for the LLM.""" + blank_id = self.config.blank_token_id + n = token_ids.numel() + total_len = max(2 * n + 1, self.config.min_edit_sequence_length) + idx = torch.arange(n, device=token_ids.device) + out_idx = 2 * idx + 1 + out = torch.full((total_len,), fill_value=blank_id, dtype=token_ids.dtype, device=token_ids.device) + out[out_idx] = token_ids + return out + + def _build_flat_inputs( + self, + ctc_token_ids: list[torch.Tensor], + audio_embeds: torch.Tensor, + audio_lengths: list[int], + ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Build flat (pad-free) LLM input: [audio_0, text_0, audio_1, text_1, ...]""" + embed_tokens = 
self.language_model.embed_tokens + + embeds_list = [] + position_ids_list = [] + text_lengths = [] + + for i, audio_len in enumerate(audio_lengths): + audio_emb = audio_embeds[i, :audio_len] + text_ids_with_slots = self._add_insertion_slots(ctc_token_ids[i]) + text_emb = embed_tokens(text_ids_with_slots) + sample_embeds = torch.cat([audio_emb, text_emb], dim=0) + embeds_list.append(sample_embeds) + position_ids_list.append(torch.arange(sample_embeds.shape[0], device=audio_embeds.device)) + text_lengths.append(text_ids_with_slots.shape[0]) + + flat_embeds = torch.cat(embeds_list, dim=0).unsqueeze(0) + flat_position_ids = torch.cat(position_ids_list, dim=0).unsqueeze(0) + return flat_embeds, flat_position_ids, text_lengths + + def forward( + self, + input_features: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, + **kwargs, + ) -> GraniteSpeechNarModelOutput: + if attention_mask is None: + attention_mask = torch.ones(input_features.shape[:-1], dtype=torch.bool, device=input_features.device) + + encoder_loss = None + encoder_logits = None + + if audio_embeds is None: + encoder_labels = labels if self.config.encoder_ctc_loss_lambda else None + enc_out = self.encoder( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + labels=encoder_labels, + label_lengths=label_lengths if encoder_labels is not None else None, + ) + + encoder_lengths = attention_mask.sum(dim=1) + + pool_window = self.encoder.config.bpe_pooling_window + bpe_lengths = (-(encoder_lengths // -pool_window)).tolist() + ctc_token_ids = self._ctc_collapse_decode(enc_out.logits, bpe_lengths) + + multilayer_features = torch.cat( + [enc_out.hidden_states[idx] for idx in self.config.encoder_layer_indices], dim=-1 + ) + + 
encoder_loss = enc_out.loss + encoder_logits = enc_out.logits if output_encoder_logits else None + del enc_out + + audio_embeds = self.projector(multilayer_features) + del multilayer_features + if self.config.scale_projected_embeddings: + embedding_multiplier = getattr(self.config.text_config, "embedding_multiplier", 1.0) + audio_embeds = audio_embeds / embedding_multiplier + audio_embeds = audio_embeds.to(self.language_model.embed_tokens.weight.dtype) + + encoder_lengths = attention_mask.sum(dim=1) + audio_lengths = (encoder_lengths // self.projector.config.downsample_rate).cpu().tolist() + + flat_embeds, flat_position_ids, text_lengths = self._build_flat_inputs( + ctc_token_ids, audio_embeds, audio_lengths + ) + + llm_out = self.language_model( + inputs_embeds=flat_embeds, + position_ids=flat_position_ids, + ) + + return GraniteSpeechNarModelOutput( + last_hidden_state=llm_out.last_hidden_state.squeeze(0), + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, + text_lengths=text_lengths, + audio_lengths=audio_lengths, + encoder_loss=encoder_loss, + encoder_logits=encoder_logits, + ) + + +@auto_docstring( + custom_intro=""" + The GraniteSpeechNar model for non-autoregressive CTC-based speech recognition. + Consists of a conformer encoder with BPE CTC head, a QFormer-based projector, + and a bidirectional Granite LLM backbone that refines CTC predictions in a single pass. 
+ """ +) +class GraniteSpeechNarForCTC(GraniteSpeechNarPreTrainedModel): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: GraniteSpeechNarConfig): + super().__init__(config) + + self.model = GraniteSpeechNarModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def forward( + self, + input_features: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + label_lengths: torch.Tensor | None = None, + output_encoder_logits: bool = False, + audio_embeds: torch.Tensor | None = None, + ctc_token_ids: list[torch.Tensor] | None = None, + **kwargs, + ) -> GraniteSpeechNarOutput: + r""" + Args: + input_features (`torch.Tensor` of shape `(batch_size, seq_len, input_dim)`): + Mel spectrogram features. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): + Encoder attention mask (1 for valid frames, 0 for padding). + labels (`torch.Tensor` of shape `(batch_size, max_label_len)`, *optional*): + Ground truth LLM token IDs for training. + label_lengths (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Number of valid tokens per sample in `labels`. + output_encoder_logits (`bool`, *optional*, defaults to `False`): + Whether to return encoder BPE logits. When False, the large logits + tensor is freed early to reduce peak memory. + audio_embeds (`torch.Tensor`, *optional*): + Pre-computed projected audio embeddings. When provided, encoder and + projector are skipped (used for multi-step editing). + ctc_token_ids (`list[torch.Tensor]`, *optional*): + Pre-computed CTC token predictions to use as text input. When provided + with `audio_embeds`, replaces encoder CTC predictions. 
+ + Returns: + [`GraniteSpeechNarOutput`] + """ + model_out = self.model( + input_features=input_features, + attention_mask=attention_mask, + labels=labels, + label_lengths=label_lengths, + output_encoder_logits=output_encoder_logits, + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, + ) + + all_logits = self.lm_head(model_out.last_hidden_state) + if hasattr(self.config.text_config, "logits_scaling"): + all_logits = all_logits / self.config.text_config.logits_scaling + + audio_lengths = model_out.audio_lengths + text_lengths = model_out.text_lengths + segment_lengths = [l for a, t in zip(audio_lengths, text_lengths) for l in (a, t)] + text_logits = torch.cat(list(all_logits.split(segment_lengths)[1::2])) + + loss = None + encoder_loss = model_out.encoder_loss + ce_loss = None + if labels is not None: + loss = _ctc_loss_from_flat_logits( + text_logits, text_lengths, labels, label_lengths, self.config.blank_token_id + ) + + if self.config.ce_loss_lambda > 0.0: + ce_targets = torch.cat([self.model._add_insertion_slots(ids) for ids in model_out.ctc_token_ids]) + ce_loss = F.cross_entropy( + text_logits, + ce_targets.long(), + reduction="mean", + ignore_index=-100, + ) + loss = loss + self.config.ce_loss_lambda * ce_loss + + if encoder_loss is not None: + loss = loss + self.config.encoder_ctc_loss_lambda * encoder_loss + + return GraniteSpeechNarOutput( + loss=loss, + logits=text_logits, + text_lengths=text_lengths, + audio_embeds=model_out.audio_embeds, + encoder_logits=model_out.encoder_logits, + encoder_preds=model_out.ctc_token_ids, + encoder_loss=encoder_loss, + ce_loss=ce_loss, + ) + + @torch.inference_mode() + def generate( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_encoder_logits: bool = False, + num_editing_steps: int = 1, + ) -> GraniteSpeechNarOutput: + """Non-autoregressive inference with iterative editing. 
+ + Each editing step collapses the LLM output via CTC and feeds it back + as input for refinement, reusing cached audio embeddings. + + Returns token ID tensors in `preds`. Use `GraniteSpeechNarProcessor.batch_decode()` + to convert to strings. + """ + audio_embeds = None + ctc_token_ids = None + encoder_preds = None + + for step in range(num_editing_steps): + output = self.forward( + input_features=input_features, + attention_mask=attention_mask, + audio_embeds=audio_embeds, + ctc_token_ids=ctc_token_ids, + output_encoder_logits=(output_encoder_logits and step == 0), + ) + + blank_id = self.config.blank_token_id + logits_per_sample = output.logits.split(output.text_lengths) + ctc_token_ids = [_ctc_greedy_decode(sample_logits, blank_id) for sample_logits in logits_per_sample] + audio_embeds = output.audio_embeds + + if step == 0: + encoder_preds = output.encoder_preds + + return GraniteSpeechNarOutput( + preds=ctc_token_ids, + logits=output.logits, + text_lengths=output.text_lengths, + encoder_logits=output.encoder_logits, + encoder_preds=encoder_preds, + ) + + +__all__ = [ + "GraniteSpeechNarCTCEncoder", + "GraniteSpeechNarForCTC", + "GraniteSpeechNarLanguageModel", + "GraniteSpeechNarModel", + "GraniteSpeechNarPreTrainedModel", +] diff --git a/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py new file mode 100644 index 000000000000..679b984d22f9 --- /dev/null +++ b/src/transformers/models/granite_speech_nar/processing_granite_speech_nar.py @@ -0,0 +1,48 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class GraniteSpeechNarProcessor(ProcessorMixin):
    """Processor combining audio feature extraction and tokenizer for GraniteSpeechNar.

    Calling the processor runs the feature extractor on raw audio; `batch_decode`
    turns predicted token ID tensors back into strings with the tokenizer.
    """

    tokenizer_class = "AutoTokenizer"

    def __init__(self, feature_extractor: GraniteSpeechNarFeatureExtractor, tokenizer=None, **kwargs):
        super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs)

    def __call__(
        self,
        audios: AudioInput,
        device: str | None = None,
        **kwargs,
    ) -> dict:
        """Extract model inputs from raw audio.

        Args:
            audios: Audio input handed to the feature extractor.
            device: Optional device forwarded to the feature extractor.

        Returns:
            Whatever the feature extractor returns for `audios`.
        """
        # NOTE(review): extra keyword arguments are accepted but not forwarded; the feature
        # extractor's signature is not visible here, so they are left untouched.
        return self.feature_extractor(audios, device=device)

    def batch_decode(self, token_ids_list: list["torch.Tensor"], **kwargs) -> list[str]:
        """Decode a list of token ID tensors into strings.

        Args:
            token_ids_list: One tensor of token IDs per sample.
            **kwargs: Forwarded to the tokenizer's `decode`. `skip_special_tokens`
                defaults to `True` and can be overridden here.

        Returns:
            One decoded string per entry of `token_ids_list`.

        Raises:
            ValueError: If the processor was created without a tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not set. Pass tokenizer to GraniteSpeechNarProcessor.")
        # Caller-supplied options win over the default instead of being silently dropped.
        decode_kwargs = {"skip_special_tokens": True, **kwargs}
        return [self.tokenizer.decode(ids, **decode_kwargs) for ids in token_ids_list]
+"""Testing suite for the PyTorch GraniteSpeechNar model."""
+
+import tempfile
+import unittest
+
+from transformers import is_datasets_available, is_torch_available
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor
+
+
+if is_datasets_available():
+    from datasets import load_dataset
+
+if is_torch_available():
+    import torch
+
+    from transformers import (
+        AutoModel,
+        AutoProcessor,
+        GraniteConfig,
+        GraniteSpeechNarConfig,
+    )
+    from transformers.models.granite_speech_nar.configuration_granite_speech_nar import (
+        GraniteSpeechNarEncoderConfig,
+        GraniteSpeechNarProjectorConfig,
+    )
+    from transformers.models.granite_speech_nar.modeling_granite_speech_nar import (
+        GraniteSpeechNarCTCEncoder,
+        GraniteSpeechNarForCTC,
+    )
+
+
+class GraniteSpeechNarEncoderModelTester:
+    """Builds a tiny `GraniteSpeechNarEncoderConfig` and random inputs for the CTC encoder tests."""
+
+    def __init__(
+        self,
+        parent,
+        batch_size=2,
+        seq_length=100,
+        is_training=True,
+        num_layers=2,
+        hidden_dim=64,
+        num_heads=4,
+        dim_head=16,
+        input_dim=160,
+        output_dim=10,
+        context_size=50,
+        self_conditioning_layer=1,
+        bpe_output_dim=51,
+        bpe_pooling_window=4,
+        dropout=0.0,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+
+        self.num_layers = num_layers
+        # Same value under a second name (presumably the one the shared test mixin reads).
+        self.num_hidden_layers = num_layers
+        self.hidden_dim = hidden_dim
+        # Same value under a second name (presumably the one the shared test mixin reads).
+        self.hidden_size = hidden_dim
+        self.num_heads = num_heads
+        self.dim_head = dim_head
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.context_size = context_size
+        self.self_conditioning_layer = self_conditioning_layer
+        self.bpe_output_dim = bpe_output_dim
+        self.bpe_pooling_window = bpe_pooling_window
+        self.dropout = dropout
+
+    def get_config(self):
+        """Return an encoder config built from the tester's hyper-parameters (blank token id 0)."""
+        return GraniteSpeechNarEncoderConfig(
+            num_layers=self.num_layers,
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            dim_head=self.dim_head,
+            input_dim=self.input_dim,
+            output_dim=self.output_dim,
+            context_size=self.context_size,
+            self_conditioning_layer=self.self_conditioning_layer,
+            bpe_output_dim=self.bpe_output_dim,
+            bpe_pooling_window=self.bpe_pooling_window,
+            dropout=self.dropout,
+            blank_token_id=0,
+        )
+
+    def prepare_config_and_inputs(self):
+        """Return `(config, input_features, attention_mask)` with the second sample partially masked."""
+        input_features = floats_tensor([self.batch_size, self.seq_length, self.input_dim])
+        attention_mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.bool)
+        # Sample 1 is masked from frame 80 on, so the batch mixes full and shortened inputs.
+        attention_mask[1, 80:] = False
+        config = self.get_config()
+        return config, input_features, attention_mask
+
+    def create_and_check_model(self, config, input_features, attention_mask):
+        """Run the encoder once and check the logits width and the number of hidden states."""
+        model = GraniteSpeechNarCTCEncoder(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.no_grad():
+            result = model(input_features, attention_mask=attention_mask, output_hidden_states=True)
+
+        self.parent.assertIsNotNone(result.logits)
+        self.parent.assertEqual(result.logits.shape[-1], self.bpe_output_dim)
+        self.parent.assertIsNotNone(result.hidden_states)
+        self.parent.assertEqual(len(result.hidden_states), self.num_layers + 1)
+
+    def prepare_config_and_inputs_for_common(self):
+        """Return `(config, inputs_dict)` in the form the shared model tests consume."""
+        config, input_features, attention_mask = self.prepare_config_and_inputs()
+        inputs_dict = {
+            "input_features": input_features,
+            "attention_mask": attention_mask,
+        }
+        return config, inputs_dict
+
+
+class GraniteSpeechNarForCTCModelTester:
+    """Builds a tiny composite `GraniteSpeechNarConfig` and random inputs for the full model tests."""
+
+    def __init__(
+        self,
+        parent,
+        batch_size=2,
+        seq_length=100,
+        is_training=True,
+        vocab_size=51,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.vocab_size = vocab_size
+
+        # The encoder's BPE output size is tied to the text vocabulary size.
+        self.encoder_model_tester = GraniteSpeechNarEncoderModelTester(
+            parent,
+            batch_size=batch_size,
+            seq_length=seq_length,
+            bpe_output_dim=vocab_size,
+        )
+
+    def get_config(self):
+        """Return the composite config: tiny encoder, projector and Granite text configs."""
+        encoder_config = self.encoder_model_tester.get_config()
+        projector_config = GraniteSpeechNarProjectorConfig(
+            encoder_dim=self.encoder_model_tester.hidden_dim,
+            llm_dim=128,
+            downsample_rate=5,
+            num_encoder_layers=2,
+            hidden_size=128,
+            num_heads=4,
+            num_layers=1,
+            block_size=15,
+        )
+        text_config = GraniteConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=128,
+            num_hidden_layers=2,
+            num_attention_heads=4,
+            num_key_value_heads=2,
+            intermediate_size=128,
+            max_position_embeddings=512,
+            tie_word_embeddings=True,
+            embedding_multiplier=1.0,
+            attention_multiplier=1.0,
+            residual_multiplier=1.0,
+            logits_scaling=1.0,
+        )
+        return GraniteSpeechNarConfig(
+            encoder_config=encoder_config,
+            projector_config=projector_config,
+            text_config=text_config.to_dict(),
+            encoder_layer_indices=[1, -1],
+            scale_projected_embeddings=False,
+        )
+
+    def prepare_config_and_inputs(self):
+        """Return `(config, input_features, attention_mask)` with the second sample partially masked."""
+        input_features = floats_tensor([self.batch_size, self.seq_length, self.encoder_model_tester.input_dim])
+        attention_mask = torch.ones(self.batch_size, self.seq_length, dtype=torch.bool)
+        # Sample 1 is masked from frame 80 on, so the batch mixes full and shortened inputs.
+        attention_mask[1, 80:] = False
+        config = self.get_config()
+        return config, input_features, attention_mask
+
+    def create_and_check_model(self, config, input_features, attention_mask):
+        """Run one forward pass and check the packed logits against `text_lengths`."""
+        model = GraniteSpeechNarForCTC(config=config)
+        model.to(torch_device)
+        model.eval()
+        with torch.no_grad():
+            result = model(input_features=input_features, attention_mask=attention_mask)
+
+        self.parent.assertIsNotNone(result.logits)
+        self.parent.assertIsInstance(result.logits, torch.Tensor)
+        self.parent.assertEqual(result.logits.ndim, 2)
+        self.parent.assertEqual(result.logits.shape[1], self.vocab_size)
+        self.parent.assertIsNotNone(result.text_lengths)
+        self.parent.assertEqual(len(result.text_lengths), self.batch_size)
+        # Logits are 2-D with one row per position over all samples, so the lengths sum to the row count.
+        self.parent.assertEqual(sum(result.text_lengths), result.logits.shape[0])
+
+    def _run_generate_and_check(self, config, input_features, attention_mask, **generate_kwargs):
+        """Call `generate` with `generate_kwargs` and check it yields one 1-D tensor per sample."""
+        model = GraniteSpeechNarForCTC(config=config)
+        model.to(torch_device)
+        model.eval()
+        output = model.generate(input_features=input_features, attention_mask=attention_mask, **generate_kwargs)
+
+        self.parent.assertIsNotNone(output.preds)
+        self.parent.assertEqual(len(output.preds), self.batch_size)
+        for pred in output.preds:
+            self.parent.assertIsInstance(pred, torch.Tensor)
+            self.parent.assertEqual(pred.ndim, 1)
+
+    def create_and_check_generate(self, config, input_features, attention_mask):
+        """Check `generate` with its default number of editing steps."""
+        self._run_generate_and_check(config, input_features, attention_mask)
+
+    def create_and_check_generate_multi_step(self, config, input_features, attention_mask):
+        """Check `generate` with three editing steps."""
+        self._run_generate_and_check(config, input_features, attention_mask, num_editing_steps=3)
+
+    def prepare_config_and_inputs_for_common(self):
+        """Return `(config, inputs_dict)` in the form the shared model tests consume."""
+        config, input_features, attention_mask = self.prepare_config_and_inputs()
+        inputs_dict = {
+            "input_features": input_features,
+            "attention_mask": attention_mask,
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class GraniteSpeechNarEncoderModelTest(ModelTesterMixin, unittest.TestCase):
+    """Shared model tests plus a smoke test for the standalone CTC encoder."""
+
+    all_model_classes = (GraniteSpeechNarCTCEncoder,) if is_torch_available() else ()
+
+    test_resize_embeddings = False
+    test_attention_outputs = False
+    has_attentions = False
+
+    @unittest.skip(reason="GraniteSpeechNarCTCEncoder does not use inputs_embeds")
+    def test_model_get_set_embeddings(self):
+        pass
+
+    @unittest.skip(reason="GraniteSpeechNarCTCEncoder does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="GraniteSpeechNarCTCEncoder does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
+    @unittest.skip(reason="Conformer encoder does not expose attention outputs")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="Self-conditioning injection between layers causes hidden_states mismatch in tuple vs dict")
+    def test_model_outputs_equivalence(self):
+        pass
+
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        """Add `labels` / `label_lengths` to the inputs when the shared tests ask for labels."""
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+        if return_labels:
+            batch_size = self.model_tester.batch_size
+            # Labels are created on `torch_device`, like the features, so the loss never mixes devices.
+            inputs_dict["labels"] = torch.arange(1, 6, device=torch_device).unsqueeze(0).expand(batch_size, -1)
+            inputs_dict["label_lengths"] = torch.tensor([5] * batch_size, device=torch_device)
+        return inputs_dict
+
+    def setUp(self):
+        self.model_tester = GraniteSpeechNarEncoderModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=GraniteSpeechNarEncoderConfig, has_text_modality=False)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+
+@require_torch
+class GraniteSpeechNarForCTCModelTest(ModelTesterMixin, unittest.TestCase):
+    """Shared model tests plus generation, loss and attention checks for the full model."""
+
+    all_model_classes = (GraniteSpeechNarForCTC,) if is_torch_available() else ()
+
+    test_resize_embeddings = False
+    test_attention_outputs = False
+    has_attentions = False
+    _is_composite = True
+
+    @unittest.skip(reason="GraniteSpeechNarForCTC takes audio input_features, not input_ids/inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="GraniteSpeechNarForCTC takes audio input_features, not input_ids/inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
+    @unittest.skip(reason="GraniteSpeechNarForCTC has a custom generate method, not standard GenerationMixin")
+    def test_generation_tester_mixin_inheritance(self):
+        pass
+
+    @unittest.skip(reason="text_lengths (list[int]) in output breaks recursive tuple/dict comparison")
+    def test_model_outputs_equivalence(self):
+        pass
+
+    @unittest.skip(reason="Composite model with flat packed sequences; hidden_states not piped to top-level output")
+    def test_hidden_states_output(self):
+        pass
+
+    @unittest.skip(reason="Composite model with flat packed sequences; hidden_states not piped to top-level output")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="Encoder does not have standard embedding layer for gradient checkpointing")
+    def test_enable_input_require_grads_with_gradient_checkpointing(self):
+        pass
+
+    def setUp(self):
+        self.model_tester = GraniteSpeechNarForCTCModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=GraniteSpeechNarConfig, has_text_modality=False)
+
+    def _forward_with_labels(self, config, input_features, attention_mask, label_width, label_lengths):
+        """Run a training-mode forward pass with random labels and return the model output.
+
+        Args:
+            config: Model config; any loss weights must already be set on it.
+            input_features: Audio features from the model tester.
+            attention_mask: Mask from the model tester.
+            label_width: Padded number of label positions per sample.
+            label_lengths: Per-sample label lengths, one entry per batch element.
+        """
+        model = GraniteSpeechNarForCTC(config).to(torch_device).train()
+
+        # Labels are created on `torch_device`, like the model and features, so the loss never mixes devices.
+        labels = torch.randint(
+            0,
+            self.model_tester.vocab_size,
+            (self.model_tester.batch_size, label_width),
+            device=torch_device,
+        )
+        label_lengths = torch.tensor(label_lengths, device=torch_device)
+
+        return model(
+            input_features=input_features,
+            attention_mask=attention_mask,
+            labels=labels,
+            label_lengths=label_lengths,
+        )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_generate(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_generate(*config_and_inputs)
+
+    def test_generate_multi_step(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_generate_multi_step(*config_and_inputs)
+
+    def test_loss(self):
+        config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs()
+
+        output = self._forward_with_labels(
+            config, input_features, attention_mask, label_width=5, label_lengths=[5, 3]
+        )
+
+        self.assertIsNotNone(output.loss)
+        self.assertEqual(output.loss.ndim, 0)
+        self.assertTrue(output.loss.requires_grad)
+        output.loss.backward()
+
+    def test_loss_with_ce(self):
+        config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs()
+        config.ce_loss_lambda = 0.5
+
+        output = self._forward_with_labels(
+            config, input_features, attention_mask, label_width=4, label_lengths=[4, 3]
+        )
+
+        self.assertIsNotNone(output.loss)
+        self.assertTrue(output.loss.requires_grad)
+        output.loss.backward()
+
+    def test_loss_with_encoder_ctc(self):
+        config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs()
+        config.encoder_ctc_loss_lambda = 0.3
+
+        output = self._forward_with_labels(
+            config, input_features, attention_mask, label_width=4, label_lengths=[4, 3]
+        )
+
+        self.assertIsNotNone(output.loss)
+        self.assertTrue(output.loss.requires_grad)
+        output.loss.backward()
+
+    def test_no_loss_without_labels(self):
+        config, input_features, attention_mask = self.model_tester.prepare_config_and_inputs()
+        model = GraniteSpeechNarForCTC(config).to(torch_device).eval()
+
+        with torch.no_grad():
+            output = model(input_features=input_features, attention_mask=attention_mask)
+
+        self.assertIsNone(output.loss)
+
+    def test_bidirectional_attention(self):
+        """Changing only the last position must change the first position's output."""
+        config = self.model_tester.get_config()
+        model = GraniteSpeechNarForCTC(config).to(torch_device).eval()
+        granite_model = model.model.language_model
+
+        # Must match the text `hidden_size` set in `GraniteSpeechNarForCTCModelTester.get_config`.
+        hidden_size = 128
+        embeds_a = torch.randn(1, 10, hidden_size, device=torch_device)
+        embeds_b = embeds_a.clone()
+        embeds_b[0, -1, :] = torch.randn(hidden_size, device=torch_device)
+
+        with torch.no_grad():
+            out_a = granite_model(inputs_embeds=embeds_a).last_hidden_state
+            out_b = granite_model(inputs_embeds=embeds_b).last_hidden_state
+
+        # With causal attention position 0 could not see position -1 and the difference would be zero.
+        diff_first = (out_a[0, 0] - out_b[0, 0]).abs().max().item()
+        self.assertGreater(diff_first, 1e-5, "First token unchanged — attention appears causal.")
+
+    def test_is_causal_false_on_layers(self):
+        config = self.model_tester.get_config()
+        model = GraniteSpeechNarForCTC(config)
+        for i, layer in enumerate(model.model.language_model.layers):
+            self.assertFalse(layer.self_attn.is_causal, f"Layer {i} is_causal is not False")
+
+    def test_sdpa_can_dispatch_composite_models(self):
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        if not self._is_composite:
+            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+        for model_class in self.all_model_classes:
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_sdpa = model_class.from_pretrained(tmpdirname)
+                model_sdpa = model_sdpa.eval().to(torch_device)
+
+                model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
+                model_eager = model_eager.eval().to(torch_device)
+                self.assertEqual(model_eager.config._attn_implementation, "eager")
+
+                for submodule in model_eager.modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        raise ValueError("The eager model should not have SDPA attention layers")
+
+
+@require_torch
+class GraniteSpeechNarIntegrationTest(unittest.TestCase):
+    """Slow end-to-end transcription tests against the released checkpoint."""
+
+    checkpoint_name = "ibm-granite/granite-speech-4.1-2b-nar"
+    # Loaded on first use and cached on the class so every test shares one copy.
+    _dataset = None
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
+    @classmethod
+    def _load_dataset(cls):
+        """Load the dummy LibriSpeech validation split once and cache it on the class."""
+        if cls._dataset is None:
+            cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+    def _load_datasamples(self, num_samples):
+        """Return the first `num_samples` waveforms (ordered by `id`) as float32 tensors."""
+        self._load_dataset()
+        samples = self._dataset.sort("id")[:num_samples]["audio"]
+        return [torch.tensor(x["array"], dtype=torch.float32) for x in samples]
+
+    def _load_model_and_processor(self):
+        """Load the checkpoint in bfloat16 with flash attention, together with its processor."""
+        model = AutoModel.from_pretrained(
+            self.checkpoint_name,
+            attn_implementation="flash_attention_2",
+            device_map=torch_device,
+            dtype=torch.bfloat16,
+        ).eval()
+        processor = AutoProcessor.from_pretrained(self.checkpoint_name)
+        return model, processor
+
+    def _transcribe(self, num_samples):
+        """Transcribe the first `num_samples` dataset samples and return the decoded strings."""
+        model, processor = self._load_model_and_processor()
+
+        waveforms = self._load_datasamples(num_samples)
+        inputs = processor(waveforms, device=torch_device)
+        output = model.generate(**inputs)
+        return processor.batch_decode(output.preds)
+
+    @slow
+    def test_single_sample_transcription(self):
+        transcriptions = self._transcribe(1)
+
+        expected = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
+        self.assertEqual(transcriptions[0], expected)
+
+    @slow
+    def test_batch_transcription(self):
+        transcriptions = self._transcribe(2)
+
+        expected = [
+            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
+            "nor is mister quilter's manner less interesting than his matter",
+        ]
+        self.assertEqual(len(transcriptions), 2)
+        self.assertEqual(transcriptions, expected)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index a4dd3f06bd92..b833622b7b4b 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -290,6 +290,8 @@
     "Sam3LiteTextTextModel",  # Building part of a bigger model, tested implicitly through Sam3LiteTextModel
     "Exaone4_5_VisionModel",  # Building part of a bigger model
     "Granite4VisionTextModel",  # Building part of bigger (tested) model. Tested implicitly through Granite4VisionModel.
+    "GraniteSpeechNarLanguageModel",  # Building part of bigger (tested) model. Tested implicitly through GraniteSpeechNarForCTC.
+    "GraniteSpeechNarModel",  # Building part of bigger (tested) model. Tested implicitly through GraniteSpeechNarForCTC.
 ]
 )
 
@@ -518,6 +520,8 @@
     "Ernie4_5_VL_MoeTextModel",  # BC Alias
     "UVDocBridge",  # Building part of a bigger model, tested implicitly through UVDocModel
     "Granite4VisionTextModel",  # Building part of bigger (tested) model.
+    "GraniteSpeechNarModel",  # Building part of bigger (tested) model.
+    "GraniteSpeechNarLanguageModel",  # Building part of bigger (tested) model.
 ]