
ojayballer/bvaex


Convolutional β-VAE from Scratch in JAX

A Convolutional Beta Variational Autoencoder implemented in JAX, with all components written manually, including forward passes, backward propagation, and parameter updates.

This project builds and trains a complete generative model from scratch, working directly with jax.numpy and jax.lax and without relying on automatic differentiation or high-level neural network libraries. The model was trained on the full CelebA dataset with GPU acceleration.

Full details are in the blog post: Convolutional Beta VAE from Scratch in JAX


Results

Evaluation on 1,000 held-out images

| Metric | Value |
|---|---|
| SSIM | 0.9364 |
| PSNR | 22.49 dB |
| Final reconstruction loss | 95.38 |
| Final KL divergence | 110.14 |
| Channel balance (R / G / B) | 0.56 / 0.49 / 0.44 |
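PSNR is a log-scale transform of the mean squared error, PSNR = 10 · log₁₀(MAX² / MSE). Below is a minimal NumPy sketch, assuming pixels in [0, 1] (so MAX = 1) and an MSE averaged over all channels and pixels of one image. The README does not say how metrics.py averages over the 1,000 images, and the function name is hypothetical.

```python
import numpy as np


def psnr(original: np.ndarray, reconstruction: np.ndarray, max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for images with pixel values in [0, max_value]."""
    diff = original.astype(np.float64) - reconstruction.astype(np.float64)
    mse = np.mean(diff ** 2)  # averaged over channels and pixels
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))


# A uniform error of 0.1 on every pixel gives MSE = 0.01, i.e. 20 dB.
x = np.zeros((3, 64, 64))
y = np.full((3, 64, 64), 0.1)
print(round(psnr(x, y), 2))  # 20.0
```

On this [0, 1] scale, a PSNR of 22.49 dB corresponds to an MSE of about 0.0056.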

Reconstructions: the top row shows the original inputs; the bottom row shows the model's reconstructions.

Reconstructed Faces

Generated Faces: faces generated by sampling z ~ N(0, I) and decoding it with the custom decoder.

Generated Faces

Latent Interpolation: linear transitions between the latent representations of two distinct identities.

Latent Interpolation
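Linear interpolation decodes evenly spaced points on the segment between two latent codes, z_t = (1 − t)·z_a + t·z_b for t in [0, 1]. Below is a minimal NumPy sketch, assuming each face is represented by its encoder mean μ (interpolate.py may instead use a sampled z). Random vectors stand in for the two encodings, and the function name is hypothetical.

```python
import numpy as np


def interpolate_latents(z_a: np.ndarray, z_b: np.ndarray, num_steps: int = 8) -> np.ndarray:
    """Return num_steps latent vectors evenly spaced from z_a to z_b, endpoints included."""
    t = np.linspace(0.0, 1.0, num_steps)[:, None]  # (num_steps, 1), broadcasts over the latent axis
    return (1.0 - t) * z_a[None, :] + t * z_b[None, :]


rng = np.random.default_rng(0)
z_a = rng.standard_normal(128)  # stand-in for the encoding of face A
z_b = rng.standard_normal(128)  # stand-in for the encoding of face B
path = interpolate_latents(z_a, z_b)
print(path.shape)  # (8, 128); each row would be passed through the decoder
```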

Training Loss: ELBO, reconstruction (MSE), and KL divergence tracked over 100 epochs.

Loss Curves


Features

Implementation of core components:

All gradients and parameter updates are explicitly derived and implemented without autograd.
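To illustrate what "without autograd" means for a single layer, here is a minimal NumPy sketch of a dense layer with a hand-written backward pass, checked against central finite differences. It is not the code in model/Dense.py: the class name, the row-vector convention y = xW + b, the He-style initialization, and the toy loss are assumptions. The forward and backward methods port to jax.numpy by swapping the import.

```python
import numpy as np


class Dense:
    """Fully connected layer y = x @ W + b with a hand-written backward pass."""

    def __init__(self, in_dim, out_dim, rng):
        # He-style initialization (an assumption, not necessarily what Dense.py uses).
        self.W = rng.standard_normal((in_dim, out_dim)) * np.sqrt(2.0 / in_dim)
        self.b = np.zeros(out_dim)

    def forward(self, x):
        self.x = x  # cache the input; the weight gradient needs it
        return x @ self.W + self.b

    def backward(self, grad_out):
        """grad_out = dL/dy with shape (batch, out_dim); returns dL/dx."""
        self.dW = self.x.T @ grad_out   # dL/dW = x^T @ dL/dy
        self.db = grad_out.sum(axis=0)  # dL/db sums over the batch
        return grad_out @ self.W.T      # dL/dx, passed to the previous layer


# Toy loss L = 0.5 * sum(y**2), whose gradient dL/dy is simply y.
rng = np.random.default_rng(0)
layer = Dense(5, 3, rng)
x = rng.standard_normal((4, 5))
y = layer.forward(x)
layer.backward(y)

# Central finite differences on every weight, for comparison with layer.dW.
eps = 1e-6
numeric = np.zeros_like(layer.W)
for i in range(layer.W.shape[0]):
    for j in range(layer.W.shape[1]):
        original = layer.W[i, j]
        layer.W[i, j] = original + eps
        loss_plus = 0.5 * np.sum(layer.forward(x) ** 2)
        layer.W[i, j] = original - eps
        loss_minus = 0.5 * np.sum(layer.forward(x) ** 2)
        layer.W[i, j] = original  # restore the weight
        numeric[i, j] = (loss_plus - loss_minus) / (2 * eps)

print(np.max(np.abs(numeric - layer.dW)))  # should be close to zero
```

Because this toy loss is quadratic in W, central differences are exact up to rounding, so a mismatch much larger than about 1e-6 would point to a bug in backward.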


Architecture

Encoder

Input (3, 64, 64)
  → Conv2D(3→32, kernel=4, stride=2, padding=1)   → ReLU  → (32, 32, 32)
  → Conv2D(32→64, kernel=4, stride=2, padding=1)  → ReLU  → (64, 16, 16)
  → Conv2D(64→128, kernel=4, stride=2, padding=1) → ReLU  → (128, 8, 8)
  → Flatten → (8192,)
  → Dense(8192→128) → μ
  → Dense(8192→128) → log σ²
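The encoder's spatial sizes follow from the standard convolution output formula, H_out = ⌊(H + 2p − k) / s⌋ + 1. The NumPy sketch below evaluates it for the three k = 4, s = 2, p = 1 layers. It also adds a naive, loop-based Conv2D for one CHW image, computed as cross-correlation as in most deep learning libraries. The function names and the (C_out, C_in, k, k) weight layout are assumptions for illustration. This is a reference for the arithmetic, not the repository's jax.lax implementation.

```python
import numpy as np


def conv_output_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial size after Conv2D: floor((size + 2 * padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d(x, w, b, stride=2, padding=1):
    """Naive Conv2D (cross-correlation) for one image.

    x: (C_in, H, W), w: (C_out, C_in, k, k), b: (C_out,) -> (C_out, H_out, W_out)
    """
    c_out, _, k, _ = w.shape
    x_pad = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))  # zero padding
    h_out = conv_output_size(x.shape[1], k, stride, padding)
    w_out = conv_output_size(x.shape[2], k, stride, padding)
    out = np.empty((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x_pad[:, i * stride:i * stride + k, j * stride:j * stride + k]
            # Sum over input channels and the k x k window, once per output channel.
            out[:, i, j] = np.einsum("cuv,ocuv->o", patch, w) + b
    return out


# Encoder spatial sizes with kernel=4, stride=2, padding=1.
size = 64
for _ in range(3):
    size = conv_output_size(size, kernel=4, stride=2, padding=1)
    print(size)  # 32, then 16, then 8
print(128 * size * size)  # 8192 features after Flatten

rng = np.random.default_rng(0)
first = conv2d(rng.standard_normal((3, 64, 64)), rng.standard_normal((32, 3, 4, 4)), np.zeros(32))
print(first.shape)  # (32, 32, 32)
```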

Latent Space

  • 128-dimensional Gaussian
  • z = μ + σ ⊙ ε, with ε ~ N(0, I) and σ = exp(½ · log σ²)
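Because z depends on μ and log σ² through a differentiable expression, the backward pass through this step reduces to two elementwise formulas: ∂z/∂μ = 1 and ∂z/∂log σ² = ½ σ ε. Below is a minimal NumPy sketch, assuming the sampled ε and σ are cached during the forward pass. The function names are hypothetical, not taken from Reparameterize.py.

```python
import numpy as np


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with sigma = exp(0.5 * log_var) and eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * eps
    return z, (eps, sigma)  # cache for the backward pass


def reparameterize_backward(grad_z, cache):
    """Map dL/dz back to the encoder outputs (both rules are elementwise).

    dz/dmu = 1 and dz/dlog_var = 0.5 * sigma * eps.
    """
    eps, sigma = cache
    grad_mu = grad_z
    grad_log_var = grad_z * 0.5 * sigma * eps
    return grad_mu, grad_log_var


rng = np.random.default_rng(0)
mu = np.zeros((2, 128))       # batch of 2, latent dimension 128
log_var = np.zeros((2, 128))  # sigma = 1
z, cache = reparameterize(mu, log_var, rng)
grad_mu, grad_log_var = reparameterize_backward(np.ones_like(z), cache)
```

The KL term's own gradients with respect to μ and log σ² are added to these separately, since the KL divergence depends on the encoder outputs directly rather than through z.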

Decoder

z (128,)
  → Dense(128→8192) → Reshape → (128, 8, 8)
  → TransposedConv2D(128→64, kernel=4, stride=2, padding=1) → ReLU    → (64, 16, 16)
  → TransposedConv2D(64→32, kernel=4, stride=2, padding=1)  → ReLU    → (32, 32, 32)
  → TransposedConv2D(32→3, kernel=4, stride=2, padding=1)   → Sigmoid → (3, 64, 64)
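For the decoder, the output size of a transposed convolution is H_out = (H − 1)·s − 2p + k, which gives 8 → 16 → 32 → 64 for k = 4, s = 2, p = 1. One way to implement the layer by hand is sketched below in NumPy, under assumed function names and weight layout. It scatters each input pixel's weighted kernel into the output and then crops p pixels from each border. This operation is the adjoint of the strided convolution, so the same routine also computes a Conv2D layer's input gradient. The check at the end verifies ⟨conv(x), y⟩ = ⟨x, convᵀ(y)⟩. The README does not say whether Decoder.py uses this formulation.

```python
import numpy as np


def conv2d(x, w, stride=2, padding=1):
    """Naive Conv2D (cross-correlation): x (C_in, H, W), w (C_out, C_in, k, k)."""
    c_out, _, k, _ = w.shape
    xp = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    h_out = (x.shape[1] + 2 * padding - k) // stride + 1
    w_out = (x.shape[2] + 2 * padding - k) // stride + 1
    out = np.empty((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = xp[:, i * stride:i * stride + k, j * stride:j * stride + k]
            out[:, i, j] = np.einsum("cuv,ocuv->o", patch, w)
    return out


def conv_transpose2d(y, w, stride=2, padding=1):
    """Transposed Conv2D with the same w as conv2d: y (C_out, H, W) -> (C_in, H_out, W_out).

    Output size is (H - 1) * stride - 2 * padding + k.
    """
    _, c_in, k, _ = w.shape
    h, wd = y.shape[1], y.shape[2]
    full = np.zeros((c_in, (h - 1) * stride + k, (wd - 1) * stride + k))
    for i in range(h):
        for j in range(wd):
            # Scatter the kernel, weighted by y[:, i, j], into the k x k window at (i*s, j*s).
            full[:, i * stride:i * stride + k, j * stride:j * stride + k] += np.einsum(
                "o,ocuv->cuv", y[:, i, j], w)
    # Cropping p pixels per border undoes the padding of the matching forward conv.
    return full[:, padding:full.shape[1] - padding, padding:full.shape[2] - padding]


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 2, 4, 4))  # conv maps 2 -> 4 channels; its transpose maps 4 -> 2
x = rng.standard_normal((2, 16, 16))
y = rng.standard_normal((4, 8, 8))
up = conv_transpose2d(y, w)
print(up.shape)  # (2, 16, 16)
lhs = np.sum(conv2d(x, w) * y)
rhs = np.sum(x * up)
print(np.isclose(lhs, rhs))  # True
```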

Objective

  • Loss = Reconstruction Loss (MSE) + β × KL Divergence (the β-weighted negative ELBO, which training minimizes)
  • β = 0.5
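Both terms of the objective and their gradients have closed forms. For a diagonal Gaussian posterior, KL(N(μ, σ²) ‖ N(0, I)) = −½ Σ (1 + log σ² − μ² − σ²), so ∂KL/∂μ = μ and ∂KL/∂log σ² = ½ (σ² − 1). The NumPy sketch below computes the β-weighted loss and the gradients that seed a manual backward pass. It assumes the squared error is summed over pixels and both terms are averaged over the batch. The README does not state the reduction used in ELBO.py, and the function name is hypothetical.

```python
import numpy as np


def beta_vae_loss(x, x_hat, mu, log_var, beta=0.5):
    """Return (loss, recon, kl, grads) for a batch.

    x, x_hat: (batch, C, H, W) images; mu, log_var: (batch, latent_dim).
    grads = (dL/dx_hat, dL/dmu, dL/dlog_var).
    """
    batch = x.shape[0]
    # Squared error summed over pixels, averaged over the batch (assumed reduction).
    recon = np.sum((x_hat - x) ** 2) / batch
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch.
    kl = -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var)) / batch
    loss = recon + beta * kl

    grad_x_hat = 2.0 * (x_hat - x) / batch
    grad_mu = beta * mu / batch                                  # d(beta * KL)/dmu
    grad_log_var = beta * 0.5 * (np.exp(log_var) - 1.0) / batch  # d(beta * KL)/dlog_var
    return loss, recon, kl, (grad_x_hat, grad_mu, grad_log_var)


rng = np.random.default_rng(0)
x = rng.random((4, 3, 64, 64))
x_hat = rng.random((4, 3, 64, 64))
mu = rng.standard_normal((4, 128))
log_var = 0.1 * rng.standard_normal((4, 128))
loss, recon, kl, (g_x_hat, g_mu, g_log_var) = beta_vae_loss(x, x_hat, mu, log_var)
print(loss, recon, kl)
```

In a full backward pass, grad_x_hat would flow back through the decoder, while grad_mu and grad_log_var are added to the gradients arriving from the reparameterization step.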

Dataset

  • CelebA dataset
  • 202,599 images (full dataset)
  • Resolution: 64 × 64
  • Format: CHW (channels × height × width) tensors

The model was trained on the complete CelebA dataset using a custom data loader that preloads all images into memory for fast GPU training.
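Below is a NumPy sketch of such a loader, assuming the images are already decoded into a single (N, 3, 64, 64) uint8 array. Each epoch draws a fresh permutation and slices batches from memory, converting only the current batch to float32 in [0, 1]. The function name, the uint8 storage, and the [0, 1] scaling are assumptions, not details taken from load_data.py. Storing pixels as uint8 keeps the footprint down: the full dataset is 202,599 × 3 × 64 × 64 bytes, about 2.5 GB, versus about 10 GB as float32.

```python
import numpy as np


def iterate_minibatches(images, batch_size=512, rng=None, drop_last=False):
    """Yield shuffled float32 batches in [0, 1] from a uint8 array held in memory.

    images: (N, 3, 64, 64) uint8, already decoded and resized.
    """
    rng = rng if rng is not None else np.random.default_rng()
    order = rng.permutation(len(images))  # fresh shuffle each time the generator is created
    stop = len(images) - batch_size + 1 if drop_last else len(images)
    for start in range(0, stop, batch_size):
        idx = order[start:start + batch_size]
        yield images[idx].astype(np.float32) / 255.0  # convert only the current batch


rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(1000, 3, 64, 64), dtype=np.uint8)  # stand-in for CelebA
print([batch.shape[0] for batch in iterate_minibatches(fake, 512, rng)])  # [512, 488]

n_bytes_uint8 = 202_599 * 3 * 64 * 64
print(n_bytes_uint8, 4 * n_bytes_uint8)  # 2489536512 9958146048
```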


Setup and Installation

git clone https://github.com/ojayballer/bvaex.git
cd bvaex
pip install jax jaxlib numpy matplotlib pillow

Training

python train.py

| Setting | Value |
|---|---|
| GPU | NVIDIA Tesla P100 (16 GB) |
| Epochs | 100 |
| Batch size | 512 |
| Optimizer | AdamW (α = 0.001, β₁ = 0.9, β₂ = 0.999, weight decay = 0.01) |
| Latent dimension | 128 |
| KL weight (β) | 0.5 |
| Training time | ~2 hours |
| Platform | Kaggle |
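Below is a minimal NumPy sketch of one AdamW step with the hyperparameters above. AdamW applies weight decay directly to the weights rather than adding it to the gradient, which is what distinguishes it from Adam with L2 regularization. The ε = 1e-8 constant, applying decay to every parameter (including biases), and the dict-of-arrays interface are assumptions; model/Adam.py may differ.

```python
import numpy as np


class AdamW:
    """AdamW with decoupled weight decay, defaulting to the hyperparameters in the table."""

    def __init__(self, lr=1e-3, beta1=0.9, beta2=0.999, weight_decay=0.01, eps=1e-8):
        self.lr, self.beta1, self.beta2 = lr, beta1, beta2
        self.weight_decay, self.eps = weight_decay, eps
        self.m, self.v, self.t = {}, {}, 0  # per-parameter moment estimates and step count

    def step(self, params, grads):
        """Update every array in the dict `params` using the matching entry in `grads`."""
        self.t += 1
        for name, g in grads.items():
            m = self.beta1 * self.m.get(name, np.zeros_like(g)) + (1 - self.beta1) * g
            v = self.beta2 * self.v.get(name, np.zeros_like(g)) + (1 - self.beta2) * g * g
            self.m[name], self.v[name] = m, v
            m_hat = m / (1 - self.beta1 ** self.t)  # bias correction for the zero initialization
            v_hat = v / (1 - self.beta2 ** self.t)
            adam_update = m_hat / (np.sqrt(v_hat) + self.eps)
            # Decoupled decay: shrink the weights directly instead of adding decay * w to g.
            params[name] = params[name] - self.lr * (adam_update + self.weight_decay * params[name])


params = {"w": np.array([1.0, -2.0])}
opt = AdamW()
opt.step(params, {"w": 2.0 * params["w"]})  # gradient of sum(w**2)
print(params["w"])
```

On the first step the bias-corrected ratio m̂/√v̂ equals sign(g) up to ε, so each weight moves by about α · (sign(g) + 0.01 · w) regardless of the gradient's scale.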

JAX was used to compile all numerical operations through XLA, allowing the manually implemented layers and gradients to execute efficiently on GPU.

Trained weights are stored in:

weights/epoch_100/

Evaluate

python evaluate.py
python metrics.py
python interpolate.py

Note: the reconstructions are slightly blurry, a known effect of the MSE reconstruction loss. See the blog post for details.

Outputs are stored in:

results/

Project Structure

bvaex/
├── model/
│   ├── __init__.py
│   ├── Encoder.py
│   ├── Decoder.py
│   ├── Dense.py
│   ├── ELBO.py
│   ├── Adam.py
│   ├── Activation.py
│   ├── Reshape.py
│   ├── Reparameterize.py
│   └── model.py
├── train.py
├── evaluate.py
├── metrics.py
├── interpolate.py
├── plots.py
├── load_data.py
├── results/
│   ├── reconstruction_grid.png
│   ├── generated_faces.png
│   ├── latent_interpolation.png
│   └── loss_curves.png
└── weights/
    └── epoch_100/


License

MIT

About

A generative model that learns to reconstruct and sample human faces, built on raw JAX primitives
