karayanni/torch-scae

 
 


torch-scae

PyTorch implementation of Stacked Capsule Auto-Encoders [1].

Ported from the official TensorFlow v1 implementation. The model architecture and hyper-parameters are kept the same; however, some parts have been refactored for ease of use.

Please open an issue for bugs or inconsistencies with the original implementation.


Installation

# clone project   
git clone https://github.com/bdsaglam/torch-scae   

# install project   
cd torch-scae
pip install -e .

Train with MNIST

PyTorch Lightning is used for training.

python -m torch_scae_experiments.mnist.train --batch_size 32 --learning_rate 1e-4
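To try several settings, the command above can be assembled programmatically. The following is a minimal sketch that uses only the module name and the two flags shown above; the helper name `build_train_command` and the second learning rate are illustrative assumptions, not part of the project.

```python
import sys


def build_train_command(batch_size, learning_rate):
    """Build the argument list for the MNIST training command shown above."""
    return [
        sys.executable, "-m", "torch_scae_experiments.mnist.train",
        "--batch_size", str(batch_size),
        "--learning_rate", str(learning_rate),
    ]


# Hypothetical sweep: only batch size 32 and learning rate 1e-4 come from the README.
commands = [build_train_command(32, lr) for lr in (1e-4, 3e-4)]

for cmd in commands:
    # Print each command; to launch it instead, pass `cmd` to
    # subprocess.run(cmd, check=True) after `import subprocess`.
    print(" ".join(cmd))
```

Building the command as a list rather than a single string avoids shell quoting issues if it is later passed to `subprocess.run`.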

Results

Image reconstructions

After training for 5 epochs


Fig 1. Rows: original images, bottom-up reconstructions, and top-down reconstructions.

Custom model

To build a custom model, create a parameter dictionary similar to the one at torch_scae_experiments.mnist.hparams.model_params.
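One way to derive such a dictionary is to copy an existing one and override selected entries. The sketch below is illustrative only: `override_params` is a hypothetical helper, and the keys in `base_params` are placeholders, since the actual keys of `model_params` are not listed here. In practice, `base_params` would presumably be the dictionary imported from `torch_scae_experiments.mnist.hparams`.

```python
import copy


def override_params(base, overrides):
    """Return a deep copy of `base` with `overrides` applied recursively.

    Raises KeyError for keys that are absent from `base`, so typos are
    caught instead of silently adding settings that nothing reads.
    """
    params = copy.deepcopy(base)
    for key, value in overrides.items():
        if key not in params:
            raise KeyError(f"unknown parameter: {key!r}")
        if isinstance(params[key], dict) and isinstance(value, dict):
            # Merge nested dictionaries instead of replacing them wholesale.
            params[key] = override_params(params[key], value)
        else:
            params[key] = value
    return params


# Placeholder stand-in for the MNIST model_params dictionary.
# These keys are assumptions; the real dictionary's keys may differ.
base_params = {
    "image_shape": (1, 40, 40),
    "n_classes": 10,
    "part_encoder": {"n_caps": 40},
}

custom_params = override_params(
    base_params,
    {"n_classes": 5, "part_encoder": {"n_caps": 16}},
)
```

Working on a deep copy leaves the original dictionary untouched, so the MNIST defaults remain available for comparison.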

References

  1. Kosiorek, A. R., Sabour, S., Teh, Y. W., & Hinton, G. E. (2019). Stacked Capsule Autoencoders. NeurIPS. http://arxiv.org/abs/1906.06818
