diff --git a/configs/benchmarks/classification/_base_/models/vit-base-p16_ft.py b/configs/benchmarks/classification/_base_/models/vit-base-p16_ft.py
new file mode 100644
index 000000000..87553f71f
--- /dev/null
+++ b/configs/benchmarks/classification/_base_/models/vit-base-p16_ft.py
@@ -0,0 +1,17 @@
+model = dict(
+ type='Classification',
+ backbone=dict(
+ type='MIMVisionTransformer',
+ arch='b',
+ patch_size=16,
+ drop_path_rate=0.1,
+ final_norm=False),
+ head=dict(
+ type='MAEFinetuneHead',
+ num_classes=1000,
+ embed_dim=768,
+ label_smooth_val=0.1),
+ train_cfg=dict(augments=[
+ dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+ dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+ ]))
diff --git a/configs/benchmarks/classification/_base_/models/vit-base-p16_linprobe.py b/configs/benchmarks/classification/_base_/models/vit-base-p16_linprobe.py
new file mode 100644
index 000000000..f14212baf
--- /dev/null
+++ b/configs/benchmarks/classification/_base_/models/vit-base-p16_linprobe.py
@@ -0,0 +1,9 @@
+model = dict(
+ type='Classification',
+ backbone=dict(
+ type='MIMVisionTransformer',
+ arch='b',
+ patch_size=16,
+ final_norm=True,
+ finetune=False),
+ head=dict(type='MAELinprobeHead', num_classes=1000, embed_dim=768))
diff --git a/configs/benchmarks/classification/_base_/schedules/adamw_coslr-100e_in1k.py b/configs/benchmarks/classification/_base_/schedules/adamw_coslr-100e_in1k.py
new file mode 100644
index 000000000..ab979139b
--- /dev/null
+++ b/configs/benchmarks/classification/_base_/schedules/adamw_coslr-100e_in1k.py
@@ -0,0 +1,14 @@
+# optimizer
+optimizer = dict(type='AdamW', lr=1e-3, betas=(0.9, 0.999), weight_decay=0.05)
+
+# learning policy
+lr_config = dict(
+ policy='CosineAnnealing',
+ min_lr=0.,
+ warmup='linear',
+ warmup_iters=5,
+ warmup_ratio=1e-4, # cannot be 0
+ warmup_by_epoch=True)
+
+# runtime settings
+runner = dict(type='EpochBasedRunner', max_epochs=100)
diff --git a/configs/benchmarks/classification/imagenet/vit-b-p16_ft-8xb128-coslr-100e_in1k.py b/configs/benchmarks/classification/imagenet/vit-b-p16_ft-8xb128-coslr-100e_in1k.py
new file mode 100644
index 000000000..aead1b430
--- /dev/null
+++ b/configs/benchmarks/classification/imagenet/vit-b-p16_ft-8xb128-coslr-100e_in1k.py
@@ -0,0 +1,67 @@
+_base_ = [
+ '../_base_/models/vit-base-p16_ft.py',
+ '../_base_/datasets/imagenet.py',
+ '../_base_/schedules/adamw_coslr-100e_in1k.py',
+ '../_base_/default_runtime.py',
+]
+
+# dataset
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+ dict(
+ type='RandomAug',
+ input_size=224,
+ color_jitter=None,
+ auto_augment='rand-m9-mstd0.5-inc1',
+ interpolation='bicubic',
+ re_prob=0.25,
+ re_mode='pixel',
+ re_count=1,
+ mean=(0.485, 0.456, 0.406),
+ std=(0.229, 0.224, 0.225))
+]
+test_pipeline = [
+ dict(type='Resize', size=256, interpolation=3),
+ dict(type='CenterCrop', size=224),
+ dict(type='ToTensor'),
+ dict(type='Normalize', **img_norm_cfg)
+]
+data = dict(
+ samples_per_gpu=128,
+ drop_last=False,
+ workers_per_gpu=32,
+ train=dict(pipeline=train_pipeline),
+ val=dict(pipeline=test_pipeline))
+
+# model
+model = dict(backbone=dict(init_cfg=dict()))
+
+# optimizer
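+# the base lr follows the linear scaling rule: scaled by total batch size / 256
+# (8 GPUs x 128 = 1024); MAEFtOptimizerConstructor additionally applies a
+# layer-wise lr decay of 0.65 per transformer layer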
+optimizer = dict(
+ lr=1e-3 * 1024 / 256,
+ paramwise_options={
+ 'norm': dict(weight_decay=0.),
+ 'bias': dict(weight_decay=0.),
+ 'pos_embed': dict(weight_decay=0.),
+ 'cls_token': dict(weight_decay=0.)
+ },
+ constructor='MAEFtOptimizerConstructor',
+ layer_decay=0.65)
+
+# learning policy
+lr_config = dict(
+ policy='StepFixCosineAnnealing',
+ min_lr=1e-6,
+ warmup='linear',
+ warmup_iters=5,
+ warmup_ratio=1e-4,
+ warmup_by_epoch=True,
+ by_epoch=False)
+
+# runtime
+checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
+persistent_workers = True
+log_config = dict(
+ interval=100, hooks=[
+ dict(type='TextLoggerHook'),
+ ])
diff --git a/configs/selfsup/_base_/datasets/imagenet_mae.py b/configs/selfsup/_base_/datasets/imagenet_mae.py
new file mode 100644
index 000000000..939fc1039
--- /dev/null
+++ b/configs/selfsup/_base_/datasets/imagenet_mae.py
@@ -0,0 +1,30 @@
+# dataset settings
+data_source = 'ImageNet'
+dataset_type = 'SingleViewDataset'
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+ dict(
+ type='RandomResizedCrop', size=224, scale=(0.2, 1.0), interpolation=3),
+ dict(type='RandomHorizontalFlip')
+]
+
+# prefetch
+prefetch = False
+if not prefetch:
+ train_pipeline.extend(
+ [dict(type='ToTensor'),
+ dict(type='Normalize', **img_norm_cfg)])
+
+# dataset summary
+data = dict(
+ samples_per_gpu=128,
+ workers_per_gpu=8,
+ train=dict(
+ type=dataset_type,
+ data_source=dict(
+ type=data_source,
+ data_prefix='data/imagenet/train',
+ ann_file='data/imagenet/meta/train.txt',
+ ),
+ pipeline=train_pipeline,
+ prefetch=prefetch))
diff --git a/configs/selfsup/_base_/models/mae_vit-base-p16.py b/configs/selfsup/_base_/models/mae_vit-base-p16.py
new file mode 100644
index 000000000..82db9942c
--- /dev/null
+++ b/configs/selfsup/_base_/models/mae_vit-base-p16.py
@@ -0,0 +1,15 @@
+# model settings
+model = dict(
+ type='MAE',
+ backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
+ neck=dict(
+ type='MAEPretrainDecoder',
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ decoder_embed_dim=512,
+ decoder_depth=8,
+ decoder_num_heads=16,
+ mlp_ratio=4.,
+ ),
+ head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))
diff --git a/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py b/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py
new file mode 100644
index 000000000..16adc742b
--- /dev/null
+++ b/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py
@@ -0,0 +1,15 @@
+# optimizer
+optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
+optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb
+
+# learning policy
+lr_config = dict(
+ policy='CosineAnnealing',
+ min_lr=0.,
+ warmup='linear',
+ warmup_iters=40,
+ warmup_ratio=1e-4, # cannot be 0
+ warmup_by_epoch=True)
+
+# runtime settings
+runner = dict(type='EpochBasedRunner', max_epochs=200)
diff --git a/configs/selfsup/mae/README.md b/configs/selfsup/mae/README.md
new file mode 100644
index 000000000..2001ae5ae
--- /dev/null
+++ b/configs/selfsup/mae/README.md
@@ -0,0 +1,54 @@
+# MAE
+
+> [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
+
+
+
+## Abstract
+
+This paper shows that masked autoencoders (MAE) are
+scalable self-supervised learners for computer vision. Our
+MAE approach is simple: we mask random patches of the
+input image and reconstruct the missing pixels. It is based
+on two core designs. First, we develop an asymmetric
+encoder-decoder architecture, with an encoder that operates only on the
+visible subset of patches (without mask tokens), along with a lightweight
+decoder that reconstructs the original image from the latent representation
+and mask tokens. Second, we find that masking a high proportion
+of the input image, e.g., 75%, yields a nontrivial and
+meaningful self-supervisory task. Coupling these two designs enables us to
+train large models efficiently and effectively: we accelerate
+training (by 3× or more) and improve accuracy. Our scalable approach allows
+for learning high-capacity models that generalize well: e.g., a vanilla
+ViT-Huge model achieves the best accuracy (87.8%) among
+methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pretraining and shows promising scaling behavior.
+
+
+
+
+
+
+## Models and Benchmarks
+
+Here, we report the results of the model, which is pre-trained on ImageNet-1K
+for 400 epochs. The details are below:
+
+
+
+| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
+| :------: | :-------------: | :---------------: | :-------------------------------------------------: | :---------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| ViT-B/16 | 400 | 83.1 | [config](./mae_vit-b-p16_8xb512-coslr-400e_in1k.py) | [config](../../benchmarks/classification/imagenet/vit-b-p16_ft-8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) | [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
+
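+As a quick sanity check, the `MAE` algorithm added in this PR can be built
+directly from the same dicts as
+`configs/selfsup/_base_/models/mae_vit-base-p16.py`. A minimal sketch
+(mirroring the unit tests added in this PR; not a training entry point):
+
+```python
+import torch
+
+from mmselfsup.models.algorithms import MAE
+
+model = MAE(
+    backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
+    neck=dict(
+        type='MAEPretrainDecoder',
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        decoder_embed_dim=512,
+        decoder_depth=8,
+        decoder_num_heads=16,
+        mlp_ratio=4.),
+    head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))
+
+# one pre-training step: reconstruction loss is computed on masked patches only
+losses = model.forward_train(torch.randn(2, 3, 224, 224))
+print(losses['loss'])
+```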
+
+## Citation
+
+```bibtex
+@article{He2021MaskedAA,
+ title={Masked Autoencoders Are Scalable Vision Learners},
+ author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and
+ Piotr Doll{\'a}r and Ross B. Girshick},
+ journal={ArXiv},
+ year={2021},
+ volume={abs/2111.06377}
+}
+```
diff --git a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-1600e_in1k.py b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-1600e_in1k.py
new file mode 100644
index 000000000..b783e74fe
--- /dev/null
+++ b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-1600e_in1k.py
@@ -0,0 +1,4 @@
+_base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'
+
+# schedule
+runner = dict(max_epochs=1600)
diff --git a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py
new file mode 100644
index 000000000..33bfcb2bd
--- /dev/null
+++ b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py
@@ -0,0 +1,42 @@
+_base_ = [
+ '../_base_/models/mae_vit-base-p16.py',
+ '../_base_/datasets/imagenet_mae.py',
+ '../_base_/schedules/adamw_coslr-200e_in1k.py',
+ '../_base_/default_runtime.py',
+]
+
+# dataset
+data = dict(samples_per_gpu=512, workers_per_gpu=32)
+
+# optimizer
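+# the base lr follows the linear scaling rule: scaled by total batch size / 256
+# (8 GPUs x 512 = 4096)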
+optimizer = dict(
+ lr=1.5e-4 * 4096 / 256,
+ paramwise_options={
+ 'norm': dict(weight_decay=0.),
+ 'bias': dict(weight_decay=0.),
+ 'pos_embed': dict(weight_decay=0.),
+ 'mask_token': dict(weight_decay=0.),
+ 'cls_token': dict(weight_decay=0.)
+ })
+optimizer_config = dict()
+
+# learning policy
+lr_config = dict(
+ policy='StepFixCosineAnnealing',
+ min_lr=0.0,
+ warmup='linear',
+ warmup_iters=40,
+ warmup_ratio=1e-4,
+ warmup_by_epoch=True,
+ by_epoch=False)
+
+# schedule
+runner = dict(max_epochs=400)
+
+# runtime
+checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
+persistent_workers = True
+log_config = dict(
+ interval=100, hooks=[
+ dict(type='TextLoggerHook'),
+ ])
diff --git a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-800e_in1k.py b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-800e_in1k.py
new file mode 100644
index 000000000..8dc9b3e74
--- /dev/null
+++ b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-800e_in1k.py
@@ -0,0 +1,4 @@
+_base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'
+
+# schedule
+runner = dict(max_epochs=800)
diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md
index b10ac33a1..1349b25c0 100644
--- a/docs/en/model_zoo.md
+++ b/docs/en/model_zoo.md
@@ -20,6 +20,7 @@ All models and part of benchmark results are recorded below.
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220225-2f488143.pth) | [log](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220210_195402.log.json) |
| [SwAV](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/README.md) | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | [model](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220225-0497dd5d.pth) | [log](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220211_061131.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) | [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
+| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) | [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
Remarks:
@@ -52,6 +53,12 @@ If not specified, we use linear evaluation setting from [MoCo](http://openaccess
| SwAV | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | SwAV paper setting | 70.47 |
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 paper setting | 73.19 |
+
+### ImageNet Fine-tuning
+| Algorithm | Config | Remarks | Top-1 (%) |
+| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- |
+| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
+
### COCO17 Object Detection
In COCO17 Object detection task, we choose the evluation protocol from [MoCo](http://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf), with Mask-RCNN architecture, the results below are trained with the same [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/mmdetection/coco/mask_rcnn_r50_fpn_mstrain_1x_coco.py).
diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md
index ba2a1143b..3751fe651 100644
--- a/docs/zh_cn/model_zoo.md
+++ b/docs/zh_cn/model_zoo.md
@@ -20,6 +20,7 @@
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220225-2f488143.pth) | [log](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220210_195402.log.json) |
| [SwAV](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/README.md) | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | [model](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220225-0497dd5d.pth) | [log](https://download.openmmlab.com/mmselfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96_20220211_061131.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) | [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
+| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) | [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
备注:
@@ -52,6 +53,11 @@
| SwAV | [swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/swav/swav_resnet50_8xb32-mcrop-2-6-coslr-200e_in1k-224-96.py) | SwAV 论文设置 | 70.47 |
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 论文设置 | 73.19 |
+
+### ImageNet 微调
+| 算法 | 配置文件 | 备注 | Top-1 (%) |
+| ---- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---- | --------- |
+| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
+
### COCO17 目标检测
在 COCO17 数据集的目标检测任务中,我们选用 [MoCo](http://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf) 的评估设置,基于 Mask-RCNN 网络架构,下列结果通过同样的 [配置文件](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/mmdetection/coco/mask_rcnn_r50_fpn_mstrain_1x_coco.py) 训练得到。
diff --git a/mmselfsup/core/hooks/__init__.py b/mmselfsup/core/hooks/__init__.py
index 288b3e486..3c769b66a 100644
--- a/mmselfsup/core/hooks/__init__.py
+++ b/mmselfsup/core/hooks/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .cosineAnnealing_hook import StepFixCosineAnnealingLrUpdaterHook
from .deepcluster_hook import DeepClusterHook
from .densecl_hook import DenseCLHook
from .momentum_update_hook import MomentumUpdateHook
@@ -10,5 +11,5 @@
__all__ = [
'MomentumUpdateHook', 'DeepClusterHook', 'DenseCLHook', 'ODCHook',
'DistOptimizerHook', 'GradAccumFp16OptimizerHook', 'SimSiamHook',
- 'SwAVHook'
+ 'SwAVHook', 'StepFixCosineAnnealingLrUpdaterHook'
]
diff --git a/mmselfsup/core/hooks/cosineAnnealing_hook.py b/mmselfsup/core/hooks/cosineAnnealing_hook.py
new file mode 100644
index 000000000..e55866058
--- /dev/null
+++ b/mmselfsup/core/hooks/cosineAnnealing_hook.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.runner import HOOKS
+from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
+ annealing_cos)
+
+
+@HOOKS.register_module()
+class StepFixCosineAnnealingLrUpdaterHook(CosineAnnealingLrUpdaterHook):
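+ """Cosine annealing LR scheduler that excludes the warmup period.
+
+ Compared with ``CosineAnnealingLrUpdaterHook``, the warmup iterations (or
+ epochs) are subtracted from both the current and the maximum progress, so
+ the cosine curve spans only the steps after warmup instead of being
+ compressed by them.
+ """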
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+
+ # Delete warmup epochs
+ if self.warmup is not None:
+ progress = progress - self.warmup_iters // len(
+ runner.data_loader)
+ max_progress = max_progress - self.warmup_iters // len(
+ runner.data_loader)
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+
+ # Delete warmup iters
+ if self.warmup is not None:
+ progress = progress - self.warmup_iters
+ max_progress = max_progress - self.warmup_iters
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
diff --git a/mmselfsup/core/optimizer/__init__.py b/mmselfsup/core/optimizer/__init__.py
index 9f25370c2..3378fa0fa 100644
--- a/mmselfsup/core/optimizer/__init__.py
+++ b/mmselfsup/core/optimizer/__init__.py
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizer
from .constructor import DefaultOptimizerConstructor
+from .mae_finetune_constructor import MAEFtOptimizerConstructor
from .optimizers import LARS
-__all__ = ['LARS', 'build_optimizer', 'DefaultOptimizerConstructor']
+__all__ = [
+ 'LARS', 'build_optimizer', 'DefaultOptimizerConstructor',
+ 'MAEFtOptimizerConstructor'
+]
diff --git a/mmselfsup/core/optimizer/constructor.py b/mmselfsup/core/optimizer/constructor.py
index 42d3b1a36..2010f2300 100644
--- a/mmselfsup/core/optimizer/constructor.py
+++ b/mmselfsup/core/optimizer/constructor.py
@@ -22,7 +22,7 @@ class DefaultOptimizerConstructor:
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options.
- Defaults to None
+ Defaults to None.
Example 1:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
@@ -47,6 +47,7 @@ def __call__(self, model):
model = model.module
optimizer_cfg = self.optimizer_cfg.copy()
paramwise_options = self.paramwise_cfg
+
# if no paramwise option is specified, just use the global setting
if paramwise_options is None:
optimizer_cfg['params'] = model.parameters()
diff --git a/mmselfsup/core/optimizer/mae_finetune_constructor.py b/mmselfsup/core/optimizer/mae_finetune_constructor.py
new file mode 100644
index 000000000..e0674fd59
--- /dev/null
+++ b/mmselfsup/core/optimizer/mae_finetune_constructor.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+
+import torch.distributed as dist
+from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS
+from mmcv.utils import build_from_cfg, print_log
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class MAEFtOptimizerConstructor:
+ """Rewrote default constructor for optimizers. By default each parameter
+ share the same optimizer settings, and we provide an argument
+ ``paramwise_cfg`` to specify parameter-wise settings and set layer-wise
+ learning rate. It is a dict and may contain the following fields:
+
+ Args:
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
+ optimizer_cfg (dict): The config dict of the optimizer.
+ Positional fields are
+ - `type`: class name of the optimizer.
+ Optional fields are
+ - any arguments of the corresponding optimizer type, e.g.,
+ lr, weight_decay, momentum, etc.
+ paramwise_cfg (dict, optional): Parameter-wise options.
+ Defaults to None
+ layer_decay (float): base value for layer wise learning rate decay.
+ Defaults to 0.0
+
+ Example 1:
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+ >>> weight_decay=0.0001)
+ >>> paramwise_cfg = dict('bias': dict(weight_decay=0., \
+ lars_exclude=True))
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+ """
+
+ def __init__(self, optimizer_cfg, paramwise_cfg=None):
+ if not isinstance(optimizer_cfg, dict):
+ raise TypeError('optimizer_cfg should be a dict',
+ f'but got {type(optimizer_cfg)}')
+ self.optimizer_cfg = optimizer_cfg
+ self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
+ self.layer_decay = self.optimizer_cfg.pop('layer_decay', 0.0)
+
+ def __call__(self, model):
+ if hasattr(model, 'module'):
+ model = model.module
+ optimizer_cfg = self.optimizer_cfg.copy()
+ paramwise_options = self.paramwise_cfg
+
+ # generate layer-wise lr decay
+ if self.layer_decay > 0:
+ self._generate_layer_wise_lr_decay(model, paramwise_options)
+
+ # if no paramwise option is specified, just use the global setting
+ if paramwise_options is None:
+ optimizer_cfg['params'] = model.parameters()
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+ else:
+ assert isinstance(paramwise_options, dict)
+ params = []
+ for name, param in model.named_parameters():
+ param_group = {'params': [param]}
+ if not param.requires_grad:
+ params.append(param_group)
+ continue
+
+ for regexp, options in paramwise_options.items():
+ if re.search(regexp, name):
+ for key, value in options.items():
+ if key.endswith('_mult'): # is a multiplier
+ key = key[:-5]
+ assert key in optimizer_cfg, \
+ f'{key} not in optimizer_cfg'
+ value = optimizer_cfg[key] * value
+ param_group[key] = value
+ if not dist.is_initialized() or \
+ dist.get_rank() == 0:
+ print_log(
+ f'paramwise_options -- {name}: {key}={value}')
+
+ # otherwise use the global settings
+ params.append(param_group)
+
+ optimizer_cfg['params'] = params
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+
+ def _generate_layer_wise_lr_decay(self, model, paramwise_options):
+ """Currently, we follow the same layer-wise lr decay schedule as
+ MAE."""
+ num_layers = len(model.backbone.layers) + 1
+ layer_scales = list(self.layer_decay**(num_layers - i)
+ for i in range(num_layers + 1))
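+ # e.g. for ViT-B, num_layers = 12 + 1 = 13, so lr_mult decays from
+ # layer_decay**13 (patch/pos embeddings) through layer_decay**1 (last
+ # block); parameters matching no entry keep the default lr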
+
+ if 'pos_embed' in paramwise_options:
+ paramwise_options['pos_embed'].update(
+ dict(lr_mult=layer_scales[0]))
+ else:
+ paramwise_options['pos_embed'] = dict(lr_mult=layer_scales[0])
+
+ if 'cls_token' in paramwise_options:
+ paramwise_options['cls_token'].update(
+ dict(lr_mult=layer_scales[0]))
+ else:
+ paramwise_options['cls_token'] = dict(lr_mult=layer_scales[0])
+
+ if 'patch_embed' in paramwise_options:
+ paramwise_options['patch_embed'].update(
+ dict(lr_mult=layer_scales[0]))
+ else:
+ paramwise_options['patch_embed'] = dict(lr_mult=layer_scales[0])
+
+ for i in range(num_layers - 1):
+ paramwise_options[f'backbone\\.layers\\.{i}\\.'] = dict(
+ lr_mult=layer_scales[i + 1])
diff --git a/mmselfsup/datasets/builder.py b/mmselfsup/datasets/builder.py
index 4ec8cf239..512ed46b0 100644
--- a/mmselfsup/datasets/builder.py
+++ b/mmselfsup/datasets/builder.py
@@ -131,7 +131,6 @@ def build_dataloader(dataset,
img_norm_cfg = kwargs.pop('img_norm_cfg')
else:
prefetch = False
-
data_loader = DataLoader(
dataset,
batch_size=batch_size,
diff --git a/mmselfsup/datasets/data_sources/base.py b/mmselfsup/datasets/data_sources/base.py
index 03d74efdc..6de429ffe 100644
--- a/mmselfsup/datasets/data_sources/base.py
+++ b/mmselfsup/datasets/data_sources/base.py
@@ -102,6 +102,9 @@ def get_img(self, idx):
else:
img = self.data_infos[idx]['img']
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(
+ img_bytes, flag=self.color_type, channel_order=self.channel_order)
img = img.astype(np.uint8)
return Image.fromarray(img)
diff --git a/mmselfsup/datasets/pipelines/__init__.py b/mmselfsup/datasets/pipelines/__init__.py
index 215c5c42e..ec9c94446 100644
--- a/mmselfsup/datasets/pipelines/__init__.py
+++ b/mmselfsup/datasets/pipelines/__init__.py
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .transforms import (GaussianBlur, Lighting, RandomAppliedTrans,
+from .transforms import (GaussianBlur, Lighting, RandomAppliedTrans, RandomAug,
Solarization)
-__all__ = ['GaussianBlur', 'Lighting', 'RandomAppliedTrans', 'Solarization']
+__all__ = [
+ 'GaussianBlur', 'Lighting', 'RandomAppliedTrans', 'Solarization',
+ 'RandomAug'
+]
diff --git a/mmselfsup/datasets/pipelines/transforms.py b/mmselfsup/datasets/pipelines/transforms.py
index 517c86072..26c400b9a 100644
--- a/mmselfsup/datasets/pipelines/transforms.py
+++ b/mmselfsup/datasets/pipelines/transforms.py
@@ -5,6 +5,7 @@
import torch
from mmcv.utils import build_from_cfg
from PIL import Image, ImageFilter
+from timm.data import create_transform
from torchvision import transforms as _transforms
from ..builder import PIPELINES
@@ -16,6 +17,48 @@
PIPELINES.register_module(m[1])
+@PIPELINES.register_module()
+class RandomAug(object):
+ """RandAugment data augmentation method based on
+ `"RandAugment: Practical automated data augmentation
+ with a reduced search space"
+ <https://arxiv.org/abs/1909.13719>`_.
+
+ This class is a thin wrapper around ``timm.data.create_transform``.
+ """
+
+ def __init__(self,
+ input_size=None,
+ color_jitter=None,
+ auto_augment=None,
+ interpolation=None,
+ re_prob=None,
+ re_mode=None,
+ re_count=None,
+ mean=None,
+ std=None):
+
+ self.trans = create_transform(
+ input_size=input_size,
+ is_training=True,
+ color_jitter=color_jitter,
+ auto_augment=auto_augment,
+ interpolation=interpolation,
+ re_prob=re_prob,
+ re_mode=re_mode,
+ re_count=re_count,
+ mean=mean,
+ std=std,
+ )
+
+ def __call__(self, img):
+ return self.trans(img)
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
@PIPELINES.register_module()
class RandomAppliedTrans(object):
"""Randomly applied transformations.
diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py
index c2c75093f..fbe2f7620 100644
--- a/mmselfsup/models/algorithms/__init__.py
+++ b/mmselfsup/models/algorithms/__init__.py
@@ -4,6 +4,7 @@
from .classification import Classification
from .deepcluster import DeepCluster
from .densecl import DenseCL
+from .mae import MAE
from .moco import MoCo
from .mocov3 import MoCoV3
from .npid import NPID
@@ -16,6 +17,6 @@
__all__ = [
'BaseModel', 'BYOL', 'Classification', 'DeepCluster', 'DenseCL', 'MoCo',
- 'MoCoV3', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR',
- 'SimSiam', 'SwAV'
+ 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV',
+ 'MAE', 'MoCoV3'
]
diff --git a/mmselfsup/models/algorithms/classification.py b/mmselfsup/models/algorithms/classification.py
index d1395ee94..5bcc27499 100644
--- a/mmselfsup/models/algorithms/classification.py
+++ b/mmselfsup/models/algorithms/classification.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from mmcls.models.utils import Augments
+
from ..builder import ALGORITHMS, build_backbone, build_head
from ..utils import Sobel
from .base import BaseModel
@@ -16,7 +18,12 @@ class Classification(BaseModel):
Defaults to None.
"""
- def __init__(self, backbone, with_sobel=False, head=None, init_cfg=None):
+ def __init__(self,
+ backbone,
+ with_sobel=False,
+ head=None,
+ train_cfg=None,
+ init_cfg=None):
super(Classification, self).__init__(init_cfg)
self.with_sobel = with_sobel
if with_sobel:
@@ -25,6 +32,11 @@ def __init__(self, backbone, with_sobel=False, head=None, init_cfg=None):
assert head is not None
self.head = build_head(head)
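+ # optional mmcls batch augments (e.g. BatchMixup / BatchCutMix in the
+ # MAE fine-tuning config), applied to (img, label) in forward_train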
+ self.augments = None
+ if train_cfg is not None:
+ augments_cfg = train_cfg.get('augments', None)
+ if augments_cfg is not None:
+ self.augments = Augments(augments_cfg)
+
def extract_feat(self, img):
"""Function to extract features from backbone.
@@ -52,6 +64,8 @@ def forward_train(self, img, label, **kwargs):
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
+ if self.augments is not None:
+ img, label = self.augments(img, label)
x = self.extract_feat(img)
outs = self.head(x)
loss_inputs = (outs, label)
diff --git a/mmselfsup/models/algorithms/mae.py b/mmselfsup/models/algorithms/mae.py
new file mode 100644
index 000000000..a51e116f6
--- /dev/null
+++ b/mmselfsup/models/algorithms/mae.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
+from .base import BaseModel
+
+
+@ALGORITHMS.register_module()
+class MAE(BaseModel):
+ """MAE.
+
+ Implementation of `Masked Autoencoders Are Scalable Vision Learners
+ <https://arxiv.org/abs/2111.06377>`_.
+
+ Args:
+ backbone (dict): Config dict for encoder. Defaults to None.
+ neck (dict): Config dict for decoder. Defaults to None.
+ head (dict): Config dict for loss functions. Defaults to None.
+ init_cfg (dict): Config dict for weight initialization.
+ Defaults to None.
+ """
+
+ def __init__(self, backbone=None, neck=None, head=None, init_cfg=None):
+ super(MAE, self).__init__(init_cfg)
+ assert backbone is not None
+ self.backbone = build_backbone(backbone)
+ assert neck is not None
+ self.neck = build_neck(neck)
+ self.neck.num_patches = self.backbone.patch_embed.num_patches
+ assert head is not None
+ self.head = build_head(head)
+
+ def init_weights(self):
+ super(MAE, self).init_weights()
+
+ def extract_feat(self, img):
+ """Function to extract features from backbone.
+
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+
+ Returns:
+ tuple[Tensor]: backbone outputs.
+ """
+ return self.backbone(img)
+
+ def forward_train(self, img, **kwargs):
+ """Forward computation during training.
+
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ kwargs: Any keyword arguments to be used to forward.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ latent, mask, ids_restore = self.backbone(img)
+ pred = self.neck(latent, ids_restore)
+ losses = self.head(img, pred, mask)
+
+ return losses
diff --git a/mmselfsup/models/backbones/__init__.py b/mmselfsup/models/backbones/__init__.py
index 32d0ed3c1..b0ae1942f 100644
--- a/mmselfsup/models/backbones/__init__.py
+++ b/mmselfsup/models/backbones/__init__.py
@@ -1,6 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .mae_pretrain_vit import MAEViT
+from .mim_cls_vit import MIMVisionTransformer
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .vision_transformer import VisionTransformer
-__all__ = ['ResNet', 'ResNetV1d', 'ResNeXt', 'VisionTransformer']
+__all__ = [
+ 'ResNet', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MIMVisionTransformer',
+ 'VisionTransformer'
+]
diff --git a/mmselfsup/models/backbones/mae_pretrain_vit.py b/mmselfsup/models/backbones/mae_pretrain_vit.py
new file mode 100644
index 000000000..75b7d8a10
--- /dev/null
+++ b/mmselfsup/models/backbones/mae_pretrain_vit.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcls.models import VisionTransformer
+from torch import nn
+
+from ..builder import BACKBONES
+from ..utils import build_2d_sincos_position_embedding
+
+
+@BACKBONES.register_module()
+class MAEViT(VisionTransformer):
+ """Vision Transformer for MAE pre-training.
+
+ A PyTorch implementation of `An Image is Worth 16x16 Words: Transformers
+ for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
+
+ Args:
+ arch (str | dict): Vision Transformer architecture.
+ Defaults to 'b'.
+ img_size (int | tuple): Input image size. Defaults to 224.
+ patch_size (int | tuple): The patch size. Defaults to 16.
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, meaning the last stage.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ final_norm (bool): Whether to add an additional layer to normalize the
+ final feature map. Defaults to True.
+ output_cls_token (bool): Whether to output the cls_token. If set True,
+ `with_cls_token` must be True. Defaults to True.
+ interpolate_mode (str): Select the interpolate mode for position
+ embedding vector resize. Defaults to "bicubic".
+ patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
+ encoder. Defaults to an empty dict.
+ mask_ratio (float): The ratio of total number of patches to be masked.
+ Defaults to 0.75.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ arch='b',
+ img_size=224,
+ patch_size=16,
+ out_indices=-1,
+ drop_rate=0,
+ drop_path_rate=0,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ final_norm=True,
+ output_cls_token=True,
+ interpolate_mode='bicubic',
+ patch_cfg=dict(),
+ layer_cfgs=dict(),
+ mask_ratio=0.75,
+ init_cfg=None):
+ super().__init__(arch, img_size, patch_size, out_indices, drop_rate,
+ drop_path_rate, norm_cfg, final_norm,
+ output_cls_token, interpolate_mode, patch_cfg,
+ layer_cfgs, init_cfg)
+
+ self.pos_embed.requires_grad = False
+ self.mask_ratio = mask_ratio
+
+ def init_weights(self):
+ super(MAEViT, self).init_weights()
+ if not (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # initialize position embedding in backbone
+ pos_embed = build_2d_sincos_position_embedding(
+ int(self.patch_embed.num_patches**.5),
+ self.pos_embed.shape[-1],
+ cls_token=True)
+ self.pos_embed.data.copy_(pos_embed.float())
+
+ w = self.patch_embed.projection.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ torch.nn.init.normal_(self.cls_token, std=.02)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def random_masking(self, x, mask_ratio=0.75):
+ """Generate the mask for MAE Pre-training.
+
+ Args:
+ x (torch.tensor): Image with data augmentation applied.
+ mask_ratio (float): The mask ratio of total patches.
+ Defaults to 0.75.
+
+ Returns:
+ tuple[Tensor, Tensor, Tensor]: masked image, mask and the ids
+ to restore original image.
+
+ - x_masked (Tensor): masked image.
+ - mask (Tensor): mask used to mask image.
+ - ids_restore (Tensor): ids to restore original image.
+ """
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(
+ x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
+
+ def forward(self, x):
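+ """Forward the masked image through the encoder.
+
+ Returns:
+ tuple: encoded visible tokens with the cls token prepended, the
+ binary patch mask, and the indices to restore the original
+ patch order.
+ """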
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ # add pos embed w/o cls token
+ x = x + self.pos_embed[:, 1:, :]
+
+ # masking: length -> length * (1 - mask_ratio)
+ x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
+
+ # append cls token
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self.drop_after_pos(x)
+
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ if i == len(self.layers) - 1 and self.final_norm:
+ x = self.norm1(x)
+
+ return (x, mask, ids_restore)
diff --git a/mmselfsup/models/backbones/mim_cls_vit.py b/mmselfsup/models/backbones/mim_cls_vit.py
new file mode 100644
index 000000000..b29807cee
--- /dev/null
+++ b/mmselfsup/models/backbones/mim_cls_vit.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcls.models import VisionTransformer
+from mmcv.cnn import build_norm_layer
+
+from ..builder import BACKBONES
+
+
+@BACKBONES.register_module()
+class MIMVisionTransformer(VisionTransformer):
+ """Vision Transformer for MIM-style model (Mask Image Modeling)
+ classification (fine-tuning or linear probe).
+
+ A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers
+ for Image Recognition at Scale `_
+
+ Args:
+ arch (str | dict): Vision Transformer architecture
+ Default: 'b'
+ img_size (int | tuple): Input image size
+ patch_size (int | tuple): The patch size
+ out_indices (Sequence | int): Output from which stages.
+ Defaults to -1, means the last stage.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Defaults to True.
+ output_cls_token (bool): Whether output the cls_token. If set True,
+ `with_cls_token` must be True. Defaults to True.
+ interpolate_mode (str): Select the interpolate mode for position
+ embeding vector resize. Defaults to "bicubic".
+ patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
+ layer_cfgs (Sequence | dict): Configs of each transformer layer in
+ encoder. Defaults to an empty dict.
+ finetune (bool): Whether or not do fine-tuning. Defaults to True.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ arch='b',
+ img_size=224,
+ patch_size=16,
+ out_indices=-1,
+ drop_rate=0,
+ drop_path_rate=0,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ final_norm=True,
+ output_cls_token=True,
+ interpolate_mode='bicubic',
+ patch_cfg=dict(),
+ layer_cfgs=dict(),
+ finetune=True,
+ init_cfg=None):
+ super().__init__(arch, img_size, patch_size, out_indices, drop_rate,
+ drop_path_rate, norm_cfg, final_norm,
+ output_cls_token, interpolate_mode, patch_cfg,
+ layer_cfgs, init_cfg)
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ if not self.final_norm:
+ _, self.fc_norm = build_norm_layer(
+ norm_cfg, self.embed_dims, postfix=1)
+
+ self.finetune = finetune
+ if not self.finetune:
+ self._freeze_stages()
+
+ def train(self, mode=True):
+ super(MIMVisionTransformer, self).train(mode)
+ if not self.finetune:
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ """Freeze params in backbone when linear probing."""
+ for _, param in self.named_parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+ x = self.drop_after_pos(x)
+
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ if i == len(self.layers) - 1 and self.final_norm:
+ x = self.norm1(x)
+
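+ # final_norm=False: global average pooling over patch tokens + fc_norm
+ # (MAE fine-tuning); final_norm=True: return the cls token (linear probe)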
+ if not self.final_norm:
+ x = x[:, 1:, :].mean(dim=1)
+ outcome = self.fc_norm(x)
+ else:
+ outcome = x[:, 0]
+ return outcome
diff --git a/mmselfsup/models/heads/__init__.py b/mmselfsup/models/heads/__init__.py
index 754c1a874..db2e2e30a 100644
--- a/mmselfsup/models/heads/__init__.py
+++ b/mmselfsup/models/heads/__init__.py
@@ -2,11 +2,13 @@
from .cls_head import ClsHead
from .contrastive_head import ContrastiveHead
from .latent_pred_head import LatentClsHead, LatentPredictHead
+from .mae_head import MAEFinetuneHead, MAEPretrainHead
from .mocov3_head import MoCoV3Head
from .multi_cls_head import MultiClsHead
from .swav_head import SwAVHead
__all__ = [
'ContrastiveHead', 'ClsHead', 'LatentPredictHead', 'LatentClsHead',
- 'MoCoV3Head', 'MultiClsHead', 'SwAVHead'
+ 'MultiClsHead', 'SwAVHead', 'MAEFinetuneHead', 'MAEPretrainHead',
+ 'MoCoV3Head'
]
diff --git a/mmselfsup/models/heads/mae_head.py b/mmselfsup/models/heads/mae_head.py
new file mode 100644
index 000000000..cb8b566c1
--- /dev/null
+++ b/mmselfsup/models/heads/mae_head.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcls.models import LabelSmoothLoss
+from mmcv.cnn.utils.weight_init import trunc_normal_
+from mmcv.runner import BaseModule
+from torch import nn
+
+from ..builder import HEADS
+
+
+@HEADS.register_module()
+class MAEPretrainHead(BaseModule):
+ """Pre-training head for MAE.
+
+ Args:
+ norm_pix (bool): Whether or not to normalize the target with
+ per-patch mean and variance. Defaults to False.
+ patch_size (int): Patch size. Defaults to 16.
+ """
+
+ def __init__(self, norm_pix=False, patch_size=16):
+ super(MAEPretrainHead, self).__init__()
+ self.norm_pix = norm_pix
+ self.patch_size = patch_size
+
+ def patchify(self, imgs):
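+ """Split images into non-overlapping patches.
+
+ (N, 3, H, W) -> (N, L, patch_size**2 * 3), where
+ L = (H / patch_size) * (W / patch_size).
+ """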
+
+ p = self.patch_size
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum('nchpwq->nhwpqc', x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+ return x
+
+ def forward(self, x, pred, mask):
+ losses = dict()
+ target = self.patchify(x)
+ if self.norm_pix:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.e-6)**.5
+
+ loss = (pred - target)**2
+ loss = loss.mean(dim=-1)
+
+ loss = (loss * mask).sum() / mask.sum()
+ losses['loss'] = loss
+ return losses
+
+
+@HEADS.register_module()
+class MAEFinetuneHead(BaseModule):
+ """Fine-tuning head for MAE.
+
+ Args:
+ embed_dim (int): The dim of the feature before the classifier head.
+ num_classes (int): The total classes. Defaults to 1000.
+ label_smooth_val (float): The degree of label smoothing.
+ Defaults to 0.1.
+ """
+
+ def __init__(self, embed_dim, num_classes=1000, label_smooth_val=0.1):
+ super(MAEFinetuneHead, self).__init__()
+ self.head = nn.Linear(embed_dim, num_classes)
+ self.criterion = LabelSmoothLoss(label_smooth_val, num_classes)
+
+ def init_weights(self):
+ nn.init.constant_(self.head.bias, 0)
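+ # the small init std for the classifier weight matches the official
+ # MAE fine-tuning code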
+ trunc_normal_(self.head.weight, std=2e-5)
+
+ def forward(self, x):
+ """"Get the logits."""
+ outputs = self.head(x)
+
+ return [outputs]
+
+ def loss(self, outputs, labels):
+ """Compute the loss."""
+ losses = dict()
+ losses['loss'] = self.criterion(outputs[0], labels)
+
+ return losses
diff --git a/mmselfsup/models/necks/__init__.py b/mmselfsup/models/necks/__init__.py
index 21346e35e..52e1e3343 100644
--- a/mmselfsup/models/necks/__init__.py
+++ b/mmselfsup/models/necks/__init__.py
@@ -2,6 +2,7 @@
from .avgpool2d_neck import AvgPool2dNeck
from .densecl_neck import DenseCLNeck
from .linear_neck import LinearNeck
+from .mae_neck import MAEPretrainDecoder
from .mocov2_neck import MoCoV2Neck
from .nonlinear_neck import NonLinearNeck
from .odc_neck import ODCNeck
@@ -10,5 +11,6 @@
__all__ = [
'AvgPool2dNeck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck',
- 'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck'
+ 'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck',
+ 'MAEPretrainDecoder'
]
diff --git a/mmselfsup/models/necks/mae_neck.py b/mmselfsup/models/necks/mae_neck.py
new file mode 100644
index 000000000..15954f775
--- /dev/null
+++ b/mmselfsup/models/necks/mae_neck.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
+from mmcv.cnn import build_norm_layer
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+from ..utils import build_2d_sincos_position_embedding
+
+
+@NECKS.register_module()
+class MAEPretrainDecoder(BaseModule):
+ """Decoder for MAE Pre-training.
+
+ Args:
+ num_patches (int): The number of total patches. Defaults to 196.
+ patch_size (int): Image patch size. Defaults to 16.
+ in_chans (int): The channel of input image. Defaults to 3.
+ embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
+ decoder_embed_dim (int): Decoder's embedding dimension.
+ Defaults to 512.
+ decoder_depth (int): The depth of decoder. Defaults to 8.
+ decoder_num_heads (int): Number of attention heads of decoder.
+ Defaults to 16.
+ mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
+ Defaults to 4.
+ norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
+
+ Some of the code is borrowed from
+ `https://github.com/facebookresearch/mae`.
+
+ Example:
+ >>> from mmselfsup.models import MAEPretrainDecoder
+ >>> import torch
+ >>> self = MAEPretrainDecoder()
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 50, 1024)
+ >>> ids_restore = torch.arange(0, 196).unsqueeze(0)
+ >>> level_outputs = self.forward(inputs, ids_restore)
+ >>> print(tuple(level_outputs.shape))
+ (1, 196, 768)
+ """
+
+ def __init__(self,
+ num_patches=196,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=1024,
+ decoder_embed_dim=512,
+ decoder_depth=8,
+ decoder_num_heads=16,
+ mlp_ratio=4.,
+ norm_cfg=dict(type='LN', eps=1e-6)):
+ super(MAEPretrainDecoder, self).__init__()
+ self.num_patches = num_patches
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+ self.decoder_pos_embed = nn.Parameter(
+ torch.zeros(1, self.num_patches + 1, decoder_embed_dim),
+ requires_grad=False)
+
+ self.decoder_blocks = nn.ModuleList([
+ TransformerEncoderLayer(
+ decoder_embed_dim,
+ decoder_num_heads,
+ int(mlp_ratio * decoder_embed_dim),
+ qkv_bias=True,
+ norm_cfg=norm_cfg) for _ in range(decoder_depth)
+ ])
+
+ self.decoder_norm_name, decoder_norm = build_norm_layer(
+ norm_cfg, decoder_embed_dim, postfix=1)
+ self.add_module(self.decoder_norm_name, decoder_norm)
+ self.decoder_pred = nn.Linear(
+ decoder_embed_dim, patch_size**2 * in_chans, bias=True)
+
+ def init_weights(self):
+ super(MAEPretrainDecoder, self).init_weights()
+
+ # initialize position embedding of MAE decoder
+ decoder_pos_embed = build_2d_sincos_position_embedding(
+ int(self.num_patches**.5),
+ self.decoder_pos_embed.shape[-1],
+ cls_token=True)
+ self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
+
+ torch.nn.init.normal_(self.mask_token, std=.02)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @property
+ def decoder_norm(self):
+ return getattr(self, self.decoder_norm_name)
+
+ def forward(self, x, ids_restore):
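+ """Decode the visible tokens back to per-patch pixel predictions.
+
+ Mask tokens are appended to the encoded visible patches, unshuffled
+ with ``ids_restore``, and decoded; the cls token is dropped at the
+ end.
+ """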
+ # embed tokens
+ x = self.decoder_embed(x)
+
+ # append mask tokens to sequence
+ mask_tokens = self.mask_token.repeat(
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
+ x_ = torch.gather(
+ x_,
+ dim=1,
+ index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
+ x = torch.cat([x[:, :1, :], x_], dim=1)
+
+ # add pos embed
+ x = x + self.decoder_pos_embed
+
+ # apply Transformer blocks
+ for blk in self.decoder_blocks:
+ x = blk(x)
+ x = self.decoder_norm(x)
+
+ # predictor projection
+ x = self.decoder_pred(x)
+
+ # remove cls token
+ x = x[:, 1:, :]
+
+ return x
diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py
index 4e5b64ab6..e9a5ebf1c 100644
--- a/mmselfsup/models/utils/__init__.py
+++ b/mmselfsup/models/utils/__init__.py
@@ -10,6 +10,6 @@
__all__ = [
'Accuracy', 'accuracy', 'ExtractProcess', 'GatherLayer', 'MultiPooling',
- 'MultiPrototypes', 'build_2d_sincos_position_embedding', 'ResLayer',
- 'Sobel'
+ 'MultiPrototypes', 'ResLayer', 'Sobel',
+ 'build_2d_sincos_position_embedding'
]
diff --git a/mmselfsup/utils/__init__.py b/mmselfsup/utils/__init__.py
index 27a872e68..26c00528c 100644
--- a/mmselfsup/utils/__init__.py
+++ b/mmselfsup/utils/__init__.py
@@ -14,6 +14,6 @@
'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
'distributed_sinkhorn', 'Extractor', 'concat_all_gather', 'gather_tensors',
- 'gather_tensors_batch', 'get_root_logger', 'setup_multi_processes',
- 'multi_gpu_test', 'single_gpu_test'
+ 'gather_tensors_batch', 'get_root_logger', 'multi_gpu_test',
+ 'single_gpu_test', 'setup_multi_processes'
]
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index e286fe0c2..179263445 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -7,4 +7,5 @@ scipy
six
sklearn
tensorboard
+timm
tqdm
diff --git a/setup.cfg b/setup.cfg
index 2e22906e2..f7cec1f73 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,7 +8,7 @@ line_length = 79
multi_line_output = 0
extra_standard_library = setuptools
known_first_party = mmselfsup
-known_third_party = PIL,cv2,detectron2,faiss,matplotlib,mmcls,mmcv,mmdet,numpy,packaging,pytest,pytorch_sphinx_theme,scipy,seaborn,six,sklearn,svm_helper,torch,torchvision,tqdm
+known_third_party = PIL,detectron2,faiss,matplotlib,mmcls,mmcv,mmdet,numpy,packaging,pytest,pytorch_sphinx_theme,scipy,seaborn,six,sklearn,svm_helper,timm,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
diff --git a/tests/test_data/test_pipelines.py b/tests/test_data/test_pipeline.py
similarity index 83%
rename from tests/test_data/test_pipelines.py
rename to tests/test_data/test_pipeline.py
index b67b21bf0..1bacc9e65 100644
--- a/tests/test_data/test_pipelines.py
+++ b/tests/test_data/test_pipeline.py
@@ -95,3 +95,26 @@ def test_solarization():
res = module(img)
assert img.size == res.size
+
+
+def test_randomaug():
+ transform = dict(
+ type='RandomAug',
+ input_size=224,
+ color_jitter=None,
+ auto_augment='rand-m9-mstd0.5-inc1',
+ interpolation='bicubic',
+ re_prob=0.25,
+ re_mode='pixel',
+ re_count=1,
+ mean=(0.485, 0.456, 0.406),
+ std=(0.229, 0.224, 0.225))
+
+ img = Image.fromarray(np.uint8(np.ones((224, 224, 3))))
+
+ module = build_from_cfg(transform, PIPELINES)
+ res = module(img)
+
+ assert list(res.shape) == [3, 224, 224]
+
+ assert isinstance(str(module), str)
diff --git a/tests/test_models/test_algorithms/test_mae.py b/tests/test_models/test_algorithms/test_mae.py
new file mode 100644
index 000000000..d985f44f9
--- /dev/null
+++ b/tests/test_models/test_algorithms/test_mae.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+
+from mmselfsup.models.algorithms import MAE
+
+backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
+neck = dict(
+ type='MAEPretrainDecoder',
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ decoder_embed_dim=512,
+ decoder_depth=8,
+ decoder_num_heads=16,
+ mlp_ratio=4.,
+)
+head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16)
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_mae():
+ with pytest.raises(AssertionError):
+ alg = MAE(backbone=backbone, neck=None, head=head)
+ with pytest.raises(AssertionError):
+ alg = MAE(backbone=backbone, neck=neck, head=None)
+ with pytest.raises(AssertionError):
+ alg = MAE(backbone=None, neck=neck, head=head)
+ alg = MAE(backbone=backbone, neck=neck, head=head)
+
+ fake_input = torch.randn((16, 3, 224, 224))
+ fake_loss = alg.forward_train(fake_input)
+ fake_feature = alg.extract_feat(fake_input)
+ assert isinstance(fake_loss['loss'].item(), float)
+ assert list(fake_feature[0].shape) == [16, 50, 768]
diff --git a/tests/test_models/test_backbones/test_mae_pretrain_vit.py b/tests/test_models/test_backbones/test_mae_pretrain_vit.py
new file mode 100644
index 000000000..7a772962e
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mae_pretrain_vit.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+
+from mmselfsup.models.backbones import MAEViT
+
+backbone = dict(arch='b', patch_size=16, mask_ratio=0.75)
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_mae_pretrain_vit():
+ mae_pretrain_backbone = MAEViT(**backbone)
+ mae_pretrain_backbone.init_weights()
+ fake_inputs = torch.randn((2, 3, 224, 224))
+ fake_outputs = mae_pretrain_backbone(fake_inputs)[0]
+
+ assert list(fake_outputs.shape) == [2, 50, 768]
diff --git a/tests/test_models/test_backbones/test_mim_cls_vit.py b/tests/test_models/test_backbones/test_mim_cls_vit.py
new file mode 100644
index 000000000..005ea1cfe
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mim_cls_vit.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+
+from mmselfsup.models.backbones import MIMVisionTransformer
+
+finetune_backbone = dict(
+ arch='b', patch_size=16, drop_path_rate=0.1, final_norm=False)
+
+finetune_backbone_norm = dict(
+ arch='b', patch_size=16, drop_path_rate=0.1, final_norm=True)
+
+linprobe_backbone = dict(
+ arch='b', patch_size=16, finetune=False, final_norm=False)
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_mae_cls_vit():
+ mae_finetune_backbone = MIMVisionTransformer(**finetune_backbone)
+ mae_finetune_backbone_norm = MIMVisionTransformer(**finetune_backbone_norm)
+ mae_linprobe_backbone = MIMVisionTransformer(**linprobe_backbone)
+ mae_linprobe_backbone.train()
+
+ fake_inputs = torch.randn((2, 3, 224, 224))
+ fake_finetune_outputs = mae_finetune_backbone(fake_inputs)
+ fake_finetune_outputs_norm = mae_finetune_backbone_norm(fake_inputs)
+ fake_linprobe_outputs = mae_linprobe_backbone(fake_inputs)
+ assert list(fake_finetune_outputs.shape) == [2, 768]
+ assert list(fake_linprobe_outputs.shape) == [2, 768]
+ assert list(fake_finetune_outputs_norm.shape) == [2, 768]
diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index c9548164e..69b3bf8b8 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
+import torch.nn.functional as F
from mmselfsup.models.heads import (ClsHead, ContrastiveHead, LatentClsHead,
- LatentPredictHead, MultiClsHead, SwAVHead)
+ LatentPredictHead, MAEFinetuneHead,
+ MAEPretrainHead, MultiClsHead, SwAVHead)
def test_cls_head():
@@ -73,3 +75,34 @@ def test_swav_head():
loss = head.forward(fake_input)
assert loss['loss'].item() > 0
+
+
+def test_mae_pretrain_head():
+ head = MAEPretrainHead(norm_pix=False, patch_size=16)
+ fake_input = torch.rand((2, 3, 224, 224))
+ fake_mask = torch.ones((2, 196))
+ fake_pred = torch.rand((2, 196, 768))
+
+ loss = head.forward(fake_input, fake_pred, fake_mask)
+
+ assert loss['loss'].item() > 0
+
+ head_norm_pixel = MAEPretrainHead(norm_pix=True, patch_size=16)
+
+ loss_norm_pixel = head_norm_pixel.forward(fake_input, fake_pred, fake_mask)
+
+ assert loss_norm_pixel['loss'].item() > 0
+
+
+def test_mae_finetune_head():
+
+ head = MAEFinetuneHead(num_classes=1000, embed_dim=768)
+ fake_input = torch.rand((2, 768))
+ fake_labels = F.normalize(torch.rand((2, 1000)), dim=-1)
+ fake_features = head.forward(fake_input)
+
+ assert list(fake_features[0].shape) == [2, 1000]
+
+ loss = head.loss(fake_features, fake_labels)
+
+ assert loss['loss'].item() > 0
diff --git a/tests/test_models/test_necks/test_mae_neck.py b/tests/test_models/test_necks/test_mae_neck.py
new file mode 100644
index 000000000..230e32b04
--- /dev/null
+++ b/tests/test_models/test_necks/test_mae_neck.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.models.necks import MAEPretrainDecoder
+
+
+def test_mae_pretrain_decoder():
+ decoder = MAEPretrainDecoder()
+ decoder.eval()
+ inputs = torch.rand(1, 50, 1024)
+ ids_restore = torch.arange(0, 196).unsqueeze(0)
+ level_outputs = decoder.forward(inputs, ids_restore)
+ assert tuple(level_outputs.shape) == (1, 196, 768)