Makhzani, A., Shlens, J., Jaitly, N., Goodfellow, I., & Frey, B. (2016). *Adversarial Autoencoders*. arXiv preprint arXiv:1511.05644. https://arxiv.org/abs/1511.05644
A PyTorch implementation of a semi-supervised Adversarial Autoencoder (AAE), with example experiments on MNIST.
The AAE learns a low-dimensional latent representation split into:
- a categorical part for class labels, and
- a continuous part for “style.”
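In the semi-supervised AAE of Makhzani et al., the encoder produces a softmax over classes (the categorical part) and an unconstrained vector (the style part). The decoder reconstructs the input from their concatenation. During adversarial regularization, one discriminator compares the categorical output with one-hot samples from a categorical prior, and another compares the style output with samples from a Gaussian prior. The sketch below illustrates only this split and is not the code in `aae.py`. The class `TwoHeadEncoder`, the helper `sample_priors`, and the layer layout are hypothetical. The `train_mbgd` call later in this README takes a `prior_std` argument, which presumably sets the standard deviation of the Gaussian prior; the value `5.0` below is borrowed from that call as an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadEncoder(nn.Module):
    """Hypothetical encoder with a categorical head and a style head (not aae.py)."""

    def __init__(self, input_dim=784, hidden_dim=1024, n_classes=10, style_dim=16):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.cat_head = nn.Linear(hidden_dim, n_classes)    # class logits
        self.style_head = nn.Linear(hidden_dim, style_dim)  # continuous style code

    def forward(self, x):
        h = self.body(x.flatten(start_dim=1))    # accepts (B, 1, 28, 28) or (B, 784)
        y = F.softmax(self.cat_head(h), dim=1)   # soft class assignment, rows sum to 1
        z = self.style_head(h)                   # unconstrained style vector
        return y, z


def sample_priors(batch_size, n_classes=10, style_dim=16, std=1.0):
    """Draw the 'real' samples that the two discriminators compare against."""
    labels = torch.randint(0, n_classes, (batch_size,))
    y_prior = F.one_hot(labels, n_classes).float()       # categorical prior (one-hot)
    z_prior = torch.randn(batch_size, style_dim) * std   # Gaussian prior N(0, std^2 I)
    return y_prior, z_prior


encoder = TwoHeadEncoder()
x = torch.rand(8, 1, 28, 28)              # stand-in batch of MNIST-sized images
y, z = encoder(x)
decoder_input = torch.cat([y, z], dim=1)  # the decoder sees [class | style], 26 dims
y_prior, z_prior = sample_priors(8, std=5.0)
```

During regularization, `y` is pushed towards the distribution of `y_prior` and `z` towards that of `z_prior`. When labels are available, the categorical head is additionally trained with a classification loss.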
The repository includes the following code and documents:
- The Python implementation of the adversarial autoencoder: `aae.py`
- A Jupyter notebook with example usage: `experiments.ipynb`
- A report covering the reproduction results and additional experiments: AAE Semi-Supervised Report (PDF)
The model was trained on the MNIST dataset of handwritten digits. Below is a comparison of images generated by the model when trained in a semi-supervised and in a fully supervised setting.
Trained for 2000 epochs, using 2000 labeled samples:
The samples show visible style variation within each digit class. The disentanglement plot shows that holding the remaining latent dimensions fixed and sweeping two of the style dimensions over a grid produces smooth changes in the style of the digits.
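A grid of this kind can be built by fixing the class and all but two style dimensions, sweeping those two over a range of values, and decoding each resulting latent vector. The sketch below constructs such a grid of latent codes. The function `style_grid`, the swept dimensions, the value range `span`, and the stand-in decoder are hypothetical; the exact procedure behind the disentanglement plot (for example, the values at which the other dimensions were held) is not described in this README.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def style_grid(class_idx, n_classes=10, style_dim=16, dims=(0, 1),
               n_steps=8, span=2.0, base_style=None):
    """Latent codes for one digit class, sweeping two style dimensions on a grid.

    All other style dimensions are held at `base_style` (zeros by default).
    Returns a tensor of shape (n_steps * n_steps, n_classes + style_dim).
    """
    if base_style is None:
        base_style = torch.zeros(style_dim)
    values = torch.linspace(-span, span, n_steps)
    # indexing="ij": the first grid varies down the rows, the second across columns
    grid_a, grid_b = torch.meshgrid(values, values, indexing="ij")
    style = base_style.repeat(n_steps * n_steps, 1)   # copy of the fixed style vector
    style[:, dims[0]] = grid_a.reshape(-1)
    style[:, dims[1]] = grid_b.reshape(-1)
    label = F.one_hot(torch.tensor([class_idx]), n_classes).float()  # shape (1, n_classes)
    return torch.cat([label.expand(style.shape[0], -1), style], dim=1)


# Hypothetical stand-in decoder: [class | style] (26 dims) -> 784 pixel intensities.
decoder = nn.Sequential(nn.Linear(26, 1024), nn.ReLU(), nn.Linear(1024, 784), nn.Sigmoid())

codes = style_grid(class_idx=3)
with torch.no_grad():
    images = decoder(codes).view(8, 8, 28, 28)  # images[i, j] is grid cell (row i, column j)
```

The sweep range should roughly match the spread of the style prior; values far outside it correspond to latent regions the decoder has rarely seen during training.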
Trained for 500 epochs, with all samples labeled:
With all samples labeled, generation fidelity is higher, but the style dimensions appear to capture slightly less meaningful variation than in the semi-supervised case: the model relies more on the labels and less on learning a meaningful style representation. Note that, because unlabeled samples are cycled in the reconstruction phase, training for 500 epochs with all samples labeled actually exposes the model to more data than the 2000-epoch semi-supervised run.
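The note about cycling can be made concrete with a little arithmetic. The sketch below assumes one plausible reading of the training loop, which this README does not spell out. Under this reading, each epoch iterates once over the labeled loader, and the unlabeled loader is cycled (restarted whenever it is exhausted) so that every labeled batch is paired with one unlabeled batch for reconstruction. The sketch also assumes about 60,000 training images in the fully labeled run; the real count depends on how much of the MNIST training set was held out for validation. The helpers `cycle_loader` and `samples_per_run` are hypothetical.

```python
def cycle_loader(loader):
    """Yield batches forever, starting a fresh pass whenever the loader runs out.

    Unlike itertools.cycle, which replays cached items, this re-iterates the
    loader, so a DataLoader with shuffle=True reshuffles on every pass.
    """
    while True:
        yield from loader


def samples_per_run(samples_per_epoch, epochs):
    """Samples processed per training phase over a whole run."""
    return samples_per_epoch * epochs


# Toy "loaders": each epoch walks the labeled batches once and pulls one
# unlabeled batch per labeled batch from the cycled iterator.
labeled = ["L0", "L1"]
unlabeled = ["U0", "U1", "U2"]
unlabeled_iter = cycle_loader(unlabeled)
epoch_1 = [(lb, next(unlabeled_iter)) for lb in labeled]  # [("L0", "U0"), ("L1", "U1")]
epoch_2 = [(lb, next(unlabeled_iter)) for lb in labeled]  # [("L0", "U2"), ("L1", "U0")]

# Epoch length follows the labeled set, so samples per epoch = number of labeled samples.
semi = samples_per_run(2_000, epochs=2_000)   # 2,000 labeled samples, 2000 epochs
full = samples_per_run(60_000, epochs=500)    # all ~60,000 samples labeled (assumed), 500 epochs
print(f"semi-supervised: {semi:,}  fully supervised: {full:,}")
# prints: semi-supervised: 4,000,000  fully supervised: 30,000,000
```

Under these assumptions, the fully supervised run processes about 7.5 times as many samples per phase. With a validation split the factor would be somewhat lower, but it would still favor the 500-epoch run.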
Clone the repo:
```bash
git clone https://github.com/flatala/adverserial-autoencoders-semi-supervised.git
cd adverserial-autoencoders-semi-supervised
```
Install requirements with:
```bash
pip install -r requirements.txt
```
Key dependencies:
- torch >= 2.6.0
- torchvision >= 0.21.0
- numpy >= 2.2.5
- pandas >= 2.2.3
- matplotlib >= 3.10.1
- tensorboard >= 2.19.0
- scikit-learn == 1.6.1
- notebook >= 7.4.1, < 8.0.0
- jupyter >= 1.1.1, < 2.0.0
- ipykernel >= 6.29.5, < 7.0.0
- tqdm
```python
import torch

from aae import SemiSupervisedAutoEncoderOptions, SemiSupervisedAdversarialAutoencoder

opts = SemiSupervisedAutoEncoderOptions(
    input_dim=784,                 # 28x28 MNIST images, flattened
    ae_hidden_dim=1024,
    disc_hidden_dim=512,
    latent_dim_categorical=10,     # 10 classes
    latent_dim_style=16,           # style vector size
    recon_loss_fn=torch.nn.MSELoss(),
    init_recon_lr=1e-3,
    semi_supervised_loss_fn=torch.nn.CrossEntropyLoss(),
    init_semi_sup_lr=1e-3,
    init_gen_lr=1e-4,
    init_disc_categorical_lr=1e-4,
    init_disc_style_lr=1e-4,
    use_decoder_sigmoid=True
)
model = SemiSupervisedAdversarialAutoencoder(opts)
```
Use standard PyTorch DataLoaders for MNIST (or your own dataset), split into labeled and unlabeled sets, and pass them to `train_mbgd`.
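One way to build these loaders is to draw a fixed random subset of the training set as the labeled pool and use the remaining images as the unlabeled pool. The sketch below is an assumption about how this could be done, not the code from `experiments.ipynb`. The function `make_semi_supervised_loaders`, the batch size, and the seed are hypothetical, and a toy `TensorDataset` stands in for MNIST. Here the two pools are disjoint; also feeding the labeled images into the unlabeled pool is an equally reasonable choice, and this README does not say which one was used. Whether `aae.py` expects flattened 784-dimensional inputs is likewise not stated. If it does, MNIST images can be flattened with `transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])`.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def make_semi_supervised_loaders(dataset, n_labeled, batch_size=100, seed=0):
    """Split `dataset` into a labeled subset and an unlabeled remainder.

    Returns (labeled_loader, unlabeled_loader). The unlabeled loader still
    yields (x, y) pairs; the training loop is expected to ignore y.
    """
    gen = torch.Generator().manual_seed(seed)  # reproducible split
    perm = torch.randperm(len(dataset), generator=gen).tolist()
    labeled_set = Subset(dataset, perm[:n_labeled])
    unlabeled_set = Subset(dataset, perm[n_labeled:])
    labeled_loader = DataLoader(labeled_set, batch_size=batch_size, shuffle=True)
    unlabeled_loader = DataLoader(unlabeled_set, batch_size=batch_size, shuffle=True)
    return labeled_loader, unlabeled_loader


# Toy stand-in for MNIST: 1,000 flattened "images" with labels 0-9.
toy = TensorDataset(torch.rand(1_000, 784), torch.randint(0, 10, (1_000,)))
train_labeled_loader, train_unlabeled_loader = make_semi_supervised_loaders(toy, n_labeled=200)
```

With only a few labels, a class-balanced selection (an equal number of labeled images per digit) is a common alternative to a purely random one. A validation loader can be built from a held-out portion of the data in the same way.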
```python
# train_labeled_loader, train_unlabeled_loader and val_loader are DataLoaders
# built as described above.
model.train_mbgd(
    train_labeled_loader,
    val_loader,
    epochs=2000,
    result_folder="results_2000_epochs_2000_samples",
    prior_std=5.0,
    add_gaussian_noise=False,
    train_unlabeled_loader=train_unlabeled_loader,
    save_interval=100
)
```
This codebase contains the part of the code that I implemented for a group project in the Generative Modelling (DSAIT4030) course at TU Delft. The collaborators on the project were:
- Goos, Rowdey (TU Delft)
- Łodziński, Maciej (TU Delft)
- Latała, Franciszek (Me) (TU Delft)
- Page, Henry (TU Delft)
- Savvidi, Danae (TU Delft)



