Vision Transformer from Scratch in PyTorch

A simplified, from-scratch PyTorch implementation of the Vision Transformer (ViT), with detailed steps (code in model.py).

Overview:

  • The default network is a scaled-down version of the original Vision Transformer (ViT) architecture from the ViT paper.
  • Has only 200k–800k parameters, depending on the embedding dimension (the original ViT-Base has 86 million).
  • Tested on common datasets: MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100.
    • Uses a 4×4 patch size to create longer token sequences from small images (see the sketch after this list).
  • For larger datasets, increase the model size and patch size accordingly.
  • Includes an option to switch between PyTorch’s built-in transformer layers and the layers implemented from scratch when defining the ViT.
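
For intuition, here is a minimal, self-contained sketch (not the repository's exact model.py code) of how a 4×4 patch embedding turns a 28×28 MNIST image into a 49-token sequence, using a strided Conv2d projection as in the ViT paper:

```python
import torch
import torch.nn as nn

# Illustrative ViT-style patch embedding. A Conv2d with kernel_size == stride
# == patch_size splits the image into non-overlapping patches and projects
# each patch to the embedding dimension in a single operation.
patch_size, embed_dim = 4, 64
patch_embed = nn.Conv2d(in_channels=1, out_channels=embed_dim,
                        kernel_size=patch_size, stride=patch_size)

x = torch.randn(8, 1, 28, 28)               # batch of MNIST-sized images
tokens = patch_embed(x)                      # (8, 64, 7, 7)
tokens = tokens.flatten(2).transpose(1, 2)   # (8, 49, 64): 49 tokens of dim 64
print(tokens.shape)
```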

Usage

Run the following commands to train the model on supported datasets:

```bash
# Train on MNIST
python main.py --dataset mnist --epochs 100

# Train on CIFAR10 with custom embedding size
python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128
```
  • See scripts.sh for more example commands.
  • Adjust dataset, image size, and embedding dimension via the command-line arguments as needed.
  • The --use_torch_transformer_layers flag switches between PyTorch’s built-in transformer layers and the layers implemented in model.py (one plausible wiring is sketched below).
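
As an illustration of what such a switch can look like, the hypothetical sketch below builds the encoder either from PyTorch's nn.TransformerEncoderLayer or from a small hand-written pre-norm layer. The names build_encoder and ScratchEncoderLayer are assumptions for this example, not the repository's actual API:

```python
import torch.nn as nn

class ScratchEncoderLayer(nn.Module):
    """Minimal pre-norm encoder layer (illustrative, not the repo's model.py)."""
    def __init__(self, embed_dim, n_heads, forward_mul, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, n_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, forward_mul * embed_dim), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(forward_mul * embed_dim, embed_dim), nn.Dropout(dropout))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]     # self-attention with residual
        return x + self.mlp(self.norm2(x))  # MLP block with residual

def build_encoder(embed_dim=64, n_layers=6, n_heads=4, forward_mul=2,
                  dropout=0.1, use_torch_layers=False):
    # Hypothetical version of the switch the flag controls.
    if use_torch_layers:
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads,
            dim_feedforward=forward_mul * embed_dim,
            dropout=dropout, batch_first=True)
        return nn.TransformerEncoder(layer, num_layers=n_layers)
    return nn.Sequential(*[ScratchEncoderLayer(embed_dim, n_heads,
                                               forward_mul, dropout)
                           for _ in range(n_layers)])
```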

Datasets and Performance

The model has been tested on multiple datasets with the following results:

| Dataset      | Run Command                                                                        | Test Accuracy (%)        |
|--------------|------------------------------------------------------------------------------------|--------------------------|
| MNIST        | `python main.py --dataset mnist --epochs 100`                                      | 99.5                     |
| FashionMNIST | `python main.py --dataset fmnist`                                                  | 92.3                     |
| SVHN         | `python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128`     | 96.2                     |
| CIFAR10      | `python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128`  | 86.3 (82.5 w/o RandAug)  |
| CIFAR100     | `python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128` | 59.6 (55.8 w/o RandAug)  |
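
The "w/o RandAug" entries suggest RandAugment was used for the CIFAR-style runs. The repository's exact augmentation pipeline is not reproduced here, but a typical torchvision setup for CIFAR10 looks like the following (the crop, flip, and normalization statistics are common defaults, not confirmed from the repo):

```python
from torchvision import transforms

# Hedged sketch of a CIFAR10 training pipeline with RandAugment.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(),                # augmentation behind the RandAug gains
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),  # common CIFAR10 statistics
])
```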

The following curves show the training and validation accuracy and loss for MNIST.

[Accuracy curve] [Loss curve]

For the accuracy and loss curves of all other datasets, refer to the outputs folder.


Model Configurations

Below are the key configurations for the Vision Transformer:

| Parameter          | MNIST / FMNIST | SVHN / CIFAR |
|--------------------|----------------|--------------|
| Input Size         | 1 × 28 × 28    | 3 × 32 × 32  |
| Patch Size         | 4              | 4            |
| Sequence Length    | 49             | 64           |
| Embedding Size     | 64             | 128          |
| Parameters         | 210k           | 820k         |
| Number of Layers   | 6              | 6            |
| Number of Heads    | 4              | 4            |
| Forward Multiplier | 2              | 2            |
| Dropout            | 0.1            | 0.1          |
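
The Sequence Length rows follow directly from the input and patch sizes: with square images and patches, the number of patch tokens is (image_size / patch_size)². A tiny sanity check:

```python
# Illustrative check of the table's sequence lengths.
def seq_len(image_size: int, patch_size: int) -> int:
    # Tokens per image = (H / P) * (W / P) for square inputs and patches.
    return (image_size // patch_size) ** 2

print(seq_len(28, 4))  # 49 -> MNIST / FashionMNIST row
print(seq_len(32, 4))  # 64 -> SVHN / CIFAR row
```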
