Aashutoshh01/wav2vec2-pytorch

Wav2Vec2 Implementation & Training

This repository contains a comprehensive PyTorch implementation of the Wav2Vec2 architecture for Automatic Speech Recognition (ASR). The project demonstrates both the self-supervised pre-training phase—using contrastive learning on unlabelled audio—and the supervised fine-tuning phase using Connectionist Temporal Classification (CTC) on the LibriSpeech dataset.

🚀 Project Overview

Wav2Vec2 is a milestone self-supervised framework for speech representation learning. This repository provides:

  • A highly modular and readable implementation of the fundamental Wav2Vec2 components, including the Convolutional Feature Encoder, Transformer blocks, and the Gumbel Vector Quantizer.
  • Pre-training Pipeline: Distributed training with Hugging Face accelerate, including contrastive loss computation, codebook diversity metrics, and continuous Gumbel-Softmax temperature annealing.
  • Fine-tuning Pipeline: CTC-based ASR training via the Hugging Face Trainer API for downstream adaptation.
  • Data Utilities: Custom PyTorch dataset implementations for LibriSpeech (LibriSpeechDataset) with dynamic audio duration filtering and masking mechanisms.

🏗️ Detailed Architecture

The model code (found in models.py) strictly adheres to the original Wav2Vec2 specifications. The architecture leverages localized feature extraction followed by a deep contextualizing Transformer.

1. Convolutional Feature Encoder (Wav2Vec2FeatureEncoder)

A multi-layer 1D convolutional feature extractor mapping raw audio waveforms directly to latent representations.

  • Composed of several stacked Wav2Vec2LayerNormConvLayer blocks, each containing a Conv1d, LayerNorm, and a GELU activation.
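  • Reduces the temporal resolution of the raw audio: with the original Wav2Vec2 configuration, 16 kHz audio is mapped to overlapping frames that each cover about 25 ms and are computed every 20 ms (roughly 49 frames per second).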
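
The frame geometry of the encoder can be checked with a few lines of arithmetic. The sketch below assumes the kernel sizes and strides of the original Wav2Vec2 base configuration and unpadded convolutions; the values actually used in models.py may differ, and the names CONV_LAYERS, num_output_frames, and receptive_field_and_stride are illustrative only.

```python
# Assumed (kernel, stride) pairs from the original Wav2Vec2 base configuration.
# The layers defined in models.py may use different values.
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]
SAMPLE_RATE = 16_000


def num_output_frames(num_samples, layers=CONV_LAYERS):
    """Length of the latent sequence produced by an unpadded Conv1d stack."""
    length = num_samples
    for kernel, stride in layers:
        length = (length - kernel) // stride + 1
    return length


def receptive_field_and_stride(layers=CONV_LAYERS):
    """Samples covered by one output frame, and samples between adjacent frames."""
    field, jump = 1, 1
    for kernel, stride in layers:
        field += (kernel - 1) * jump
        jump *= stride
    return field, jump


field, hop = receptive_field_and_stride()
print(num_output_frames(SAMPLE_RATE), "frames per second of audio")
print(1000 * field / SAMPLE_RATE, "ms covered by each frame")
print(1000 * hop / SAMPLE_RATE, "ms between frames")
```

Under these assumed values, each frame covers 400 samples (25 ms), consecutive frames are 320 samples (20 ms) apart, and one second of audio yields 49 frames.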

2. Feature Projection (Wav2Vec2FeatureProjection)

  • A normalization and mapping layer consisting of LayerNorm and a Linear projection.
  • Bridges the convolutional encoder's output dimension to the Transformer's expected embedding dimension.

3. Positional Convolution Embedding (Wav2Vec2PositionalConvEmbedding)

Unlike standard NLP Transformers that rely on absolute sinusoidal embeddings, Wav2Vec2 uses a local, learned relative positional embedding.

  • Implemented as a grouped 1D convolution (Conv1d) applied to the projected features.
  • Enables the model to generalize effectively to variable-length audio input without strict positional constraints.

4. Transformer Contextualizer (Wav2Vec2Encoder)

The core sequence modeling engine that processes the position-embedded representations.

  • Consists of stacked layers of Multi-Head Self-Attention (Wav2Vec2Attention) and Feed-Forward Networks.
  • Produces dense, highly contextualized representations of the audio frames incorporating information spanning the entire audio sequence.

5. Quantization Module (Wav2Vec2GumbelVectorQuantizer)

A highly specialized component used exclusively during the pre-training phase.

  • Acts on the direct uncontextualized output of the Feature Encoder.
  • Employs a Gumbel-Softmax operation to select discrete codebook entries in a fully differentiable manner.
  • Creates the "targets" (a sequence of vectors) that the Transformer output must learn to identify among distractors.
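
The selection step can be illustrated without any framework. The following is a minimal pure-Python sketch of Gumbel-Softmax sampling for a single frame and a single codebook; it is not the code in models.py, and the function name and arguments are hypothetical. In an autograd framework the hard (argmax) choice is used in the forward pass while gradients flow through the soft probabilities; PyTorch offers this behaviour through torch.nn.functional.gumbel_softmax with hard=True.

```python
import math
import random


def gumbel_softmax(logits, temperature, rng=random):
    """Return (soft probabilities, index of the hard selection) for one frame."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1).
        # U is clamped away from 0 and 1 to keep both logarithms finite.
        u = min(max(rng.random(), 1e-10), 1.0 - 1e-10)
        noisy.append((logit - math.log(-math.log(u))) / temperature)

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(noisy)
    exps = [math.exp(value - peak) for value in noisy]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs, probs.index(max(probs))


rng = random.Random(0)
probs, index = gumbel_softmax([2.0, 0.5, -1.0], temperature=2.0, rng=rng)
print(probs, index)
```

Lower temperatures push the soft probabilities toward a one-hot vector, which is why the temperature is annealed during pre-training.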

6. Task-Specific Heads

Depending on the training phase, the backbone's representations are passed into specialized heads:

  • Wav2Vec2ForPreTraining:

    • Manages the masking of incoming features in the time domain.
    • Computes the Contrastive Loss: requiring the model to confidently distinguish the true quantized latent representation for a masked time-step from a set of randomly sampled negative distractors.
    • Computes the Diversity Loss: encouraging uniform usage across all available codebook vectors to prevent index collapse.
  • Wav2Vec2ForCTC:

    • Adapts the initialized/pre-trained model for downstream tasks.
    • Applies a final Linear classification layer on top of the contextualized representation to map to individual vocabulary tokens, permitting CTC loss calculation against unaligned text transcripts.
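
The two pre-training losses can be written down compactly. The sketch below uses plain Python lists for one masked time-step and one codebook; it is an illustration under assumptions, not the repository's implementation. The temperature kappa=0.1 is the value from the original paper, and the diversity term uses the perplexity-based form found in common implementations, which may not match models.py.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def contrastive_loss(context, target, distractors, kappa=0.1):
    """Negative log-probability of the true target among target + distractors."""
    logits = [cosine(context, q) / kappa for q in [target] + distractors]
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_norm - logits[0]


def diversity_loss(avg_probs):
    """(V - perplexity) / V: 0 for uniform codebook usage, near 1 on collapse."""
    entropy = -sum(p * math.log(p) for p in avg_probs if p > 0)
    perplexity = math.exp(entropy)
    return (len(avg_probs) - perplexity) / len(avg_probs)


# The context vector matches the true target, so the contrastive loss is small.
print(contrastive_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0], [-1.0, 0.0]]))
# Uniform usage gives (almost exactly) zero diversity loss.
print(diversity_loss([0.25, 0.25, 0.25, 0.25]))
```

Here avg_probs stands for the codebook selection probabilities averaged over a batch; a collapsed codebook such as [1, 0, 0, 0] gives a diversity loss of 0.75.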

🏋️ Training Workflows

1. Pre-training

Script: pretrain_wav2vec2.py

The initial phase trains the representation space using self-supervised learning on unlabelled speech data.

  • Objective: Contrastive loss (identifying the correct quantized representation amongst distractors) and Diversity loss (encouraging uniform usage of codebook vectors).
  • Masking Mechanism: Spans of feature encoder outputs are randomly masked. The model predicts the quantized representation of the masked frames.
  • Hardware Acceleration: Scales to multiple GPUs effortlessly using Hugging Face accelerate.
  • Convergence Details: Pre-training converges smoothly over 20,000 steps, monitored through the contrastive and diversity loss curves alongside the temperature decay schedule for the Gumbel-Softmax operation.
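
Span masking and temperature annealing are both simple enough to sketch in a few lines. The example below is a hedged illustration rather than the code in pretrain_wav2vec2.py: the function names are hypothetical, and the defaults (start probability 0.065, span length 10, temperature decaying from 2.0 to 0.5 by a factor of 0.999995 per step) are the values reported for the original Wav2Vec2 base model, assumed here.

```python
import random


def sample_span_mask(num_frames, mask_prob=0.065, span=10, rng=random):
    """Choose each frame as a span start with probability mask_prob, then
    mask `span` consecutive frames from every start (spans may overlap)."""
    mask = [False] * num_frames
    for start in range(num_frames):
        if rng.random() < mask_prob:
            for t in range(start, min(start + span, num_frames)):
                mask[t] = True
    return mask


def gumbel_temperature(step, start=2.0, floor=0.5, decay=0.999995):
    """Exponentially decayed temperature, clipped at a lower bound."""
    return max(start * decay ** step, floor)


rng = random.Random(0)
mask = sample_span_mask(100, rng=rng)
print(sum(mask), "of", len(mask), "frames masked")
print(gumbel_temperature(0), gumbel_temperature(20_000))
```

With these assumed defaults the temperature is still well above its floor after 20,000 steps, so the lower bound only matters for much longer runs.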

2. Fine-tuning (ASR)

Script: finetune_wav2vec2.py

The pre-trained backbone is refined on transcribed speech to perform speech-to-text decoding.

  • Backbone Setup: Feature encoder components can be selectively frozen during fine-tuning to prevent catastrophic forgetting of the pre-trained features.
  • Objective: CTC loss aligns the context-rich output sequence to the raw text transcript.
  • Results: Monitored on the validation split. Reached a final Word Error Rate (WER) of 13.8% and evaluation loss of 0.178.
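
For readers unfamiliar with the metric, the sketch below shows how a WER figure of this kind is typically obtained: frame-level predictions are collapsed with the CTC rule (merge repeats, drop blanks), and the decoded text is compared with the reference using word-level edit distance. The function names and the blank index 0 are assumptions for illustration; the repository's evaluation code may rely on a library instead.

```python
def ctc_collapse(frame_ids, blank_id=0):
    """Greedy CTC decoding: merge consecutive repeats, then drop blanks."""
    decoded, previous = [], None
    for token in frame_ids:
        if token != previous and token != blank_id:
            decoded.append(token)
        previous = token
    return decoded


def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    previous_row = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current_row = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous_row[j - 1] + (ref_word != hyp_word)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    return previous_row[-1] / len(ref)


# A blank between two identical tokens keeps both of them.
print(ctc_collapse([0, 1, 1, 0, 1, 2, 2, 0]))
print(word_error_rate("the cat sat", "the cat sit"))
```

A corpus-level WER is usually computed by summing edit distances and reference lengths over all utterances rather than averaging per-utterance rates.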

📊 Evaluation & Metrics

Comprehensive system monitoring and evaluation metrics were recorded using Weights and Biases (W&B) to ensure training stability and measure network convergence.

Performance Results

After 5000 global steps of fine-tuning, the model reached 13.8% WER on the validation split, showing that the acoustic features learned during pre-training transfer to the downstream transcription task.

[Figure: Word Error Rate and evaluation loss curves during fine-tuning]

Training Dynamics

The fine-tuning run shows a steady decrease in both training and validation loss over the course of training.

System Health Diagnostics

System metrics were logged in W&B throughout the multi-GPU runs. They show stable compute utilization and memory usage, with no hardware throttling.


🛠️ Usage Runbook

Environment Setup

Install the dependencies required for distributed training and monitoring:

pip install torch torchaudio transformers accelerate datasets wandb

Running Training

To start a pre-training run on all visible GPUs through accelerate:

accelerate launch pretrain_wav2vec2.py

To adapt the model to downstream ASR, run the fine-tuning script, which uses the HF Trainer with the feature encoder frozen:

python finetune_wav2vec2.py
