


mavanb/capsule_network_pytorch


Capsule Networks

A PyTorch implementation of CapsNet based on Hinton's Dynamic Routing Between Capsules using PyTorch's Visdom and Ignite.

Requirements

Getting Started

# Make sure a recent Python 3 version is installed

# clone this repository and enter it
git clone git@github.com:mavanb/capsnet_pytorch.git
cd capsnet_pytorch

# install the requirements
pip install -r requirements.txt

# train the capsnet using the default settings
python train_capsnet.py

Project Overview

The main modules in this project are:

To handle the PyTorch training process, we use Ignite. All supporting modules are in ignite_features:

  • trainer.py contains the abstract Trainer class, which adds all commonly used handlers and supports a train, validation and test step. The CapsuleTrainer extends this class and implements the train, valid and test functions.
  • plot_handlers.py contains the handlers that make standard Visdom plots.
  • metric.py contains custom Ignite metrics.
  • log_handlers.py contains all handlers used for logging.
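The split between the abstract Trainer and the CapsuleTrainer can be pictured with the pure-Python sketch below. It is not the project's code and does not use Ignite; the class names (BaseTrainer, ToyTrainer) and the methods add_handler, run_epoch and run_test are hypothetical. The base class owns the loops and the shared handlers, while a subclass supplies only the per-batch train, valid and test steps.

```python
from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """Hypothetical stand-in for an abstract trainer: owns the loops and shared handlers."""

    def __init__(self):
        # Handlers are called with (phase, step_output), e.g. for logging or plotting.
        self.handlers = []

    def add_handler(self, handler):
        self.handlers.append(handler)

    def _emit(self, phase, output):
        for handler in self.handlers:
            handler(phase, output)

    # Subclasses implement only the per-batch steps.
    @abstractmethod
    def train(self, batch):
        ...

    @abstractmethod
    def valid(self, batch):
        ...

    @abstractmethod
    def test(self, batch):
        ...

    def run_epoch(self, train_data, valid_data):
        """One training pass followed by one validation pass."""
        for batch in train_data:
            self._emit("train", self.train(batch))
        for batch in valid_data:
            self._emit("valid", self.valid(batch))

    def run_test(self, test_data):
        for batch in test_data:
            self._emit("test", self.test(batch))


class ToyTrainer(BaseTrainer):
    """Hypothetical subclass: each 'step' just sums the batch."""

    def train(self, batch):
        return sum(batch)

    def valid(self, batch):
        return sum(batch)

    def test(self, batch):
        return sum(batch)


log = []
trainer = ToyTrainer()
trainer.add_handler(lambda phase, output: log.append((phase, output)))
trainer.run_epoch(train_data=[[1, 2], [3]], valid_data=[[4]])
trainer.run_test(test_data=[[5, 5]])
print(log)  # [('train', 3), ('train', 3), ('valid', 4), ('test', 10)]
```

Because the handlers live in the base class, every subclass gets the same logging and plotting behaviour without repeating it; trying to instantiate BaseTrainer directly raises a TypeError since its step methods are abstract.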

The default configuration is in configurations/default.conf. The data is downloaded to the data folder.

Run a new experiment

To run a new experiment, create a folder for it in the experiments folder and copy the default configuration into it:

# make folder in the experiments folder
mkdir experiments/newexp

# copy the default configs 
cp configurations/default.conf  experiments/newexp/

Edit the configuration files to the desired settings. In general.conf, make sure the experiment name points to the right experiment:

exp_name = newexp

Some relevant settings:

  • Log the test accuracy

If save_best = True, the test accuracy at the best validation epoch and the model name are logged to a CSV file in the experiment folder. Change the file name using score_file_name = best_acc.

  • Change the architecture

The architecture of the capsule layers can be changed in the config. The default architecture is architecture = 32,8;10,16. Layers are separated by a semicolon, and each layer consists of two numbers separated by a comma: the first is the number of capsules, the second the capsule vector length. The primary capsules are arranged in a 6x6 grid, so 32 means 6x6x32 = 1152 capsules. Example with an extra layer: architecture = 32,8;14,12;10,16.
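The format can be illustrated with a small parser. This is a hypothetical helper, not the project's own parsing code; it assumes the 6x6 primary grid stated above and keeps the grid as a parameter, since other input sizes may give a different grid.

```python
def parse_architecture(spec):
    """Parse a spec such as '32,8;10,16' into [(32, 8), (10, 16)].

    Layers are separated by ';'; each layer is 'num_capsules,vector_length'.
    """
    layers = []
    for layer in spec.split(";"):
        num_caps, vec_len = (int(value) for value in layer.split(","))
        layers.append((num_caps, vec_len))
    return layers


def capsule_counts(layers, grid=(6, 6)):
    """Total capsules per layer: only the primary (first) layer is tiled over the grid."""
    rows, cols = grid
    primary_types = layers[0][0]
    return [rows * cols * primary_types] + [num_caps for num_caps, _ in layers[1:]]


default = parse_architecture("32,8;10,16")
print(default)                  # [(32, 8), (10, 16)]
print(capsule_counts(default))  # [1152, 10]
print(capsule_counts(parse_architecture("32,8;14,12;10,16")))  # [1152, 14, 10]
```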

  • Change the dataset

Set the dataset with dataset = mnist. The project currently supports mnist, fashionmnist and cifar10. The training data is split into a training and a validation set; change the size of the validation set using valid_size = 0.1.
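One common way to make such a split is sketched below, assuming valid_size is the fraction of training samples held out for validation. The helper name and the fixed seed are hypothetical; the project's own split may differ.

```python
import random


def train_valid_split(num_samples, valid_size=0.1, seed=0):
    """Shuffle sample indices and hold out a valid_size fraction for validation.

    Returns (train_indices, valid_indices), which are disjoint.
    """
    if not 0.0 <= valid_size < 1.0:
        raise ValueError("valid_size must be in [0, 1)")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    num_valid = round(num_samples * valid_size)
    return indices[num_valid:], indices[:num_valid]


train_idx, valid_idx = train_valid_split(60000, valid_size=0.1)  # MNIST has 60,000 training images
print(len(train_idx), len(valid_idx))  # 54000 6000
```

In PyTorch, index lists like these are commonly wrapped in torch.utils.data.SubsetRandomSampler and passed to two DataLoaders over the same training dataset; whether this project does exactly that is an assumption.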

  • Debug mode

If debug = True, the data loader uses only one worker and only a few images are loaded into the dataset.
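The effect of the flag can be sketched as a hypothetical helper that turns it into data-loader settings. The normal worker count (4), the debug subset size (64) and taking the first samples are made-up assumptions, since the README only says one worker and a few images.

```python
def loader_settings(num_samples, debug=False, num_workers=4, debug_samples=64):
    """Return (num_workers, sample_indices) for a data loader.

    In debug mode: a single worker and only the first few samples.
    """
    if debug:
        return 1, list(range(min(debug_samples, num_samples)))
    return num_workers, list(range(num_samples))


workers, indices = loader_settings(60000, debug=True)
print(workers, len(indices))  # 1 64
```

The reduced index list could then be applied with torch.utils.data.Subset, so an epoch finishes in seconds while the rest of the pipeline runs unchanged.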

Using Visdom

Visdom is used to plot and log the metrics. To use Visdom, make sure that general.conf contains:

start_visdom = True
use_visdom = True

Alternatively, set start_visdom = False and start Visdom manually:

python -m visdom.server -env_path ./experiments/newexp

During training, navigate to http://localhost:8097 to follow the training process. All Visdom files are written to env_path (usually the experiment folder) for later analysis.
