Comparison of CNN and ViT on a small cultural relic data set.
The CNN-ViT project compares multiple CNN and transformer-based models, including VGG, ResNet, ViT and MobileViT, on a small cultural relic data set with 2 thousand images, each labeled with the material type of the relic. We found that: 1. Without wavelet decomposition, these images are hard for the models to classify; 2. ViT's performance on this small data set, whatever the depth, number of heads, and dimensions, is unfavourable, with accuracy less than 0.3; and 3. MobileViT_xxs can improve accuracy, but its performance is still worse than that of the CNNs. All of this demonstrates the importance of the inductive bias introduced by convolution.
This project implements ViT and MobileViT from scratch, both of which are compatible with graph execution mode in TensorFlow.
Note: All the images are preprocessed with wavelet decomposition.
On training data set: All the models are trained with the same learning rate, batch size and number of epochs.
On test data set:
model | number of parameters | loss | accuracy | weighted accuracy | F1 score | confidence score |
---|---|---|---|---|---|---|
mobilevit_xxs | 9.52e+05 | 0.992 | 0.570 | 0.548 | 0.551 | 0.597 |
mobilevit_xs | 1.93e+06 | 1.360 | 0.312 | 0.250 | 0.119 | 0.313 |
mobilevit_s | 4.94e+06 | 1.360 | 0.312 | 0.250 | 0.119 | 0.314 |
vgg16 | 6.51e+07 | 1.569 | 0.796 | 0.802 | 0.792 | 0.953 |
vgg19 | 7.04e+07 | 1.228 | 0.798 | 0.781 | 0.783 | 0.943 |
resnet18 | 1.12e+07 | 0.843 | 0.659 | 0.645 | 0.645 | 0.695 |
resnet34 | 2.13e+07 | 0.847 | 0.809 | 0.814 | 0.812 | 0.941 |
vit | 3.29e+06 | 1.317 | 0.349 | 0.289 | 0.204 | 0.350 |
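One way to read the table: the mobilevit_xs and mobilevit_s rows are numerically consistent with a model that predicts the same class for every test image. The sketch below checks this under assumptions this README does not state: that "weighted accuracy" is the mean of the per-class recalls, that the F1 score is macro-averaged, and that the largest class makes up about 31.2% of the test split. The class counts are hypothetical; only the class names are taken from example-data.

```python
def summarize(y_true, y_pred):
    """Return (accuracy, mean per-class recall, macro F1) for two label lists."""
    classes = sorted(set(y_true))
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    recalls, f1_scores = [], []
    for c in classes:
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        recalls.append(tp / (tp + fn))  # tp + fn > 0: c occurs in y_true
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    mean_recall = sum(recalls) / len(classes)
    macro_f1 = sum(f1_scores) / len(classes)
    return accuracy, mean_recall, macro_f1


# Hypothetical test split of 1000 images; the counts are NOT from the project.
y_true = ["CQ"] * 312 + ["QTQ"] * 230 + ["TQ"] * 229 + ["YQ"] * 229
y_pred = ["CQ"] * len(y_true)  # a collapsed model: always the same class

acc, mean_recall, macro_f1 = summarize(y_true, y_pred)
print(round(acc, 3), round(mean_recall, 3), round(macro_f1, 3))
```

Under these assumptions the script prints 0.312 0.25 0.119, the same values as the accuracy, weighted accuracy and F1 columns of those two rows. This suggests, but does not prove, that the two larger MobileViT variants collapsed to a single class under the shared training settings.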
└── CNN-ViT/
    ├── example-data/
    │   ├── CQ
    │   ├── QTQ
    │   ├── TQ
    │   └── YQ
    ├── for-readme/
    │   └── losses and accuracies.png
    ├── models/
    │   ├── __init__.py
    │   ├── mobileViT.py
    │   └── models.py
    ├── results/
    │   ├── results_not_vit.json
    │   └── results_vit.csv
    ├── config.py
    ├── data.py
    ├── experiments.py
    ├── main.py
    ├── README.md
    ├── requirements.txt
    ├── test.py
    └── wavelet.py
The data must have the following structure and be placed in the project folder CNN-ViT, unless a different folder is specified at run time:
└── data/
    ├── classname1
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    ├── classname2
    └── ...
That is, the data folder (named data in this case) has one child folder per class, named after the class, and each child folder contains the images of the corresponding class.
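The layout above can be walked with a few lines of standard-library Python. This is only a sketch of how such a folder could be turned into labeled samples; it is not the code in data.py. The function name list_labeled_images and the accepted file suffixes are assumptions made for illustration.

```python
from pathlib import Path

# Assumed set of image suffixes; adjust to whatever the data actually contains.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def list_labeled_images(data_folder):
    """Return (class_names, samples); samples are (path, class_index) pairs.

    Class names are the names of the child folders, sorted alphabetically,
    and class_index is the position of a folder in that sorted list.
    """
    root = Path(data_folder)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for index, name in enumerate(class_names):
        for path in sorted((root / name).iterdir()):
            if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
                samples.append((path, index))
    return class_names, samples
```

Calling list_labeled_images("data") from the project folder would return the class names together with one (path, class_index) pair per image. Files with other suffixes are skipped.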
The data can be preprocessed with wavelet decomposition. To do this, put the original data in the data_old folder and run wavelet.py.
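For readers unfamiliar with wavelet decomposition, here is a minimal single-level 2D Haar transform in plain Python. It is an illustration of the general technique only: this README does not say which wavelet, decomposition level, or library wavelet.py uses, nor which sub-bands are passed to the models, so the Haar choice and the function name haar_dwt2 are assumptions.

```python
def haar_dwt2(image):
    """Single-level 2D Haar transform of a 2D list with even height and width.

    Each 2x2 block [[a, b], [c, d]] yields one value per sub-band:
      ll: (a + b + c + d) / 2   local average (low-pass in both directions)
      lh: (a + b - c - d) / 2   top row minus bottom row
      hl: (a - b + c - d) / 2   left column minus right column
      hh: (a - b - c + d) / 2   diagonal difference
    Every sub-band is half the height and half the width of the input.
    """
    height, width = len(image), len(image[0])
    if height % 2 or width % 2:
        raise ValueError("height and width must be even")
    ll, lh, hl, hh = [], [], [], []
    for r in range(0, height, 2):
        row_ll, row_lh, row_hl, row_hh = [], [], [], []
        for col in range(0, width, 2):
            a, b = image[r][col], image[r][col + 1]
            c, d = image[r + 1][col], image[r + 1][col + 1]
            row_ll.append((a + b + c + d) / 2)
            row_lh.append((a + b - c - d) / 2)
            row_hl.append((a - b + c - d) / 2)
            row_hh.append((a - b - c + d) / 2)
        ll.append(row_ll)
        lh.append(row_lh)
        hl.append(row_hl)
        hh.append(row_hh)
    return ll, lh, hl, hh


# A flat region has no detail: only the ll band is non-zero.
flat = [[3, 3, 3, 3] for _ in range(4)]
ll, lh, hl, hh = haar_dwt2(flat)
print(ll, lh)
```

With this normalization the transform preserves the sum of squared pixel values, so the four sub-bands together carry the same energy as the input block.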
The data used in this project was collected from the Palace Museum, Taipei, and contains 4 classes with around 2000 images of cultural relics.
File / Directory | Summary |
---|---|
config.py | Defines configurations for different architectures of the MobileViT model, allowing customization of parameters such as number of classes, image size, and dropout rates to adapt to various image classification tasks within the CNN-ViT project's architecture. |
data.py | DataLoader in CNN-ViT manages image dataset preprocessing by loading, normalizing, and partitioning data into training and testing sets, supporting image resizing and format adjustments for model compatibility, and including functionality for data shuffling. |
experiments.py | experiments.py runs two groups of experiments: 1. Classification performance of VGG, ResNet and MobileViT with different architectures on the data set. 2. Classification performance of ViT with different architectures on this small data set. This script handles configuration setup, execution, and result storage. |
main.py | main.py organizes the model training and evaluation pipeline: configuration, data loading, training, and evaluation of neural network models including MobileViT, VGG, ResNet, and Vision Transformer. |
requirements.txt | Contains the dependencies of the project. |
test.py | A script to test the repo. |
wavelet.py | A script to perform wavelet decomposition on the images. |
example-data | Contains example images. |
for-readme | Contains support files for README.md . |
models | Contains scripts defining modules and models. See details below. |
results | Contains the results. |
models
File | Summary |
---|---|
mobileViT.py | Implements MobileViT from scratch, including the rearranging of tensors. Compatible with graph execution mode. |
models.py | Introduces foundational components for building and operating neural network models, including MLP and Transformer, alongside model architectures such as VGG, ResNet, and Vision Transformer (ViT), which is also implemented from scratch. |
System Requirements:
- Python: 3.10.4
- Create a virtual environment:
Windows:
py -3.10 -m venv cnn-vit-venv
cnn-vit-venv\Scripts\activate
Linux:
python3.10 -m venv cnn-vit-venv
source cnn-vit-venv/bin/activate
- Clone the repository:
git clone https://github.com/kangchengX/CNN-ViT.git
- Change to the project directory:
cd CNN-ViT
- Install the dependencies:
pip install -r requirements.txt
Place the data folder, structured as described in the Data section above, in the project folder.
Run experiments.py.
python main.py [config_arch] [OPTIONS]
Command Line Arguments:
Argument | Type | Description | Default Value |
---|---|---|---|
Positional | | | |
config_arch | String | Architecture of the model. Value should be mobilevit_xxs, mobilevit_xs, mobilevit_s, vgg16, vgg19, resnet50, resnet101, or vit. | N/A |
Option | | | |
--image_size | Integer | Height or width of the input image. | 128 |
--image_channels | Integer | Channels of the input image. | 3 |
--dropout | Float | Dropout ratio. | 0.5 |
--vit_patch_size | Integer | Patch size for the Vision Transformer. | 2 |
--vit_dim | Integer | (Word) dimension of the Vision Transformer. | 256 |
--vit_depth | Integer | Number of layers in the Vision Transformer. | 4 |
--vit_num_heads | Integer | Number of attention heads in the Vision Transformer. | 4 |
--vit_mlp_dim | Integer | Dimension of the MLP hidden layer in the Vision Transformer. | 512 |
--split_ratio | Float | Ratio of the training set in the whole dataset. | 0.75 |
--data_folder | String | Folder containing the data. | data |
--not_shuffle | Boolean | Do not shuffle the dataset. If the flag is present, i.e., the value is False, the data will not be shuffled. | True |
--num_epochs | Integer | Number of epochs. | 200 |
--batch_size | Integer | Batch size. | 16 |
--learning_rate | Float | Learning rate. | 1e-6 |
--results_filename | String | Path to save the results. | results |
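To get a feel for the default ViT settings, the arithmetic below computes the sequence length they imply. It assumes the standard ViT scheme of non-overlapping square patches; whether models.py adds a class token or pads the image is not stated in this README, so neither is counted here. The function name vit_token_stats is hypothetical.

```python
def vit_token_stats(image_size=128, image_channels=3, patch_size=2):
    """Return (num_patches, values_per_patch, attention_entries_per_head).

    Assumes non-overlapping square patches that tile the image exactly.
    """
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Each patch is flattened before being projected to the model dimension.
    values_per_patch = patch_size * patch_size * image_channels
    # Self-attention compares every patch with every other patch.
    attention_entries = num_patches ** 2
    return num_patches, values_per_patch, attention_entries


print(vit_token_stats())           # defaults from the table above
print(vit_token_stats(patch_size=16))
```

With the defaults this gives 4096 patches of 12 values each, and 16,777,216 attention entries per head and layer. A patch size of 16 would give 64 patches of 768 values. Very small patches therefore produce long sequences of low-information tokens, which is worth keeping in mind when interpreting the ViT results.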
Example:
python main.py resnet50 --results_filename results
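The behaviour of --not_shuffle is easier to see in code. The sketch below assumes the command line is parsed with argparse and that --not_shuffle is a store_false flag, which would explain a default of True that turns False when the flag is present. This README does not confirm either point; build_parser is a hypothetical name, and only a subset of the options is reproduced.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring part of the argument table above."""
    parser = argparse.ArgumentParser(description="Train and evaluate one model.")
    parser.add_argument("config_arch", type=str, help="Architecture of the model.")
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--split_ratio", type=float, default=0.75)
    parser.add_argument("--data_folder", type=str, default="data")
    # store_false: the value is True unless the flag is given, then it is False.
    parser.add_argument("--not_shuffle", action="store_false")
    parser.add_argument("--num_epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--results_filename", type=str, default="results")
    return parser


# Equivalent of: python main.py resnet50 --results_filename results
args = build_parser().parse_args(["resnet50", "--results_filename", "results"])
print(args.config_arch, args.not_shuffle)

# Equivalent of: python main.py vit --not_shuffle
flagged = build_parser().parse_args(["vit", "--not_shuffle"])
print(flagged.not_shuffle)
```

Under this assumption the first call prints resnet50 True (the data is shuffled) and the second prints False (the data is not shuffled).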