This repository contains a comprehensive PyTorch implementation of the Wav2Vec2 architecture for Automatic Speech Recognition (ASR). The project demonstrates both the self-supervised pre-training phase—using contrastive learning on unlabelled audio—and the supervised fine-tuning phase using Connectionist Temporal Classification (CTC) on the LibriSpeech dataset.
Wav2Vec2 is a milestone self-supervised framework for speech representation learning. This repository provides:
- A highly modular and readable implementation of the fundamental Wav2Vec2 components, including the Convolutional Feature Encoder, Transformer blocks, and the Gumbel Vector Quantizer.
- Pre-training Pipeline: Distributed training utilizing Hugging Face `accelerate`. Features calculation of contrastive loss, diversity metrics, and continuous Gumbel-Softmax temperature annealing.
- Fine-tuning Pipeline: CTC-based ASR training via the Hugging Face `Trainer` API for state-of-the-art downstream adaptation.
- Data Utilities: Custom PyTorch dataset implementations for LibriSpeech (`LibriSpeechDataset`) with dynamic audio duration filtering and masking mechanisms.
The model code (found in models.py) strictly adheres to the original Wav2Vec2 specifications. The architecture leverages localized feature extraction followed by a deep contextualizing Transformer.
A multi-layer 1D convolutional feature extractor mapping raw audio waveforms directly to latent representations.
- Composed of several stacked `Wav2Vec2LayerNormConvLayer` blocks, each containing a `Conv1d`, `LayerNorm`, and a `GELU` activation.
- Reduces the temporal resolution of the raw audio, mapping 16 kHz audio to overlapping frames that each span about 25 ms and are computed roughly every 20 ms.
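The frame geometry of a convolutional stack like this follows directly from its kernel sizes and strides. The sketch below works through that arithmetic, assuming the seven-layer configuration published for the original wav2vec 2.0 base model; the constants and helper names are illustrative and may not match what `models.py` actually uses.

```python
# Assumed layer configuration (original wav2vec 2.0 base model).
# The values used in models.py may differ.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16_000


def receptive_field_and_hop(kernels, strides):
    """Return (receptive field, hop size) of the stack, in input samples."""
    receptive_field, jump = 1, 1
    for kernel, stride in zip(kernels, strides):
        # Each layer widens the field by (kernel - 1) steps of the current jump.
        receptive_field += (kernel - 1) * jump
        jump *= stride
    return receptive_field, jump


def num_output_frames(num_samples, kernels, strides):
    """Number of frames produced by a stack of unpadded Conv1d layers."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        length = (length - kernel) // stride + 1
    return length


rf, hop = receptive_field_and_hop(CONV_KERNELS, CONV_STRIDES)
frames_per_second = num_output_frames(SAMPLE_RATE, CONV_KERNELS, CONV_STRIDES)

print(f"receptive field: {rf} samples ({1000 * rf / SAMPLE_RATE:.0f} ms)")
print(f"hop size: {hop} samples ({1000 * hop / SAMPLE_RATE:.0f} ms)")
print(f"frames for 1 s of audio: {frames_per_second}")
```

Under these assumed settings each frame covers 400 samples and consecutive frames start 320 samples apart, so one second of 16 kHz audio yields 49 frames.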
- A normalization and mapping layer consisting of `LayerNorm` and a `Linear` projection.
- Bridges the convolutional encoder's output dimension to the Transformer's expected embedding dimension.
Unlike standard NLP Transformers that rely on absolute sinusoidal embeddings, Wav2Vec2 uses a local, learned relative positional embedding.
- Implemented as a grouped 1D convolution (`Conv1d`) applied to the projected features.
- Enables the model to generalize effectively to variable-length audio input without strict positional constraints.
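A convolutional positional embedding of this kind can be sketched in a few lines of PyTorch. The class name, the default sizes, and the omission of weight normalization are assumptions made for illustration; this is not the code in `models.py`.

```python
import torch
import torch.nn as nn


class ConvPositionalEmbeddingSketch(nn.Module):
    """Hypothetical sketch of a grouped-convolution positional embedding."""

    def __init__(self, embed_dim=768, kernel_size=128, groups=16):
        super().__init__()
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        # An even kernel with padding k // 2 produces one extra frame to trim.
        self.trim = 1 if kernel_size % 2 == 0 else 0
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        # Conv1d expects (batch, channels, time).
        x = self.conv(hidden_states.transpose(1, 2))
        if self.trim:
            x = x[:, :, : -self.trim]
        x = self.activation(x).transpose(1, 2)
        # Position information is added to the projected features.
        return hidden_states + x


torch.manual_seed(0)
pos_embed = ConvPositionalEmbeddingSketch(embed_dim=32, kernel_size=8, groups=4)
for num_frames in (20, 57):
    features = torch.randn(2, num_frames, 32)
    print(pos_embed(features).shape)
```

Because the convolution only looks at a local window around each frame, the same weights apply to sequences of any length, which is why no fixed maximum position is needed.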
The core sequence modeling engine that processes the position-embedded representations.
- Consists of stacked layers of Multi-Head Self-Attention (`Wav2Vec2Attention`) and Feed-Forward Networks.
- Produces dense, highly contextualized representations of the audio frames incorporating information spanning the entire audio sequence.
A highly specialized component used exclusively during the pre-training phase.
- Acts on the direct uncontextualized output of the Feature Encoder.
- Employs a Gumbel-Softmax operation to select discrete codebook entries in a fully differentiable manner.
- Creates the "targets" (a sequence of vectors) that the Transformer output must learn to identify among distractors.
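The selection step described above can be illustrated with `torch.nn.functional.gumbel_softmax`. The sketch below assumes a product-quantization layout (several groups, one entry picked per group, results concatenated); the class name, dimensions, and the perplexity bookkeeping are assumptions for illustration rather than the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelQuantizerSketch(nn.Module):
    """Hypothetical sketch of a Gumbel-Softmax vector quantizer."""

    def __init__(self, in_dim=512, num_groups=2, num_vars=320, code_dim=256):
        super().__init__()
        assert code_dim % num_groups == 0
        self.num_groups, self.num_vars = num_groups, num_vars
        self.to_logits = nn.Linear(in_dim, num_groups * num_vars)
        self.codebook = nn.Parameter(
            torch.randn(num_groups, num_vars, code_dim // num_groups)
        )

    def forward(self, features, temperature=2.0):
        batch, time, _ = features.shape
        logits = self.to_logits(features).view(
            batch, time, self.num_groups, self.num_vars
        )
        # hard=True gives one-hot choices in the forward pass while gradients
        # flow through the soft distribution (straight-through estimator).
        one_hot = F.gumbel_softmax(logits, tau=temperature, hard=True, dim=-1)
        # Pick one codebook entry per group, then concatenate the groups.
        quantized = torch.einsum("btgv,gvd->btgd", one_hot, self.codebook)
        quantized = quantized.reshape(batch, time, -1)
        # Perplexity of the batch-averaged code distribution, summed over groups.
        probs = logits.softmax(dim=-1).mean(dim=(0, 1))
        perplexity = torch.exp(-(probs * torch.log(probs + 1e-7)).sum(dim=-1)).sum()
        return quantized, perplexity


torch.manual_seed(0)
quantizer = GumbelQuantizerSketch(in_dim=8, num_groups=2, num_vars=4, code_dim=6)
targets, perplexity = quantizer(torch.randn(3, 5, 8))
num_codes = quantizer.num_groups * quantizer.num_vars
# One common formulation of the diversity term: zero when usage is uniform.
diversity_term = (num_codes - perplexity) / num_codes
print(targets.shape, float(diversity_term))
```

The perplexity reaches its maximum (groups times entries per group) only when every entry is used equally often, so a diversity term of this form shrinks as codebook usage becomes more uniform.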
Depending on the training phase, the backbone's representations are passed into specialized heads:
- `Wav2Vec2ForPreTraining`:
  - Manages the masking of incoming features in the time domain.
  - Computes the Contrastive Loss, which requires the model to distinguish the true quantized latent representation for a masked time-step from a set of randomly sampled negative distractors.
  - Computes the Diversity Loss, which encourages uniform usage across all available codebook vectors to prevent index collapse.
- `Wav2Vec2ForCTC`:
  - Adapts the initialized/pre-trained model for downstream tasks.
  - Applies a final `Linear` classification layer on top of the contextualized representation to map to individual vocabulary tokens, permitting CTC loss calculation against unaligned text transcripts.
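The contrastive objective used by the pre-training head can be written as a classification problem over one true target and several distractors. The following is a minimal sketch under assumptions: the function name, tensor layout, and the temperature of 0.1 (the value reported for the original wav2vec 2.0) are illustrative and are not taken from this repository's code.

```python
import torch
import torch.nn.functional as F


def contrastive_loss_sketch(context, target, negatives, temperature=0.1):
    """Hypothetical sketch of a wav2vec 2.0 style contrastive loss.

    context:   (num_masked, dim)     Transformer outputs at masked time-steps
    target:    (num_masked, dim)     true quantized latents for those steps
    negatives: (num_masked, K, dim)  distractors sampled from other steps
    """
    # Put the true target at index 0, followed by the K distractors.
    candidates = torch.cat([target.unsqueeze(1), negatives], dim=1)
    # Cosine similarity between each context vector and its candidates.
    logits = F.cosine_similarity(context.unsqueeze(1), candidates, dim=-1)
    logits = logits / temperature
    # The correct "class" is always index 0.
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)


torch.manual_seed(0)
context = torch.randn(6, 16)
negatives = torch.randn(6, 5, 16)
# A context vector that matches its target scores a much lower loss
# than one paired with an unrelated target.
print(float(contrastive_loss_sketch(context, context.clone(), negatives)))
print(float(contrastive_loss_sketch(context, torch.randn(6, 16), negatives)))
```

Framing the loss as cross-entropy over similarity scores keeps the implementation short and numerically stable, since the softmax normalization is handled inside `F.cross_entropy`.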
Script: pretrain_wav2vec2.py
The initial phase trains the representation space using self-supervised learning on unlabelled speech data.
- Objective: Contrastive loss (identifying the correct quantized representation amongst distractors) and Diversity loss (encouraging uniform usage of codebook vectors).
- Masking Mechanism: Spans of feature encoder outputs are randomly masked. The model predicts the quantized representation of the masked frames.
- Hardware Acceleration: Scales to multiple GPUs effortlessly using Hugging Face `accelerate`.
- Convergence Details: Pre-training exhibits smooth convergence over 20,000 steps, actively monitored via contrastive and diversity loss trajectories, alongside a tightly controlled temperature decay schedule for the Gumbel-Softmax operation.
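A temperature decay schedule for the Gumbel-Softmax is often a simple exponential decay with a floor. The sketch below assumes the start value, floor, and decay factor reported for the original wav2vec 2.0 (2.0, 0.5, and 0.999995 per update); the schedule actually used in `pretrain_wav2vec2.py` is not stated here and may differ.

```python
def gumbel_temperature(step, max_temp=2.0, min_temp=0.5, decay=0.999995):
    """Exponentially decayed temperature, clamped at a minimum.

    All default values are assumptions borrowed from the original
    wav2vec 2.0 setup, not from this repository.
    """
    return max(max_temp * decay ** step, min_temp)


for step in (0, 5_000, 20_000, 1_000_000):
    print(step, round(gumbel_temperature(step), 4))
```

With these assumed defaults the temperature decays slowly, reaching roughly 1.81 after 20,000 updates, and the floor only takes effect much later in training. A faster decay factor would be needed for the quantizer to become near-deterministic within a short run.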
Script: finetune_wav2vec2.py
The pre-trained backbone is refined on transcribed speech to perform speech-to-text decoding.
- Backbone Setup: Feature encoder components can be iteratively frozen based on optimization targets to prevent catastrophic forgetting.
- Objective: CTC loss aligns the context-rich output sequence to the raw text transcript.
- Results: Monitored on the validation split. Reached a final Word Error Rate (WER) of 13.8% and evaluation loss of `0.178`.
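To make the CTC objective and the WER metric concrete, the sketch below shows greedy CTC decoding (collapse repeated predictions, then drop blanks) followed by a word-level edit distance. The tiny vocabulary, the blank index of 0, and both function names are hypothetical; the fine-tuning script may rely on a library metric instead.

```python
def ctc_greedy_decode(frame_ids, id_to_char, blank_id=0):
    """Collapse repeated frame predictions, then remove blank tokens."""
    chars, previous = [], None
    for idx in frame_ids:
        if idx != previous and idx != blank_id:
            chars.append(id_to_char[idx])
        previous = idx
    return "".join(chars)


def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            curr.append(
                min(
                    prev[j] + 1,  # deletion
                    curr[j - 1] + 1,  # insertion
                    prev[j - 1] + (ref_word != hyp_word),  # substitution
                )
            )
        prev = curr
    return prev[-1] / max(len(ref), 1)


# Hypothetical vocabulary; index 0 is assumed to be the CTC blank.
vocab = {0: "", 1: "a", 2: "c", 3: "t", 4: " ", 5: "s", 6: "h"}
frames = [2, 2, 0, 1, 1, 0, 3, 3, 4, 5, 1, 0, 3]
decoded = ctc_greedy_decode(frames, vocab)
print(repr(decoded))
print(word_error_rate("cat sat", decoded))
print(word_error_rate("the cat sat", "the hat sat"))
```

The blank token is what lets CTC emit genuinely repeated characters: two identical predictions separated by a blank survive the collapse step, while adjacent duplicates are merged.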
Comprehensive system monitoring and evaluation metrics were recorded using Weights and Biases (W&B) to ensure training stability and measure network convergence.
After 5000 global steps of fine-tuning, the model reached a 13.8% WER on the validation split, indicating that the acoustic features learned during pre-training transfer to the downstream transcription task.
Observe the sharp convergence in Word Error Rate and Loss below:
The fine-tuning trajectory shows a consistent decrease in both training and validation losses, while compute usage is tracked across training steps.
Infrastructure tracking helps monitor training health over prolonged multi-GPU runs. The recorded metrics indicate stable compute utilization and memory usage, with no hardware throttling observed.
Install the complete set of dependencies required for distributed execution and monitoring:

    pip install torch torchaudio transformers accelerate datasets wandb

To commence a pre-training run utilizing all visibly configured GPUs through `accelerate`:

    accelerate launch pretrain_wav2vec2.py

To adapt the model to downstream ASR, execute the fine-tuning script utilizing the HF Trainer with frozen feature extraction elements:

    python finetune_wav2vec2.py




