
[Feature]: MAE official #221

Merged (151 commits) on Mar 3, 2022

Commits
bd2832c
[Feature]: MAE single image pre-training
YuanLiuuuuuu Dec 17, 2021
a4abe91
[Fix]: Fix config
Dec 18, 2021
3bdaee7
[Fix]: Fix dataset link
YuanLiuuuuuu Dec 19, 2021
9fa79de
[Feature]: Add run
YuanLiuuuuuu Dec 19, 2021
ff5eb75
[Refactor]: Delete spot
YuanLiuuuuuu Dec 19, 2021
7635f8d
[Feature]: ignore nohup output file
YuanLiuuuuuu Dec 20, 2021
249c67e
[Feature]: Add auto script to generate run cmd
YuanLiuuuuuu Dec 20, 2021
85456a0
[Refactor]: Refactor mae config file
YuanLiuuuuuu Dec 20, 2021
00a3ccd
[Feature]: sz20 settings
YuanLiuuuuuu Dec 20, 2021
54ef7de
[Feature]: Add auto resume
YuanLiuuuuuu Dec 20, 2021
69ba6cd
[Fix]: Fix lint
YuanLiuuuuuu Dec 20, 2021
df01dc3
[Feature]: Make git ignore txt
YuanLiuuuuuu Dec 20, 2021
a81aedd
[Refactor]: Delete gpus in script
YuanLiuuuuuu Dec 20, 2021
bcf24dc
[Fix]: Make generate_cmd to add --async
YuanLiuuuuuu Dec 20, 2021
e32845a
[Feature]: Initial version of Vit fine-tune
YuanLiuuuuuu Dec 23, 2021
cddc03e
[Fix]: Add 1424 specific settings
YuanLiuuuuuu Dec 20, 2021
d147286
[Fix]: Fix missing file client bug for 1424
YuanLiuuuuuu Dec 21, 2021
d543567
[Feature]: 1424 customized settings
YuanLiuuuuuu Dec 24, 2021
e4fbbac
[Fix]: Make drop in eval to False
YuanLiuuuuuu Dec 25, 2021
f34e1a4
[Feature]: Change the finetune and pre-training settings
YuanLiuuuuuu Dec 27, 2021
0ff72bd
[Feature]: Add debug setting
YuanLiuuuuuu Dec 27, 2021
c5d08c6
[Refactor]: Refactor the model
YuanLiuuuuuu Dec 28, 2021
d07ba7b
[Feature]: Customized settings
YuanLiuuuuuu Dec 29, 2021
cc12ab6
[Feature]: Add A100 settings
YuanLiuuuuuu Dec 31, 2021
35c590a
[Fix]: Change mae to imagenet
YuanLiuuuuuu Dec 31, 2021
0fda414
[Feature]: Change mae pretrain num workers to 32
YuanLiuuuuuu Dec 31, 2021
f4e24f4
[Feature]: Change num workers to 16
YuanLiuuuuuu Jan 4, 2022
3b21c8b
[Feature]: Add A100 setting for pre_release ft version
YuanLiuuuuuu Jan 5, 2022
02eaeb4
[Feature]: Add img_norm_cfg
YuanLiuuuuuu Jan 5, 2022
6e7f077
[Fix]: Fix mae cls test missing logits bug
YuanLiuuuuuu Jan 5, 2022
735b118
[Fix]: Fix mae cls head bias initialize to zero
YuanLiuuuuuu Jan 6, 2022
daf29b5
[Feature]: Rename mae config name
YuanLiuuuuuu Jan 6, 2022
964a46e
[Feature]: Add MAE README.md
YuanLiuuuuuu Jan 7, 2022
53f3cef
[Fix]: Fix lint
YuanLiuuuuuu Jan 7, 2022
18037f5
[Feature]: Fix typo
YuanLiuuuuuu Jan 7, 2022
0cc25ea
[Fix]: Fix typo
YuanLiuuuuuu Jan 7, 2022
54ac07d
[Feature]: Fix invalid link
YuanLiuuuuuu Jan 7, 2022
2250554
[Fix]: Fix finetune config file name
YuanLiuuuuuu Jan 7, 2022
bbb792a
[Feature]: Official pretrain v1
YuanLiuuuuuu Jan 8, 2022
6cddae6
[Feature]: Change log interval to 100
YuanLiuuuuuu Jan 8, 2022
d4ac05e
[Feature]: pretrain 1600 epochs
YuanLiuuuuuu Jan 8, 2022
f65083b
[Fix]: Change encoder num head to 12
YuanLiuuuuuu Jan 10, 2022
0356ac1
[Feature]: Mix precision
YuanLiuuuuuu Jan 10, 2022
951ae49
[Feature]: Add default value to random masking
YuanLiuuuuuu Jan 10, 2022
dec4e12
[Feature]: Official MAE finetune
YuanLiuuuuuu Jan 10, 2022
c67a398
[Feature]: Finetune img per gpu 32
YuanLiuuuuuu Jan 10, 2022
e1509b8
[Feature]: Add multi machine training for lincls
YuanLiuuuuuu Jan 10, 2022
850afeb
[Fix]: Fix lincls master port master addr
YuanLiuuuuuu Jan 10, 2022
8faf83b
[Feature]: Change img per gpu to 128
YuanLiuuuuuu Jan 10, 2022
230e30e
[Feature]: Add linear eval and Refactor
YuanLiuuuuuu Jan 11, 2022
6f000f4
[Fix]: Fix debug mode
YuanLiuuuuuu Jan 11, 2022
ecf499b
[Fix]: Delete MAE dataset in __init__.py
YuanLiuuuuuu Jan 11, 2022
df106df
[Feature]: normalize pixel for mae
YuanLiuuuuuu Jan 11, 2022
a79dbd2
[Fix]: Fix lint
YuanLiuuuuuu Jan 13, 2022
d8cf7f7
[Feature]: LARS for linear eval
YuanLiuuuuuu Jan 13, 2022
f27b98b
[Feature]: Add lars for mae linear eval
YuanLiuuuuuu Jan 13, 2022
f85888e
[Feature]: Change mae linear lars num workers to 32
YuanLiuuuuuu Jan 13, 2022
ce96371
[Feature]: Change mae linear lars num workers to 8
YuanLiuuuuuu Jan 13, 2022
7a458c0
[Feature]: log every 25 iter for mae linear eval lars
YuanLiuuuuuu Jan 13, 2022
c29c969
[Feature]: Add 1600 epoch and 800 epoch pretraining
YuanLiuuuuuu Jan 13, 2022
f38b023
[Fix]: Change linear eval to 902
YuanLiuuuuuu Jan 14, 2022
8b265f5
[Fix]: Add random flip to linear eval
YuanLiuuuuuu Jan 14, 2022
05b5d1a
[Fix]: delete fp16 in mae
YuanLiuuuuuu Jan 17, 2022
1a0c391
[Refactor]: Change backbone to mmcls
YuanLiuuuuuu Jan 18, 2022
4cc9675
[Fix]: Align finetune settings
YuanLiuuuuuu Jan 18, 2022
8e5224f
[Fix]: replace timm trunc_normal with mmcv trunc_normal
YuanLiuuuuuu Jan 18, 2022
3007827
[Fix]: Change finetune layer_decay to 0.65
YuanLiuuuuuu Jan 18, 2022
fa993d4
[Fix]: Delete pretrain last norm when global_pooling
YuanLiuuuuuu Jan 20, 2022
e41873a
[Fix]: set requires_grad of norm1 to False
YuanLiuuuuuu Jan 20, 2022
a93bf4d
[Fix]: delete norm1
YuanLiuuuuuu Jan 20, 2022
e2b5a46
[Fix]: Fix docstring bug
YuanLiuuuuuu Jan 20, 2022
bfcede7
[Fix]: Fix lint
YuanLiuuuuuu Jan 20, 2022
2292097
[Fix]: Add external link
YuanLiuuuuuu Jan 20, 2022
06eb87f
[Fix]: Delete auto_resume and reformat config readme.
YuanLiuuuuuu Jan 20, 2022
92cb93a
[Fix]: Fix pytest bug
YuanLiuuuuuu Jan 20, 2022
eb1f1e7
[Fix]: Fix lint
YuanLiuuuuuu Jan 20, 2022
0e0b329
[Refactor]: Rename filename
YuanLiuuuuuu Jan 21, 2022
a6bcb17
[Feature]: Add docstring
YuanLiuuuuuu Jan 21, 2022
819d757
[Fix]: Rename config file name
YuanLiuuuuuu Jan 24, 2022
4daf8ea
[Fix]: Fix name inconsistency bug
YuanLiuuuuuu Jan 24, 2022
4431e22
[Fix]: Change the default value of persistent_worker in builder to True
YuanLiuuuuuu Jan 24, 2022
3da7da1
[Fix]: Change the default value of CPUS_PER_TASK to 5
YuanLiuuuuuu Jan 24, 2022
740aaa3
[Fix]: Add a blank line to line136 in tools/train.py
YuanLiuuuuuu Jan 24, 2022
ad0a0cd
[Fix]: Fix MAE algorithm docstring format and add paper name and url
YuanLiuuuuuu Jan 24, 2022
32c3034
[Feature]: Add MAE paper name and link, and store mae teaser on github
YuanLiuuuuuu Jan 24, 2022
8546415
[Refactor]: Delete mae.png
YuanLiuuuuuu Jan 24, 2022
431837c
[Fix]: Fix config file name
YuanLiuuuuuu Jan 24, 2022
9bffe84
[Fix]: Fix name bug
YuanLiuuuuuu Jan 24, 2022
67fbd40
[Refactor]: Change default GPUS to 8
YuanLiuuuuuu Jan 24, 2022
ebbe995
[Fix]: Abandon change to drop_last
YuanLiuuuuuu Jan 24, 2022
7b81385
[Fix]: Fix docstring in mae algorithm
YuanLiuuuuuu Jan 24, 2022
c37c9a3
[Fix]: Fix lint
YuanLiuuuuuu Jan 24, 2022
43f2e8b
[Fix]: Fix lint
YuanLiuuuuuu Jan 24, 2022
f895c79
[Fix]: Fix mae finetune algo type bug
YuanLiuuuuuu Jan 24, 2022
0d8a008
[Feature]: Add unit test for algorithm
YuanLiuuuuuu Jan 24, 2022
145772a
[Feature]: Add unit test for remaining parts
YuanLiuuuuuu Jan 25, 2022
2668560
[Fix]: Fix lint
YuanLiuuuuuu Jan 25, 2022
e1a18a5
[Fix]: Fix typo
YuanLiuuuuuu Jan 25, 2022
967666b
[Fix]: Delete some unnecessary modification in gitignore
YuanLiuuuuuu Jan 26, 2022
9f48a8b
[Feature]: Change finetune setting in mae algo to mixup setting
YuanLiuuuuuu Jan 26, 2022
8c70139
[Fix]: Change norm_pix_loss to norm_pix in pretrain head
YuanLiuuuuuu Jan 26, 2022
f672434
[Fix]: Delete modification in dist_train_linear.sh
YuanLiuuuuuu Jan 26, 2022
e817133
[Refactor]: Delete global pool in mae_cls_vit.py
YuanLiuuuuuu Jan 26, 2022
cd72bed
[Fix]: Change finetune param to mixup in test_mae_classification
YuanLiuuuuuu Jan 26, 2022
aa49693
[Fix]: Change norm_pix_loss to norm_pix of mae_pretrain_head in unit …
YuanLiuuuuuu Jan 26, 2022
6e7f4a4
[Fix]: Change norm_pix_loss to norm_pix in unit test
YuanLiuuuuuu Jan 26, 2022
bebf727
[Refactor]: Create init_weights for mae_finetune_head and mae_linprob…
YuanLiuuuuuu Jan 26, 2022
b3a2401
[Refactor]: Construct 2d sin-cosine position embedding using torch
YuanLiuuuuuu Jan 27, 2022
7b80971
[Refactor]: Using classification and using mixup from mmcls
YuanLiuuuuuu Jan 27, 2022
6d764cc
[Fix]: Fix lint
YuanLiuuuuuu Jan 27, 2022
7664df6
[Fix]: Add False to finetune mae linprobe
YuanLiuuuuuu Jan 27, 2022
9ae1536
[Fix]: Set drop_last to False
YuanLiuuuuuu Jan 28, 2022
7643e24
[Fix]: Fix MAE finetune layerwise lr bug
YuanLiuuuuuu Feb 7, 2022
89096de
[Refactor]: Delete redundant MAE when registering MAE
YuanLiuuuuuu Feb 7, 2022
b76ce88
[Refactor]: Split initialize_weights in MAE to submodules
YuanLiuuuuuu Feb 7, 2022
771682d
[Fix]: Change the min_lr of mae pretrain to 0.0
YuanLiuuuuuu Feb 7, 2022
a1f1bc0
[Refactor]: Delete unused _init_weights in mae_cls_vit
YuanLiuuuuuu Feb 8, 2022
ec36e4b
[Refactor]: Change MAE cls vit to a more general name
YuanLiuuuuuu Feb 8, 2022
9dbd6ef
[Feature]: Add Epoch Fix cosine annealing lr updater
YuanLiuuuuuu Feb 8, 2022
b686452
[Fix]: Fix lint
YuanLiuuuuuu Feb 8, 2022
95091a7
[Feature]: Add layer wise lr decay in optimizer constructor
YuanLiuuuuuu Feb 8, 2022
4120bd2
[Fix]: Fix lint
YuanLiuuuuuu Feb 8, 2022
5e00b4e
[Fix]: Fix set layer wise lr decay bug
YuanLiuuuuuu Feb 8, 2022
dca3e1f
[Fix]: Fix UT for MAE
YuanLiuuuuuu Feb 9, 2022
74f33d9
[Fix]: Fix lint
YuanLiuuuuuu Feb 9, 2022
f844b16
[Fix]: update algorithm readme format for MAE
YuanLiuuuuuu Feb 9, 2022
cd36793
[Fix]: Fix isort
YuanLiuuuuuu Feb 9, 2022
97d5410
[Fix]: Add Returns in mae_pretrain_vit
YuanLiuuuuuu Feb 9, 2022
722354a
[Fix]: Change bgr to rgb
YuanLiuuuuuu Feb 9, 2022
f7796a9
[Fix]: Change norm pix to True
YuanLiuuuuuu Feb 9, 2022
102be55
[Fix]: Use cls_token to linear prob
YuanLiuuuuuu Feb 17, 2022
6d06091
[Fix]: Delete mixup.py
YuanLiuuuuuu Feb 23, 2022
e9c5ec3
[Fix]: Fix MAE readme
YuanLiuuuuuu Feb 23, 2022
35b6767
[Feature]: Delete linprobe
YuanLiuuuuuu Feb 23, 2022
78c877b
[Refactor]: Merge MAE head into one file
YuanLiuuuuuu Feb 24, 2022
7d80bdc
[Fix]: Fix lint
YuanLiuuuuuu Feb 24, 2022
1042cd8
[Fix]: rename mae_pretrain_head to mae_head
YuanLiuuuuuu Feb 24, 2022
439834b
[Fix]: Fix import error in __init__.py
YuanLiuuuuuu Feb 24, 2022
880b423
[Feature]: skip MAE algo UT when running on windows
YuanLiuuuuuu Feb 24, 2022
03c6a73
[Fix]: Fix UT bug
YuanLiuuuuuu Feb 24, 2022
7a17ac0
[Feature]: Update model_zoo
YuanLiuuuuuu Feb 24, 2022
92d5b38
[Fix]: Rename MAE pretrain model name
YuanLiuuuuuu Feb 24, 2022
7e6c0e4
[Fix]: Delete mae ft prefix
YuanLiuuuuuu Feb 25, 2022
9e285fc
[Feature]: Change b to base
YuanLiuuuuuu Feb 25, 2022
337de6e
[Refactor]: Change b in MAE pt config to base
YuanLiuuuuuu Feb 25, 2022
e3323b5
[Fix]: Fix typo in docstring
YuanLiuuuuuu Feb 25, 2022
29521cd
[Fix]: Fix name bug
YuanLiuuuuuu Feb 25, 2022
e084a43
[Feature]: Add new constructor for MAE finetune
YuanLiuuuuuu Feb 25, 2022
3f1f8b8
[Fix]: Fix model_zoo link
YuanLiuuuuuu Mar 2, 2022
c8c8560
[Fix]: Skip UT for MAE
YuanLiuuuuuu Mar 2, 2022
cfb3bca
[Fix]: Change fixed channel order to param
YuanLiuuuuuu Mar 2, 2022
17 changes: 17 additions & 0 deletions configs/benchmarks/classification/_base_/models/vit-base-p16_ft.py
@@ -0,0 +1,17 @@
model = dict(
type='Classification',
backbone=dict(
type='MIMVisionTransformer',
arch='b',
patch_size=16,
drop_path_rate=0.1,
final_norm=False),
head=dict(
type='MAEFinetuneHead',
num_classes=1000,
embed_dim=768,
label_smooth_val=0.1),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))
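The `train_cfg` above applies batch-level Mixup and CutMix from mmcls during fine-tuning; since the two probabilities sum to 1, each batch receives exactly one of the two augments. As a rough sketch of what the Mixup branch computes — a simplified, hypothetical stand-in assuming one-hot targets, not the mmcls `BatchMixup` implementation:

```python
import torch

def batch_mixup(imgs, one_hot_targets, alpha=0.8):
    """Blend each sample with a randomly chosen partner from the same batch."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(imgs.size(0))
    mixed_imgs = lam * imgs + (1 - lam) * imgs[index]
    mixed_targets = lam * one_hot_targets + (1 - lam) * one_hot_targets[index]
    return mixed_imgs, mixed_targets
```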
@@ -0,0 +1,9 @@
model = dict(
type='Classification',
backbone=dict(
type='MIMVisionTransformer',
arch='b',
patch_size=16,
final_norm=True,
finetune=False),
head=dict(type='MAELinprobeHead', num_classes=1000, embed_dim=768))
@@ -0,0 +1,14 @@
# optimizer
optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.05)

# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0.,
warmup='linear',
warmup_iters=5,
warmup_ratio=1e-4, # cannot be 0
warmup_by_epoch=True)

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=100)
@@ -0,0 +1,67 @@
_base_ = [
'../_base_/models/vit-base-p16_ft.py',
'../_base_/datasets/imagenet.py',
'../_base_/schedules/adamw_coslr-100e_in1k.py',
'../_base_/default_runtime.py',
]

# dataset
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(
type='RandomAug',
input_size=224,
color_jitter=None,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
]
test_pipeline = [
dict(type='Resize', size=256, interpolation=3),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg)
]
data = dict(
samples_per_gpu=128,
drop_last=False,
workers_per_gpu=32,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline))

# model
model = dict(backbone=dict(init_cfg=dict()))

# optimizer
optimizer = dict(
lr=1e-3 * 1024 / 256,
paramwise_options={
'norm': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
'pos_embed': dict(weight_decay=0.),
'cls_token': dict(weight_decay=0.)
},
constructor='MAEFtOptimizerConstructor',
layer_decay=0.65)

# learning policy
lr_config = dict(
policy='StepFixCosineAnnealing',
min_lr=1e-6,
warmup='linear',
warmup_iters=5,
warmup_ratio=1e-4,
warmup_by_epoch=True,
by_epoch=False)

# runtime
checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
persistent_workers = True
log_config = dict(
interval=100, hooks=[
dict(type='TextLoggerHook'),
])
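`constructor='MAEFtOptimizerConstructor'` with `layer_decay=0.65` applies layer-wise learning-rate decay: the closer a parameter sits to the input, the smaller its learning rate. A minimal sketch of the scaling rule — `layer_lr_scales` is a hypothetical helper for illustration, not the constructor's actual code:

```python
def layer_lr_scales(num_layers: int = 12, layer_decay: float = 0.65):
    """LR scale per depth: patch embedding at depth 0, blocks at 1..N,
    and the head at depth N + 1 keeping the full base learning rate."""
    return [layer_decay ** (num_layers + 1 - depth)
            for depth in range(num_layers + 2)]

scales = layer_lr_scales()
# For ViT-B (12 blocks): the patch embedding trains at ~0.65**13 = 0.0037x
# the base lr, while the classification head keeps the full base lr.
```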
30 changes: 30 additions & 0 deletions configs/selfsup/_base_/datasets/imagenet_mae.py
@@ -0,0 +1,30 @@
# dataset settings
data_source = 'ImageNet'
dataset_type = 'SingleViewDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(
type='RandomResizedCrop', size=224, scale=(0.2, 1.0), interpolation=3),
dict(type='RandomHorizontalFlip')
]

# prefetch
prefetch = False
if not prefetch:
train_pipeline.extend(
[dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg)])

# dataset summary
data = dict(
imgs_per_gpu=128,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/train',
ann_file='data/imagenet/meta/train.txt',
),
pipeline=train_pipeline,
prefetch=prefetch))
15 changes: 15 additions & 0 deletions configs/selfsup/_base_/models/mae_vit-base-p16.py
@@ -0,0 +1,15 @@
# model settings
model = dict(
type='MAE',
backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
neck=dict(
type='MAEPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=768,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.,
),
head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))
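`norm_pix=True` in `MAEPretrainHead` means the regression target for each patch is that patch's pixels normalized by their own mean and variance, which the MAE paper reports improves representation quality. A sketch of the target computation under assumed shapes — simplified for illustration, not the MMSelfSup head's code:

```python
import torch

def norm_pix_target(imgs: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Patchify (N, 3, H, W) images into (N, L, p*p*3) and normalize per patch."""
    n, c, h, w = imgs.shape
    p = patch_size
    x = imgs.reshape(n, c, h // p, p, w // p, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(n, (h // p) * (w // p), p * p * c)
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True)
    return (x - mean) / (var + 1e-6).sqrt()
```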
15 changes: 15 additions & 0 deletions configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py
@@ -0,0 +1,15 @@
# optimizer
optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb

# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0.,
warmup='linear',
warmup_iters=40,
warmup_ratio=1e-4, # cannot be 0
warmup_by_epoch=True)

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=300)
54 changes: 54 additions & 0 deletions configs/selfsup/mae/README.md
@@ -0,0 +1,54 @@
# MAE

> [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)

<!-- [ALGORITHM] -->

## Abstract

This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), along with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g., 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3× or more) and improve accuracy. Our scalable approach allows for learning high-capacity models that generalize well: e.g., a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pretraining and shows promising scaling behavior.

<div align="center">
<img src="https://user-images.githubusercontent.com/30762564/150733959-2959852a-c7bd-4d3f-911f-3e8d8839fe67.png" width="40%"/>
</div>
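The 75% masking described in the abstract is uniform random sampling over patch tokens; the encoder only ever sees the kept 25%. A minimal sketch of the idea — simplified from the paper, not the exact `MAEViT` code:

```python
import torch

def random_masking(x: torch.Tensor, mask_ratio: float = 0.75):
    """x: patch tokens of shape (N, L, D). Returns kept tokens and the mask."""
    n, l, d = x.shape
    len_keep = int(l * (1 - mask_ratio))
    noise = torch.rand(n, l, device=x.device)        # one score per token
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: smallest kept
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, d))
    mask = torch.ones(n, l, device=x.device)         # 1 = masked / removed
    mask.scatter_(1, ids_keep, 0)
    return x_kept, mask
```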


## Models and Benchmarks

Here we report the results of the model pre-trained on ImageNet-1K for 400 epochs; the details are listed below:



| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
| :------: | :-------------: | :---------------: | :-------------------------------------------------: | :---------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| ViT-B/16 | 400 | 83.1 | [config](./mae_vit-b-p16_8xb512-coslr-400e_in1k.py) | [config](../../benchmarks/classification/imagenet/vit-b-p16_ft-8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |


## Citation

```bibtex
@article{He2021MaskedAA,
title={Masked Autoencoders Are Scalable Vision Learners},
author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and
Piotr Doll{\'a}r and Ross B. Girshick},
journal={ArXiv},
year={2021},
volume={abs/2111.06377}
}
```
@@ -0,0 +1,4 @@
_base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'

# schedule
runner = dict(max_epochs=1600)
42 changes: 42 additions & 0 deletions configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py
@@ -0,0 +1,42 @@
_base_ = [
'../_base_/models/mae_vit-base-p16.py',
'../_base_/datasets/imagenet_mae.py',
'../_base_/schedules/adamw_coslr-200e_in1k.py',
'../_base_/default_runtime.py',
]

# dataset
data = dict(samples_per_gpu=512, workers_per_gpu=32)

# optimizer
optimizer = dict(
lr=1.5e-4 * 4096 / 256,
paramwise_options={
'norm': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
'pos_embed': dict(weight_decay=0.),
'mask_token': dict(weight_decay=0.),
'cls_token': dict(weight_decay=0.)
})
optimizer_config = dict()

# learning policy
lr_config = dict(
policy='StepFixCosineAnnealing',
min_lr=0.0,
warmup='linear',
warmup_iters=40,
warmup_ratio=1e-4,
warmup_by_epoch=True,
by_epoch=False)

# schedule
runner = dict(max_epochs=400)

# runtime
checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
persistent_workers = True
log_config = dict(
interval=100, hooks=[
dict(type='TextLoggerHook'),
])
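The learning rate in this config follows the linear scaling rule: the schedule's base `1.5e-4` is defined for a total batch of 256, while this setup runs 8 GPUs at 512 samples each. A worked check of the numbers (plain Python, for illustration only):

```python
base_lr, base_batch = 1.5e-4, 256
effective_batch = 8 * 512                      # 8 GPUs x 512 samples/GPU = 4096
lr = base_lr * effective_batch / base_batch    # 2.4e-3 peak learning rate
warmup_start = lr * 1e-4                       # warmup_ratio sets the epoch-0 lr
# Linear warmup then raises the lr to 2.4e-3 over the first 40 epochs.
```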
@@ -0,0 +1,4 @@
_base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'

# schedule
runner = dict(max_epochs=800)
7 changes: 7 additions & 0 deletions docs/en/model_zoo.md
@@ -20,6 +20,7 @@ All models and part of benchmark results are recorded below.
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220225-2f488143.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220210_195402.log.json) |
| [SwAV](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/README.md) | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | [model](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220225-0497dd5d.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220211_061131.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |

Remarks:

@@ -52,6 +53,12 @@ If not specified, we use the linear evaluation setting from [MoCo](http://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf).
| SwAV | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | SwAV paper setting | 70.47 |
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 paper setting | 73.19 |


### ImageNet Fine-tuning

| Algorithm | Config | Remarks | Top-1 (%) |
| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |

### COCO17 Object Detection

For the COCO17 object detection task, we adopt the evaluation protocol from [MoCo](http://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf) with the Mask R-CNN architecture; the results below are trained with the same [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/mmdetection/coco/mask_rcnn_r50_fpn_mstrain_1x_coco.py).
6 changes: 6 additions & 0 deletions docs/zh_cn/model_zoo.md
@@ -20,6 +20,7 @@
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220225-2f488143.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220210_195402.log.json) |
| [SwAV](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/README.md) | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | [model](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220225-0497dd5d.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220211_061131.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |

Remarks:

@@ -52,6 +53,11 @@
| SwAV | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | SwAV paper setting | 70.47 |
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 paper setting | 73.19 |


### ImageNet Fine-tuning

| Algorithm | Config | Remarks | Top-1 (%) |
| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |

### COCO17 Object Detection

For the object detection task on the COCO17 dataset, we adopt the evaluation settings from [MoCo](http://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf), based on the Mask R-CNN architecture; the results below are trained with the same [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/mmdetection/coco/mask_rcnn_r50_fpn_mstrain_1x_coco.py).
3 changes: 2 additions & 1 deletion mmselfsup/core/hooks/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cosineAnnealing_hook import StepFixCosineAnnealingLrUpdaterHook
from .deepcluster_hook import DeepClusterHook
from .densecl_hook import DenseCLHook
from .momentum_update_hook import MomentumUpdateHook
@@ -10,5 +11,5 @@
__all__ = [
'MomentumUpdateHook', 'DeepClusterHook', 'DenseCLHook', 'ODCHook',
'DistOptimizerHook', 'GradAccumFp16OptimizerHook', 'SimSiamHook',
'SwAVHook'
'SwAVHook', 'StepFixCosineAnnealingLrUpdaterHook'
]
35 changes: 35 additions & 0 deletions mmselfsup/core/hooks/cosineAnnealing_hook.py
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import HOOKS
from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
annealing_cos)


@HOOKS.register_module()
class StepFixCosineAnnealingLrUpdaterHook(CosineAnnealingLrUpdaterHook):
    """Cosine annealing LR updater that excludes warmup from the schedule.

    Both the current progress and the total length are shifted by the
    warmup duration, so the cosine curve spans exactly the post-warmup
    part of training and still reaches the target minimum learning rate.
    """

    def get_lr(self, runner, base_lr):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs

# Delete warmup epochs
if self.warmup is not None:
progress = progress - self.warmup_iters // len(
runner.data_loader)
max_progress = max_progress - self.warmup_iters // len(
runner.data_loader)
else:
progress = runner.iter
max_progress = runner.max_iters

# Delete warmup iters
if self.warmup is not None:
progress = progress - self.warmup_iters
max_progress = max_progress - self.warmup_iters

if self.min_lr_ratio is not None:
target_lr = base_lr * self.min_lr_ratio
else:
target_lr = self.min_lr

return annealing_cos(base_lr, target_lr, progress / max_progress)
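The point of the subclass is that both the progress and the total length are shifted by the warmup duration, so the cosine phase covers exactly the post-warmup iterations. A self-contained check of the effect — `annealing_cos` is reimplemented here to match mmcv's formula, for illustration only:

```python
import math

def annealing_cos(start: float, end: float, factor: float) -> float:
    """Cosine interpolation from start to end, as mmcv's lr updaters define it."""
    return end + 0.5 * (start - end) * (1 + math.cos(math.pi * factor))

base_lr, min_lr = 2.4e-3, 0.0
warmup_iters, max_iters = 1000, 10000
cur_iter = 5500                      # midpoint of the post-warmup phase

plain = annealing_cos(base_lr, min_lr, cur_iter / max_iters)  # ~1.01e-3
fixed = annealing_cos(base_lr, min_lr,
                      (cur_iter - warmup_iters) / (max_iters - warmup_iters))
# fixed == 1.2e-3 exactly: with the shift, the midpoint of the cosine
# phase lands at half the peak lr, as the schedule intends.
```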
6 changes: 5 additions & 1 deletion mmselfsup/core/optimizer/__init__.py
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizer
from .constructor import DefaultOptimizerConstructor
from .mae_finetune_constructor import MAEFtOptimizerConstructor
from .optimizers import LARS

__all__ = ['LARS', 'build_optimizer', 'DefaultOptimizerConstructor']
__all__ = [
'LARS', 'build_optimizer', 'DefaultOptimizerConstructor',
'MAEFtOptimizerConstructor'
]