This repository contains the official PyTorch implementation of Similarity Contrastive Estimation for Image and Video Soft Contrastive Self-Supervised Learning (SCE), published in the journal Machine Vision and Applications (2023).
Data preparation details are available here.
Documentation is available here.
We launched our experiments on computational clusters configured via SLURM using up to 16 A100-80G GPUs depending on the experiments.
The commands below use SLURM's srun command, which we ran inside SLURM scripts. PyTorch Lightning automatically detects that SLURM is used and configures distributed training accordingly. If you do not have access to a SLURM cluster, we strongly suggest referring to PyTorch Lightning's documentation to correctly set up a command line without srun.
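If you switch between a SLURM cluster and a single machine, a small wrapper can add srun only when needed. This is a sketch under assumptions: `launch` is a hypothetical helper, it relies on the `SLURM_JOB_ID` variable that SLURM sets inside an allocation, and it assumes that on a single machine PyTorch Lightning starts the extra GPU processes itself. Multi-node runs without SLURM need a different setup.

```shell
# Hypothetical helper: go through srun only inside a SLURM allocation.
launch() {
  if [ -n "${SLURM_JOB_ID:-}" ]; then
    # Inside SLURM: srun starts one task per GPU and kills all tasks if one fails.
    srun --kill-on-bad-exit=1 "$@"
  else
    # Outside SLURM: run directly and let PyTorch Lightning spawn its processes.
    "$@"
  fi
}
```

For example, `launch python pretrain.py -cp $config_path -cn $config_name ...` behaves like the srun commands below on a cluster and like a plain Python call elsewhere.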
Results of models pretrained on Kinetics400. We provide the encoder checkpoints.
| Frames | K400 | UCF101 Acc 1 | UCF101 Retrieval 1 | HMDB51 Acc 1 | HMDB51 Retrieval 1 | ckpt |
|---|---|---|---|---|---|---|
| 8 | 67.6 | 94.1 | 81.5 | 70.5 | 43.0 | Download |
| 16 | 69.6 | 95.3 | 83.9 | 74.7 | 45.9 | Download |
Define the output directory, the experiment directory, the dataset directory and the seed used for all experiments.
output_dir=...
exp_dir=...
dataset_dir=...
seed=42
cd eztorch/run
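The values above are placeholders (`...`). A small guard can stop a script before it submits a job with a path left unset. This is a minimal sketch; `require_set` is a hypothetical helper, not part of eztorch.

```shell
# Hypothetical helper: fail if a variable is empty or still the "..." placeholder.
require_set() {
  # $1 = variable name (used in the message), $2 = its current value
  if [ -z "$2" ] || [ "$2" = "..." ]; then
    echo "error: please set $1" >&2
    return 1
  fi
}
```

Usage: `require_set output_dir "$output_dir" && require_set dataset_dir "$dataset_dir" || exit 1`.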
ResNet3D18 on Kinetics200. Can be launched on 2 A100-80G GPUs.
config_path="../eztorch/configs/run/pretrain/sce/resnet3d18"
config_name="resnet3d18_kinetics200"
srun --kill-on-bad-exit=1 python pretrain.py \
-cp $config_path -cn $config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="pretrain" \
seed.seed=$seed \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
trainer.devices=2
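`-cp` and `-cn` are Hydra's short forms of `--config-path` and `--config-name`, and each `key=value` argument is a Hydra override of the selected config. Hydra resolves a relative config path against the directory of the script declaring the Hydra entry point (here `eztorch/run`, where `pretrain.py` lives) and loads `<config_name>.yaml` from it. The hypothetical `check_config` helper below is a sketch that fails early on a typo in either value, assuming you run it from `eztorch/run`.

```shell
# Hypothetical helper: check that the Hydra config selected by -cp/-cn exists.
check_config() {
  # $1 = config path, $2 = config name (without the .yaml extension)
  file="$1/$2.yaml"
  if [ -f "$file" ]; then
    echo "using $file"
  else
    echo "error: $file not found" >&2
    return 1
  fi
}
```

Usage: `check_config "$config_path" "$config_name" && srun --kill-on-bad-exit=1 python pretrain.py ...`.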
ResNet3D50 on Kinetics400 with 8 frames. Can be launched on 4 A100-80G GPUs.
config_path="../eztorch/configs/run/pretrain/sce/resnet3d50"
config_name="resnet3d50_kinetics400"
srun --kill-on-bad-exit=1 python pretrain.py \
-cp $config_path -cn $config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="pretrain" \
seed.seed=$seed \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
trainer.devices=8
ResNet3D50 on Kinetics400 with 16 frames. Can be launched on 8 A100-80G GPUs.
config_path="../eztorch/configs/run/pretrain/sce/resnet3d50"
config_name="resnet3d50_kinetics400"
srun --kill-on-bad-exit=1 python pretrain.py \
-cp $config_path -cn $config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="pretrain" \
seed.seed=$seed \
datamodule.train.transform.transform.transforms.1.num_samples=16 \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
trainer.devices=8 \
trainer.num_nodes=2
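Assuming the `trainer` config is passed to PyTorch Lightning's `Trainer`, `trainer.devices` is the number of GPUs per node and `trainer.num_nodes` the number of nodes, so the overrides above request 8 × 2 = 16 GPUs in total. Lightning's SLURM integration expects one SLURM task per GPU. The hypothetical `sbatch_header` helper below sketches `#SBATCH` directives consistent with those two overrides; partitions, time limits and other cluster-specific options are left to you.

```shell
# Hypothetical helper: print SLURM directives matching trainer.devices and trainer.num_nodes.
sbatch_header() {
  local devices="$1"   # value passed as trainer.devices (GPUs per node)
  local nodes="$2"     # value passed as trainer.num_nodes
  echo "#SBATCH --nodes=${nodes}"
  echo "#SBATCH --ntasks-per-node=${devices}"   # one task per GPU
  echo "#SBATCH --gres=gpu:${devices}"
  echo "# world size: $(( devices * nodes )) GPUs"
}

sbatch_header 8 2
```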
For downstream tasks, we assume by default that you use checkpoints you pretrained yourself.
If you instead downloaded the checkpoints we provide, do not forget to set the model.trunk_pattern
config, which is used to find the trunk in the state dict, to an empty string:
srun --kill-on-bad-exit=1 python downstream_script.py \
...
model.trunk_pattern="" \
...
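Conceptually, the trunk pattern is matched against the keys of the checkpoint's state dict to select the encoder weights. The exact matching logic lives in eztorch; the hypothetical `select_trunk_keys` helper below only illustrates the idea on plain key names. Keys starting with the pattern are kept and the prefix is stripped, while an empty pattern keeps every key unchanged, which is presumably why the provided encoder checkpoints need `model.trunk_pattern=""`. The key names and the `trunk\.` pattern are made-up examples.

```shell
# Illustration only: select keys by a prefix regex and strip that prefix.
select_trunk_keys() {
  local pattern="$1"
  # Keep keys starting with the pattern, then remove the matched prefix.
  grep -E "^${pattern}" | sed -E "s/^${pattern}//"
}

# Hypothetical keys from a full pretraining checkpoint:
printf 'trunk.conv1.weight\nhead.fc.weight\n' | select_trunk_keys 'trunk\.'
# An empty pattern keeps keys from an encoder-only checkpoint as they are:
printf 'conv1.weight\n' | select_trunk_keys ''
```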
Linear evaluation of ResNet3D18 on Kinetics200:
eval_config_path="../eztorch/configs/run/evaluation/linear_classifier/sce/resnet3d18"
eval_config_name="resnet3d18_kinetics200_frame"
pretrain_checkpoint=...
srun --kill-on-bad-exit=1 python linear_classifier_evaluation.py \
-cp $eval_config_path -cn $eval_config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="linear_classifier_evaluation" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
seed.seed=$seed \
trainer.devices=-1
Linear evaluation of ResNet3D50 on Kinetics400 with 8 frames:
eval_config_path="../eztorch/configs/run/evaluation/linear_classifier/sce/resnet3d50"
eval_config_name="resnet3d50_kinetics400"
pretrain_checkpoint=...
srun --kill-on-bad-exit=1 python linear_classifier_evaluation.py \
-cp $eval_config_path -cn $eval_config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="linear_classifier_evaluation" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
seed.seed=$seed \
trainer.devices=-1
Linear evaluation of ResNet3D50 on Kinetics400 with 16 frames:
eval_config_path="../eztorch/configs/run/evaluation/linear_classifier/sce/resnet3d50"
eval_config_name="resnet3d50_kinetics400"
pretrain_checkpoint=...
srun --kill-on-bad-exit=1 python linear_classifier_evaluation.py \
-cp $eval_config_path -cn $eval_config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="linear_classifier_evaluation" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.train.transform.transform.transforms.0.num_samples=16 \
datamodule.train.loader.num_workers=5 \
datamodule.val.loader.num_workers=5 \
seed.seed=$seed \
trainer.devices=-1
Validation can take a long time, so the code only evaluates every 5 epochs. Two steps can speed things up:
- Speed up training by either:
  - removing validation and only saving the last checkpoint, or
  - validating with only one crop instead of 30.
- Perform testing afterward.
To do this, change the validation config and launch a test after training (example for Kinetics400 R3D50 16 frames):
eval_config_path="../eztorch/configs/run/evaluation/linear_classifier/sce/resnet3d50"
eval_config_name="resnet3d50_kinetics400"
pretrain_checkpoint=...
srun --kill-on-bad-exit=1 python test.py \
-cp $eval_config_path -cn $eval_config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="linear_classifier_evaluation" \
model.pretrained_trunk_path=$pretrain_checkpoint \
model.optimizer.batch_size=512 \
datamodule.train=null \
datamodule.val=null \
datamodule.test.loader.num_workers=3 \
datamodule.test.global_batch_size=2 \
datamodule.test.transform.transform.transforms.0.num_samples=16 \
seed.seed=$seed \
trainer=gpu \
trainer.devices=1 \
test.ckpt_by_callback_mode=best
We give here the configurations for fine-tuning a ResNet3d50 with 16 frames, but configs for other networks are available.
config_path="../eztorch/configs/run/finetuning/resnet3d50"
config_name="resnet3d50_hmdb51_frame"
pretrain_checkpoint=...
split=1
srun --kill-on-bad-exit=1 python supervised.py \
-cp $config_path -cn $config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="finetuning_hmdb51_split${split}" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.split_id=$split \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
datamodule.decoder_args.frame_filter.num_samples=16 \
seed.seed=$seed \
trainer.devices=-1 \
test=null
config_path="../eztorch/configs/run/finetuning/resnet3d50"
config_name="resnet3d50_ucf101_frame"
pretrain_checkpoint=...
split=1
srun --kill-on-bad-exit=1 python supervised.py \
-cp $config_path -cn $config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="finetuning_ucf101_split${split}" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.split_id=$split \
datamodule.train.loader.num_workers=4 \
datamodule.val.loader.num_workers=4 \
datamodule.decoder_args.frame_filter.num_samples=16 \
seed.seed=$seed \
trainer.devices=-1 \
test=null
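UCF101 and HMDB51 each come with three official train/test splits, and the commands above run split 1. A sketch looping over all three, assuming `datamodule.split_id` accepts 1, 2 and 3 and that `config_path`, `config_name`, `dataset_dir`, `output_dir`, `pretrain_checkpoint` and `seed` are set as above; `finetune_all_splits` is a hypothetical helper name.

```shell
# Hypothetical helper: fine-tune once per official split, stopping at the first failure.
# $1 = dataset tag used in the experiment name (e.g. ucf101 or hmdb51).
finetune_all_splits() {
  local dataset="$1" split
  for split in 1 2 3; do
    srun --kill-on-bad-exit=1 python supervised.py \
      -cp "$config_path" -cn "$config_name" \
      dir.data="$dataset_dir" \
      dir.root="$output_dir" \
      dir.exp="finetuning_${dataset}_split${split}" \
      model.pretrained_trunk_path="$pretrain_checkpoint" \
      datamodule.split_id="$split" \
      datamodule.train.loader.num_workers=4 \
      datamodule.val.loader.num_workers=4 \
      datamodule.decoder_args.frame_filter.num_samples=16 \
      seed.seed="$seed" \
      trainer.devices=-1 \
      test=null || return 1
  done
}
```

Usage: `config_name="resnet3d50_hmdb51_frame"; finetune_all_splits hmdb51`. Each split gets its own experiment directory, so runs do not overwrite each other.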
Validation can take a long time, so the code only evaluates every 5 epochs. Two steps can speed things up:
- Speed up training by either:
  - removing validation and only saving the last checkpoint, or
  - validating with only one crop instead of 30.
- Perform testing afterward.
To do this, change the validation config and launch a test after training (example for UCF101):
config_path="../eztorch/configs/run/finetuning/resnet3d50"
config_name="resnet3d50_ucf101_frame"
pretrain_checkpoint=...
split=1
srun --kill-on-bad-exit=1 python test.py \
-cp $config_path -cn $config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="finetuning_ucf101_split${split}" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.train=null \
datamodule.val=null \
model.optimizer.batch_size=64 \
datamodule.test.global_batch_size=2 \
datamodule.test.loader.num_workers=4 \
datamodule.decoder_args.frame_filter.num_samples=16 \
trainer=gpu \
seed.seed=$seed \
trainer.devices=1 \
test.ckpt_by_callback_mode=best
We give here the configurations for video retrieval using a ResNet3d50 with 16 frames, but configs for other networks are available.
It has two steps:
- Feature extraction
- Retrieval
extract_config_path="../eztorch/configs/run/evaluation/feature_extractor/resnet3d50"
extract_config_name="resnet3d50_hmdb51_frame"
retrieval_config_path="../eztorch/configs/run/evaluation/retrieval_from_bank"
retrieval_config_name="default"
split=1
pretrain_checkpoint=...
# Extraction
srun --kill-on-bad-exit=1 python extract_features.py \
-cp $extract_config_path -cn $extract_config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="features_extraction_split${split}" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.decoder_args.frame_filter.num_samples=16 \
datamodule.train.loader.num_workers=3 \
datamodule.val.loader.num_workers=3 \
datamodule.train.global_batch_size=2 \
datamodule.val.global_batch_size=2 \
seed.seed=$seed \
trainer.num_nodes=$SLURM_NNODES \
datamodule.split_id=$split \
trainer.max_epochs=1
# Retrieval
query_features="${output_dir}/features_extraction_split${split}/val_features.pth"
bank_features="${output_dir}/features_extraction_split${split}/train_features.pth"
query_labels="${output_dir}/features_extraction_split${split}/val_labels.pth"
bank_labels="${output_dir}/features_extraction_split${split}/train_labels.pth"
srun --kill-on-bad-exit=1 python retrieval_from_bank.py \
-cp $retrieval_config_path -cn $retrieval_config_name \
dir.root=$output_dir \
dir.exp="retrieval_split${split}" \
query.features_path=$query_features \
query.labels_path=$query_labels \
bank.features_path=$bank_features \
bank.labels_path=$bank_labels
extract_config_path="../eztorch/configs/run/evaluation/feature_extractor/resnet3d50"
extract_config_name="resnet3d50_ucf101_frame"
retrieval_config_path="../eztorch/configs/run/evaluation/retrieval_from_bank"
retrieval_config_name="default"
split=1
pretrain_checkpoint=...
# Extraction
srun --kill-on-bad-exit=1 python extract_features.py \
-cp $extract_config_path -cn $extract_config_name \
dir.data=$dataset_dir \
dir.root=$output_dir \
dir.exp="features_extraction_split${split}" \
model.pretrained_trunk_path=$pretrain_checkpoint \
datamodule.decoder_args.frame_filter.num_samples=16 \
datamodule.train.loader.num_workers=3 \
datamodule.val.loader.num_workers=3 \
datamodule.train.global_batch_size=2 \
datamodule.val.global_batch_size=2 \
seed.seed=$seed \
trainer.num_nodes=$SLURM_NNODES \
datamodule.split_id=$split \
trainer.max_epochs=1
# Retrieval
query_features="${output_dir}/features_extraction_split${split}/val_features.pth"
bank_features="${output_dir}/features_extraction_split${split}/train_features.pth"
query_labels="${output_dir}/features_extraction_split${split}/val_labels.pth"
bank_labels="${output_dir}/features_extraction_split${split}/train_labels.pth"
srun --kill-on-bad-exit=1 python retrieval_from_bank.py \
-cp $retrieval_config_path -cn $retrieval_config_name \
dir.root=$output_dir \
dir.exp="retrieval_split${split}" \
query.features_path=$query_features \
query.labels_path=$query_labels \
bank.features_path=$bank_features \
bank.labels_path=$bank_labels
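The retrieval step compares each query (validation) feature with every bank (training) feature and checks whether the nearest neighbour has the same label; the fraction of such queries is presumably the "Retrieval 1" score in the table above. The toy `recall_at_1` function below reproduces that computation on small text files, assuming cosine similarity. It is an illustration, not what `retrieval_from_bank.py` does internally, and the `.pth` features would first need converting to this text format.

```shell
# Toy Recall@1: each input line is "label v1 v2 ... vd" (whitespace-separated).
# $1 = query file (e.g. validation features), $2 = bank file (e.g. training features).
# Assumes cosine similarity and a non-empty bank of non-zero vectors.
recall_at_1() {
  awk '
    NR == FNR {                               # first file read: the bank
      nb++; blabel[nb] = $1; s = 0
      for (i = 2; i <= NF; i++) { bank[nb, i] = $i; s += $i * $i }
      bnorm[nb] = sqrt(s)
      next
    }
    {                                         # second file read: the queries
      nq++; s = 0
      for (i = 2; i <= NF; i++) s += $i * $i
      qnorm = sqrt(s); best = -2; bestlabel = ""
      for (j = 1; j <= nb; j++) {             # find the most similar bank entry
        dot = 0
        for (i = 2; i <= NF; i++) dot += $i * bank[j, i]
        sim = dot / (qnorm * bnorm[j])
        if (sim > best) { best = sim; bestlabel = blabel[j] }
      }
      if (bestlabel == $1) hits++             # hit if the nearest neighbour shares the label
    }
    END { printf "%.4f\n", (nq > 0 ? hits / nq : 0) }
  ' "$2" "$1"
}
```

With a bank of `a 1 0` and `b 0 1` and queries `a 0.9 0.1`, `b 0.2 0.8` and `a 0.1 0.9`, the first two queries retrieve a neighbour with the correct label and the third does not, giving 2/3.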
Generalization to action localization on AVA and action recognition on SSV2 was evaluated with the SlowFast repository, which supports PyTorchVideo models; we used these models as backbones.
If you find an error, have trouble making this work, or have any questions, please open an issue describing your problem.
This publication was made possible by the use of the Factory-AI supercomputer, financially supported by the Ile-de-France Regional Council and the HPC resources of IDRIS under the allocation 2022-AD011013575 made by GENCI.
If you found our work useful, please consider citing us:
@article{Denize_2023_MVAP,
author={Denize, Julien and Rabarisoa, Jaonary and Orcesi, Astrid and H{\'e}rault, Romain},
title={Similarity contrastive estimation for image and video soft contrastive self-supervised learning},
journal={Machine Vision and Applications},
year={2023},
volume={34},
number={6},
doi={10.1007/s00138-023-01444-9},
}