A Convolutional Beta Variational Autoencoder implemented in JAX, with all components written manually, including forward passes, backward propagation, and parameter updates.
This project focuses on building and training a complete generative model from scratch by working directly with jax.numpy and jax.lax, without relying on automatic differentiation or high-level neural network libraries. The model was trained on the full CelebA dataset using GPU acceleration.
Full details can be found in the blog post: Convolutional Beta VAE from Scratch in JAX.
Evaluation on 1,000 held-out images:
| Metric | Value |
|---|---|
| SSIM | 0.9364 |
| PSNR | 22.49 dB |
| Final Reconstruction Loss | 95.38 |
| Final KL Divergence | 110.14 |
| Channel Balance (R / G / B) | 0.56 / 0.49 / 0.44 |
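PSNR here is assumed to be measured on images scaled to [0, 1]. A minimal sketch of per-image PSNR averaged over a batch (how `metrics.py` actually averages is not stated in this README, so the reduction is an assumption, and `psnr_per_image` is a hypothetical name):

```python
import jax.numpy as jnp


def psnr_per_image(x, x_hat, max_val=1.0):
    """PSNR in dB for each image; x and x_hat are (N, C, H, W) in [0, max_val]."""
    mse = jnp.mean((x - x_hat) ** 2, axis=(1, 2, 3))  # one MSE per image
    return 10.0 * jnp.log10(max_val ** 2 / mse)        # one PSNR value per image


x = jnp.zeros((2, 3, 64, 64))
x_hat = x + 0.1  # uniform error of 0.1 per pixel, so MSE = 0.01
print(psnr_per_image(x, x_hat).mean())  # about 20 dB
```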
Reconstructions: the top row shows the original inputs; the bottom row shows the model's reconstructions.
Generated Faces: faces generated by sampling z ~ N(0, I) and passing the samples through the custom decoder.
Latent Interpolation: linear transitions between the latent representations of two distinct identities.
Training Loss: ELBO, reconstruction (MSE), and KL divergence tracked over 100 epochs.
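Both the generated faces and the interpolation only need the latent space and the decoder. A small sketch with hypothetical helper names (`sample_prior`, `interpolate`; the decoder call is left out): sampling draws 128-dimensional codes z ~ N(0, I), and interpolation walks a straight line between two codes, which in practice would be the encoder means μ of two images.

```python
import jax
import jax.numpy as jnp

LATENT_DIM = 128


def sample_prior(key, n, latent_dim=LATENT_DIM):
    """Draw n latent codes z ~ N(0, I) for generating new faces."""
    return jax.random.normal(key, (n, latent_dim))


def interpolate(z_a, z_b, steps=8):
    """Points on the straight line z(t) = (1 - t) * z_a + t * z_b, t in [0, 1]."""
    t = jnp.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - t) * z_a[None, :] + t * z_b[None, :]


z = sample_prior(jax.random.PRNGKey(0), 16)  # 16 codes to decode into a face grid
path = interpolate(z[0], z[1], steps=8)       # (8, 128); z_a, z_b would be encoder means
# images = decode(path)  # `decode` is a hypothetical stand-in for the custom decoder
```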
Implementation of core components:
- Conv2D (forward and backward)
- TransposedConv2D (forward and backward)
- Dense layers with manual gradients
- Activation functions (ReLU, Sigmoid)
- Reparameterization trick (μ, σ, ε)
- ELBO loss (Reconstruction + KL Divergence)
- Custom AdamW optimizer (decoupled weight decay)
- Xavier/Glorot weight initialization
- Gradient clipping
All gradients and parameter updates are explicitly derived and implemented without autograd.
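Because no autograd is used, every layer needs a hand-written backward pass. The sketch below shows the standard formulas for a dense layer y = xW + b, for ReLU and sigmoid, and for Xavier/Glorot initialization. The (in, out) weight layout, the uniform (rather than normal) Xavier variant, and all function names are assumptions, not necessarily what `Dense.py` and `Activation.py` use.

```python
import math

import jax
import jax.numpy as jnp


def xavier_uniform(key, fan_in, fan_out):
    """Glorot/Xavier uniform: W ~ U(-a, a), a = sqrt(6 / (fan_in + fan_out))."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, (fan_in, fan_out), minval=-limit, maxval=limit)


def dense_forward(x, W, b):
    """x: (N, in), W: (in, out), b: (out,) -> (N, out)."""
    return x @ W + b


def dense_backward(x, W, grad_y):
    """Return dL/dx, dL/dW, dL/db for y = x @ W + b given dL/dy."""
    return grad_y @ W.T, x.T @ grad_y, grad_y.sum(axis=0)


def relu(a):
    return jnp.maximum(a, 0.0)


def relu_backward(a, grad_out):
    return grad_out * (a > 0)  # derivative is 1 where a > 0, else 0


def sigmoid(a):
    return 1.0 / (1.0 + jnp.exp(-a))


def sigmoid_backward(s, grad_out):
    return grad_out * s * (1.0 - s)  # sigmoid'(a) = s * (1 - s) with s = sigmoid(a)


# Toy chain: a = x @ W + b, h = relu(a), L = 0.5 * sum(h ** 2), so dL/dh = h
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (4, 8))
W = xavier_uniform(kw, 8, 5)
b = jnp.zeros(5)
a = dense_forward(x, W, b)
h = relu(a)
grad_x, grad_W, grad_b = dense_backward(x, W, relu_backward(a, h))
```

Hand-derived gradients like these are easy to check against `jax.grad` in a unit test, even though training itself never calls autograd.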
Encoder
Input (3, 64, 64)
→ Conv2D(3→32, kernel=4, stride=2, padding=1) → ReLU → (32, 32, 32)
→ Conv2D(32→64, kernel=4, stride=2, padding=1) → ReLU → (64, 16, 16)
→ Conv2D(64→128, kernel=4, stride=2, padding=1) → ReLU → (128, 8, 8)
→ Flatten → (8192,)
→ Dense(8192→128) → μ
→ Dense(8192→128) → log σ²
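For each stride-2, padding-1, 4×4 encoder convolution the output size is (H + 2p - k)/s + 1, e.g. (64 + 2 - 4)/2 + 1 = 32. A minimal forward/backward sketch with `jax.lax.conv_general_dilated`: the input gradient is a stride-dilated correlation with the flipped, channel-swapped kernel, and the weight gradient correlates the padded input with the dilated upstream gradient. These are the standard identities for strided convolution; the repo's `Encoder.py` may implement them differently (for example via im2col), the function names are hypothetical, and the code assumes (H + 2p - k) is divisible by the stride, which holds for every encoder layer above.

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ("NCHW", "OIHW", "NCHW")  # channel-first layout, matching the CHW data


def conv2d_forward(x, W, b, stride=2, pad=1):
    """x: (N, C_in, H, W), W: (C_out, C_in, k, k), b: (C_out,)."""
    y = lax.conv_general_dilated(
        x, W, window_strides=(stride, stride),
        padding=[(pad, pad), (pad, pad)], dimension_numbers=DN)
    return y + b[None, :, None, None]


def conv2d_backward(x, W, grad_y, stride=2, pad=1):
    """Manual gradients of conv2d_forward, assuming (H + 2*pad - k) % stride == 0."""
    k = W.shape[-1]
    # dL/dx: insert (stride - 1) zeros between grad_y entries, pad by k - 1 - pad,
    # then correlate with the spatially flipped kernel with in/out channels swapped.
    W_flipped = jnp.flip(W, axis=(2, 3)).transpose(1, 0, 2, 3)
    grad_x = lax.conv_general_dilated(
        grad_y, W_flipped, window_strides=(1, 1),
        padding=[(k - 1 - pad, k - 1 - pad)] * 2,
        lhs_dilation=(stride, stride), dimension_numbers=DN)
    # dL/dW: correlate the padded input with the stride-dilated grad_y,
    # swapping batch and channel axes so the sum runs over the batch.
    grad_W = lax.conv_general_dilated(
        x.transpose(1, 0, 2, 3), grad_y.transpose(1, 0, 2, 3),
        window_strides=(1, 1), padding=[(pad, pad), (pad, pad)],
        rhs_dilation=(stride, stride), dimension_numbers=DN).transpose(1, 0, 2, 3)
    grad_b = grad_y.sum(axis=(0, 2, 3))
    return grad_x, grad_W, grad_b


# First encoder layer: (3, 64, 64) -> (32, 32, 32)
kx, kw, kg = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (2, 3, 64, 64))
W = 0.1 * jax.random.normal(kw, (32, 3, 4, 4))
b = jnp.zeros(32)
y = conv2d_forward(x, W, b)              # shape (2, 32, 32, 32)
grad_y = jax.random.normal(kg, y.shape)  # stand-in upstream gradient
grad_x, grad_W, grad_b = conv2d_backward(x, W, grad_y)
```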
Latent Space
- 128-dimensional Gaussian
- z = μ + σ ⊙ ε, where σ = exp(½ log σ²) and ε ~ N(0, I)
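Since σ = exp(½ log σ²), the reparameterization has simple local derivatives: ∂z/∂μ = 1 and ∂z/∂log σ² = ½ σ ⊙ ε. A sketch assuming the encoder outputs log σ² (as in the architecture above) and that ε is cached during the forward pass so the backward pass reuses the same noise; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mu, logvar):
    """z = mu + sigma * eps with sigma = exp(0.5 * logvar), eps ~ N(0, I)."""
    eps = jax.random.normal(key, mu.shape)
    sigma = jnp.exp(0.5 * logvar)
    z = mu + sigma * eps
    return z, (eps, sigma)  # cache eps and sigma for the backward pass


def reparameterize_backward(grad_z, cache):
    """Map dL/dz to dL/dmu and dL/dlogvar."""
    eps, sigma = cache
    grad_mu = grad_z                          # dz/dmu = 1
    grad_logvar = grad_z * eps * 0.5 * sigma  # dz/dlogvar = 0.5 * sigma * eps
    return grad_mu, grad_logvar


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
mu = jax.random.normal(k1, (4, 128))
logvar = 0.1 * jax.random.normal(k2, (4, 128))
z, cache = reparameterize(k3, mu, logvar)
# Toy loss L = sum(z ** 2), so dL/dz = 2z
grad_mu, grad_logvar = reparameterize_backward(2.0 * z, cache)
```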
Decoder
z (128,)
→ Dense(128→8192) → Reshape → (128, 8, 8)
→ TransposedConv2D(128→64, kernel=4, stride=2, padding=1) → ReLU → (64, 16, 16)
→ TransposedConv2D(64→32, kernel=4, stride=2, padding=1) → ReLU → (32, 32, 32)
→ TransposedConv2D(32→3, kernel=4, stride=2, padding=1) → Sigmoid → (3, 64, 64)
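With k = 4, s = 2, p = 1 a transposed convolution maps H to (H - 1)s - 2p + k = 2H, so the decoder goes 8 → 16 → 32 → 64. One way to write it, sketched here under assumptions (PyTorch-style weight layout (C_in, C_out, k, k) and hypothetical function names), is as the adjoint of the matching strided convolution: the forward pass dilates the input by the stride and correlates with the flipped kernel, and the input gradient is an ordinary strided convolution with the same weights.

```python
import jax
import jax.numpy as jnp
from jax import lax

DN = ("NCHW", "OIHW", "NCHW")


def conv_transpose2d_forward(x, W, b, stride=2, pad=1):
    """x: (N, C_in, H, W), W: (C_in, C_out, k, k), b: (C_out,) -> (N, C_out, 2H, 2W)."""
    k = W.shape[-1]
    # Dilate the input by the stride, pad by k - 1 - pad, correlate with flipped kernel.
    W_conv = jnp.flip(W, axis=(2, 3)).transpose(1, 0, 2, 3)  # (C_out, C_in, k, k)
    y = lax.conv_general_dilated(
        x, W_conv, window_strides=(1, 1),
        padding=[(k - 1 - pad, k - 1 - pad)] * 2,
        lhs_dilation=(stride, stride), dimension_numbers=DN)
    return y + b[None, :, None, None]


def conv_transpose2d_backward(x, W, grad_y, stride=2, pad=1):
    """Manual gradients of conv_transpose2d_forward."""
    # dL/dx: an ordinary strided convolution of grad_y with the same W read as OIHW.
    grad_x = lax.conv_general_dilated(
        grad_y, W, window_strides=(stride, stride),
        padding=[(pad, pad), (pad, pad)], dimension_numbers=DN)
    # dL/dW: correlate padded grad_y with the stride-dilated input, summing over batch.
    grad_W = lax.conv_general_dilated(
        grad_y.transpose(1, 0, 2, 3), x.transpose(1, 0, 2, 3),
        window_strides=(1, 1), padding=[(pad, pad), (pad, pad)],
        rhs_dilation=(stride, stride), dimension_numbers=DN).transpose(1, 0, 2, 3)
    grad_b = grad_y.sum(axis=(0, 2, 3))
    return grad_x, grad_W, grad_b


# First decoder upsampling layer: (128, 8, 8) -> (64, 16, 16)
kx, kw, kg = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (2, 128, 8, 8))
W = 0.05 * jax.random.normal(kw, (128, 64, 4, 4))
b = jnp.zeros(64)
y = conv_transpose2d_forward(x, W, b)  # shape (2, 64, 16, 16)
grad_y = jax.random.normal(kg, y.shape)
grad_x, grad_W, grad_b = conv_transpose2d_backward(x, W, grad_y)
```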
Objective
- Loss = Reconstruction Loss (MSE) + β × KL Divergence (the negative of the β-weighted ELBO, which training minimizes)
- β = 0.5
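For a diagonal Gaussian posterior and a standard normal prior the KL term has the closed form KL = -½ Σ (1 + log σ² - μ² - σ²), which gives ∂KL/∂μ = μ and ∂KL/∂log σ² = ½(σ² - 1). The sketch below assumes the reconstruction term is squared error summed over pixels and averaged over the batch; the README does not state the reduction, so that choice and the function names are assumptions. These are only the loss's direct gradients: μ and log σ² also receive gradient through z from the decoder.

```python
import jax
import jax.numpy as jnp

BETA = 0.5  # KL weight from the training setup


def beta_vae_loss(x, x_hat, mu, logvar, beta=BETA):
    """Return (total, recon, kl), each averaged over the batch."""
    n = x.shape[0]
    recon = jnp.sum((x_hat - x) ** 2) / n  # squared error summed over pixels (assumed)
    kl = -0.5 * jnp.sum(1.0 + logvar - mu ** 2 - jnp.exp(logvar)) / n
    return recon + beta * kl, recon, kl


def beta_vae_loss_backward(x, x_hat, mu, logvar, beta=BETA):
    """Direct gradients of the total loss w.r.t. x_hat, mu and logvar."""
    n = x.shape[0]
    grad_x_hat = 2.0 * (x_hat - x) / n
    grad_mu = beta * mu / n
    grad_logvar = beta * 0.5 * (jnp.exp(logvar) - 1.0) / n
    return grad_x_hat, grad_mu, grad_logvar


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.uniform(k1, (4, 3, 64, 64))
x_hat = jax.random.uniform(k2, (4, 3, 64, 64))
mu = jax.random.normal(k3, (4, 128))
logvar = 0.1 * jax.random.normal(k4, (4, 128))
total, recon, kl = beta_vae_loss(x, x_hat, mu, logvar)
grads = beta_vae_loss_backward(x, x_hat, mu, logvar)
```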
Dataset
- CelebA, 202,599 images (full dataset)
- Resolution: 64 × 64
- Format: CHW tensors (channels, height, width)
The model was trained on the complete CelebA dataset using a custom data loader that preloads all images into memory for fast GPU training.
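Preloading the whole dataset is feasible at this resolution: 202,599 × 64 × 64 × 3 bytes ≈ 2.49 GB as uint8, versus ≈ 9.96 GB as float32. A hypothetical loader sketch that keeps raw HWC uint8 images on the host and converts each shuffled batch to CHW float32 in [0, 1] only when it is yielded; whether `load_data.py` stores uint8 or float32, and whether it drops the last partial batch, is not stated here, so both are assumptions.

```python
import numpy as np
import jax.numpy as jnp


def iterate_minibatches(images_hwc_uint8, batch_size, rng):
    """Yield shuffled (B, 3, H, W) float32 batches in [0, 1] from a preloaded uint8 array."""
    n = images_hwc_uint8.shape[0]
    order = rng.permutation(n)
    for start in range(0, n - batch_size + 1, batch_size):  # drops the last partial batch
        batch = images_hwc_uint8[order[start:start + batch_size]]
        batch = np.transpose(batch, (0, 3, 1, 2)).astype(np.float32) / 255.0  # HWC -> CHW
        yield jnp.asarray(batch)  # move only this batch to the accelerator


rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(20, 64, 64, 3), dtype=np.uint8)  # stand-in for CelebA
batches = list(iterate_minibatches(fake, 8, rng))
```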
```bash
git clone https://github.com/ojayballer/bvaex.git
cd bvaex
pip install jax jaxlib numpy matplotlib pillow
```

```bash
python train.py
```

| Setting | Value |
|---|---|
| GPU | NVIDIA Tesla P100 (16GB) |
| Epochs | 100 |
| Batch Size | 512 |
| Optimizer | AdamW (α=0.001, β₁=0.9, β₂=0.999, weight decay=0.01) |
| Latent Dimension | 128 |
| KL Weight (β) | 0.5 |
| Training Time | ~2 hours |
| Platform | Kaggle |
JAX was used to compile all numerical operations through XLA, allowing the manually implemented layers and gradients to execute efficiently on GPU.
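AdamW differs from Adam with L2 regularization in that weight decay is applied directly to the parameters instead of being added to the gradient: θ ← θ - α (m̂ / (√v̂ + ε) + λθ). A sketch over parameter pytrees using α, β₁, β₂ and the weight decay from the table, preceded by global-norm gradient clipping. The clipping threshold is not given in this README, so `max_norm=1.0` is a hypothetical value; ε = 1e-8 is assumed; decaying every parameter (biases included) is a simplification that `Adam.py` may not share; and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together so their combined L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (total_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def adamw_init(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": 0}


def adamw_step(params, grads, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    t = state["t"] + 1
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, state["m"], grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g ** 2, state["v"], grads)

    def update(p, m_, v_):
        m_hat = m_ / (1 - b1 ** t)  # bias-corrected first moment
        v_hat = v_ / (1 - b2 ** t)  # bias-corrected second moment
        # Decoupled weight decay: wd * p is added to the step, not to the gradient
        return p - lr * (m_hat / (jnp.sqrt(v_hat) + eps) + wd * p)

    new_params = jax.tree_util.tree_map(update, params, m, v)
    return new_params, {"m": m, "v": v, "t": t}


params = {"W": jnp.ones((3, 2)), "b": jnp.zeros(2)}
grads = {"W": jnp.full((3, 2), 2.0), "b": jnp.full(2, 0.5)}
state = adamw_init(params)
params, state = adamw_step(params, clip_by_global_norm(grads, max_norm=1.0), state)
```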
Trained weights are stored in:
weights/epoch_100/
```bash
python evaluate.py
python metrics.py
python interpolate.py
```

Note: the reconstructions are slightly blurry because of the MSE reconstruction loss. More details are in the blog post.
Outputs are stored in:
results/
```
bvaex/
├── model/
│   ├── __init__.py
│   ├── Encoder.py
│   ├── Decoder.py
│   ├── Dense.py
│   ├── ELBO.py
│   ├── Adam.py
│   ├── Activation.py
│   ├── Reshape.py
│   ├── Reparameterize.py
│   └── model.py
├── train.py
├── evaluate.py
├── metrics.py
├── interpolate.py
├── plots.py
├── load_data.py
├── results/
│   ├── reconstruction_grid.png
│   ├── generated_faces.png
│   ├── latent_interpolation.png
│   └── loss_curves.png
└── weights/
    └── epoch_100/
```
Papers I studied closely while implementing bvaex:
- Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. ICLR 2014.
- Kingma, D. P., & Welling, M. (2019). An Introduction to Variational Autoencoders. Foundations and Trends in Machine Learning.
- Kingma, D. P., & Ba, J. (2015). Adam: A Method for Stochastic Optimization. ICLR 2015.
- Higgins, I., et al. (2017). β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. ICLR 2017.
License: MIT



