# Sharpness-aware Quantization for Deep Neural Networks

[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

## Recent Update

**`2021.11.23`**: We released the source code of SAQ.

## Set up the environment

1. Clone the repository locally:

```
git clone https://github.com/zhuang-group/SAQ
cd SAQ
```
2. Install PyTorch 1.8+, TensorBoard, and PrettyTable:

```
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install tensorboard
pip install prettytable
```
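A quick sanity check that the install succeeded and CUDA is visible:

```
python -c "import torch, torchvision; print(torch.__version__, torchvision.__version__, torch.cuda.is_available())"
```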
## Data preparation

### ImageNet

Download the ImageNet 2012 dataset from [here](http://image-net.org/), and prepare it using this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
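Once prepared, each split should contain one sub-folder per class, following the layout shown below; a quick check (ImageNet 2012 has 1,000 classes):

```bash
find dataset/imagenet/train -mindepth 1 -maxdepth 1 -type d | wc -l   # expect 1000
find dataset/imagenet/val -mindepth 1 -maxdepth 1 -type d | wc -l     # expect 1000
```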
### CIFAR-100

Download the CIFAR-100 dataset from [here](https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz).

After downloading ImageNet and CIFAR-100, the file structure should look like:

```
dataset
├── imagenet
│   ├── train
│   │   ├── class1
│   │   │   ├── img1.jpeg
│   │   │   ├── img2.jpeg
│   │   │   └── ...
│   │   ├── class2
│   │   │   ├── img3.jpeg
│   │   │   └── ...
│   │   └── ...
│   └── val
│       ├── class1
│       │   ├── img4.jpeg
│       │   ├── img5.jpeg
│       │   └── ...
│       ├── class2
│       │   ├── img6.jpeg
│       │   └── ...
│       └── ...
└── cifar100
    ├── cifar-100-python
    │   ├── meta
    │   ├── test
    │   ├── train
    │   └── ...
    └── ...
```
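For CIFAR-100, a minimal sketch that produces this layout, assuming the `dataset` directory lives in the repository root:

```bash
mkdir -p dataset/cifar100
wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
# extracting creates dataset/cifar100/cifar-100-python
tar -xzf cifar-100-python.tar.gz -C dataset/cifar100
```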
## Training

### Fixed-precision quantization

1. Download the pre-trained full-precision models from the [model zoo](https://github.com/zhuang-group/SAQ/wiki/Model-Zoo).

2. Train low-precision models.

To train low-precision ResNet-20 on CIFAR-100, run:

```bash
sh script/train_qsam_cifar_r20.sh
```
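Note that `script/train_qsam_cifar_r20.sh` (reproduced in the appendix below) ships with `XXXXX` placeholders for `--data_path` and `--pretrained`. Point them at your dataset and downloaded checkpoint first, e.g. with an in-place edit (both target paths here are hypothetical):

```bash
sed -i 's|--data_path XXXXX|--data_path ./dataset|; s|--pretrained XXXXX|--pretrained ./models/preresnet20_cifar100.pth|' script/train_qsam_cifar_r20.sh
```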
To train low-precision ResNet-18 on ImageNet, run:

```bash
sh script/train_qsam_imagenet_r18.sh
```
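The ImageNet scripts launch 4-GPU distributed training via `torch.distributed.launch` (see the appendix below). When training on a different number of GPUs, `--nproc_per_node` and `--gpu` presumably need to stay in sync; a hypothetical 2-GPU variant of the ResNet-18 command:

```bash
python -m torch.distributed.launch --nproc_per_node=2 --master_port=29500 --use_env \
    train_sam.py --save_path ./output/imagenet/qresnet18/w4a4/ --data_path path_of_dataset \
    --dataset imagenet --lr 0.02 --clip_lr 0.02 --opt_type QSAM_SGD --network qsamresnet18 \
    --rho 0.3 --pretrained path_of_pretrained_model --qw 4.0 --qa 4.0 --quan_type LIQ_wn_qsam \
    --seed 01 --gpu 0,1 --include_aclip True --batch_size 128 --n_epochs 90 \
    --lr_scheduler_type cosine --n_threads 8 --experiment_id 01
```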
### Mixed-precision quantization

1. Download the pre-trained full-precision models from the [model zoo](https://github.com/zhuang-group/SAQ/wiki/Model-Zoo).

2. Train the configuration generator.

To train the configuration generator of ResNet-20 on CIFAR-100, run:

```bash
sh script/train_generator_cifar_r20.sh
```

To train the configuration generator of ResNet-18 on ImageNet, run:

```bash
sh script/train_generator_imagenet_r18.sh
```
## License

This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.

## Acknowledgement

This repository adapts code from [SAM](https://github.com/davda54/sam), [ASAM](https://github.com/SamsungLabs/ASAM), and [ESAM](https://github.com/dydjw9/efficient_sam). We thank the authors for open-sourcing their code.
## Appendix: training scripts

The four one-line scripts referenced above are reproduced below, re-wrapped for readability; the filenames are inferred from the commands they contain. Replace the placeholder paths before running.

### `script/train_qsam_cifar_r20.sh`

```bash
# Fixed-precision W4A4 ResNet-20 on CIFAR-100, single GPU.
# XXXXX placeholders: dataset path and pre-trained full-precision checkpoint.
python train_sam.py \
    --save_path ./output/cifar100/qresnet20/w4a4/ \
    --data_path XXXXX \
    --dataset cifar100 \
    --lr 0.01 \
    --clip_lr 0.01 \
    --opt_type QSAM_SGD \
    --network qsampreresnet20 \
    --rho 0.4 \
    --pretrained XXXXX \
    --qw 4.0 \
    --qa 4.0 \
    --quan_type LIQ_wn_qsam \
    --experiment_id 01 \
    --seed 01 \
    --gpu 0 \
    --include_aclip True
```
### `script/train_generator_cifar_r20.sh`

```bash
# Mixed-precision configuration generator for ResNet-20 on CIFAR-100, single GPU.
python train_controller.py \
    --save_path ./output/cifar100/generator/r20/ \
    --data_path path_of_dataset \
    --dataset cifar100 \
    --lr 0.01 \
    --clip_lr 0.01 \
    --opt_type QSAM_SGD \
    --network qsamspreresnet20 \
    --rho 0.4 \
    --pretrained path_of_pretrained_model \
    --qw 3.0 \
    --qa 3.0 \
    --quan_type switchable_LIQ_wn_qsam \
    --gpu 0 \
    --lr_scheduler_type multi_step \
    --n_epochs 100 \
    --loss_lambda 1e-4 \
    --suffix generator_01 \
    --c_lr 5e-4 \
    --entropy_coeff 5e-3 \
    --target_bops 674 \
    --include_aclip True \
    --bits_choice 2,3,4,5 \
    --bit_warmup_epochs 10 \
    --wa_same_bit True
```
### `script/train_generator_imagenet_r18.sh`

```bash
# Mixed-precision configuration generator for ResNet-18, 4-GPU distributed run.
# --master_port: any free TCP port below 65536.
python -m torch.distributed.launch --nproc_per_node=4 --master_port=29630 --use_env \
    train_controller.py \
    --save_path ./output/imagenet/generator/r18/ \
    --data_path path_of_dataset \
    --dataset imagenet100 \
    --lr 0.01 \
    --clip_lr 0.01 \
    --opt_type QSAM_SGD \
    --network qsamsresnet18 \
    --rho 0.3 \
    --pretrained path_of_pretrained_model \
    --qw 3.0 \
    --qa 3.0 \
    --quan_type switchable_LIQ_wn_qsam \
    --gpu 4,5,6,7 \
    --lr_scheduler_type multi_step \
    --n_epochs 100 \
    --loss_lambda 5e-3 \
    --suffix controller_rho0.3_unshare_include_aclip_tb5.32_multi_step_lr0.01_warmup10 \
    --c_lr 5e-4 \
    --entropy_coeff 5e-3 \
    --target_bops 5.32 \
    --include_aclip True \
    --bits_choice 2,3,4,5 \
    --bit_warmup_epochs 10 \
    --batch_size 64 \
    --val_num 50000 \
    --wa_same_bit True
```
### `script/train_qsam_imagenet_r18.sh`

```bash
# Fixed-precision W4A4 ResNet-18 on ImageNet, 4-GPU distributed run.
python -m torch.distributed.launch --nproc_per_node=4 --master_port=65535 --use_env \
    train_sam.py \
    --save_path ./output/imagenet/qresnet18/w4a4/ \
    --data_path path_of_dataset \
    --dataset imagenet \
    --lr 0.02 \
    --clip_lr 0.02 \
    --opt_type QSAM_SGD \
    --network qsamresnet18 \
    --rho 0.3 \
    --pretrained path_of_pretrained_model \
    --qw 4.0 \
    --qa 4.0 \
    --quan_type LIQ_wn_qsam \
    --seed 01 \
    --gpu 0,1,2,3 \
    --include_aclip True \
    --batch_size 128 \
    --n_epochs 90 \
    --lr_scheduler_type cosine \
    --n_threads 8 \
    --experiment_id 01
```