PyTorch implementation of Stacked Capsule Auto-Encoders [1].
Ported from the official TensorFlow v1 implementation. The model architecture and hyper-parameters are kept the same; however, some parts are refactored for ease of use.

Please open an issue for bugs or inconsistencies with the original implementation.
```bash
# clone project
git clone https://github.com/bdsaglam/torch-scae

# install project
cd torch-scae
pip install -e .
```
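As a quick sanity check of the editable install, the imports below should succeed. `torch_scae_experiments` is the package used by the training command further down; `torch_scae` is assumed (not verified here) to be the top-level library package.

```python
# Smoke test: both packages should import after `pip install -e .`
import torch_scae  # assumed top-level package name
import torch_scae_experiments  # experiments package referenced in this README

print(torch_scae.__name__, torch_scae_experiments.__name__)
```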
PyTorch Lightning is used for training:

```bash
python -m torch_scae_experiments.mnist.train --batch_size 32 --learning_rate 1e-4
```
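For reference, the snippet below sketches the LightningModule/Trainer pattern that the training entry point builds on. The model here is a deliberately tiny stand-in autoencoder, not the SCAE model itself; swap in the project's model to reproduce the real training loop.

```python
# Minimal PyTorch Lightning training sketch. TinyAutoEncoder is a
# stand-in for the SCAE model, used only to make the example runnable.
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class TinyAutoEncoder(pl.LightningModule):
    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64))
        self.decoder = nn.Linear(64, 28 * 28)

    def training_step(self, batch, batch_idx):
        x, _ = batch  # labels are unused in auto-encoding
        x_hat = self.decoder(self.encoder(x)).view_as(x)
        loss = F.mse_loss(x_hat, x)  # pixel reconstruction loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


if __name__ == "__main__":
    train_loader = DataLoader(
        datasets.MNIST(".", train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=32,
        shuffle=True,
    )
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(TinyAutoEncoder(), train_loader)
```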
After training for 5 epochs:

Fig 1. Rows: original images, bottom-up reconstructions, and top-down reconstructions.
For a custom model, create a parameter dictionary similar to the one at `torch_scae_experiments.mnist.hparams.model_params`.
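One way to do that, sketched below, is to copy the MNIST defaults and override selected entries. The overridden key name in the comment is a hypothetical placeholder; inspect the printed key list for the real names.

```python
# Sketch: derive custom hyper-parameters from the MNIST defaults.
import copy

from torch_scae_experiments.mnist.hparams import model_params

custom_params = copy.deepcopy(model_params)
print(sorted(custom_params))  # list the available keys first

# Then override entries as needed, e.g.:
# custom_params["n_obj_caps"] = 16  # hypothetical key; check the printout
```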
[1] Kosiorek, A. R., Sabour, S., Teh, Y. W., & Hinton, G. E. (2019). Stacked Capsule Autoencoders. NeurIPS. http://arxiv.org/abs/1906.06818