This repository is the official implementation of *Mixture of Experts Meets Prompt-Based Continual Learning* (NeurIPS 2024).
Exploiting the power of pre-trained models, prompt-based approaches stand out compared to other continual learning solutions in effectively preventing catastrophic forgetting, even with very few learnable parameters and without the need for a memory buffer. While existing prompt-based continual learning methods excel in leveraging prompts for state-of-the-art performance, they often lack a theoretical explanation for the effectiveness of prompting. This paper conducts a theoretical analysis to unravel how prompts bestow such advantages in continual learning, thus offering a new perspective on prompt design. We first show that the attention block of pre-trained models like Vision Transformers inherently encodes a mixture of experts architecture, characterized by linear experts and quadratic gating score functions. This realization drives us to provide a novel view on prefix tuning, reframing it as the addition of new task-specific experts, thereby inspiring the design of a novel gating mechanism termed Non-linear Residual Gates (NoRGa). Through the incorporation of non-linear activation and residual connection, NoRGa enhances continual learning performance while preserving parameter efficiency. The effectiveness of NoRGa is substantiated both theoretically and empirically across diverse benchmarks and pretraining paradigms.
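For intuition only, below is a minimal PyTorch sketch of the idea behind a non-linear residual gate in prefix tuning: the attention scores between the queries and the task-specific prefix keys are passed through a non-linear activation and added back residually before the softmax. This is an illustrative sketch under our own assumptions, not the implementation in this repository; the class name, the alpha scale, and the choice of tanh are hypothetical.

```python
# Illustrative sketch of a non-linear residual gate on prefix attention scores.
# Not the repository's code; names and hyperparameters are assumptions.
import torch
import torch.nn as nn

class PrefixAttentionSketch(nn.Module):
    def __init__(self, dim, num_heads, prefix_len, alpha=1.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Task-specific prefix key/value pairs act as extra "experts".
        self.prefix_k = nn.Parameter(torch.randn(num_heads, prefix_len, self.head_dim) * 0.02)
        self.prefix_v = nn.Parameter(torch.randn(num_heads, prefix_len, self.head_dim) * 0.02)
        self.alpha = alpha  # residual scaling for the non-linear gate

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]                         # (B, heads, N, head_dim)

        k_p = self.prefix_k.unsqueeze(0).expand(B, -1, -1, -1)   # (B, heads, P, head_dim)
        v_p = self.prefix_v.unsqueeze(0).expand(B, -1, -1, -1)

        attn_x = (q @ k.transpose(-2, -1)) * self.scale          # scores on original tokens
        attn_p = (q @ k_p.transpose(-2, -1)) * self.scale        # scores on prefix tokens
        # Non-linear residual gate: pass the prefix scores through an activation
        # (tanh here, as an example) and add them back before the shared softmax.
        attn_p = attn_p + self.alpha * torch.tanh(attn_p)

        attn = torch.cat([attn_p, attn_x], dim=-1).softmax(dim=-1)
        out = attn @ torch.cat([v_p, v], dim=-2)                 # (B, heads, N, head_dim)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))

# Quick shape check
x = torch.randn(2, 16, 768)
blk = PrefixAttentionSketch(dim=768, num_heads=12, prefix_len=5)
print(blk(x).shape)  # torch.Size([2, 16, 768])
```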
- Python 3.10.5
- Install the dependencies with pip install -r requirements.txt
Our code has been tested on four datasets:
- CIFAR-100
- ImageNet-R
- 5-Datasets (including SVHN, MNIST, CIFAR10, NotMNIST, FashionMNIST)
- CUB-200
We incorporated the following supervised and self-supervised checkpoints as backbones:
Please download the self-supervised checkpoints and place them in the /checkpoints/{checkpoint_name} directory, except for Sup-21K.
NOTE: For iBOT, please rename the checkpoint file to ibot_vitbase16_pretrain.pth.
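As a rough illustration (not this repository's actual loading code) of how a downloaded self-supervised ViT-B/16 checkpoint could be loaded into a timm backbone, something like the sketch below works; the checkpoint path, the nesting keys, and the key remapping are assumptions that vary across pretraining methods.

```python
# Sketch only: loading a self-supervised ViT-B/16 checkpoint into a timm backbone.
# The path and key handling are assumptions, not the repository's loader.
import torch
import timm

ckpt_path = "checkpoints/ibot_vitbase16_pretrain.pth"  # hypothetical location
model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)

state = torch.load(ckpt_path, map_location="cpu")
# Many self-supervised releases nest the weights, e.g. under "state_dict", "model", or "teacher".
for key in ("state_dict", "model", "teacher"):
    if isinstance(state, dict) and key in state:
        state = state[key]
        break
# Strip a possible "backbone." prefix, drop projection-head weights, and load non-strictly.
state = {k.replace("backbone.", ""): v for k, v in state.items() if "head" not in k}
missing, unexpected = model.load_state_dict(state, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
```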
To reproduce the results reported in our paper, execute the corresponding training script in /scripts/{dataset}_{backbone}_{method}.sh. For example:
NoRGa: To train with the Sup-21K backbone, run one of the following commands:
- Split CIFAR-100:
bash scripts/cifar100_Sup21k_NoRGa.sh
- Split CUB-200:
bash scripts/cub_Sup21k_NoRGa.sh
- Split ImageNet-R:
bash scripts/imr_Sup21k_NoRGa.sh
- 5-datasets:
bash scripts/5datasets_Sup21k_NoRGa.sh
If you encounter any issues or have any questions, please let us know.
This repository is developed mainly on top of the PyTorch implementation of HiDe-Prompt. Many thanks to its contributors!