diff --git a/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb256-coslr-100e_in1k.py b/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb256-coslr-100e_in1k.py new file mode 100644 index 000000000..574fe3844 --- /dev/null +++ b/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb256-coslr-100e_in1k.py @@ -0,0 +1,76 @@ +_base_ = [ + '../_base_/models/vit-base-p16_ft.py', + '../_base_/datasets/imagenet.py', + '../_base_/schedules/adamw_coslr-100e_in1k.py', + '../_base_/default_runtime.py', +] +# maskfeat fine-tuning setting + +# dataset +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict( + type='RandomAug', + input_size=224, + color_jitter=0.4, + auto_augment='rand-m9-mstd0.5-inc1', + interpolation='bicubic', + re_prob=0.25, + re_mode='pixel', + re_count=1, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225)) +] +test_pipeline = [ + dict(type='Resize', size=256, interpolation=3), + dict(type='CenterCrop', size=224), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg) +] +data = dict( + samples_per_gpu=256, + drop_last=False, + workers_per_gpu=32, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline)) + +# model +model = dict( + backbone=dict(init_cfg=dict()), + head=dict( + type='MaskFeatFinetuneHead', + num_classes=1000, + embed_dim=768, + label_smooth_val=0.1)) + +# optimizer +optimizer = dict( + lr=0.002 * 8 / 2, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_options={ + 'ln': dict(weight_decay=0.), + 'bias': dict(weight_decay=0.), + 'pos_embed': dict(weight_decay=0.), + 'cls_token': dict(weight_decay=0.), + }, + constructor='TransformerFinetuneConstructor', + model_type='vit', + layer_decay=0.65) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-6, + warmup='linear', + warmup_iters=20, + warmup_ratio=1e-08, + warmup_by_epoch=True) + +# runtime +checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='') +persistent_workers = True +log_config = dict( + interval=100, hooks=[ + dict(type='TextLoggerHook'), + ]) diff --git a/configs/selfsup/_base_/datasets/imagenet_maskfeat.py b/configs/selfsup/_base_/datasets/imagenet_maskfeat.py new file mode 100644 index 000000000..9114a62a2 --- /dev/null +++ b/configs/selfsup/_base_/datasets/imagenet_maskfeat.py @@ -0,0 +1,35 @@ +# dataset settings +data_source = 'ImageNet' +dataset_type = 'SingleViewDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict( + type='RandomResizedCropAndInterpolationWithTwoPic', + size=224, + scale=(0.5, 1.0), + ratio=(0.75, 1.3333), + interpolation='bicubic'), + dict(type='RandomHorizontalFlip') +] + +# prefetch +prefetch = False +if not prefetch: + train_pipeline.extend( + [dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg)]) + +train_pipeline.append(dict(type='MaskFeatMaskGenerator', mask_ratio=0.4)) + +# dataset summary +data = dict( + samples_per_gpu=256, + workers_per_gpu=8, + train=dict( + type=dataset_type, + data_source=dict( + type=data_source, + data_prefix='data/imagenet/train', + ann_file='data/imagenet/meta/train.txt'), + pipeline=train_pipeline, + prefetch=prefetch)) diff --git a/configs/selfsup/_base_/models/maskfeat_vit-base-p16.py b/configs/selfsup/_base_/models/maskfeat_vit-base-p16.py new file mode 100644 index 000000000..e1de9ae11 --- /dev/null +++ b/configs/selfsup/_base_/models/maskfeat_vit-base-p16.py @@ -0,0 +1,15 @@ +# model settings +model = dict( + type='MaskFeat', + backbone=dict( + type='MaskFeatViT', + arch='b', + patch_size=16, + drop_path_rate=0, + ), + head=dict(type='MaskFeatPretrainHead', hog_dim=108), + hog_para=dict( + nbins=9, # Number of bin. Defaults to 9. + pool=8, # Number of cell. Defaults to 8. + gaussian_window=16 # Size of gaussian kernel. Defaults to 16. + )) diff --git a/configs/selfsup/maskfeat/README.md b/configs/selfsup/maskfeat/README.md new file mode 100644 index 000000000..4ffeb9b78 --- /dev/null +++ b/configs/selfsup/maskfeat/README.md @@ -0,0 +1,34 @@ +# MaskFeat + +> [Masked Feature Prediction for Self-Supervised Visual Pre-Training](https://arxiv.org/abs/2112.09133v1) + + + +## Abstract + +We present Masked Feature Prediction (MaskFeat) for self-supervised pre-training of video models. Our approach first randomly masks out a portion of the input sequence and then predicts the feature of the masked regions. We study five different types of features and find Histograms of Oriented Gradients (HOG), a hand-crafted feature descriptor, works particularly well in terms of both performance and efficiency. We observe that the local contrast normalization in HOG is essential for good results, which is in line with earlier work using HOG for visual recognition. Our approach can learn abundant visual knowledge and drive large-scale Transformer-based models. Without using extra model weights or supervision, MaskFeat pre-trained on unlabeled videos achieves unprecedented results of 86.7% with MViT-L on Kinetics-400, 88.3% on Kinetics-600, 80.4% on Kinetics-700, 38.8 mAP on AVA, and 75.0% on SSv2. MaskFeat further generalizes to image input, which can be interpreted as a video with a single frame and obtains competitive results on ImageNet. + +
+ +
+ +## Models and Benchmarks + +Here, we report the results of the model, which is pre-trained on ImageNet-1k +for 400 epochs, the details are below: + +| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download | +| :------: | :-------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ViT-B/16 | 300 | 83.5 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) | + +## Citation + +```bibtex +@article{He2021MaskedAA, + title={Masked Autoencoders Are Scalable Vision Learners}, + author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and + Piotr Doll'ar and Ross B. Girshick}, + journal={ArXiv}, + year={2021} +} +``` diff --git a/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py b/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py new file mode 100644 index 000000000..660ae0e80 --- /dev/null +++ b/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py @@ -0,0 +1,40 @@ +_base_ = [ + '../_base_/models/maskfeat_vit-base-p16.py', + '../_base_/datasets/imagenet_maskfeat.py', + '../_base_/schedules/adamw_coslr-300e_in1k.py', + '../_base_/default_runtime.py', +] + +# dataset +data = dict(samples_per_gpu=256, workers_per_gpu=32) + +# optimizer +optimizer = dict( + lr=2e-4 * 8, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_options={ + 'ln': dict(weight_decay=0.), + 'bias': dict(weight_decay=0.), + }) +optimizer_config = dict(grad_clip=dict(max_norm=0.02)) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-6, + warmup='linear', + warmup_iters=30, + warmup_ratio=1e-06, + warmup_by_epoch=True) + +# schedule +runner = dict(max_epochs=300) + +# runtime +checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='') +persistent_workers = True +log_config = dict( + interval=100, hooks=[ + dict(type='TextLoggerHook'), + ]) diff --git a/configs/selfsup/maskfeat/metafile.yaml b/configs/selfsup/maskfeat/metafile.yaml new file mode 100644 index 000000000..65823d072 --- /dev/null +++ b/configs/selfsup/maskfeat/metafile.yaml @@ -0,0 +1,27 @@ +Collections: + - Name: MaskFeat + Metadata: + Training Data: ImageNet-1k + Training Techniques: + - AdamW + Training Resources: 8x A100-80G GPUs + Architecture: + - ViT + Paper: + URL: https://arxiv.org/abs/2112.09133v1 + Title: "Masked Feature Prediction for Self-Supervised Visual Pre-Training" + README: configs/selfsup/maskfeat/README.md + +Models: + - Name: maskfeat_vit-base-p16_8xb256-coslr-300e_in1k + In Collection: MaskFeat + Metadata: + Epochs: 300 + Batch Size: 2048 + Results: + - Task: Self-Supervised Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 83.5 + Config: configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py + Weights: https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth diff --git a/docs/en/algorithms/maskfeat.md b/docs/en/algorithms/maskfeat.md new file mode 100644 index 000000000..4ffeb9b78 --- /dev/null +++ b/docs/en/algorithms/maskfeat.md @@ -0,0 +1,34 @@ +# MaskFeat + +> [Masked Feature Prediction for Self-Supervised Visual Pre-Training](https://arxiv.org/abs/2112.09133v1) + + + +## Abstract + +We present Masked Feature Prediction (MaskFeat) for self-supervised pre-training of video models. Our approach first randomly masks out a portion of the input sequence and then predicts the feature of the masked regions. We study five different types of features and find Histograms of Oriented Gradients (HOG), a hand-crafted feature descriptor, works particularly well in terms of both performance and efficiency. We observe that the local contrast normalization in HOG is essential for good results, which is in line with earlier work using HOG for visual recognition. Our approach can learn abundant visual knowledge and drive large-scale Transformer-based models. Without using extra model weights or supervision, MaskFeat pre-trained on unlabeled videos achieves unprecedented results of 86.7% with MViT-L on Kinetics-400, 88.3% on Kinetics-600, 80.4% on Kinetics-700, 38.8 mAP on AVA, and 75.0% on SSv2. MaskFeat further generalizes to image input, which can be interpreted as a video with a single frame and obtains competitive results on ImageNet. + +
+ +
+ +## Models and Benchmarks + +Here, we report the results of the model, which is pre-trained on ImageNet-1k +for 400 epochs, the details are below: + +| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download | +| :------: | :-------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ViT-B/16 | 300 | 83.5 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) | + +## Citation + +```bibtex +@article{He2021MaskedAA, + title={Masked Autoencoders Are Scalable Vision Learners}, + author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and + Piotr Doll'ar and Ross B. Girshick}, + journal={ArXiv}, + year={2021} +} +``` diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 4338a2395..8f073dae5 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -25,7 +25,8 @@ All models and part of benchmark results are recorded below. | [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) \| [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) | | [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) | | [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) \| [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) | -| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/RAEDME.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) | +| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) | +| [MaskFeat](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/README.md) | [maskfeat_vit-base-p16_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) | Remarks: @@ -63,11 +64,12 @@ If not specified, we use linear evaluation setting from [MoCo](http://openaccess ### ImageNet Fine-tuning -| Algorithm | Config | Remarks | Top-1 (%) | -| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- | -| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 | -| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 | -| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 | +| Algorithm | Config | Remarks | Top-1 (%) | +| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- | --------- | +| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 | +| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 | +| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 | +| MaskFeat | [maskfeat_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | | 83.5 | ### COCO17 Object Detection and Instance Segmentation diff --git a/docs/zh_cn/algorithms/maskfeat.md b/docs/zh_cn/algorithms/maskfeat.md new file mode 100644 index 000000000..4ffeb9b78 --- /dev/null +++ b/docs/zh_cn/algorithms/maskfeat.md @@ -0,0 +1,34 @@ +# MaskFeat + +> [Masked Feature Prediction for Self-Supervised Visual Pre-Training](https://arxiv.org/abs/2112.09133v1) + + + +## Abstract + +We present Masked Feature Prediction (MaskFeat) for self-supervised pre-training of video models. Our approach first randomly masks out a portion of the input sequence and then predicts the feature of the masked regions. We study five different types of features and find Histograms of Oriented Gradients (HOG), a hand-crafted feature descriptor, works particularly well in terms of both performance and efficiency. We observe that the local contrast normalization in HOG is essential for good results, which is in line with earlier work using HOG for visual recognition. Our approach can learn abundant visual knowledge and drive large-scale Transformer-based models. Without using extra model weights or supervision, MaskFeat pre-trained on unlabeled videos achieves unprecedented results of 86.7% with MViT-L on Kinetics-400, 88.3% on Kinetics-600, 80.4% on Kinetics-700, 38.8 mAP on AVA, and 75.0% on SSv2. MaskFeat further generalizes to image input, which can be interpreted as a video with a single frame and obtains competitive results on ImageNet. + +
+ +
+ +## Models and Benchmarks + +Here, we report the results of the model, which is pre-trained on ImageNet-1k +for 400 epochs, the details are below: + +| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download | +| :------: | :-------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ViT-B/16 | 300 | 83.5 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) | + +## Citation + +```bibtex +@article{He2021MaskedAA, + title={Masked Autoencoders Are Scalable Vision Learners}, + author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and + Piotr Doll'ar and Ross B. Girshick}, + journal={ArXiv}, + year={2021} +} +``` diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md index 742046da6..ec988ff24 100644 --- a/docs/zh_cn/model_zoo.md +++ b/docs/zh_cn/model_zoo.md @@ -25,7 +25,8 @@ | [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) \| [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) | | [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) | | [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) \| [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) | -| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/RAEDME.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) | +| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) | +| [MaskFeat](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/README.md) | [maskfeat_vit-base-p16_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) | 备注: @@ -63,11 +64,12 @@ ### ImageNet 微调 -| 算法 | 配置文件 | 备注 | Top-1 (%) | -| ------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---- | --------- | -| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 | -| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 | -| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 | +| 算法 | 配置文件 | 备注 | Top-1 (%) | +| -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---- | --------- | +| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 | +| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 | +| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 | +| MaskFeat | [maskfeat_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | | 83.5 | ### COCO17 目标检测和实例分割 diff --git a/mmselfsup/datasets/pipelines/__init__.py b/mmselfsup/datasets/pipelines/__init__.py index f4cf1e2a9..671abdd97 100644 --- a/mmselfsup/datasets/pipelines/__init__.py +++ b/mmselfsup/datasets/pipelines/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .transforms import (BEiTMaskGenerator, GaussianBlur, Lighting, - RandomAppliedTrans, RandomAug, SimMIMMaskGenerator, - Solarization, ToTensor) + MaskFeatMaskGenerator, RandomAppliedTrans, RandomAug, + SimMIMMaskGenerator, Solarization, ToTensor) __all__ = [ 'GaussianBlur', 'Lighting', 'RandomAppliedTrans', 'Solarization', - 'RandomAug', 'SimMIMMaskGenerator', 'ToTensor', 'BEiTMaskGenerator' + 'RandomAug', 'SimMIMMaskGenerator', 'ToTensor', 'BEiTMaskGenerator', + 'MaskFeatMaskGenerator' ] diff --git a/mmselfsup/datasets/pipelines/transforms.py b/mmselfsup/datasets/pipelines/transforms.py index dd4411d2b..13afed8f4 100644 --- a/mmselfsup/datasets/pipelines/transforms.py +++ b/mmselfsup/datasets/pipelines/transforms.py @@ -482,3 +482,113 @@ def __repr__(self): repr_str += f'threshold = {self.threshold}, ' repr_str += f'prob = {self.prob}' return repr_str + + +@PIPELINES.register_module() +class MaskFeatMaskGenerator(object): + """Generate random block mask for each image. + + This module is borrowed from + https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/transform.py + Args: + mask_window_size (int): Size of input image. Defaults to 14. + mask_ratio (float): The mask ratio of image. Defaults to 0.4. + min_num_patches (int): Minimum number of patches that require masking. + Defaults to 15. + max_num_patches (int, optional): Maximum number of patches that + require masking. Defaults to None. + min_aspect (int): Minimum aspect of patches. Defaults to 0.3. + max_aspect (float, optional): Maximum aspect of patches. + Defaults to None. + """ + + def __init__( + self, + mask_window_size: int = 14, + mask_ratio: float = 0.4, + min_num_patches: int = 15, + max_num_patches: Optional[int] = None, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None, + ) -> None: + self.height, self.width = mask_window_size, mask_window_size + + self.num_patches = self.height * self.width + self.num_masking_patches = int(mask_window_size**2 * mask_ratio) + + self.min_num_patches = min_num_patches + self.max_num_patches = ( + self.num_masking_patches + if max_num_patches is None else max_num_patches) + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(height={self.height}, ' + repr_str += f'width={self.width}, ' + repr_str += f'num_patches={self.num_patches}, ' + repr_str += f'num_masking_patches={self.num_masking_patches}, ' + repr_str += f'min_num_patches={self.min_num_patches}, ' + repr_str += f'max_num_patches={self.max_num_patches}, ' + repr_str += f'log_aspect_ratio={self.log_aspect_ratio})' + return repr_str + + def get_shape(self) -> Tuple[int, int]: + return self.height, self.width + + def _random_masking(self, mask: np.array, max_mask_patches: int) -> int: + """Generate random block masks for each image up to 10 times. + + Args: + mask (np.array): Initial mask of shape (mask_window_size, + mask_window_size). + max_mask_patches (int): Maximum number of masked patches required. + Returns: + int: Number of masking patches. + """ + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, + max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top:top + h, left:left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate random block mask for each image. + + Args: + img (torch.Tensor): Input image of shape (C, H, W). + Returns: + Tuple[torch.Tensor, torch.Tensor]: Input image and mask. + """ + mask = np.zeros(shape=self.get_shape(), dtype=np.int) + mask_count = 0 + while mask_count < self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._random_masking(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + return img, torch.Tensor(mask).bool() diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py index 8de1b2539..8f8992f3b 100644 --- a/mmselfsup/models/algorithms/__init__.py +++ b/mmselfsup/models/algorithms/__init__.py @@ -7,6 +7,7 @@ from .deepcluster import DeepCluster from .densecl import DenseCL from .mae import MAE +from .maskfeat import MaskFeat from .mmcls_classifier_wrapper import MMClsImageClassifierWrapper from .moco import MoCo from .mocov3 import MoCoV3 @@ -23,5 +24,5 @@ 'BaseModel', 'BarlowTwins', 'BYOL', 'Classification', 'DeepCluster', 'DenseCL', 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV', 'MAE', 'MoCoV3', 'SimMIM', - 'MMClsImageClassifierWrapper', 'CAE' + 'MMClsImageClassifierWrapper', 'CAE', 'MaskFeat' ] diff --git a/mmselfsup/models/algorithms/maskfeat.py b/mmselfsup/models/algorithms/maskfeat.py new file mode 100644 index 000000000..769179b4f --- /dev/null +++ b/mmselfsup/models/algorithms/maskfeat.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch + +from ..builder import ALGORITHMS, build_backbone, build_head +from ..utils.hog_layer import HOGLayerC +from .base import BaseModel + + +@ALGORITHMS.register_module() +class MaskFeat(BaseModel): + """MaskFeat. + + Implementation of `Masked Feature Prediction for + Self-Supervised Visual Pre-Training `_. + Args: + backbone (dict): Config dict for encoder. + head (dict): Config dict for loss functions. + hog_para (dict): Config dict for hog layer. + dict['nbins', int]: Number of bin. Defaults to 9. + dict['pool', float]: Number of cell. Defaults to 8. + dict['gaussian_window', int]: Size of gaussian kernel. + Defaults to 16. + init_cfg (dict): Config dict for weight initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + head: dict, + hog_para: dict, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg) + assert backbone is not None + self.backbone = build_backbone(backbone) + assert head is not None + self.head = build_head(head) + assert hog_para is not None + self.hog_layer = HOGLayerC(**hog_para) + + def extract_feat(self, input: List[torch.Tensor]) -> torch.Tensor: + """Function to extract features from backbone. + + Args: + input (List[torch.Tensor, torch.Tensor]): Input images and masks. + Returns: + tuple[Tensor]: backbone outputs. + """ + img = input[0] + mask = input[1] + return self.backbone(img, mask) + + def forward_train(self, input: List[torch.Tensor], **kwargs) -> dict: + """Forward computation during training. + + Args: + input (List[torch.Tensor, torch.Tensor]): Input images and masks. + kwargs: Any keyword arguments to be used to forward. + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + img = input[0] + mask = input[1] + + hog = self.hog_layer(img) + latent = self.backbone(img, mask) + losses = self.head(latent, hog, mask) + + return losses diff --git a/mmselfsup/models/backbones/__init__.py b/mmselfsup/models/backbones/__init__.py index dbc78426c..0d2a678bb 100644 --- a/mmselfsup/models/backbones/__init__.py +++ b/mmselfsup/models/backbones/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cae_vit import CAEViT from .mae_vit import MAEViT +from .maskfeat_vit import MaskFeatViT from .mim_cls_vit import MIMVisionTransformer from .resnet import ResNet, ResNetV1d from .resnext import ResNeXt @@ -9,5 +10,5 @@ __all__ = [ 'ResNet', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MIMVisionTransformer', - 'VisionTransformer', 'SimMIMSwinTransformer', 'CAEViT' + 'VisionTransformer', 'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT' ] diff --git a/mmselfsup/models/backbones/maskfeat_vit.py b/mmselfsup/models/backbones/maskfeat_vit.py new file mode 100644 index 000000000..c50949915 --- /dev/null +++ b/mmselfsup/models/backbones/maskfeat_vit.py @@ -0,0 +1,125 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch +from mmcls.models import VisionTransformer +from mmcv.cnn.utils.weight_init import trunc_normal_ +from torch import nn + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class MaskFeatViT(VisionTransformer): + """Vision Transformer for MaskFeat pre-training. + + A PyTorch implement of: `Masked Feature Prediction for Self-Supervised + Visual Pre-Training `_. + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Defaults to True. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: Union[Tuple[int, int], int] = 224, + patch_size: int = 16, + out_indices: int = -1, + drop_rate: float = 0., + drop_path_rate: float = 0., + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + output_cls_token: bool = True, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + init_cfg: Optional[dict] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + output_cls_token=output_cls_token, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.parameter.Parameter( + torch.zeros(1, 1, self.embed_dims)) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + def init_weights(self) -> None: + super().init_weights() + if not (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.mask_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m: torch.nn.Module) -> None: + if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate features for masked images. + + Args: + x (torch.Tensor): Input images. + mask (torch.Tensor): Input masks. + Returns: + torch.Tensor: Features with cls_tokens. + """ + B = x.shape[0] + x = self.patch_embed(x)[0] + + # masking: length -> length * mask_ratio + B, L, _ = x.shape + mask_tokens = self.mask_token.expand(B, L, -1) + mask = mask.flatten(1).unsqueeze(-1) + x = x * (1 - mask.int()) + mask_tokens * mask + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.drop_after_pos(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + return x diff --git a/mmselfsup/models/heads/__init__.py b/mmselfsup/models/heads/__init__.py index 360701b17..9fcb3d0bf 100644 --- a/mmselfsup/models/heads/__init__.py +++ b/mmselfsup/models/heads/__init__.py @@ -5,6 +5,7 @@ from .latent_pred_head import (LatentClsHead, LatentCrossCorrelationHead, LatentPredictHead) from .mae_head import MAEFinetuneHead, MAELinprobeHead, MAEPretrainHead +from .maskfeat_head import MaskFeatFinetuneHead, MaskFeatPretrainHead from .mocov3_head import MoCoV3Head from .multi_cls_head import MultiClsHead from .simmim_head import SimMIMHead @@ -14,5 +15,6 @@ 'ContrastiveHead', 'ClsHead', 'LatentPredictHead', 'LatentClsHead', 'LatentCrossCorrelationHead', 'MultiClsHead', 'SwAVHead', 'MAEFinetuneHead', 'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', - 'CAEHead', 'MAELinprobeHead' + 'CAEHead', 'MAELinprobeHead', 'MaskFeatFinetuneHead', + 'MaskFeatPretrainHead' ] diff --git a/mmselfsup/models/heads/maskfeat_head.py b/mmselfsup/models/heads/maskfeat_head.py new file mode 100644 index 000000000..90cf84995 --- /dev/null +++ b/mmselfsup/models/heads/maskfeat_head.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcls.models import LabelSmoothLoss +from mmcv.cnn.utils.weight_init import trunc_normal_ +from mmcv.runner import BaseModule +from torch import nn + +from ..builder import HEADS + + +@HEADS.register_module() +class MaskFeatPretrainHead(BaseModule): + """Pre-training head for MaskFeat. + + Args: + embed_dim (int): The dim of the feature before the classifier head. + Defaults to 768. + hog_dim (int): The dim of the hog feature. Defaults to 108. + """ + + def __init__(self, embed_dim: int = 768, hog_dim: int = 108) -> None: + super().__init__() + self.head = nn.Linear(embed_dim, hog_dim) + + def init_weights(self) -> None: + nn.init.constant_(self.head.bias, 0) + trunc_normal_(self.head.weight, std=0.02) + + def loss(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> dict: + """Compute the loss. + + Args: + pred (torch.Tensor): Input prediction of shape (N, L, C). + target (torch.Tensor): Input target of shape (N, L, C). + mask (torch.Tensor): Input mask of shape (N, L, 1). + Returns: + dict[str, torch.Tensor]: A dictionary of loss components. + """ + losses = dict() + + pred = pred[mask] + target = target[mask] + loss = ((pred - target)**2).mean(-1).mean() + + losses['loss'] = loss + return losses + + def forward(self, latent: torch.Tensor, hog: torch.Tensor, + mask: torch.Tensor) -> dict: + """Pre-training head for MaskFeat. + + Args: + latent (torch.Tensor): Input latent of shape (N, 1+L, C). + hog (torch.Tensor): Input hog feature of shape (N, L, C). + mask (torch.Tensor): Input mask of shape (N, H, W). + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + latent = self.head(latent) + mask = mask.flatten(1).bool() + losses = self.loss(latent[:, 1:], hog, mask) + + return losses + + +@HEADS.register_module() +class MaskFeatFinetuneHead(BaseModule): + """Fine-tuning head for MaskFeat. + + Args: + embed_dim (int): The dim of the feature before the classifier head. + num_classes (int): The total classes. Defaults to 1000. + label_smooth_val (float): The degree of label smoothing. + Defaults to 0.1. + """ + + def __init__(self, + embed_dim: int, + num_classes: int = 1000, + label_smooth_val: float = 0.1) -> None: + super().__init__() + self.head = nn.Linear(embed_dim, num_classes, bias=True) + self.act = nn.Softmax(dim=1) + self.criterion = LabelSmoothLoss(label_smooth_val, num_classes) + + def init_weights(self) -> None: + nn.init.constant_(self.head.bias, 0) + trunc_normal_(self.head.weight, std=.02) + + def forward(self, x: torch.Tensor) -> list: + """"Get the logits.""" + outputs = self.head(x) + if not self.training: + outputs = self.act(outputs) + return [outputs] + + def loss(self, outputs: torch.Tensor, labels: torch.Tensor) -> dict: + """Compute the loss.""" + losses = dict() + losses['loss'] = self.criterion(outputs[0], labels) + + return losses diff --git a/mmselfsup/models/utils/hog_layer.py b/mmselfsup/models/utils/hog_layer.py new file mode 100644 index 000000000..958cbfeba --- /dev/null +++ b/mmselfsup/models/utils/hog_layer.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class HOGLayerC(nn.Module): + """Generate hog feature for each batch images. This module is used in + Maskfeat to generate hog feature. This code is borrowed from. + + + Args: + nbins (int): Number of bin. Defaults to 9. + pool (float): Number of cell. Defaults to 8. + gaussian_window (int): Size of gaussian kernel. Defaults to 16. + """ + + def __init__(self, + nbins: int = 9, + pool: int = 8, + gaussian_window: int = 16) -> None: + super().__init__() + self.nbins = nbins + self.pool = pool + self.pi = math.pi + weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) + weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1) + weight_y = weight_x.transpose(2, 3) + self.register_buffer('weight_x', weight_x) + self.register_buffer('weight_y', weight_y) + + self.gaussian_window = gaussian_window + if gaussian_window: + gkern = self.get_gkern(gaussian_window, gaussian_window // 2) + self.register_buffer('gkern', gkern) + + def get_gkern(self, kernlen: int, std: int) -> torch.Tensor: + """Returns a 2D Gaussian kernel array.""" + + def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor: + n = torch.arange(0, kernlen).float() + n -= n.mean() + n /= std + w = torch.exp(-0.5 * n**2) + return w + + gkern1d = _gaussian_fn(kernlen, std) + gkern2d = gkern1d[:, None] * gkern1d[None, :] + return gkern2d / gkern2d.sum() + + def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: + hog_feat = hog_feat.flatten(1, 2) + unfold_size = hog_feat.shape[-1] // 14 + hog_feat = ( + hog_feat.permute(0, 2, 3, + 1).unfold(1, unfold_size, unfold_size).unfold( + 2, unfold_size, + unfold_size).flatten(1, 2).flatten(2)) + return hog_feat + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Generate hog feature for each batch images. + + Args: + x (torch.Tensor): Input images of shape (N, 3, H, W). + Returns: + torch.Tensor: Hog features. + """ + # input is RGB image with shape [B 3 H W] + x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect') + gx_rgb = F.conv2d( + x, self.weight_x, bias=None, stride=1, padding=0, groups=3) + gy_rgb = F.conv2d( + x, self.weight_y, bias=None, stride=1, padding=0, groups=3) + norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1) + phase = torch.atan2(gx_rgb, gy_rgb) + phase = phase / self.pi * self.nbins # [-9, 9] + + b, c, h, w = norm_rgb.shape + out = torch.zeros((b, c, self.nbins, h, w), + dtype=torch.float, + device=x.device) + phase = phase.view(b, c, 1, h, w) + norm_rgb = norm_rgb.view(b, c, 1, h, w) + if self.gaussian_window: + if h != self.gaussian_window: + assert h % self.gaussian_window == 0, 'h {} gw {}'.format( + h, self.gaussian_window) + repeat_rate = h // self.gaussian_window + temp_gkern = self.gkern.repeat([repeat_rate, repeat_rate]) + else: + temp_gkern = self.gkern + norm_rgb *= temp_gkern + + out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb) + + out = out.unfold(3, self.pool, self.pool) + out = out.unfold(4, self.pool, self.pool) + out = out.sum(dim=[-1, -2]) + + out = F.normalize(out, p=2, dim=2) + + return self._reshape(out) diff --git a/tests/test_data/test_pipeline.py b/tests/test_data/test_pipeline.py index b62dabb25..f46f71721 100644 --- a/tests/test_data/test_pipeline.py +++ b/tests/test_data/test_pipeline.py @@ -187,3 +187,13 @@ def test_random_resize_crop_with_two_pic(): fake_output = module(fake_input) assert list(fake_output[0].size) == [224, 224] assert list(fake_output[1].size) == [112, 112] + + +def test_maskfeat_mask_gen(): + transform = dict( + type='MaskFeatMaskGenerator', mask_window_size=14, mask_ratio=0.6) + + img = torch.rand((3, 224, 224)) + module = build_from_cfg(transform, PIPELINES) + res = module(img) + assert list(res[1].shape) == [14, 14] diff --git a/tests/test_models/test_algorithms/test_maskfeat.py b/tests/test_models/test_algorithms/test_maskfeat.py new file mode 100644 index 000000000..cc102d692 --- /dev/null +++ b/tests/test_models/test_algorithms/test_maskfeat.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmselfsup.models.algorithms import MaskFeat + +backbone = dict( + type='MaskFeatViT', + arch='b', + patch_size=16, + drop_path_rate=0, +) +head = dict(type='MaskFeatPretrainHead', hog_dim=108) +hog_para = dict(nbins=9, pool=8, gaussian_window=16) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_maskfeat(): + with pytest.raises(AssertionError): + alg = MaskFeat(backbone=backbone, head=None, hog_para=hog_para) + with pytest.raises(AssertionError): + alg = MaskFeat(backbone=None, head=head, hog_para=hog_para) + alg = MaskFeat(backbone=backbone, head=head, hog_para=hog_para) + + fake_img = torch.randn((2, 3, 224, 224)) + fake_mask = torch.randn((2, 14, 14)).bool() + fake_input = (fake_img, fake_mask) + fake_loss = alg.forward_train(fake_input) + fake_feature = alg.extract_feat(fake_input) + assert isinstance(fake_loss['loss'].item(), float) + assert list(fake_feature.shape) == [2, 197, 768] diff --git a/tests/test_models/test_backbones/test_maskfeat_pretrain_vit.py b/tests/test_models/test_backbones/test_maskfeat_pretrain_vit.py new file mode 100644 index 000000000..0d83be2fb --- /dev/null +++ b/tests/test_models/test_backbones/test_maskfeat_pretrain_vit.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmselfsup.models.backbones import MaskFeatViT + +backbone = dict(arch='b', patch_size=16) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_maskfeat_pretrain_vit(): + maskfeat_pretrain_backbone = MaskFeatViT(**backbone) + maskfeat_pretrain_backbone.init_weights() + fake_inputs = torch.randn((2, 3, 224, 224)) + fake_mask = torch.randn((2, 14, 14)) + fake_outputs = maskfeat_pretrain_backbone(fake_inputs, fake_mask) + + assert list(fake_outputs.shape) == [2, 197, 768] diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index a019a51d4..8e6311ff5 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -5,7 +5,9 @@ from mmselfsup.models.heads import (ClsHead, ContrastiveHead, LatentClsHead, LatentCrossCorrelationHead, LatentPredictHead, MAEFinetuneHead, - MAEPretrainHead, MultiClsHead, SwAVHead) + MAEPretrainHead, MaskFeatFinetuneHead, + MaskFeatPretrainHead, MultiClsHead, + SwAVHead) def test_cls_head(): @@ -120,3 +122,28 @@ def test_mae_finetune_head(): loss = head.loss(fake_features, fake_labels) assert loss['loss'].item() > 0 + + +def test_maskfeat_pretrain_head(): + head = MaskFeatPretrainHead(hog_dim=108) + fake_mask = torch.ones((2, 14, 14)).bool() + fake_pred = torch.rand((2, 197, 768)) + fake_hog = torch.rand((2, 196, 108)) + + loss = head.forward(fake_pred, fake_hog, fake_mask) + + assert loss['loss'].item() > 0 + + +def test_maskfeat_finetune_head(): + + head = MaskFeatFinetuneHead(num_classes=1000, embed_dim=768) + fake_input = torch.rand((2, 768)) + fake_labels = F.normalize(torch.rand((2, 1000)), dim=-1) + fake_features = head.forward(fake_input) + + assert list(fake_features[0].shape) == [2, 1000] + + loss = head.loss(fake_features, fake_labels) + + assert loss['loss'].item() > 0